fix diffusion
Browse files
- Modules/diffusion/sampler.py +5 -28
- Utils/PLBERT/util.py +1 -1
- msinference.py +4 -9
Modules/diffusion/sampler.py
CHANGED
|
@@ -4,7 +4,6 @@ from einops import rearrange
|
|
| 4 |
from torch import Tensor
|
| 5 |
from functools import reduce
|
| 6 |
# from inspect import isfunction
|
| 7 |
-
# from math import ceil, floor, log2, pi
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
| 10 |
|
|
@@ -29,8 +28,6 @@ class UniformDistribution():
|
|
| 29 |
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
| 30 |
return torch.rand(num_samples, device=device)
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
def to_batch(
|
| 35 |
batch_size: int,
|
| 36 |
device: torch.device,
|
|
@@ -59,8 +56,6 @@ class KDiffusion(nn.Module):
|
|
| 59 |
super().__init__()
|
| 60 |
self.net = net
|
| 61 |
self.sigma_data = sigma_data
|
| 62 |
-
self.sigma_distribution = sigma_distribution
|
| 63 |
-
self.dynamic_threshold = dynamic_threshold
|
| 64 |
|
| 65 |
def get_scale_weights(self, sigmas):
|
| 66 |
sigma_data = self.sigma_data
|
|
@@ -91,17 +86,6 @@ class KDiffusion(nn.Module):
|
|
| 91 |
return x_denoised
|
| 92 |
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
class KarrasSchedule(nn.Module):
|
| 106 |
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
| 107 |
|
|
@@ -165,27 +149,20 @@ class ADPM2Sampler(nn.Module):
|
|
| 165 |
return x
|
| 166 |
|
| 167 |
class DiffusionSampler(nn.Module):
|
|
|
|
| 168 |
def __init__(
|
| 169 |
self,
|
| 170 |
-
diffusion,
|
| 171 |
-
*,
|
| 172 |
-
sampler,
|
| 173 |
-
sigma_schedule,
|
| 174 |
num_steps=None,
|
| 175 |
-
clamp=True,
|
| 176 |
):
|
| 177 |
super().__init__()
|
| 178 |
self.denoise_fn = diffusion.denoise_fn
|
| 179 |
-
self.sampler = sampler
|
| 180 |
-
self.sigma_schedule = sigma_schedule
|
| 181 |
self.num_steps = num_steps
|
| 182 |
self.clamp = clamp
|
| 183 |
|
| 184 |
-
# Check sampler is compatible with diffusion type
|
| 185 |
-
sampler_class = sampler.__class__.__name__
|
| 186 |
-
diffusion_class = diffusion.__class__.__name__
|
| 187 |
-
message = f"{sampler_class} incompatible with {diffusion_class}"
|
| 188 |
-
assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
|
| 189 |
|
| 190 |
def forward(
|
| 191 |
self, noise, num_steps=None, **kwargs):
|
|
|
|
| 4 |
from torch import Tensor
|
| 5 |
from functools import reduce
|
| 6 |
# from inspect import isfunction
|
|
|
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
|
|
|
|
| 28 |
def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
|
| 29 |
return torch.rand(num_samples, device=device)
|
| 30 |
|
|
|
|
|
|
|
| 31 |
def to_batch(
|
| 32 |
batch_size: int,
|
| 33 |
device: torch.device,
|
|
|
|
| 56 |
super().__init__()
|
| 57 |
self.net = net
|
| 58 |
self.sigma_data = sigma_data
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def get_scale_weights(self, sigmas):
|
| 61 |
sigma_data = self.sigma_data
|
|
|
|
| 86 |
return x_denoised
|
| 87 |
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
class KarrasSchedule(nn.Module):
|
| 90 |
"""https://arxiv.org/abs/2206.00364 equation 5"""
|
| 91 |
|
|
|
|
| 149 |
return x
|
| 150 |
|
| 151 |
class DiffusionSampler(nn.Module):
|
| 152 |
+
|
| 153 |
def __init__(
|
| 154 |
self,
|
| 155 |
+
diffusion=None,
|
|
|
|
|
|
|
|
|
|
| 156 |
num_steps=None,
|
| 157 |
+
clamp=True, # default=False
|
| 158 |
):
|
| 159 |
super().__init__()
|
| 160 |
self.denoise_fn = diffusion.denoise_fn
|
| 161 |
+
self.sampler = ADPM2Sampler()
|
| 162 |
+
self.sigma_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
|
| 163 |
self.num_steps = num_steps
|
| 164 |
self.clamp = clamp
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
def forward(
|
| 168 |
self, noise, num_steps=None, **kwargs):
|
Utils/PLBERT/util.py
CHANGED
|
@@ -37,6 +37,6 @@ def load_plbert(log_dir):
|
|
| 37 |
name = name[8:] # remove `encoder.`
|
| 38 |
new_state_dict[name] = v
|
| 39 |
del new_state_dict["embeddings.position_ids"]
|
| 40 |
-
bert.load_state_dict(new_state_dict, strict=False)
|
| 41 |
|
| 42 |
return bert
|
|
|
|
| 37 |
name = name[8:] # remove `encoder.`
|
| 38 |
new_state_dict[name] = v
|
| 39 |
del new_state_dict["embeddings.position_ids"]
|
| 40 |
+
bert.load_state_dict(new_state_dict, strict=True)
|
| 41 |
|
| 42 |
return bert
|
msinference.py
CHANGED
|
@@ -17,8 +17,8 @@ from torch import nn
|
|
| 17 |
from nltk.tokenize import word_tokenize
|
| 18 |
|
| 19 |
torch.manual_seed(0)
|
| 20 |
-
torch.backends.cudnn.benchmark = False
|
| 21 |
-
torch.backends.cudnn.deterministic = True
|
| 22 |
|
| 23 |
|
| 24 |
# IPA Phonemizer: https://github.com/bootphon/phonemizer
|
|
@@ -160,14 +160,9 @@ for key in model:
|
|
| 160 |
# _load(params[key], model[key])
|
| 161 |
_ = [model[key].eval() for key in model]
|
| 162 |
|
| 163 |
-
from Modules.diffusion.sampler import DiffusionSampler
|
| 164 |
|
| 165 |
-
sampler = DiffusionSampler(
|
| 166 |
-
model.diffusion.diffusion,
|
| 167 |
-
sampler=ADPM2Sampler(),
|
| 168 |
-
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
|
| 169 |
-
clamp=False
|
| 170 |
-
)
|
| 171 |
|
| 172 |
def inference(text,
|
| 173 |
ref_s,
|
|
|
|
| 17 |
from nltk.tokenize import word_tokenize
|
| 18 |
|
| 19 |
torch.manual_seed(0)
|
| 20 |
+
# torch.backends.cudnn.benchmark = False
|
| 21 |
+
# torch.backends.cudnn.deterministic = True
|
| 22 |
|
| 23 |
|
| 24 |
# IPA Phonemizer: https://github.com/bootphon/phonemizer
|
|
|
|
| 160 |
# _load(params[key], model[key])
|
| 161 |
_ = [model[key].eval() for key in model]
|
| 162 |
|
| 163 |
+
from Modules.diffusion.sampler import DiffusionSampler
|
| 164 |
|
| 165 |
+
sampler = DiffusionSampler(diffusion=model.diffusion.diffusion)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
def inference(text,
|
| 168 |
ref_s,
|