""" |
|
|
Main model for using MelodyFlow. This will combine all the required components |
|
|
and provide easy access to the generation API. |
|
|
""" |
|
|
|
|
|
import typing as tp |
|
|
from audiocraft.utils.autocast import TorchAutocast |
|
|
import torch |
|
|
|
|
|
from .genmodel import BaseGenModel |
|
|
from ..modules.conditioners import ConditioningAttributes |
|
|
from ..utils.utils import vae_sample |
|
|
from .loaders import load_compression_model, load_dit_model_melodyflow |
|
|
|
|
|
|
|
|
class MelodyFlow(BaseGenModel): |
|
|
"""MelodyFlow main model with convenient generation API. |
|
|
Args: |
|
|
See MelodyFlow class. |
|
|
""" |
|
|
|
|
|
def __init__(self, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
self.set_generation_params() |
|
|
self.set_editing_params() |
|
|
if self.device.type == 'cpu' or self.device.type == 'mps': |
|
|
self.autocast = TorchAutocast(enabled=False) |
|
|
else: |
|
|
self.autocast = TorchAutocast( |
|
|
enabled=True, device_type=self.device.type, dtype=torch.bfloat16) |
|
|
|
|
|
@staticmethod |
|
|
def get_pretrained(name: str = 'facebook/melodyflow-t24-30secs', device=None): |
|
|
|
|
|
""" |
|
|
""" |
|
|
if device is None: |
|
|
if torch.cuda.device_count(): |
|
|
device = 'cuda' |
|
|
elif torch.backends.mps.is_available(): |
|
|
device = 'mps' |
|
|
else: |
|
|
device = 'cpu' |
|
|
|
|
|
compression_model = load_compression_model(name, device=device) |
|
|
|
|
|
def _remove_weight_norm(module): |
|
|
if hasattr(module, "conv"): |
|
|
if hasattr(module.conv, "conv"): |
|
|
torch.nn.utils.parametrize.remove_parametrizations( |
|
|
module.conv.conv, "weight" |
|
|
) |
|
|
if hasattr(module, "convtr"): |
|
|
if hasattr(module.convtr, "convtr"): |
|
|
torch.nn.utils.parametrize.remove_parametrizations( |
|
|
module.convtr.convtr, "weight" |
|
|
) |
|
|
|
|
|
def _clear_weight_norm(module): |
|
|
_remove_weight_norm(module) |
|
|
for child in module.children(): |
|
|
_clear_weight_norm(child) |
|
|
|
|
|
compression_model.to('cpu') |
|
|
_clear_weight_norm(compression_model) |
|
|
compression_model.to(device) |
|
|
|
|
|
lm = load_dit_model_melodyflow(name, device=device) |
|
|
|
|
|
kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm} |
|
|
return MelodyFlow(**kwargs) |
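    # Minimal usage sketch (assumes the default Hugging Face checkpoint can be
    # downloaded in your environment):
    #
    #     model = MelodyFlow.get_pretrained('facebook/melodyflow-t24-30secs')
    #     model.set_generation_params(solver='midpoint', steps=64, duration=10.0)
    #     wav = model.generate(['80s electro pop with heavy drums'])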
    def set_generation_params(
        self,
        solver: str = "midpoint",
        steps: int = 64,
        duration: float = 10.0,
    ) -> None:
        """Set generation parameters for MelodyFlow.

        Args:
            solver (str, optional): ODE solver, either euler or midpoint.
            steps (int, optional): number of inference steps.
            duration (float, optional): duration of the generated audio, in seconds.
        """
        self.generation_params = {
            'solver': solver,
            'steps': steps,
            'duration': duration,
        }
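    # Sketch: fewer solver steps run faster at some cost in fidelity
    # (parameter names as defined above):
    #
    #     model.set_generation_params(solver='euler', steps=32, duration=15.0)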
    def set_editing_params(
        self,
        solver: str = "euler",
        steps: int = 25,
        target_flowstep: float = 0.0,
        regularize: bool = True,
        regularize_iters: int = 4,
        keep_last_k_iters: int = 2,
        lambda_kl: float = 0.2,
    ) -> None:
        """Set regularized inversion parameters for MelodyFlow.

        Args:
            solver (str, optional): ODE solver, either euler or midpoint.
            steps (int, optional): number of inference steps.
            target_flowstep (float): target flow step in [0, 1[.
            regularize (bool): regularize each solver step.
            regularize_iters (int, optional): number of regularization iterations.
            keep_last_k_iters (int, optional): number of meaningful regularization
                iterations kept for the moving average computation.
            lambda_kl (float, optional): KL regularization loss weight.
        """
        self.editing_params = {
            'solver': solver,
            'steps': steps,
            'target_flowstep': target_flowstep,
            'regularize': regularize,
            'regularize_iters': regularize_iters,
            'keep_last_k_iters': keep_last_k_iters,
            'lambda_kl': lambda_kl,
        }
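    # Sketch: a lighter editing configuration that disables the regularized
    # inversion iterations:
    #
    #     model.set_editing_params(solver='midpoint', steps=32, regularize=False)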
    def encode_audio(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode an audio waveform of shape [B, C, T] into a latent sequence."""
        assert waveform.dim() == 3
        with torch.no_grad():
            latent_sequence = self.compression_model.encode(waveform)[0].squeeze(1)
        return latent_sequence
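    # Sketch: encode a batch of waveforms; random noise stands in for real
    # audio here, `sample_rate` is assumed to be inherited from BaseGenModel,
    # and the channel count must match the compression model:
    #
    #     waveform = torch.randn(1, 2, model.sample_rate * 10)  # [B, C, T]
    #     prompt_tokens = model.encode_audio(waveform)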
    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate audio from latent tokens."""
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            if self.lm.latent_mean.shape[1] != gen_tokens.shape[1]:
                # Raw VAE output: split into mean and scale, then sample a latent.
                mean, scale = gen_tokens.chunk(2, dim=1)
                gen_tokens = vae_sample(mean, scale)
            else:
                # Normalized latents: undo the mean/std normalization.
                gen_tokens = gen_tokens * (self.lm.latent_std + 1e-5) + self.lm.latent_mean
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio
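    # Sketch: decode tokens kept from a previous call with return_tokens=True:
    #
    #     wav, tokens = model.generate(['ambient pad'], return_tokens=True)
    #     wav_again = model.generate_audio(tokens)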
    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Whether to also return the generated tokens. Defaults to False.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes=attributes,
                                       prompt_tokens=prompt_tokens,
                                       progress=progress,
                                       **self.generation_params,
                                       )
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)
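    # Sketch: draw two unconditional samples with the current generation
    # parameters:
    #
    #     wav = model.generate_unconditional(num_samples=2, progress=True)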
    def generate(self, descriptions: tp.List[str], progress: bool = False,
                 return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                          tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Whether to also return the generated tokens. Defaults to False.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes=attributes,
                                       prompt_tokens=prompt_tokens,
                                       progress=progress,
                                       **self.generation_params,
                                       )
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)
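    # Sketch: text-conditioned generation, keeping the latents for later
    # decoding or editing:
    #
    #     wav, tokens = model.generate(['lo-fi hip hop beat'], return_tokens=True)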
    def edit(self,
             prompt_tokens: torch.Tensor,
             descriptions: tp.List[str],
             src_descriptions: tp.Optional[tp.List[str]] = None,
             progress: bool = False,
             return_tokens: bool = False,
             ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Edit an audio latent sequence, conditioned on text.

        The prompt is first inverted back to the target flow step set via
        `set_editing_params` (optionally with regularization), then
        re-generated under the editing conditions.

        Args:
            prompt_tokens (torch.Tensor): Audio prompt used as initial latent sequence.
            descriptions (list of str): A list of strings used as editing conditioning.
            src_descriptions (list of str, optional): A list of strings used as
                conditioning during latent inversion. Defaults to [""].
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Whether to also return the generated tokens. Defaults to False.
        """
        empty_attributes, no_tokens = self._prepare_tokens_and_attributes(
            [""] if src_descriptions is None else src_descriptions, None)
        assert no_tokens is None
        edit_attributes, no_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert no_tokens is None

        inversion_params = self.editing_params.copy()
        if inversion_params["regularize"]:
            override_total_steps = inversion_params["steps"] * (inversion_params["regularize_iters"] + 1)
        else:
            override_total_steps = inversion_params["steps"] * 2
        current_step_offset: int = 0

        def _progress_callback(elapsed_steps: int, total_steps: int):
            elapsed_steps += current_step_offset
            if self._progress_callback is not None:
                self._progress_callback(elapsed_steps, override_total_steps)
            else:
                print(f'{elapsed_steps: 6d} / {override_total_steps: 6d}', end='\r')

        # Invert the prompt from flow step 1.0 down to the target flow step.
        intermediate_tokens = self._generate_tokens(attributes=empty_attributes,
                                                    prompt_tokens=prompt_tokens,
                                                    source_flowstep=1.0,
                                                    progress=progress,
                                                    callback=_progress_callback,
                                                    **inversion_params,
                                                    )
        if intermediate_tokens.shape[0] < len(descriptions):
            # Broadcast the inverted latents across all target descriptions.
            intermediate_tokens = intermediate_tokens.repeat(len(descriptions) // intermediate_tokens.shape[0], 1, 1)
        if inversion_params["regularize"]:
            current_step_offset += inversion_params["steps"] * inversion_params["regularize_iters"]
        else:
            current_step_offset += inversion_params["steps"]
        inversion_params.pop("regularize")
        # Re-generate from the target flow step up to 1.0 under the editing
        # conditions.
        final_tokens = self._generate_tokens(attributes=edit_attributes,
                                             prompt_tokens=intermediate_tokens,
                                             source_flowstep=inversion_params.pop("target_flowstep"),
                                             target_flowstep=1.0,
                                             progress=progress,
                                             callback=_progress_callback,
                                             **inversion_params,
                                             )
        if return_tokens:
            return self.generate_audio(final_tokens), final_tokens
        return self.generate_audio(final_tokens)
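    # Sketch of a full editing round trip, with `waveform` a [B, C, T] tensor
    # at the compression model's sample rate:
    #
    #     prompt_tokens = model.encode_audio(waveform)
    #     edited = model.edit(prompt_tokens, ['same tune, played by a string quartet'])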
    def _generate_tokens(self,
                         attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor],
                         progress: bool = False,
                         callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                         **kwargs) -> torch.Tensor:
        """Generate continuous audio tokens given audio prompt and/or conditions.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used as initial latent sequence.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            callback (callable, optional): Progress callback receiving
                (elapsed_steps, total_steps); overrides the default one.
        Returns:
            torch.Tensor: Generated latent sequence, whose length is defined by
                the prompt or by the generation params.
        """
        generate_params = kwargs.copy()
        total_gen_len = prompt_tokens.shape[-1] if prompt_tokens is not None else int(
            generate_params.pop('duration') * self.frame_rate)
        current_step_offset: int = 0

        def _progress_callback(elapsed_steps: int, total_steps: int):
            elapsed_steps += current_step_offset
            if self._progress_callback is not None:
                self._progress_callback(elapsed_steps, total_steps)
            else:
                print(f'{elapsed_steps: 6d} / {total_steps: 6d}', end='\r')

        if progress and callback is None:
            callback = _progress_callback

        assert total_gen_len <= int(self.max_duration * self.frame_rate)

        with self.autocast:
            gen_tokens = self.lm.generate(
                prompt=prompt_tokens,
                conditions=attributes,
                callback=callback,
                max_gen_len=total_gen_len,
                **generate_params,
            )

        return gen_tokens
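    # Sketch: progress reports can be routed to custom code through the
    # `_progress_callback` attribute checked above (assumed to be initialized
    # by BaseGenModel); it receives (elapsed_steps, total_steps):
    #
    #     steps_seen = []
    #     model._progress_callback = lambda step, total: steps_seen.append((step, total))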