del diffusion [unused]
- audiocraft/audiogen.py +0 -129
- audiocraft/builders.py +2 -24
- audiocraft/diffusion_schedule.py +0 -272
- audiocraft/loaders.py +0 -24
- audiocraft/rope.py +0 -125
- audiocraft/unet.py +0 -214
audiocraft/audiogen.py
DELETED
```diff
@@ -1,129 +0,0 @@
-import typing as tp
-import torch
-from audiocraft.loaders import load_compression_model, load_lm_model
-import typing as tp
-import omegaconf
-import torch
-import numpy as np
-from .lm import LMModel
-from .conditioners import ConditioningAttributes
-from .utils.autocast import TorchAutocast
-
-
-
-def _shift(x):
-    n = x.shape[2]
-    i = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD do we have very short segments
-    x = torch.roll(x, i, dims=2)
-    return x
-
-
-class AudioGen():
-    """Base generative model with convenient generation API.
-
-    Args:
-        name (str)
-        compression_model (CompressionModel): Encodec with Seanet Decoder
-        lm
-        max_duration (float, optional): As is using top250 token draw() we can gen xN sequences
-    """
-    def __init__(self,
-                 name,
-                 compression_model,
-                 lm,
-                 max_duration=None):
-        self.name = name
-        self.compression_model = compression_model
-        self.lm = lm
-        self.cfg: tp.Optional[omegaconf.DictConfig] = None
-        # Just to be safe, let's put everything in eval mode.
-        self.compression_model.eval()
-        self.lm.eval()
-
-        if hasattr(lm, 'cfg'):
-            cfg = lm.cfg
-            assert isinstance(cfg, omegaconf.DictConfig)
-            self.cfg = cfg
-
-        if max_duration is None:
-            if self.cfg is not None:
-                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
-            else:
-                raise ValueError("You must provide max_duration when building directly your GenModel")
-        assert max_duration is not None
-
-        self.max_duration: float = max_duration
-        self.duration = self.max_duration
-        self.device = next(iter(lm.parameters())).device
-        self.generation_params = {}
-
-        if self.device.type == 'cpu':
-            self.autocast = TorchAutocast(enabled=False)
-        else:
-            self.autocast = TorchAutocast(
-                enabled=True,
-                device_type=self.device.type,
-                dtype=torch.float16)
-
-    @property
-    def frame_rate(self) -> float:
-        """Roughly the number of AR steps per seconds."""
-        return self.compression_model.frame_rate
-
-    @property
-    def sample_rate(self) -> int:
-        """Sample rate of the generated audio."""
-        return self.compression_model.sample_rate
-
-
-
-
-
-    def generate(self, descriptions):
-        attributes = [
-            ConditioningAttributes(text={'description': d}) for d in descriptions]
-        tokens = self._generate_tokens(attributes)
-        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
-        return self.generate_audio(tokens)
-
-    def _generate_tokens(self, attributes):
-
-        total_gen_len = int(self.duration * self.frame_rate)
-
-        if self.duration <= self.max_duration:
-            # generate by sampling from LM, simple case.
-
-            with self.autocast:
-                gen_tokens = self.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
-        else:
-            print('<>Long gen ?<>')
-        # print(f'{gen_tokens.shape=}')  # [5,4,35]
-        # FLATTEN BATCH AS EXTRA SEQUENCE (BATCH IS VIRTUAL JUST MULTINOMIAL SAMPLING OF N_DRAW TOKENS)
-        gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
-        for _ in range(3):
-            print(gen_tokens.shape)
-            gen_tokens = _shift(gen_tokens)
-        return gen_tokens
-
-    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
-        """Generate Audio from tokens."""
-        assert gen_tokens.dim() == 3
-        with torch.no_grad():
-            gen_audio = self.compression_model.decode(gen_tokens, None)
-        return gen_audio
-
-
-def get_pretrained(name='facebook/audiogen-medium',
-                   device=None):
-    """Return pretrained model, we provide a single model for now:
-    - facebook/audiogen-medium (1.5B), text to sound,
-    # see: https://huggingface.co/facebook/audiogen-medium
-    """
-    compression_model = load_compression_model(name, device=device)
-    lm = load_lm_model(name, device=device)
-    assert 'self_wav' not in lm.condition_provider.conditioners, \
-        "AudioGen do not support waveform conditioning for now"
-    return AudioGen(name, compression_model, lm)
-
-
-
```
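Note: for reference, a minimal usage sketch of the entry point deleted above, reconstructed from the removed code (`get_pretrained` followed by `generate`). It only applies to checkouts that still ship `audiocraft/audiogen.py`, and the prompt string is purely illustrative.

```python
import torch
from audiocraft.audiogen import get_pretrained  # module removed by this commit

# Build the pretrained text-to-sound model; device string chosen here, not by the library.
model = get_pretrained('facebook/audiogen-medium',
                       device='cuda' if torch.cuda.is_available() else 'cpu')

# generate() takes a list of text descriptions and returns decoded audio [batch, channels, samples].
wav = model.generate(['dog barking in the rain'])
print(wav.shape, model.sample_rate)
```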
audiocraft/builders.py
CHANGED
```diff
@@ -16,10 +16,10 @@ from .conditioners import (
     ConditioningProvider,
     T5Conditioner,
 )
-
+
 from .vq import ResidualVectorQuantizer
 
-
+
 
 def dict_from_config(cfg):
     dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
@@ -155,25 +155,3 @@ def get_codebooks_pattern_provider(n_q, cfg):
 
     klass = pattern_providers[name]
     return klass(n_q, **kwargs)
-
-
-
-
-
-def get_diffusion_model(cfg: omegaconf.DictConfig):
-    # TODO Find a way to infer the channels from dset
-    channels = cfg.channels
-    num_steps = cfg.schedule.num_steps
-    return DiffusionUnet(
-        chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
-
-
-def get_processor(cfg, sample_rate: int = 24000):
-    sample_processor = SampleProcessor()
-    if cfg.use:
-        kw = dict(cfg)
-        kw.pop('use')
-        kw.pop('name')
-        if cfg.name == "multi_band_processor":
-            sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
-    return sample_processor
```
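Note: a sketch of the config shape the removed `get_diffusion_model` expected, inferred from its body (`cfg.channels`, `cfg.schedule.num_steps`, `cfg.diffusion_unet`). The field values below are made up for illustration; they are not taken from the repository configs.

```python
import omegaconf

# Hypothetical hydra-style config matching the fields the removed builder read.
cfg = omegaconf.OmegaConf.create({
    'channels': 1,
    'schedule': {'num_steps': 1000},
    'diffusion_unet': {'hidden': 24, 'depth': 3, 'transformer': True},
})

# The removed helper expanded to:
#   DiffusionUnet(chin=cfg.channels, num_steps=cfg.schedule.num_steps, **cfg.diffusion_unet)
print(cfg.channels, cfg.schedule.num_steps, dict(cfg.diffusion_unet))
```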
audiocraft/diffusion_schedule.py
DELETED
```diff
@@ -1,272 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
-"""
-
-from collections import namedtuple
-import random
-import typing as tp
-import julius
-import torch
-
-TrainingItem = namedtuple("TrainingItem", "noisy noise step")
-
-
-def betas_from_alpha_bar(alpha_bar):
-    alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
-    return 1 - alphas
-
-
-class SampleProcessor(torch.nn.Module):
-    def project_sample(self, x: torch.Tensor):
-        """Project the original sample to the 'space' where the diffusion will happen."""
-        return x
-
-    def return_sample(self, z: torch.Tensor):
-        """Project back from diffusion space to the actual sample space."""
-        return z
-
-
-class MultiBandProcessor(SampleProcessor):
-    """
-    MultiBand sample processor. The input audio is splitted across
-    frequency bands evenly distributed in mel-scale.
-
-    Each band will be rescaled to match the power distribution
-    of Gaussian noise in that band, using online metrics
-    computed on the first few samples.
-
-    Args:
-        n_bands (int): Number of mel-bands to split the signal over.
-        sample_rate (int): Sample rate of the audio.
-        num_samples (int): Number of samples to use to fit the rescaling
-            for each band. The processor won't be stable
-            until it has seen that many samples.
-        power_std (float or list/tensor): The rescaling factor computed to match the
-            power of Gaussian noise in each band is taken to
-            that power, i.e. `1.` means full correction of the energy
-            in each band, and values less than `1` means only partial
-            correction. Can be used to balance the relative importance
-            of low vs. high freq in typical audio signals.
-    """
-    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
-                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
-        super().__init__()
-        self.n_bands = n_bands
-        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
-        self.num_samples = num_samples
-        self.power_std = power_std
-        if isinstance(power_std, list):
-            assert len(power_std) == n_bands
-            power_std = torch.tensor(power_std)
-        self.register_buffer('counts', torch.zeros(1))
-        self.register_buffer('sum_x', torch.zeros(n_bands))
-        self.register_buffer('sum_x2', torch.zeros(n_bands))
-        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
-        self.counts: torch.Tensor
-        self.sum_x: torch.Tensor
-        self.sum_x2: torch.Tensor
-        self.sum_target_x2: torch.Tensor
-
-    @property
-    def mean(self):
-        mean = self.sum_x / self.counts
-        return mean
-
-    @property
-    def std(self):
-        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
-        return std
-
-    @property
-    def target_std(self):
-        target_std = self.sum_target_x2 / self.counts
-        return target_std
-
-    def project_sample(self, x: torch.Tensor):
-        assert x.dim() == 3
-        bands = self.split_bands(x)
-        if self.counts.item() < self.num_samples:
-            ref_bands = self.split_bands(torch.randn_like(x))
-            self.counts += len(x)
-            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
-            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
-            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
-        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
-        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
-        return bands.sum(dim=0)
-
-    def return_sample(self, x: torch.Tensor):
-        assert x.dim() == 3
-        bands = self.split_bands(x)
-        rescale = (self.std / self.target_std) ** self.power_std
-        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
-        return bands.sum(dim=0)
-
-
-class NoiseSchedule:
-    """Noise schedule for diffusion.
-
-    Args:
-        beta_t0 (float): Variance of the first diffusion step.
-        beta_t1 (float): Variance of the last diffusion step.
-        beta_exp (float): Power schedule exponent
-        num_steps (int): Number of diffusion step.
-        variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
-        clip (float): clipping value for the denoising steps
-        rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
-        repartition (str): shape of the schedule only power schedule is supported
-        sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
-        noise_scale (float): Scaling factor for the noise
-    """
-    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
-                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
-                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
-                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
-
-        self.beta_t0 = beta_t0
-        self.beta_t1 = beta_t1
-        self.variance = variance
-        self.num_steps = num_steps
-        self.clip = clip
-        self.sample_processor = sample_processor
-        self.rescale = rescale
-        self.n_bands = n_bands
-        self.noise_scale = noise_scale
-        assert n_bands is None
-        if repartition == "power":
-            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
-                                        device=device, dtype=torch.float) ** beta_exp
-        else:
-            raise RuntimeError('Not implemented')
-        self.rng = random.Random(1234)
-
-    def get_beta(self, step: tp.Union[int, torch.Tensor]):
-        if self.n_bands is None:
-            return self.betas[step]
-        else:
-            return self.betas[:, step]  # [n_bands, len(step)]
-
-    def get_initial_noise(self, x: torch.Tensor):
-        if self.n_bands is None:
-            return torch.randn_like(x)
-        return torch.randn((x.size(0), self.n_bands, x.size(2)))
-
-    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
-        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
-        if step is None:
-            return (1 - self.betas).cumprod(dim=-1)  # works for simgle and multi bands
-        if type(step) is int:
-            return (1 - self.betas[:step + 1]).prod()
-        else:
-            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
-
-    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
-        """Create a noisy data item for diffusion model training:
-
-        Args:
-            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
-            tensor_step (bool): If tensor_step = false, only one step t is sample,
-                the whole batch is diffused to the same step and t is int.
-                If tensor_step = true, t is a tensor of size (x.size(0),)
-                every element of the batch is diffused to a independently sampled.
-        """
-        step: tp.Union[int, torch.Tensor]
-        if tensor_step:
-            bs = x.size(0)
-            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
-        else:
-            step = self.rng.randrange(self.num_steps)
-        alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]
-
-        x = self.sample_processor.project_sample(x)
-        noise = torch.randn_like(x)
-        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
-        return TrainingItem(noisy, noise, step)
-
-    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
-                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
-        """Full ddpm reverse process.
-
-        Args:
-            model (nn.Module): Diffusion model.
-            initial (tensor): Initial Noise.
-            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
-            return_list (bool): Whether to return the whole process or only the sampled point.
-        """
-        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
-        current = initial
-        iterates = [initial]
-        for step in range(self.num_steps)[::-1]:
-            with torch.no_grad():
-                estimate = model(current, step, condition=condition).sample
-            alpha = 1 - self.betas[step]
-            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
-            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
-            if step == 0:
-                sigma2 = 0
-            elif self.variance == 'beta':
-                sigma2 = 1 - alpha
-            elif self.variance == 'beta_tilde':
-                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
-            elif self.variance == 'none':
-                sigma2 = 0
-            else:
-                raise ValueError(f'Invalid variance type {self.variance}')
-
-            if sigma2 > 0:
-                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
-            if self.clip:
-                previous = previous.clamp(-self.clip, self.clip)
-            current = previous
-            alpha_bar = previous_alpha_bar
-            if step == 0:
-                previous *= self.rescale
-            if return_list:
-                iterates.append(previous.cpu())
-
-        if return_list:
-            return iterates
-        else:
-            return self.sample_processor.return_sample(previous)
-
-    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
-                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
-        """Reverse process that only goes through Markov chain states in step_list."""
-        if step_list is None:
-            step_list = list(range(1000))[::-50] + [0]
-        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
-        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
-        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
-        current = initial * self.noise_scale
-        iterates = [current]
-        for idx, step in enumerate(step_list[:-1]):
-            with torch.no_grad():
-                estimate = model(current, step, condition=condition).sample * self.noise_scale
-            alpha = 1 - betas_subsampled[-1 - idx]
-            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
-            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
-            if step == step_list[-2]:
-                sigma2 = 0
-                previous_alpha_bar = torch.tensor(1.0)
-            else:
-                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
-            if sigma2 > 0:
-                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
-            if self.clip:
-                previous = previous.clamp(-self.clip, self.clip)
-            current = previous
-            alpha_bar = previous_alpha_bar
-            if step == 0:
-                previous *= self.rescale
-            if return_list:
-                iterates.append(previous.cpu())
-        if return_list:
-            return iterates
-        else:
-            return self.sample_processor.return_sample(previous)
```
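Note: the core of the removed `NoiseSchedule.get_training_item` is the standard DDPM forward-noising step, `noisy = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps` with `alpha_bar_t = prod_{s<=t}(1 - beta_s)`. A self-contained sketch using the file's default schedule values (`beta_t0=1e-4`, `beta_t1=0.02`, 1000 steps, `rescale=1`, `noise_scale=1`); the waveform batch here is random placeholder data.

```python
import torch

num_steps = 1000
betas = torch.linspace(1e-4, 0.02, num_steps)     # linear schedule, power exponent 1
alpha_bar = (1 - betas).cumprod(dim=-1)           # cumulative product over steps

x = torch.randn(2, 1, 24000)                      # stand-in for a clean audio batch [bs, 1, T]
step = 500
eps = torch.randn_like(x)

# Diffuse the whole batch to the same step, as in the int-step branch of get_training_item.
noisy = alpha_bar[step].sqrt() * x + (1 - alpha_bar[step]).sqrt() * eps
print(noisy.shape, float(alpha_bar[step]))
```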
audiocraft/loaders.py
CHANGED
```diff
@@ -1,33 +1,9 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Utility functions to load from the checkpoints.
-Each checkpoint is a torch.saved dict with the following keys:
-- 'xp.cfg': the hydra config as dumped during training. This should be used
-    to rebuild the object using the audiocraft.models.builders functions,
-- 'model_best_state': a readily loadable best state for the model, including
-    the conditioner. The model obtained from `xp.cfg` should be compatible
-    with this state dict. In the case of a LM, the encodec model would not be
-    bundled along but instead provided separately.
-
-Those functions also support loading from a remote location with the Torch Hub API.
-They also support overriding some parameters, in particular the device and dtype
-of the returned model.
-"""
-
 from pathlib import Path
 from huggingface_hub import hf_hub_download
 import typing as tp
 import os
-
 from omegaconf import OmegaConf, DictConfig
 import torch
-
-import audiocraft
 from . import builders
 from .encodec import EncodecModel
```
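Note: the removed module docstring documented the checkpoint layout (a torch-saved dict with an `'xp.cfg'` hydra config and a `'model_best_state'` state dict). A toy round-trip illustrating that layout; the field names follow the docstring, but the values and the file name are stand-ins, not a real checkpoint.

```python
import torch
from omegaconf import OmegaConf

# Write and read back a toy checkpoint with the layout the docstring described.
pkg = {
    'xp.cfg': {'sample_rate': 16000, 'channels': 1},        # stand-in for the dumped hydra config
    'model_best_state': {'weight': torch.zeros(4, 4)},      # stand-in for the best model state dict
}
torch.save(pkg, 'toy_checkpoint.bin')

loaded = torch.load('toy_checkpoint.bin', map_location='cpu')
cfg = OmegaConf.create(loaded['xp.cfg'])                    # config used to rebuild the model via builders
print(cfg.sample_rate, list(loaded['model_best_state'].keys()))
```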
audiocraft/rope.py
DELETED
```diff
@@ -1,125 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import typing as tp
-
-from torch import nn
-import torch
-
-
-class XPos(nn.Module):
-    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
-    This applies an exponential decay to the RoPE rotation matrix.
-
-    Args:
-        dim (int): Embedding dimension.
-        smoothing (float): Smoothing factor applied to the decay rates.
-        base_scale (int): Base decay rate, given in terms of scaling time.
-        device (torch.device, optional): Device on which to initialize the module.
-        dtype (torch.dtype): dtype to use to generate the embedding.
-    """
-    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
-                 device=None, dtype: torch.dtype = torch.float32):
-        super().__init__()
-        assert dim % 2 == 0
-        assert dtype in [torch.float64, torch.float32]
-        self.dtype = dtype
-        self.base_scale = base_scale
-
-        half_dim = dim // 2
-        adim = torch.arange(half_dim, device=device, dtype=dtype)
-        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
-        self.register_buffer("decay_rates", decay_rates)
-        self.decay: tp.Optional[torch.Tensor] = None
-
-    def get_decay(self, start: int, end: int):
-        """Create complex decay tensor, cache values for fast computation."""
-        if self.decay is None or end > self.decay.shape[0]:
-            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
-            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
-            power = idx / self.base_scale
-            scale = self.decay_rates ** power.unsqueeze(-1)
-            self.decay = torch.polar(scale, torch.zeros_like(scale))
-        return self.decay[start:end]  # [T, C/2]
-
-
-class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
-
-    Args:
-        dim (int): Embedding dimension (twice the number of frequencies).
-        max_period (float): Maximum period of the rotation frequencies.
-        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
-        scale (float): Scale of positional embedding, set to 0 to deactivate.
-        device (torch.device, optional): Device on which to initialize the module.
-        dtype (torch.dtype): dtype to use to generate the embedding.
-    """
-    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
-                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
-        super().__init__()
-        assert dim % 2 == 0
-        self.scale = scale
-        assert dtype in [torch.float64, torch.float32]
-        self.dtype = dtype
-
-        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
-        frequencies = 1.0 / (max_period ** (adim / dim))
-        self.register_buffer("frequencies", frequencies)
-        self.rotation: tp.Optional[torch.Tensor] = None
-
-        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
-
-    def get_rotation(self, start: int, end: int):
-        """Create complex rotation tensor, cache values for fast computation."""
-        if self.rotation is None or end > self.rotation.shape[0]:
-            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
-            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
-            angles = torch.outer(idx, self.frequencies)
-            self.rotation = torch.polar(torch.ones_like(angles), angles)
-        return self.rotation[start:end]
-
-    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
-        """Apply rope rotation to query or key tensor."""
-        T = x.shape[time_dim]
-        target_shape = [1] * x.dim()
-        target_shape[time_dim] = T
-        target_shape[-1] = -1
-        rotation = self.get_rotation(start, start + T).view(target_shape)
-
-        if self.xpos:
-            decay = self.xpos.get_decay(start, start + T).view(target_shape)
-        else:
-            decay = 1.0
-
-        if invert_decay:
-            decay = decay ** -1
-
-        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
-        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
-        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
-
-        return x_out.type_as(x)
-
-    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
-        """ Apply rope rotation to both query and key tensors.
-        Supports streaming mode, in which query and key are not expected to have the same shape.
-        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
-        query will be [C] (typically C == 1).
-
-        Args:
-            query (torch.Tensor): Query to rotate.
-            key (torch.Tensor): Key to rotate.
-            start (int): Start index of the sequence for time offset.
-            time_dim (int): which dimension represent the time steps.
-        """
-        query_timesteps = query.shape[time_dim]
-        key_timesteps = key.shape[time_dim]
-        streaming_offset = key_timesteps - query_timesteps
-
-        query_out = self.rotate(query, start + streaming_offset, time_dim)
-        key_out = self.rotate(key, start, time_dim, invert_decay=True)
-
-        return query_out, key_out
```
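Note: a stand-alone sketch of the rotation performed by the deleted `RotaryEmbedding.rotate`: pair up adjacent feature dimensions as complex numbers and multiply by unit-modulus `e^{i*theta}` with per-frequency angles. The helper name, shapes, and inputs are illustrative; the xPos decay and the `scale` blending are omitted for brevity.

```python
import torch

def rope_rotate(x: torch.Tensor, max_period: float = 10000.0, start: int = 0) -> torch.Tensor:
    """x: [B, T, D] with D even; returns the RoPE-rotated tensor of the same shape."""
    B, T, D = x.shape
    adim = torch.arange(0, D, 2, dtype=torch.float32)[: D // 2]
    frequencies = 1.0 / (max_period ** (adim / D))               # one frequency per feature pair
    positions = torch.arange(start, start + T, dtype=torch.float32)
    angles = torch.outer(positions, frequencies)                 # [T, D/2]
    rotation = torch.polar(torch.ones_like(angles), angles)      # complex numbers on the unit circle
    x_complex = torch.view_as_complex(x.float().reshape(B, T, -1, 2))
    rotated = x_complex * rotation.view(1, T, -1)                # complex multiply = 2D rotation per pair
    return torch.view_as_real(rotated).reshape(B, T, D).type_as(x)

q = torch.randn(2, 10, 64)
print(rope_rotate(q).shape)   # torch.Size([2, 10, 64])
```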
audiocraft/unet.py
DELETED
```diff
@@ -1,214 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Pytorch Unet Module used for diffusion.
-"""
-
-from dataclasses import dataclass
-import typing as tp
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from .transformer import StreamingTransformer, create_sin_embedding
-
-
-@dataclass
-class Output:
-    sample: torch.Tensor
-
-
-def get_model(cfg, channels: int, side: int, num_steps: int):
-    if cfg.model == 'unet':
-        return DiffusionUnet(
-            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
-    else:
-        raise RuntimeError('Not Implemented')
-
-
-class ResBlock(nn.Module):
-    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
-                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        stride = 1
-        padding = dilation * (kernel - stride) // 2
-        Conv = nn.Conv1d
-        Drop = nn.Dropout1d
-        self.norm1 = nn.GroupNorm(norm_groups, channels)
-        self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
-        self.activation1 = activation()
-        self.dropout1 = Drop(dropout)
-
-        self.norm2 = nn.GroupNorm(norm_groups, channels)
-        self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
-        self.activation2 = activation()
-        self.dropout2 = Drop(dropout)
-
-    def forward(self, x):
-        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
-        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
-        return x + h
-
-
-class DecoderLayer(nn.Module):
-    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
-                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        padding = (kernel - stride) // 2
-        self.res_blocks = nn.Sequential(
-            *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
-              for idx in range(res_blocks)])
-        self.norm = nn.GroupNorm(norm_groups, chin)
-        ConvTr = nn.ConvTranspose1d
-        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
-        self.activation = activation()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.res_blocks(x)
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.convtr(x)
-        return x
-
-
-class EncoderLayer(nn.Module):
-    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
-                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        padding = (kernel - stride) // 2
-        Conv = nn.Conv1d
-        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
-        self.norm = nn.GroupNorm(norm_groups, chout)
-        self.activation = activation()
-        self.res_blocks = nn.Sequential(
-            *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
-              for idx in range(res_blocks)])
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, C, T = x.shape
-        stride, = self.conv.stride
-        pad = (stride - (T % stride)) % stride
-        x = F.pad(x, (0, pad))
-
-        x = self.conv(x)
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.res_blocks(x)
-        return x
-
-
-class BLSTM(nn.Module):
-    """BiLSTM with same hidden units as input dim.
-    """
-    def __init__(self, dim, layers=2):
-        super().__init__()
-        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
-        self.linear = nn.Linear(2 * dim, dim)
-
-    def forward(self, x):
-        x = x.permute(2, 0, 1)
-        x = self.lstm(x)[0]
-        x = self.linear(x)
-        x = x.permute(1, 2, 0)
-        return x
-
-
-class DiffusionUnet(nn.Module):
-    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
-                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
-                 bilstm: bool = False, transformer: bool = False,
-                 codec_dim: tp.Optional[int] = None, **kwargs):
-        super().__init__()
-        self.encoders = nn.ModuleList()
-        self.decoders = nn.ModuleList()
-        self.embeddings: tp.Optional[nn.ModuleList] = None
-        self.embedding = nn.Embedding(num_steps, hidden)
-        if emb_all_layers:
-            self.embeddings = nn.ModuleList()
-        self.condition_embedding: tp.Optional[nn.Module] = None
-        for d in range(depth):
-            encoder = EncoderLayer(chin, hidden, **kwargs)
-            decoder = DecoderLayer(hidden, chin, **kwargs)
-            self.encoders.append(encoder)
-            self.decoders.insert(0, decoder)
-            if emb_all_layers and d > 0:
-                assert self.embeddings is not None
-                self.embeddings.append(nn.Embedding(num_steps, hidden))
-            chin = hidden
-            hidden = min(int(chin * growth), max_channels)
-        self.bilstm: tp.Optional[nn.Module]
-        if bilstm:
-            self.bilstm = BLSTM(chin)
-        else:
-            self.bilstm = None
-        self.use_transformer = transformer
-        self.cross_attention = False
-        if transformer:
-            self.cross_attention = cross_attention
-            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
-                                                    cross_attention=cross_attention)
-
-        self.use_codec = False
-        if codec_dim is not None:
-            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
-            self.use_codec = True
-
-    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
-        skips = []
-        bs = x.size(0)
-        z = x
-        view_args = [1]
-        if type(step) is torch.Tensor:
-            step_tensor = step
-        else:
-            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
-
-        for idx, encoder in enumerate(self.encoders):
-            z = encoder(z)
-            if idx == 0:
-                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
-            elif self.embeddings is not None:
-                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
-
-            skips.append(z)
-
-        if self.use_codec:  # insert condition in the bottleneck
-            assert condition is not None, "Model defined for conditionnal generation"
-            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
-            assert condition_emb.size(-1) <= 2 * z.size(-1), \
-                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
-            if not self.cross_attention:
-
-                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
-                assert z.size() == condition_emb.size()
-                z += condition_emb
-                cross_attention_src = None
-            else:
-                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
-                B, T, C = cross_attention_src.shape
-                positions = torch.arange(T, device=x.device).view(1, -1, 1)
-                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
-                cross_attention_src = cross_attention_src + pos_emb
-        if self.use_transformer:
-            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
-        else:
-            if self.bilstm is None:
-                z = torch.zeros_like(z)
-            else:
-                z = self.bilstm(z)
-
-        for decoder in self.decoders:
-            s = skips.pop(-1)
-            z = z[:, :, :s.shape[2]]
-            z = z + s
-            z = decoder(z)
-
-        z = z[:, :, :x.shape[2]]
-        return Output(z)
```
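Note: a toy sketch of the shape bookkeeping in the deleted `DiffusionUnet.forward`: each encoder pads the time axis to a multiple of its stride, each decoder crops back to the matching skip length before adding it, and the output is cropped to the input length. This is a simplified stand-in (no step embedding, normalization, residual blocks, or transformer bottleneck), not the removed model.

```python
import torch
from torch import nn
from torch.nn import functional as F

class TinyUnet(nn.Module):
    def __init__(self, chin=1, hidden=8, depth=3, kernel=4, stride=2):
        super().__init__()
        self.encoders, self.decoders = nn.ModuleList(), nn.ModuleList()
        pad = (kernel - stride) // 2
        ch = chin
        for _ in range(depth):
            self.encoders.append(nn.Conv1d(ch, hidden, kernel, stride, pad))
            # decoders are prepended so they mirror the encoders in reverse order
            self.decoders.insert(0, nn.ConvTranspose1d(hidden, ch, kernel, stride, pad))
            ch, hidden = hidden, hidden * 2
        self.stride = stride

    def forward(self, x):
        skips, z = [], x
        for enc in self.encoders:
            T = z.shape[-1]
            z = F.pad(z, (0, (self.stride - T % self.stride) % self.stride))  # pad to stride multiple
            z = enc(z)
            skips.append(z)
        for dec in self.decoders:
            s = skips.pop()
            z = dec(z[:, :, :s.shape[-1]] + s)        # crop to skip length, add, upsample
        return z[:, :, :x.shape[-1]]                  # crop back to the input length

print(TinyUnet()(torch.randn(2, 1, 1001)).shape)      # torch.Size([2, 1, 1001])
```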