debug special token
- audiocraft/builders.py +26 -20
- audiocraft/conditioners.py +1 -8
- audiocraft/encodec.py +9 -51
- audiocraft/lm.py +34 -70
- audiocraft/loaders.py +1 -1
- audiocraft/seanet.py +14 -125
- audiocraft/utils/utils.py +11 -136
audiocraft/builders.py
CHANGED
@@ -17,7 +17,7 @@ import torch

 from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
-from .seanet import SEANetEncoder, SEANetDecoder
+from .seanet import SEANetDecoder
 from .codebooks_patterns import (
     CodebooksPatternProvider,
     DelayedPatternProvider,
@@ -49,34 +49,40 @@ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) ->
     return klass(**kwargs)


-def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
-    if encoder_name == 'seanet':
-        kwargs = dict_from_config(getattr(cfg, 'seanet'))
-        encoder_override_kwargs = kwargs.pop('encoder')
-        decoder_override_kwargs = kwargs.pop('decoder')
-        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
-        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
-        encoder = SEANetEncoder(**encoder_kwargs)
-        decoder = SEANetDecoder(**decoder_kwargs)
-        return encoder, decoder
-    else:
-        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+def get_encodec_autoencoder(cfg):
+    kwargs = dict_from_config(getattr(cfg, 'seanet'))
+    _ = kwargs.pop('encoder')
+    decoder_override_kwargs = kwargs.pop('decoder')
+    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+    decoder = SEANetDecoder(**decoder_kwargs)
+    return decoder


-def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+def get_compression_model(cfg):
     """Instantiate a compression model."""
     if cfg.compression_model == 'encodec':
         kwargs = dict_from_config(getattr(cfg, 'encodec'))
-        encoder_name = kwargs.pop('autoencoder')
         quantizer_name = kwargs.pop('quantizer')
-        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
-        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
-        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        decoder = get_encodec_autoencoder(cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, 128)
         renormalize = kwargs.pop('renormalize', False)
         # deprecated params
+        # print(f'{frame_rate=} {encoder.dimension=}') frame_rate=50 encoder.dimension=128
         kwargs.pop('renorm', None)
-        return EncodecModel(encoder, decoder, quantizer,
-                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+        # print('\n______!____________\n', kwargs, '\n______!____________\n')
+        # ______!____________
+        # {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
+        # ______!____________
+
+        return EncodecModel(decoder=decoder,
+                            quantizer=quantizer,
+                            frame_rate=50,
+                            renormalize=renormalize,
+                            sample_rate=16000,
+                            channels=1,
+                            causal=False
+                            ).to(cfg.device)
     else:
         raise KeyError(f"Unexpected compression model {cfg.compression_model}")
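Note: the hardcoded frame_rate=50 and the quantizer dimension 128 stand in for values that were previously derived from the now-deleted encoder. A minimal sanity check of where they come from, assuming the SEANet defaults visible in seanet.py below (ratios [8, 5, 4, 2], dimension 128) and the 16 kHz sample rate shown in the debug print:

    import numpy as np

    ratios = [8, 5, 4, 2]                  # SEANet default stride ratios
    hop_length = int(np.prod(ratios))      # 320 audio samples per latent frame
    sample_rate = 16000
    frame_rate = sample_rate // hop_length
    assert frame_rate == 50                # matches the hardcoded frame_rate=50
    # dimension=128 is likewise the SEANetDecoder default, matching
    # get_quantizer(quantizer_name, cfg, 128) above.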
audiocraft/conditioners.py
CHANGED
@@ -1,11 +1,7 @@
 from collections import defaultdict
 from dataclasses import dataclass, field
-from itertools import chain
 import logging
-import math
-from pathlib import Path
 import random
-import re
 import typing as tp
 import warnings
 import soundfile
@@ -14,11 +10,8 @@ import torch
 from torch import nn
 from .streaming import StreamingModule

-
-from .quantization import ResidualVectorQuantizer
 from .utils.autocast import TorchAutocast
-
-from .utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+


 logger = logging.getLogger(__name__)
audiocraft/encodec.py
CHANGED
@@ -30,14 +30,7 @@ class CompressionModel(ABC, nn.Module):
     with a language model.
     """

-    @abstractmethod
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        ...
-
-    @abstractmethod
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        """See `EncodecModel.encode`."""
-        ...
+

     @abstractmethod
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
@@ -142,16 +135,15 @@ class EncodecModel(CompressionModel):
     channels: int = 0

     def __init__(self,
-                 encoder: nn.Module,
-                 decoder: nn.Module,
-                 quantizer: qt.BaseQuantizer,
-                 frame_rate: int,
-                 sample_rate: int,
-                 channels: int,
-                 causal: bool = False,
-                 renormalize: bool = False):
+                 decoder=None,
+                 quantizer=None,
+                 frame_rate=None,
+                 sample_rate=None,
+                 channels=None,
+                 causal=False,
+                 renormalize=False):
         super().__init__()
-        self.encoder = encoder
+
         self.decoder = decoder
         self.quantizer = quantizer
         self.frame_rate = frame_rate
@@ -203,40 +195,6 @@ class EncodecModel(CompressionModel):
             x = x * scale.view(-1, 1, 1)
         return x

-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        assert x.dim() == 3
-        length = x.shape[-1]
-        x, scale = self.preprocess(x)
-
-        emb = self.encoder(x)
-        q_res = self.quantizer(emb, self.frame_rate)
-        out = self.decoder(q_res.x)
-
-        # remove extra padding added by the encoder and decoder
-        assert out.shape[-1] >= length, (out.shape[-1], length)
-        out = out[..., :length]
-
-        q_res.x = self.postprocess(out, scale)
-
-        return q_res
-
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        """Encode the given input tensor to quantized representation along with scale parameter.
-
-        Args:
-            x (torch.Tensor): Float tensor of shape [B, C, T]
-
-        Returns:
-            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
-                codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
-                scale: a float tensor containing the scale for audio renormalization.
-        """
-        assert x.dim() == 3
-        x, scale = self.preprocess(x)
-        emb = self.encoder(x)
-        codes = self.quantizer.encode(emb)
-        return codes, scale
-
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
         """Decode the given codes to a reconstructed representation, using the scale to perform
         audio denormalization if needed.
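Note: with forward() and encode() deleted, EncodecModel only supports the token-to-waveform direction. A hypothetical usage sketch (the checkpoint path, the 4-codebook shape, and the exact output length are assumptions, not taken from this diff):

    import torch
    from audiocraft.loaders import load_compression_model

    model = load_compression_model('checkpoint.th', device='cpu')  # placeholder path
    codes = torch.zeros(1, 4, 50, dtype=torch.long)  # [B, K, T]: 4 codebooks, 1 s at frame_rate=50
    with torch.no_grad():
        audio = model.decode(codes)   # scale=None, since renormalize is disabled
    print(audio.shape)                # expected on the order of [1, 1, 16000] at 16 kHz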
audiocraft/lm.py
CHANGED
@@ -14,16 +14,14 @@ import warnings
 import einops
 from num2words import num2words
 import spacy
-from transformers import
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from audiocraft.streaming import StreamingModule
 from audiocraft.transformer import create_sin_embedding
-from audiocraft.utils.audio_utils import convert_audio
 from audiocraft.utils.autocast import TorchAutocast
-from audiocraft.utils.
-from audiocraft.utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+from audiocraft.utils.utils import collate, length_to_mask
 from audiocraft.transformer import StreamingTransformer, create_norm_fn
 from dataclasses import dataclass
 from functools import partial
@@ -297,13 +295,7 @@ class BaseConditioner(nn.Module):
         self.output_dim = output_dim
         self.output_proj = nn.Linear(dim, output_dim)

-    def tokenize(self, *args, **kwargs) -> tp.Any:
-        """Should be any part of the processing that will lead to a synchronization
-        point, e.g. BPE tokenization with transfer to the GPU.
-
-        The returned value will be saved and return later when calling forward().
-        """
-        raise NotImplementedError()
+

     def forward(self, inputs: tp.Any) -> ConditionType:
         """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
@@ -530,34 +522,6 @@ class ConditioningProvider(nn.Module):
     def has_wav_condition(self):
         return len(self.wav_conditions) > 0

-    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
-        This should be called before starting any real GPU work to avoid synchronization points.
-        This will return a dict matching conditioner names to their arbitrary tokenized representations.
-
-        Args:
-            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
-                text and wav conditions.
-        """
-        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
-            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
-            f" but types were {set([type(x) for x in inputs])}"
-        )
-
-        output = {}
-        text = self._collate_text(inputs)
-        wavs = self._collate_wavs(inputs)
-        joint_embeds = self._collate_joint_embeds(inputs)
-
-        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        )
-
-        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
-            output[attribute] = self.conditioners[attribute].tokenize(batch)
-        return output
-
     def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
         """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
         The output is for example:
@@ -780,6 +744,7 @@ class ConditionFuser(StreamingModule):
             raise ValueError(f"unknown op ({op})")

         if self.cross_attention_pos_emb and cross_attention_output is not None:
+            print('SIN EMBED')
             positions = torch.arange(
                 cross_attention_output.shape[1],
                 device=cross_attention_output.device
@@ -925,7 +890,7 @@ class LMModel(StreamingModule):

         self.condition_provider = condition_provider
         self.fuser = fuser
-        self.card = card
+        self.card = card  # 2048 ?
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
@@ -1030,6 +995,7 @@ class LMModel(StreamingModule):
         # remove the prefix from the model outputs
         if len(self.fuser.fuse2cond['prepend']) > 0:
             logits = logits[:, :, -S:]
+            print('PRESFIX')

         return logits  # [B, K, S, card]

@@ -1067,6 +1033,8 @@ class LMModel(StreamingModule):
         B, K, T = codes.shape
         codes = codes.contiguous()
         # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        # what is the T is it 2048 ?
+        # and then what is pattern -> another function?
         pattern = self.pattern_provider.get_pattern(T)
         sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
             codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
@@ -1118,35 +1086,33 @@ class LMModel(StreamingModule):
         model = self if self._fsdp is None else self._fsdp
         two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
         if two_step_cfg and cfg_conditions != {}:
-            assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
-            condition_tensors, null_condition_tensors = cfg_conditions
-            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
-            state = self.get_streaming_state()
-            self.set_streaming_state(unconditional_state)
-            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
-            unconditional_state.update(self.get_streaming_state())
-            self.set_streaming_state(state)
-            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+            print('\nNOT HERE\n')
         else:
+            print('C')
             assert isinstance(cfg_conditions, dict)
             condition_tensors = cfg_conditions
             if condition_tensors:
+                # print('\nD\n')
                 # Preparing for CFG, predicting both conditional and unconditional logits.
                 sequence = torch.cat([sequence, sequence], dim=0)
             all_logits = model(
                 sequence,
                 conditions=[], condition_tensors=condition_tensors)
             if condition_tensors:
-                cond_logits, uncond_logits = all_logits.split(B, dim=0)
-                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # torch.Size([2, 4, 1, 2048])
+                # logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+                # logits = 3 * cond_logits - 2.4 * uncond_logits
+                logits = 2 * cond_logits - 1.4 * uncond_logits
             else:
-                logits = all_logits
+                print('\nF!\n')
+

         logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
         logits = logits[..., -1]  # [B x K x card]

         # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
         if use_sampling and temp > 0.0:
+            # print(f'\nR {temp=} {top_p=} {top_k=}\n') -------------> R temp=1.0 top_p=0.0 top_k=250
             probs = torch.softmax(logits / temp, dim=-1)
             if top_p > 0.0:
                 next_token = utils.sample_top_p(probs, p=top_p)
@@ -1155,7 +1121,9 @@ class LMModel(StreamingModule):
             else:
                 next_token = utils.multinomial(probs, num_samples=1)
         else:
-            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+            #
+            print('\nNeverHere\n')
+

         return next_token

@@ -1249,9 +1217,9 @@ class LMModel(StreamingModule):
         # this token is used as default value for codes that are not generated yet
         unknown_token = -1

-        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+
         gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
-        # filling the gen_codes with the prompt if needed
+
        gen_codes[..., :start_offset] = prompt
         # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
         gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
@@ -1280,9 +1248,17 @@ class LMModel(StreamingModule):
             # ensure the tokens that should be masked are properly set to special_token_id
             # as the model never output special_token_id
             valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
-            next_token[~valid_mask] = self.special_token_id
+
+            # next_token[~valid_mask] = self.special_token_id
+
+            # print(f'{unconditional_state=} \n
+            # print('Set All to Special')
+            # next_token[:] = self.special_token_id
+
+
+
             # ensure we don't overwrite prompt tokens, we only write over unknown tokens
-            # (then mask tokens should be left as is as well, which is correct)
+
             gen_sequence[..., offset:offset+1] = torch.where(
                 gen_sequence[..., offset:offset+1] == unknown_token,
                 next_token, gen_sequence[..., offset:offset+1]
@@ -1292,23 +1268,11 @@ class LMModel(StreamingModule):
                 callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
         unconditional_state.clear()

-        # ensure sequence has been entirely filled
-        assert not (gen_sequence == unknown_token).any()
-        # ensure gen_sequence pattern and mask are matching
-        # which means the gen_sequence is valid according to the pattern
-        assert (
-            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
-        ).all()
-        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
         out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

-        # sanity checks over the returned codes and corresponding masks
-        assert (out_codes[..., :max_gen_len] != unknown_token).all()
-        assert (out_mask[..., :max_gen_len] == 1).all()
-
         out_start_offset = start_offset if remove_prompts else 0
         out_codes = out_codes[..., out_start_offset:max_gen_len]

         # ensure the returned codes are all valid
-        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
         return out_codes
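Note: the replaced guidance line changes the classifier-free guidance weights. The removed form keeps the coefficients summing to 1; the hardcoded 2 * cond_logits - 1.4 * uncond_logits keeps the guidance direction but its weights sum to 0.6, which also rescales the logits (a temperature-like side effect). A small sketch of the algebra, with cond/uncond standing in for the split logits:

    import torch

    def cfg_logits(cond: torch.Tensor, uncond: torch.Tensor, cfg_coef: float) -> torch.Tensor:
        # standard classifier-free guidance, as in the commented-out line
        return uncond + (cond - uncond) * cfg_coef  # == cfg_coef*cond + (1 - cfg_coef)*uncond

    cond = torch.randn(1, 4, 1, 2048)
    uncond = torch.randn(1, 4, 1, 2048)
    standard = cfg_logits(cond, uncond, cfg_coef=2.0)  # 2*cond - 1*uncond, weights sum to 1.0
    debug = 2 * cond - 1.4 * uncond                    # hardcoded variant, weights sum to 0.6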
audiocraft/loaders.py
CHANGED
@@ -79,7 +79,7 @@ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
     cfg = OmegaConf.create(pkg['xp.cfg'])
     cfg.device = str(device)
     model = builders.get_compression_model(cfg)
-    model.load_state_dict(pkg['best_state'])
+    model.load_state_dict(pkg['best_state'], strict=False)  # ckpt contains uninstantiated encoder
     model.eval()
     return model
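Note: strict=False is what makes the decoder-only model loadable from the full checkpoint: the encoder.* weights in best_state no longer have a matching submodule, and non-strict loading ignores them instead of raising a RuntimeError. A self-contained sketch of that behaviour (the two-module setup is hypothetical):

    import torch
    from torch import nn

    full = nn.ModuleDict({'encoder': nn.Linear(4, 4), 'decoder': nn.Linear(4, 4)})
    decoder_only = nn.ModuleDict({'decoder': nn.Linear(4, 4)})

    state = full.state_dict()  # plays the role of pkg['best_state']
    result = decoder_only.load_state_dict(state, strict=False)
    print(result.unexpected_keys)  # ['encoder.weight', 'encoder.bias'] -> silently skipped
    print(result.missing_keys)     # [] -> every decoder weight was found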
audiocraft/seanet.py
CHANGED
@@ -60,136 +60,25 @@ class SEANetResnetBlock(nn.Module):
         return self.shortcut(x) + self.block(x)


-class SEANetEncoder(nn.Module):
-    """SEANet encoder.
-
-    Args:
-        channels (int): Audio channels.
-        dimension (int): Intermediate representation dimension.
-        n_filters (int): Base width for the model.
-        n_residual_layers (int): nb of residual layers.
-        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
-            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
-            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
-        activation (str): Activation function.
-        activation_params (dict): Parameters to provide to the activation function.
-        norm (str): Normalization method.
-        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-        kernel_size (int): Kernel size for the initial convolution.
-        last_kernel_size (int): Kernel size for the initial convolution.
-        residual_kernel_size (int): Kernel size for the residual layers.
-        dilation_base (int): How much to increase the dilation with each layer.
-        causal (bool): Whether to use fully causal convolution.
-        pad_mode (str): Padding mode for the convolutions.
-        true_skip (bool): Whether to use true skip connection or a simple
-            (streamable) convolution as the skip connection in the residual network blocks.
-        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-        lstm (int): Number of LSTM layers at the end of the encoder.
-        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-            For the encoder, it corresponds to the N first blocks.
-    """
-    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0):
-        super().__init__()
-        self.channels = channels
-        self.dimension = dimension
-        self.n_filters = n_filters
-        self.ratios = list(reversed(ratios))
-        del ratios
-        self.n_residual_layers = n_residual_layers
-        self.hop_length = np.prod(self.ratios)
-        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
-        self.disable_norm_outer_blocks = disable_norm_outer_blocks
-        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
-            "Number of blocks for which to disable norm is invalid." \
-            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
-
-        act = getattr(nn, activation)
-        mult = 1
-        model: tp.List[nn.Module] = [
-            StreamableConv1d(channels, mult * n_filters, kernel_size,
-                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
-                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
-        ]
-        # Downsample to raw audio scale
-        for i, ratio in enumerate(self.ratios):
-            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
-            # Add residual layers
-            for j in range(n_residual_layers):
-                model += [
-                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
-                                      dilations=[dilation_base ** j, 1],
-                                      norm=block_norm, norm_params=norm_params,
-                                      activation=activation, activation_params=activation_params,
-                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
-
-            # Add downsampling layers
-            model += [
-                act(**activation_params),
-                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
-                                 kernel_size=ratio * 2, stride=ratio,
-                                 norm=block_norm, norm_kwargs=norm_params,
-                                 causal=causal, pad_mode=pad_mode),
-            ]
-            mult *= 2
-
-        if lstm:
-            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
-
-        model += [
-            act(**activation_params),
-            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
-                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
-                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
-        ]
-
-        self.model = nn.Sequential(*model)
-
-    def forward(self, x):
-        return self.model(x)
+


 class SEANetDecoder(nn.Module):
-    """SEANet decoder.
-
-    Args:
-        channels (int): Audio channels.
-        dimension (int): Intermediate representation dimension.
-        n_filters (int): Base width for the model.
-        n_residual_layers (int): nb of residual layers.
-        ratios (Sequence[int]): kernel size and stride ratios.
-        activation (str): Activation function.
-        activation_params (dict): Parameters to provide to the activation function.
-        final_activation (str): Final activation function after all convolutions.
-        final_activation_params (dict): Parameters to provide to the activation function.
-        norm (str): Normalization method.
-        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-        kernel_size (int): Kernel size for the initial convolution.
-        last_kernel_size (int): Kernel size for the initial convolution.
-        residual_kernel_size (int): Kernel size for the residual layers.
-        dilation_base (int): How much to increase the dilation with each layer.
-        causal (bool): Whether to use fully causal convolution.
-        pad_mode (str): Padding mode for the convolutions.
-        true_skip (bool): Whether to use true skip connection or a simple.
-            (streamable) convolution as the skip connection in the residual network blocks.
-        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-        lstm (int): Number of LSTM layers at the end of the encoder.
-        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-            For the decoder, it corresponds to the N last blocks.
-        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
-            If equal to 1.0, it means that all the trimming is done at the right.
-    """
-    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
-                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+
+    def __init__(self, channels: int = 1,
+                 dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
+                 activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None,
+                 final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {},
+                 kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3,
+                 dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True,
+                 compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0,
+                 trim_right_ratio: float = 1.0):
         super().__init__()
         self.dimension = dimension
         self.channels = channels
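Note: every keyword in the reflowed SEANetDecoder signature can be driven from the cfg.seanet config through the merge in builders.get_encodec_autoencoder: in {**kwargs, **decoder_override_kwargs} the decoder sub-config wins on conflicts. A tiny sketch with made-up values:

    kwargs = {'dimension': 128, 'n_filters': 32, 'ratios': [8, 5, 4, 2]}   # shared 'seanet' section
    decoder_override_kwargs = {'trim_right_ratio': 1.0, 'n_filters': 64}   # 'decoder' sub-section
    decoder_kwargs = {**kwargs, **decoder_override_kwargs}                 # later keys win
    assert decoder_kwargs['n_filters'] == 64
    # -> SEANetDecoder(dimension=128, n_filters=64, ratios=[8, 5, 4, 2], trim_right_ratio=1.0)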
audiocraft/utils/utils.py
CHANGED
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from concurrent.futures import ProcessPoolExecutor
+
 from contextlib import contextmanager
 from functools import wraps, lru_cache
 import hashlib
@@ -103,6 +103,9 @@ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, gen
     input_ = input.reshape(-1, input.shape[-1])
     output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
     output = output_.reshape(*list(input.shape[:-1]), -1)
+
+    # print('MULTINOmial', input.shape, output.shape)  # MULTINOmial torch.Size([1, 4, 2048]) torch.Size([1, 4, 1])
+    # output = input[..., 0:1]
     return output

@@ -115,61 +118,18 @@ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
     Returns:
         torch.Tensor: Sampled tokens.
     """
-    top_k_value, _ = torch.topk(probs, k, dim=-1)
-    min_value_top_k = top_k_value[..., [-1]]
-    probs *= (probs >= min_value_top_k).float()
-    probs.div_(probs.sum(dim=-1, keepdim=True))
+    top_k_value, i250 = torch.topk(probs, k, dim=-1)  # probs: [1, 4, 2048]
+    min_value_top_k = top_k_value[..., [-1]]  #
+    probs *= (probs >= min_value_top_k).float()  # multiply all being > of min_topk with 1 thus zeroing others
+    probs.div_(probs.sum(dim=-1, keepdim=True))  # why normalize by the sum ? oh in order to choose mult
     next_token = multinomial(probs, num_samples=1)
+    # so instead of chooose multinomial what happens if we take all 250 topk tokens
+    # probs.shape=torch.Size([1, 4, 2048]) <, print(next_token,f'{probs.shape=}', 'h')  # 1,4,1 next token is 4tok
+    # next_token = i250
     return next_token


-def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
-    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        p (int): The p in “top-p”.
-    Returns:
-        torch.Tensor: Sampled tokens.
-    """
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort *= (~mask).float()
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-    next_token = multinomial(probs_sort, num_samples=1)
-    next_token = torch.gather(probs_idx, -1, next_token)
-    return next_token
-
-
-class DummyPoolExecutor:
-    """Dummy pool executor to use when we actually have only 1 worker.
-    (e.g. instead of ProcessPoolExecutor).
-    """
-    class DummyResult:
-        def __init__(self, func, *args, **kwargs):
-            self.func = func
-            self.args = args
-            self.kwargs = kwargs
-
-        def result(self):
-            return self.func(*self.args, **self.kwargs)
-
-    def __init__(self, workers, mp_context=None):
-        pass
-
-    def submit(self, func, *args, **kwargs):
-        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        return
-
-
-def get_pool_executor(num_workers: int, mp_context=None):
-    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)


 def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
@@ -188,42 +148,6 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t
     return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]


-def hash_trick(word: str, vocab_size: int) -> int:
-    """Hash trick to pair each word with an index
-
-    Args:
-        word (str): word we wish to convert to an index
-        vocab_size (int): size of the vocabulary
-    Returns:
-        int: index of the word in the embedding LUT
-    """
-    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
-    return hash % vocab_size
-
-
-def with_rank_rng(base_seed: int = 1234):
-    """Decorator for a function so that the function will use a Random Number Generator
-    whose state depend on the GPU rank. The original RNG state is restored upon returning.
-
-    Args:
-        base_seed (int): Random seed.
-    """
-    def _decorator(fun: tp.Callable):
-        @wraps(fun)
-        def _decorated(*args, **kwargs):
-            state = torch.get_rng_state()
-            seed = base_seed ^ flashy.distrib.rank()
-            torch.manual_seed(seed)
-            logger.debug('Rank dependent seed set to %d', seed)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                torch.set_rng_state(state)
-                logger.debug('RNG state restored.')
-        return _decorated
-    return _decorator
-
-
 def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """Get a list of tensors and collate them to a single tensor. according to the following logic:
     - `dim` specifies the time dimension which will be stacked and padded.
@@ -247,52 +171,3 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens
     return padded_tensors, lens


-# TODO: Move to flashy?
-def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
-               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
-    if isinstance(state, torch.Tensor):
-        if dtype is None or not state.is_floating_point():
-            dtype = state.dtype
-        return state.detach().to(device=device, dtype=dtype, copy=True)
-    elif isinstance(state, dict):
-        return {k: copy_state(v, device, dtype) for k, v in state.items()}
-    elif isinstance(state, list):
-        return [copy_state(v, device, dtype) for v in state]
-
-
-# TODO: Move to flashy?
-@contextmanager
-def swap_state(model, state, **kwargs):
-    old_state = copy_state(model.state_dict())
-    model.load_state_dict(state, **kwargs)
-    try:
-        yield
-    finally:
-        model.load_state_dict(old_state)
-
-
-@lru_cache(None)
-def warn_once(logger, msg):
-    """Warn about a given message only once."""
-    logger.warning(msg)
-
-
-def is_jsonable(x: tp.Any):
-    """Check if an object can be serialized into a json:"""
-    try:
-        json.dumps(x)
-        return True
-    except (TypeError, OverflowError):
-        return False
-
-
-def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
-    """Wrapper around state dict loading of CLAP model
-    addressing compatibility issues between CLAP and AudioCraft
-    HuggingFace transformer version.
-    See: https://github.com/LAION-AI/CLAP/issues/118
-    """
-    from clap_module.factory import load_state_dict  # type: ignore
-    pkg = load_state_dict(path)
-    pkg.pop('text_branch.embeddings.position_ids', None)
-    clap_model.model.load_state_dict(pkg)
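Note: the question left in sample_top_k ("why normalize by the sum?") has a short answer: torch.multinomial only needs non-negative relative weights, so the renormalization mainly keeps probs a proper distribution after everything below the k-th value has been zeroed. A worked example with k=2:

    import torch

    probs = torch.tensor([0.50, 0.30, 0.15, 0.05])
    top_k_value, _ = torch.topk(probs, k=2, dim=-1)      # values [0.50, 0.30]
    min_value_top_k = top_k_value[..., [-1]]             # 0.30, the smallest surviving prob
    probs = probs * (probs >= min_value_top_k).float()   # [0.50, 0.30, 0.00, 0.00]
    probs = probs / probs.sum(dim=-1, keepdim=True)      # [0.625, 0.375, 0.0, 0.0]
    next_token = torch.multinomial(probs, num_samples=1) # samples index 0 or 1 only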