|
|
import warnings |
|
|
from typing import Optional |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    sample_rate=8000.0,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
            submodule in this subpackage (e.g. :class:`~.FreeFB`).
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.0.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used. If string (among [``'encoder'``, ``'decoder'``]), decides
            which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
            the other one.
        padding (int): Zero-padding added to both sides of the input.
            Passed to Encoder and Decoder.
        output_padding (int): Additional size added to one side of the output shape.
            Passed to Decoder.
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.

    Returns:
        :class:`.Encoder`, :class:`.Decoder`

    Raises:
        ValueError: If ``who_is_pinv`` is neither ``None`` nor one of
            ``'enc'``, ``'encoder'``, ``'dec'``, ``'decoder'``.
    """
    fb_class = get(fb_name)

    def _new_fb():
        # All branches need an identically-configured filterbank instance.
        return fb_class(
            n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
        )

    if who_is_pinv in ["dec", "decoder"]:
        fb = _new_fb()
        enc = Encoder(fb, padding=padding)
        # Decoder is derived as the pseudo-inverse of the encoder's filterbank.
        dec = Decoder.pinv_of(fb)
    elif who_is_pinv in ["enc", "encoder"]:
        fb = _new_fb()
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder is derived as the pseudo-inverse of the decoder's filterbank.
        enc = Encoder.pinv_of(fb)
    elif who_is_pinv is None:
        # Independent filterbanks for analysis and synthesis.
        enc = Encoder(_new_fb(), padding=padding)
        dec = Decoder(_new_fb(), padding=padding, output_padding=output_padding)
    else:
        # Previously, unrecognized strings silently behaved like None,
        # hiding caller typos such as who_is_pinv="encodr".
        raise ValueError(
            f"Unexpected `who_is_pinv` value: {who_is_pinv}. Expected one of "
            "[None, 'enc', 'encoder', 'dec', 'decoder']."
        )
    return enc, dec
|
|
|
|
|
|
|
|
def register_filterbank(custom_fb):
    """Register a custom filterbank, gettable with `filterbanks.get`.

    Args:
        custom_fb: Custom filterbank to register.
    """
    fb_name = custom_fb.__name__
    # Refuse to shadow an existing name, whether it matches exactly or
    # collides with a lowercase alias (e.g. `free` for `FreeFB`).
    if fb_name in globals().keys() or fb_name.lower() in globals().keys():
        raise ValueError(
            f"Filterbank {custom_fb.__name__} already exists. Choose another name."
        )
    globals().update({custom_fb.__name__: custom_fb})
|
|
|
|
|
|
|
|
def get(identifier):
    """Returns a filterbank class from a string. Returns its input if it
    is callable (already a :class:`.Filterbank` for example).

    Args:
        identifier (str or Callable or None): the filterbank identifier.

    Returns:
        :class:`.Filterbank` or None
    """
    if identifier is None:
        return None
    if callable(identifier):
        # Already a class or factory; pass it through untouched.
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is not None:
            return cls
    # Fall through: unknown string, or an unsupported identifier type.
    raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
|
|
|
|
|
|
|
|
class Filterbank(nn.Module):
    """Base Filterbank class.
    Each subclass has to implement a ``filters`` method.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the conv or transposed conv. (Hop size).
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.

    Attributes:
        n_feats_out (int): Number of output filters.
    """

    def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
        super().__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        # Default hop size is half the filter length (50% overlap).
        self.stride = stride or kernel_size // 2
        self.n_feats_out = n_filters
        self.sample_rate = sample_rate

    def filters(self):
        """Abstract method for filters."""
        raise NotImplementedError

    def pre_analysis(self, wav: torch.Tensor):
        """Apply transform before encoder convolution."""
        return wav

    def post_analysis(self, spec: torch.Tensor):
        """Apply transform to encoder convolution."""
        return spec

    def pre_synthesis(self, spec: torch.Tensor):
        """Apply transform before decoder transposed convolution."""
        return spec

    def post_synthesis(self, wav: torch.Tensor):
        """Apply transform after decoder transposed convolution."""
        return wav

    def get_config(self):
        """Returns dictionary of arguments to re-instantiate the class.
        Needs to be subclassed if the filterbanks takes additional arguments
        than ``n_filters`` ``kernel_size`` ``stride`` and ``sample_rate``.
        """
        return {
            "fb_name": self.__class__.__name__,
            "n_filters": self.n_filters,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "sample_rate": self.sample_rate,
        }

    def forward(self, waveform):
        # A bare Filterbank only holds the filters; analysis/synthesis is
        # performed by the wrapping Encoder/Decoder.
        raise NotImplementedError(
            "Filterbanks must be wrapped with an Encoder or a Decoder."
        )
|
|
|
|
|
|
|
|
class _EncDec(nn.Module):
    """Base private class for Encoder and Decoder.
    Common parameters and methods.

    Args:
        filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
            to use as an encoder or a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.

    Attributes:
        filterbank (:class:`Filterbank`)
        stride (int)
        is_pinv (bool)
    """

    def __init__(self, filterbank, is_pinv=False):
        super().__init__()
        self.filterbank = filterbank
        # Some custom filterbanks may not expose a sample rate.
        self.sample_rate = getattr(filterbank, "sample_rate", None)
        self.stride = filterbank.stride
        self.is_pinv = is_pinv

    def filters(self):
        return self.filterbank.filters()

    def compute_filter_pinv(self, filters):
        """Computes pseudo inverse filterbank of given filters."""
        # Rescale by stride/kernel_size so analysis followed by synthesis
        # approximately preserves signal level with overlapping frames.
        scale = self.filterbank.stride / self.filterbank.kernel_size
        pinv = torch.pinverse(filters.squeeze()).transpose(-1, -2)
        return pinv.view(filters.shape) * scale

    def get_filters(self):
        """Returns filters or pinv filters depending on `is_pinv` attribute"""
        if not self.is_pinv:
            return self.filters()
        return self.compute_filter_pinv(self.filters())

    def get_config(self):
        """Returns dictionary of arguments to re-instantiate the class."""
        # `is_pinv` is set last so it takes precedence over any key the
        # filterbank config might carry.
        config = dict(self.filterbank.get_config())
        config["is_pinv"] = self.is_pinv
        return config
|
|
|
|
|
|
|
|
class Encoder(_EncDec):
    r"""Encoder class.
    Add encoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use
            as an encoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        as_conv1d (bool): Whether to behave like nn.Conv1d.
            If True (default), forwarding input with shape :math:`(batch, 1, time)`
            will output a tensor of shape :math:`(batch, freq, conv\_time)`.
            If False, will output a tensor of shape :math:`(batch, 1, freq, conv\_time)`.
        padding (int): Zero-padding added to both sides of the input.
    """

    def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
        super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
        self.as_conv1d = as_conv1d
        self.n_feats_out = self.filterbank.n_feats_out
        self.kernel_size = self.filterbank.kernel_size
        self.padding = padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns an :class:`~.Encoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Decoder`.

        Raises:
            TypeError: If ``filterbank`` is neither a :class:`Filterbank`
                nor a :class:`Decoder`.
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        elif isinstance(filterbank, Decoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)
        # Previously fell through and returned None, silently hiding
        # caller errors; fail loudly instead.
        raise TypeError(
            f"Expected a `Filterbank` or `Decoder` instance, got {type(filterbank)}."
        )

    def forward(self, waveform):
        """Convolve input waveform with the filters from a filterbank.

        Args:
            waveform (:class:`torch.Tensor`): any tensor with samples along the
                last dimension. The waveform representation with and
                batch/channel etc.. dimension.

        Returns:
            :class:`torch.Tensor`: The corresponding TF domain signal.

        Shapes
            >>> (time, ) -> (freq, conv_time)
            >>> (batch, time) -> (batch, freq, conv_time)  # Avoid
            >>> if as_conv1d:
            >>>     (batch, 1, time) -> (batch, freq, conv_time)
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> else:
            >>>     (batch, chan, time) -> (batch, chan, freq, conv_time)
            >>> (batch, any, dim, time) -> (batch, any, dim, freq, conv_time)
        """
        # Filters are fetched per call: pseudo-inverse filters depend on the
        # current (possibly learnable) filterbank weights.
        filters = self.get_filters()
        waveform = self.filterbank.pre_analysis(waveform)
        spec = multishape_conv1d(
            waveform,
            filters=filters,
            stride=self.stride,
            padding=self.padding,
            as_conv1d=self.as_conv1d,
        )
        return self.filterbank.post_analysis(spec)
|
|
|
|
|
|
|
|
def multishape_conv1d(
    waveform: torch.Tensor,
    filters: torch.Tensor,
    stride: int,
    padding: int = 0,
    as_conv1d: bool = True,
) -> torch.Tensor:
    """Dispatch 1D convolution over waveforms of various ranks.

    Supports 1D ``(time,)``, 2D ``(batch, time)``, 3D ``(batch, chan, time)``
    and higher-rank inputs; the convolution always runs along the last axis.
    """
    ndim = waveform.ndim
    if ndim == 1:
        # (time,): fake batch/channel dims for conv1d, then drop them.
        out = F.conv1d(waveform[None, None], filters, stride=stride, padding=padding)
        return out.squeeze()
    if ndim == 2:
        warnings.warn(
            "Input tensor was 2D. Applying the corresponding "
            "Decoder to the current output will result in a 3D "
            "tensor. This behaviours was introduced to match "
            "Conv1D and ConvTranspose1D, please use 3D inputs "
            "to avoid it. For example, this can be done with "
            "input_tensor.unsqueeze(1)."
        )
        return F.conv1d(waveform.unsqueeze(1), filters, stride=stride, padding=padding)
    if ndim == 3:
        n_channels = waveform.shape[1]
        if n_channels == 1 and as_conv1d:
            # Plain nn.Conv1d behaviour: (batch, 1, time) -> (batch, freq, conv_time).
            return F.conv1d(waveform, filters, stride=stride, padding=padding)
        # (batch, chan, time) -> (batch, chan, freq, conv_time).
        return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)
    # Rank >= 4: convolve along the last axis, keep all leading dims.
    return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)
|
|
|
|
|
|
|
|
def batch_packed_1d_conv(
    inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
    """Convolve along the last axis of an arbitrary-rank tensor.

    Flattens all leading dims into one batch axis, runs a single conv1d,
    then restores the leading dims, replacing the time axis with
    ``(freq, conv_time)``.
    """
    flat = inp.view(-1, 1, inp.shape[-1])
    conv_out = F.conv1d(flat, filters, stride=stride, padding=padding)
    # Leading dims back + the two trailing conv dims.
    return conv_out.view(inp.shape[:-1] + conv_out.shape[-2:])
|
|
|
|
|
|
|
|
class Decoder(_EncDec):
    """Decoder class.
    Add decoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use as an decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        padding (int): Zero-padding added to both sides of the input.
        output_padding (int): Additional size added to one side of the
            output shape.

    .. note::
        ``padding`` and ``output_padding`` arguments are directly passed to
        ``F.conv_transpose1d``.
    """

    def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.padding = padding
        self.output_padding = output_padding

    @classmethod
    def pinv_of(cls, filterbank):
        """Returns an Decoder, pseudo inverse of a filterbank or Encoder.

        Raises:
            TypeError: If ``filterbank`` is neither a :class:`Filterbank`
                nor an :class:`Encoder`.
        """
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True)
        elif isinstance(filterbank, Encoder):
            return cls(filterbank.filterbank, is_pinv=True)
        # Previously fell through and returned None, silently hiding
        # caller errors; fail loudly instead.
        raise TypeError(
            f"Expected a `Filterbank` or `Encoder` instance, got {type(filterbank)}."
        )

    def forward(self, spec, length: Optional[int] = None) -> torch.Tensor:
        """Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
            length: desired output length.

        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        # Filters are fetched per call: pseudo-inverse filters depend on the
        # current (possibly learnable) filterbank weights.
        filters = self.get_filters()
        spec = self.filterbank.pre_synthesis(spec)
        wav = multishape_conv_transpose1d(
            spec,
            filters,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )
        wav = self.filterbank.post_synthesis(wav)
        if length is not None:
            # Clamp so we never index past the synthesized samples.
            length = min(length, wav.shape[-1])
            return wav[..., :length]
        return wav
