import math
import typing as tp

import julius
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
from torchaudio.functional.filtering import highpass_biquad, treble_biquad


def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """This is a simpler loudness function that is more stable.

    Args:
        waveform (torch.Tensor): audio waveform of dimension `(..., channels, time)`
        sample_rate (int): sampling rate of the waveform
    Returns:
        loudness of each 400 ms gating block, as a tensor of shape `(..., n_blocks)`
    """
    if waveform.size(-2) > 5:
        raise ValueError("Only up to 5 channels are supported.")
    eps = torch.finfo(torch.float32).eps
    gate_duration = 0.4
    overlap = 0.75
    gate_samples = int(round(gate_duration * sample_rate))
    step = int(round(gate_samples * (1 - overlap)))

    # Apply K-weighting: high-shelf (treble) filter followed by a high-pass filter.
    waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
    waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)

    # Mean-square energy of each overlapping 400 ms gating block.
    energy = torch.square(waveform).unfold(-1, gate_samples, step)
    energy = torch.mean(energy, dim=-1)

    # Channel-weighted summation (surround channels are weighted more heavily).
    g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
    g = g[: energy.size(-2)]

    energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
    # Loudness in dB; epsilon keeps the log stable for near-silent blocks.
    loudness = -0.691 + 10 * torch.log10(energy_weighted + eps)
    return loudness
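
# Usage sketch (illustrative shapes, not part of the API): for a batch `x` of shape
# [B, 1, T] sampled at 16 kHz, `basic_loudness(x, 16000)` returns a tensor of shape
# [B, n_blocks] with the loudness of each 400 ms block (75% overlap). Unlike
# `torchaudio.transforms.Loudness`, no gating or averaging over blocks is applied.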


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size `[*OT, T]`, output Tensor of size `[*OT, F, K]`
    with `K` the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
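
# Shape example (illustrative values): with `a` of shape [8, 1, 32000], kernel_size=8000
# and stride=4000, the input is right-padded to 36000 samples and the output has shape
# [8, 1, 8, 8000], i.e. 8 overlapping frames of 8000 samples each.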


class FLoudnessRatio(nn.Module):
    """FSNR loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50% overlap.
        epsilon (float): Epsilon value for numerical stability.
        n_bands (int): number of mel scale bands that we include
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
        n_bands: int = 0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon
        if n_bands == 0:
            self.filter = None
        else:
            self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands)
        self.loudness = torchaudio.transforms.Loudness(sample_rate)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape
        assert self.filter is not None
        bands_ref = self.filter(ref_sig)
        bands_out = self.filter(out_sig)
        # Loudness of the residual vs. loudness of the reference, per frequency band.
        l_noise = self.loudness(bands_ref - bands_out)
        l_ref = self.loudness(bands_ref)
        l_ratio = (l_noise - l_ref).view(-1, B)
        # Softmax over bands puts more weight on the bands with the largest ratio.
        loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio
        return loss.sum()
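
# Usage sketch (assumed values): `n_bands` must be > 0, since `forward` asserts that
# the band-split filter exists.
#     fl = FLoudnessRatio(sample_rate=16000, n_bands=8)
#     loss = fl(out_sig, ref_sig)   # out_sig, ref_sig: [B, C, T]; loss: scalar tensor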


class TLoudnessRatio(nn.Module):
    """TSNR loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50% overlap.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: float = 0.5,
        overlap: float = 0.5,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.loudness = torchaudio.transforms.Loudness(sample_rate)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape
        assert C == 1

        # Cut both signals into overlapping segments and fold them into the batch dim.
        frame = int(self.segment * self.sample_rate)
        stride = int(frame * (1 - self.overlap))
        gt = _unfold(ref_sig, frame, stride).view(-1, 1, frame)
        est = _unfold(out_sig, frame, stride).view(-1, 1, frame)
        l_noise = self.loudness(gt - est)  # loudness of the error signal
        l_ref = self.loudness(gt)  # loudness of the reference
        l_ratio = (l_noise - l_ref).view(-1, B)
        # Softmax over segments puts more weight on the segments with the largest ratio.
        loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio
        return loss.sum()
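
# Usage sketch (assumed values): inputs must be mono, i.e. shaped [B, 1, T].
#     tl = TLoudnessRatio(sample_rate=16000, segment=0.5, overlap=0.5)
#     loss = tl(out_sig, ref_sig)   # scalar tensor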


class TFLoudnessRatio(nn.Module):
    """TF-loudness ratio loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50% overlap.
        n_bands (int): number of bands to separate
        temperature (float): temperature of the softmax step
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: float = 0.5,
        overlap: float = 0.5,
        n_bands: int = 0,
        clip_min: float = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.clip_min = clip_min
        self.temperature = temperature
        if n_bands == 0:
            self.filter = None
        else:
            self.n_bands = n_bands
            self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands)

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape

        assert ref_sig.shape == out_sig.shape
        assert C == 1
        assert self.filter is not None

        # Split into frequency bands, then into overlapping time segments, so that each
        # row of `gt` / `est` is one time-frequency patch.
        bands_ref = self.filter(ref_sig).view(B * self.n_bands, 1, -1)
        bands_out = self.filter(out_sig).view(B * self.n_bands, 1, -1)
        frame = int(self.segment * self.sample_rate)
        stride = int(frame * (1 - self.overlap))
        gt = _unfold(bands_ref, frame, stride).squeeze(1).contiguous().view(-1, 1, frame)
        est = _unfold(bands_out, frame, stride).squeeze(1).contiguous().view(-1, 1, frame)
        l_noise = basic_loudness(est - gt, sample_rate=self.sample_rate)  # loudness of the error
        l_ref = basic_loudness(gt, sample_rate=self.sample_rate)  # loudness of the reference
        l_ratio = (l_noise - l_ref).view(-1, B)
        # Temperature-scaled softmax over patches puts more weight on the worst patches.
        loss = torch.nn.functional.softmax(l_ratio / self.temperature, dim=0) * l_ratio
        return loss.mean()
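

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The shapes and hyper-parameters below are
    # illustrative assumptions, not canonical settings from this module.
    sr = 16000
    ref = torch.randn(2, 1, 2 * sr)           # [B=2, C=1, T=2 s] reference signal
    out = ref + 0.1 * torch.randn_like(ref)   # noisy estimate of the reference

    tf_loss = TFLoudnessRatio(sample_rate=sr, segment=0.5, overlap=0.5, n_bands=8)
    print("TFLoudnessRatio:", tf_loss(out, ref).item())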