DEL DIFFUSION

Files changed:
- Modules/diffusion/sampler.py  +0 -181
- models.py  +5 -23
- msinference.py  +3 -12
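Summary: this commit removes the diffusion-based style sampler. Modules/diffusion/sampler.py is deleted outright, models.py drops the KDiffusion wiring from build_model (and the diffusion entry from the nets Munch), and msinference.py replaces the sampled style vector with slices taken directly from the reference embedding ref_s.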
Modules/diffusion/sampler.py
DELETED (181 lines). Full contents of the removed file:
from math import sqrt
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


def default(val, d):
    if val is not None:
        return val
    return d


class LogNormalDistribution:
    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        normal = self.mean + self.std * torch.randn((num_samples,), device=device)
        return normal.exp()


class UniformDistribution:
    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        return torch.rand(num_samples, device=device)


def to_batch(batch_size: int, device: torch.device, x=None, xs=None):
    # Either x or xs must be provided; if the scalar x is given,
    # broadcast it to the whole batch.
    if x is not None:
        xs = torch.full(size=(batch_size,), fill_value=x).to(device)
    return xs


class KDiffusion(nn.Module):
    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution,
        sigma_data: float,  # data distribution standard deviation
        dynamic_threshold: float = 0.0,
    ):
        super().__init__()
        # sigma_distribution and dynamic_threshold are accepted for API
        # compatibility but unused in this inference-only version.
        self.net = net
        self.sigma_data = sigma_data

    def get_scale_weights(self, sigmas):
        sigma_data = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(self, x_noisy, sigmas=None, sigma=None, **kwargs):
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Predict network output and add skip connection; kwargs carries the
        # conditioning tensors (dict_keys(['embedding', 'features'])).
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred

        return x_denoised


class KarrasSchedule(nn.Module):
    """https://arxiv.org/abs/2206.00364 equation 5"""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device):
        rho_inv = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        sigmas = (
            self.sigma_max ** rho_inv
            + (steps / (num_steps - 1))
            * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
        ) ** self.rho
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas


class ADPM2Sampler(nn.Module):
    """https://www.desmos.com/calculator/jbxjlqd9mb"""

    diffusion_types = [KDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        self.rho = rho

    def get_sigmas(self, sigma, sigma_next):
        r = self.rho
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x, fn, sigma, sigma_next):
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (dx/dsigma)
        d = (x - fn(x, sigma=sigma)) / sigma
        # Denoise to midpoint
        x_mid = x + d * (sigma_mid - sigma)
        # Derivative at sigma_mid (dx_mid/dsigma_mid)
        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Denoise to next
        x = x + d_mid * (sigma_down - sigma)
        # Add randomness
        x_next = x + torch.randn_like(x) * sigma_up
        return x_next

    def forward(self, noise, fn, sigmas, num_steps):
        x = sigmas[0] * noise
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])
        return x


class DiffusionSampler(nn.Module):

    def __init__(
        self,
        diffusion=None,
        num_steps=None,
        clamp=True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = ADPM2Sampler()
        self.sigma_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
        self.num_steps = num_steps
        self.clamp = clamp

    def forward(self, noise, num_steps=None, **kwargs):
        device = noise.device
        num_steps = default(num_steps, self.num_steps)

        # Compute sigmas using schedule
        sigmas = self.sigma_schedule(num_steps, device)

        # Forward the conditioning kwargs ('embedding', 'features') into denoise_fn
        fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs})  # noqa
        # Sample using sampler
        x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        x = x.clamp(-1.0, 1.0) if self.clamp else x
        return x
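For reference, the deleted module was fully self-contained: KarrasSchedule implements equation 5 of Karras et al., sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0..N-1 (padded with a trailing 0), and ADPM2Sampler integrates those noise levels with a midpoint step plus ancestral re-noising. A minimal smoke test of the deleted classes; ToyNet is a hypothetical stand-in for the real StyleTransformer1d, and sigma_data and the latent shape are assumptions, while everything else is the code above:

import torch
import torch.nn as nn

class ToyNet(nn.Module):
    # Hypothetical stand-in for StyleTransformer1d: same call signature
    # (scaled input, noise embedding, conditioning kwargs). Returning zeros
    # means each step shrinks x through the c_skip connection alone.
    def forward(self, x, c_noise, **kwargs):
        return torch.zeros_like(x)

diffusion = KDiffusion(net=ToyNet(), sigma_distribution=UniformDistribution(),
                       sigma_data=0.2)   # sigma_data = 0.2 is an assumption
sampler = DiffusionSampler(diffusion=diffusion, num_steps=5)
out = sampler(torch.randn(2, 1, 256))    # (batch, 1, 2 * style_dim) is an assumption
print(out.shape)                         # torch.Size([2, 1, 256]), clamped to [-1, 1]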
models.py
CHANGED
@@ -11,20 +11,19 @@ import torch.nn.functional as F
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
-
+
 from Modules.diffusion.modules import StyleTransformer1d
-
+
 from munch import Munch
 import yaml
 from math import pi
 from random import randint
-
+
 import torch
 from einops import rearrange
 from torch import Tensor, nn
 from tqdm import tqdm
-
-# from Modules.diffusion.sampler import *
+
 
 
@@ -623,23 +622,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
     else:
         raise NotImplementedError
 
-
-        in_channels=1,
-        embedding_max_length=bert.config.max_position_embeddings,
-        embedding_features=bert.config.hidden_size,
-        embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
-        channels=args.style_dim*2,
-        context_features=args.style_dim*2,
-    )
-    # this initialises self.diffusion for AudioDiffusionConditional
-    diffusion.diffusion = KDiffusion(
-        net=diffusion.unet,
-        sigma_distribution=LogNormalDistribution(mean=args.diffusion.dist.mean, std=args.diffusion.dist.std),
-        sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
-        dynamic_threshold=0.0
-    )
-    diffusion.diffusion.net = transformer
-    diffusion.unet = transformer
+
 
 
     nets = Munch(
@@ -652,7 +635,6 @@ def build_model(args, text_aligner, pitch_extractor, bert):
 
         predictor_encoder=predictor_encoder,
         style_encoder=style_encoder,
-        diffusion=diffusion,
 
         text_aligner=text_aligner,
         pitch_extractor=pitch_extractor
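The removed block was the tail of the AudioDiffusionConditional setup plus the KDiffusion wiring; its opening line is already missing from the hunk, so it is not reconstructed here. For orientation, the preconditioning weights that KDiffusion computed in get_scale_weights reduce, at sigma = sigma_data, to c_skip = 1/2, c_out = sigma_data/sqrt(2), c_in = 1/(sigma_data*sqrt(2)). A quick numeric check, assuming sigma_data = 0.2 (the real value came from args.diffusion.dist.sigma_data):

import torch

sigma_data = 0.2                     # assumed; args.diffusion.dist.sigma_data in the repo
sigmas = torch.tensor([sigma_data])  # evaluate exactly at sigma == sigma_data
c_skip = sigma_data ** 2 / (sigmas ** 2 + sigma_data ** 2)
c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
print(c_skip.item(), c_out.item(), c_in.item())  # 0.5, 0.1414..., 3.5355...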
msinference.py
CHANGED
@@ -160,9 +160,7 @@ for key in model:
     # _load(params[key], model[key])
 _ = [model[key].eval() for key in model]
 
-from Modules.diffusion.sampler import DiffusionSampler
 
-sampler = DiffusionSampler(diffusion=model.diffusion.diffusion)
 
 def inference(text,
               ref_s,
@@ -205,17 +203,10 @@ def inference(text,
     # print('BERTdu', bert_dur.shape, tokens.shape, '\n') # bert what is the 768 per token -> IS USED in sampler
     # BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])
 
-
-                     embedding=bert_dur,
-                     features=ref_s, # reference from the same speaker as the embedding
-                     num_steps=diffusion_steps).squeeze(1)
-
-
-    s = s_pred[:, 128:]
-    ref = s_pred[:, :128]
+
 
-    ref =
-    s =
+    ref = ref_s[:, :128]
+    s = ref_s[:, 128:]
 
     d = model.predictor.text_encoder(d_en,
                                      s, input_lengths, text_mask)