|
|
|
|
|
|
|
|
def multishape_conv_transpose1d(
    spec: torch.Tensor,
    filters: torch.Tensor,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
) -> torch.Tensor:
    """Dispatch transposed 1D convolution over TF representations of various ranks.

    Supports 2D ``(freq, conv_time)``, 3D ``(batch, freq, conv_time)`` and
    higher-rank inputs; synthesis always runs along the last two axes.
    """
    if spec.ndim == 2:
        # (freq, conv_time): add a batch dim, synthesize, then drop size-1 dims.
        wav = F.conv_transpose1d(
            spec.unsqueeze(0),
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        return wav.squeeze()
    if spec.ndim == 3:
        # (batch, freq, conv_time): direct transposed convolution.
        return F.conv_transpose1d(
            spec,
            filters,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
    # Rank >= 4: pack leading dims into the batch axis, synthesize, unpack.
    packed = spec.reshape((-1,) + spec.shape[-2:])
    wav = F.conv_transpose1d(
        packed,
        filters,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    return wav.view(spec.shape[:-2] + (-1,))
|
|
|
|
|
|
|
|
class FreeFB(Filterbank):
    """Free filterbank without any constraints. Equivalent to
    :class:`nn.Conv1d`.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        sample_rate (float): Sample rate of the expected audio.
            Defaults to 8000.

    Attributes:
        n_feats_out (int): Number of output filters.

    References
        [1] : "Filterbank design for end-to-end speech separation". ICASSP 2020.
        Manuel Pariente, Samuele Cornell, Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(
        self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kwargs
    ):
        super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
        # Fully learnable filters, Xavier-initialized in place.
        self._filters = nn.Parameter(torch.ones(n_filters, 1, kernel_size))
        for param in self.parameters():
            nn.init.xavier_normal_(param)

    def filters(self):
        """Returns the learnable filter tensor of shape (n_filters, 1, kernel_size)."""
        return self._filters
|
|
|
|
|
|
|
|
# Lowercase alias so `get("free")` / `make_enc_dec(fb_name="free", ...)` resolve.
free = FreeFB
|
|
|