fx sounds batch inference
Files changed:
- README.md +2 -2
- api.py +7 -15
- audiocraft/builders.py +10 -40
- audiocraft/conditioners.py +24 -198
- audiocraft/lm.py +7 -23
- audiocraft/transformer.py +1 -4
- models.py +5 -1
- requirements.txt +19 -0
README.md
CHANGED
@@ -62,7 +62,7 @@ pip install -r requirements.txt
 Flask `tmux-session`

 ```
-CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME
+CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=0 python api.py
 ```

 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
@@ -127,5 +127,5 @@ Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.y

 ```python
 # audiobook will be saved in ./tts_audiobooks
-python audiobook.py
+python audiobook.py
 ```
api.py
CHANGED
@@ -9,12 +9,11 @@ import re
 import srt
 import subprocess
 import cv2
-import markdown
 from pathlib import Path
 from types import SimpleNamespace
 from flask import Flask, request, send_from_directory
-from
-from moviepy.
+from moviepy.video.io.VideoFileClip import VideoFileClip
+from moviepy.video.VideoClip import ImageClip
 from audiocraft.builders import AudioGen
 CACHE_DIR = 'flask_cache/'
 NUM_SOUND_GENERATIONS = 3  # batch size to generate same text (same soundscape for long video)
@@ -79,10 +78,10 @@ def overlay(x, soundscape=None):
     background = sound_generator.generate(
         [soundscape] * NUM_SOUND_GENERATIONS
     ).reshape(-1).detach().cpu().numpy()  # bs, 11400 @.74s
-
-    # upsample 16 kHz AudioGen to 24kHZ
+
+    # upsample 16 kHz AudioGen to 24kHZ of VITS/StyleTTS2

-    print('Resampling')
+    print('Resampling')  # soundscape each generation in batch differs from the other generations thus clone/shift each element in batch, finally concat w/o shift


     background = audresample.resample(
@@ -178,14 +177,6 @@ def tts_multi_sentence(precomputed_style_vector=None,
 # global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)

 app = Flask(__name__)
-cors = CORS(app)
-
-
-@app.route("/")
-def index():
-    with open('README.md', 'r') as f:
-        return markdown.markdown(f.read())
-

 @app.route("/", methods=['GET', 'POST', 'PUT'])
 def serve_wav():
@@ -460,7 +451,8 @@ def serve_wav():

     # SILENT CLIP

-    clip_silent = ImageClip(STATIC_FRAME
+    clip_silent = ImageClip(img=STATIC_FRAME,
+                            duration=5)  # ffmpeg continues this silent video for duration of TTS
     clip_silent.write_videofile(SILENT_VIDEO, fps=24)

     x = tts_multi_sentence(text=text,
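The `overlay` hunk above is the core of the batch-inference fix: the same soundscape prompt is generated `NUM_SOUND_GENERATIONS` times in one batch, flattened into a single long background track, and resampled from AudioGen's 16 kHz to the 24 kHz TTS rate. A minimal sketch of that flow, assuming a `sound_generator` whose `generate()` returns a torch tensor as in the diff (the helper name `make_background` is made up for illustration):

```python
import audresample

NUM_SOUND_GENERATIONS = 3  # variants of the same prompt drawn in one forward pass

def make_background(sound_generator, soundscape, fs_in=16000, fs_out=24000):
    # One batched call yields NUM_SOUND_GENERATIONS takes of the same soundscape;
    # flattening them end-to-end gives a longer, non-repetitive background.
    batch = sound_generator.generate(
        [soundscape] * NUM_SOUND_GENERATIONS
    ).reshape(-1).detach().cpu().numpy()
    # AudioGen decodes at 16 kHz; the TTS pipeline runs at 24 kHz, so resample before overlay.
    return audresample.resample(batch, fs_in, fs_out)
```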
audiocraft/builders.py
CHANGED
@@ -10,11 +10,7 @@ from .encodec import EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
-from .conditioners import (
-    ConditioningProvider,
-    T5Conditioner,
-    ConditioningAttributes
-)
+from .conditioners import T5Conditioner
 from .vq import ResidualVectorQuantizer


@@ -73,10 +69,8 @@ class AudioGen(nn.Module):
     def generate(self,
                  descriptions):
         with torch.no_grad():
-            attributes = [
-                ConditioningAttributes(text={'description': d}) for d in descriptions]
             gen_tokens = self.lm.generate(
-
+                descriptions=descriptions,
                 max_gen_len=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
             x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
             # print('______________\nAudioGen Tokens', gen_tokens)
@@ -144,10 +138,8 @@ class AudioGen(nn.Module):
         codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
         cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
-        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']

-        condition_provider = self.get_conditioner_provider(kwargs["dim"], cfg
-                                                           ).to(self.device)


         # if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
@@ -163,7 +155,7 @@ class AudioGen(nn.Module):
         pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
         return LMModel(
             pattern_provider=pattern_provider,
-            condition_provider=
+            condition_provider=T5Conditioner(name='t5-large', output_dim=kwargs["dim"], device=self.device),
             cfg_dropout=cfg_prob,
             cfg_coef=cfg_coef,
             attribute_dropout=attribute_dropout,
@@ -173,34 +165,8 @@ class AudioGen(nn.Module):
             ).to(cfg.device)
         else:
             raise KeyError(f"Unexpected LM model {cfg.lm_model}")
-
-
-    def get_conditioner_provider(self, output_dim,
-                                 cfg):
-        """Instantiate T5 text"""
-        cfg = getattr(cfg, 'conditioners')
-        dict_cfg = {} if cfg is None else dict_from_config(cfg)
-        conditioners = {}
-        condition_provider_args = dict_cfg.pop('args', {})
-        condition_provider_args.pop('merge_text_conditions_p', None)
-        condition_provider_args.pop('drop_desc_p', None)
-
-        for cond, cond_cfg in dict_cfg.items():
-            model_type = cond_cfg['model']
-            model_args = cond_cfg[model_type]
-            if model_type == 't5':
-                conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
-                                                        device=self.device,
-                                                        **model_args)
-            else:
-                raise ValueError(f"Unrecognized conditioning model: {model_type}")
-
-        # print(f'{condition_provider_args=}')
-        return ConditioningProvider(conditioners)
-
-
-
-
+
+
     def get_codebooks_pattern_provider(self, n_q, cfg):
         pattern_providers = {
             'delay': DelayedPatternProvider,  # THIS
@@ -249,6 +215,10 @@ class AudioGen(nn.Module):
         _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
         _delete_param(cfg, 'conditioners.args.drop_desc_p')
         model = self.get_lm_model(cfg)
+
+        _best = pkg['best_state']
+        _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
+        _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         model.load_state_dict(pkg['best_state'])
         model.cfg = cfg
         # return model
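The last hunk is what lets the old checkpoint load into the refactored model: since `ConditioningProvider` is gone and `T5Conditioner` is now the `condition_provider` itself, the two `output_proj` keys in `best_state` must be renamed before `load_state_dict`. A small sketch of that idea (the generic `remap_keys` helper is illustrative, not part of the repo):

```python
def remap_keys(state_dict, renames):
    # Rename checkpoint entries in place so weights saved against the old module
    # hierarchy load into the new, flatter one.
    for old_key, new_key in renames.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict

# Mirrors the diff: strip the removed ConditioningProvider level from the key path.
RENAMES = {
    'condition_provider.conditioners.description.output_proj.weight':
        'condition_provider.output_proj.weight',
    'condition_provider.conditioners.description.output_proj.bias':
        'condition_provider.output_proj.bias',
}
# pkg['best_state'] = remap_keys(pkg['best_state'], RENAMES)
# model.load_state_dict(pkg['best_state'])
```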
audiocraft/conditioners.py
CHANGED
@@ -1,82 +1,9 @@
-from collections import defaultdict
-from dataclasses import dataclass, field
-import logging
-import random
-import typing as tp
 import warnings
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-logger = logging.getLogger(__name__)
-TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
-ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask


-
-
-class JointEmbedCondition(tp.NamedTuple):
-    wav: torch.Tensor
-    text: tp.List[tp.Optional[str]]
-    length: torch.Tensor
-    sample_rate: tp.List[int]
-    path: tp.List[tp.Optional[str]] = []
-    seek_time: tp.List[tp.Optional[float]] = []
-
-
-@dataclass
-class ConditioningAttributes:
-    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
-
-    def __getitem__(self, item):
-        return getattr(self, item)
-
-    @property
-    def text_attributes(self):
-        return self.text.keys()
-
-    @property
-    def wav_attributes(self):
-        return self.wav.keys()
-
-    @property
-    def joint_embed_attributes(self):
-        return self.joint_embed.keys()
-
-    @property
-    def attributes(self):
-        return {
-            "text": self.text_attributes,
-            "wav": self.wav_attributes,
-            "joint_embed": self.joint_embed_attributes,
-        }
-
-    def to_flat_dict(self):
-        return {
-            **{f"text.{k}": v for k, v in self.text.items()},
-            **{f"wav.{k}": v for k, v in self.wav.items()},
-            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
-        }
-
-    @classmethod
-    def from_flat_dict(cls, x):
-        out = cls()
-        for k, v in x.items():
-            kind, att = k.split(".")
-            out[kind][att] = v
-        return out
-
-
-class Tokenizer:
-    """Base tokenizer implementation
-    (in case we want to introduce more advances tokenizers in the future).
-    """
-    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-
-
 class T5Conditioner(nn.Module):

     MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
@@ -95,12 +22,10 @@ class T5Conditioner(nn.Module):
               "google/flan-t5-11b": 1024,
               }

-    def __init__(self,
-                 name
-                 output_dim
-                 device
-                 word_dropout: float = 0.,
-                 normalize_text: bool = False,
+    def __init__(self,
+                 name,
+                 output_dim,
+                 device,
                  finetune=False):
         print(f'{finetune=}')
         assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
@@ -110,19 +35,9 @@ class T5Conditioner(nn.Module):
         self.output_proj = nn.Linear(self.dim, output_dim)
         self.device = device
         self.name = name
-
-
-
-        # thanks https://gist.github.com/simon-weber/7853144
-        previous_level = logging.root.manager.disable
-        logging.disable(logging.ERROR)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            try:
-                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
-                t5 = T5EncoderModel.from_pretrained(name).eval()  #.train(mode=finetune)
-            finally:
-                logging.disable(previous_level)
+
+        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+        t5 = T5EncoderModel.from_pretrained(name).eval()  #.train(mode=finetune)
         if finetune:
             self.t5 = t5
         else:
@@ -130,116 +45,27 @@ class T5Conditioner(nn.Module):
             # of the saved checkpoint
             self.__dict__['t5'] = t5.to(device)

-        self.normalize_text = normalize_text
-        if normalize_text:
-            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

-    def tokenize(self, x
-
-        entries
-        if self.normalize_text:
-            _, _, entries = self.text_normalizer(entries, return_text=True)
-        if self.word_dropout > 0. and self.training:
-            new_entries = []
-            for entry in entries:
-                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
-                new_entries.append(" ".join(words))
-            entries = new_entries
+    def tokenize(self, x):
+
+        entries = [xi if xi is not None else "" for xi in x]

-        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

-        inputs = self.t5_tokenizer(entries,
-
-
-
+        inputs = self.t5_tokenizer(entries,
+                                   return_tensors='pt',
+                                   padding=True).to(self.device)
+
+        return inputs  # 'input_ids' 'attentio mask'

-    def forward(self,
-
+    def forward(self, descriptions):
+
+        d = self.tokenize(descriptions)
+
         with torch.no_grad():
-            embeds = self.t5(
+            embeds = self.t5(input_ids=d['input_ids'],
+                             attention_mask=d['attention_mask']
+                             ).last_hidden_state  # no kvcache for txt conditioning
             embeds = self.output_proj(embeds.to(self.output_proj.weight))
-            embeds = (embeds *
-
-            # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
-            print(f'{embeds.dtype=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
-            return embeds, mask
-
-
-
-
-
-
-
-
-class ConditioningProvider(nn.Module):
-
-    def __init__(self,
-                 conditioners):
-        super().__init__()
-        self.conditioners = nn.ModuleDict(conditioners)
-
-    @property
-    def text_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v, T5Conditioner)]
-
-
-
-    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-        output = {}
-        text = self._collate_text(inputs)
-        # wavs = self._collate_wavs(inputs)
-        # joint_embeds = self._collate_joint_embeds(inputs)
-
-        # assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-        #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-        #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        # )
-        for attribute, batch in text.items():  #, joint_embeds.items()):
-            output[attribute] = self.conditioners[attribute].tokenize(batch)
-        print(f'COndProvToknz {output=}\n==')
-        return output
-
-    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
-        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
-        The output is for example:
-        {
-            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
-            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
-            ...
-        }
-
-        Args:
-            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
-        """
-        output = {}
-        for attribute, inputs in tokenized.items():
-            condition, mask = self.conditioners[attribute](inputs)
-            output[attribute] = (condition, mask)
-        return output
-
-    def _collate_text(self, samples):
-        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
-        are the attributes and the values are the aggregated input per attribute.
-        For example:
-        Input:
-        [
-            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
-            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
-        ]
-        Output:
-        {
-            "genre": ["Rock", "Hip-hop"],
-            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
-        }

-
-        samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
-        Returns:
-            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
-        """
-        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
-        texts = [x.text for x in samples]
-        for text in texts:
-            for condition in self.text_conditions:
-                out[condition].append(text[condition])
-        return out
+            embeds = (embeds * d['attention_mask'].unsqueeze(-1))

+        return embeds  # , d['attention_mask']
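After this commit `T5Conditioner` is self-contained: `forward()` tokenizes the raw description strings itself and returns the projected, mask-zeroed embeddings directly instead of a `(condition, mask)` tuple. A usage sketch (instantiation downloads the T5 checkpoint on first run; `output_dim=1536` is an assumption taken from the debug comment `# T5 torch.Size([2, 4, 1536])` in the removed code):

```python
from audiocraft.conditioners import T5Conditioner

# output_dim=1536 is assumed; in builders.py it is kwargs["dim"] from the model config.
cond = T5Conditioner(name='t5-large', output_dim=1536, device='cpu')

# forward() now tokenizes internally; padded positions are already zeroed via the attention mask.
embeds = cond(['windy day', 'rain storm'])
print(embeds.shape)  # [2, seq_len, 1536]
```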
audiocraft/lm.py
CHANGED
@@ -23,17 +23,6 @@ def _shift(x):



-# ============================================== From LM.py
-
-
-logger = logging.getLogger(__name__)
-TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
-ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
-
-ConditionTensors = tp.Dict[str, ConditionType]
-CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
-
-
 def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
     """LM layer initialization.
     Inspired from xlformers: https://github.com/fairinternal/xlformers
@@ -280,19 +269,14 @@ class LMModel(nn.Module):
         return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2)  # [bs=3not6, self.n_draw, 4]

     @torch.no_grad()
-    def generate(self,
-
-
-
-        tokenized = self.condition_provider.tokenize(conditions)
-
-
-        cfg_conditions = self.condition_provider(tokenized)
-
-
+    def generate(self,
+                 descriptions = ['windy day', 'rain storm'],
+                 max_gen_len = 256):

+        text_condition = self.condition_provider(descriptions)
+
         # NULL CONDITION
-        text_condition = cfg_conditions['description'][0]
+        # text_condition = cfg_conditions['description'][0]
         bs, _, _ = text_condition.shape
         text_condition = torch.cat(
             [
@@ -330,7 +314,7 @@ class LMModel(nn.Module):

         # forward duplicates the query to nullcond - then cfg & returns deduplicate token
         next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
-                                  condition_tensors=text_condition,
+                                  condition_tensors=text_condition,  # utilisation of the attention mask of txt condition ?
                                   token_count=offset-1)  # [bs, 4, 1, 2048]


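`generate()` now encodes the descriptions once and immediately builds the null condition for classifier-free guidance; the `torch.cat([` above is truncated, but the usual pattern (an assumption here, not confirmed by the visible diff) is to stack an all-zeros copy of the text condition under the conditional one so a single forward pass scores both branches. A sketch of that assumed construction:

```python
import torch

def with_null_condition(text_condition):
    # text_condition: [bs, seq_len, dim] from the T5 conditioner.
    null_condition = torch.zeros_like(text_condition)  # assumed form of the null branch
    # Conditional rows first, unconditional (null) rows second; forward() can then
    # apply CFG and return only the conditional half of the tokens.
    return torch.cat([text_condition, null_condition], dim=0)  # [2*bs, seq_len, dim]
```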
audiocraft/transformer.py
CHANGED
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from torch.utils.checkpoint import checkpoint as torch_checkpoint
-from xformers import ops


 _efficient_attention_backend: str = 'torch'
@@ -12,7 +11,6 @@ _efficient_attention_backend: str = 'torch'



-
 def _get_attention_time_dimension(memory_efficient: bool) -> int:
     if _efficient_attention_backend == 'torch' and memory_efficient:
         return 2
@@ -190,7 +188,7 @@ class StreamingMultiheadAttention(nn.Module):
             # else:
             #     bound_layout = "b t p h d"
             packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
-            q, k, v =
+            q, k, v = packed.unbind(dim=2)


             if self.k_history is not None:
@@ -222,7 +220,6 @@ class StreamingMultiheadAttention(nn.Module):

         p = self.dropout if self.training else 0
         if _efficient_attention_backend == 'torch':
-            # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSSopen')
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, is_causal=False, dropout_p=p
             )
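The restored `q, k, v = packed.unbind(dim=2)` completes the packed in-projection path: a single linear layer emits query, key and value stacked on one axis, einops moves that axis to position 2, and `unbind` splits it before `scaled_dot_product_attention`. A self-contained sketch of the same pattern (the layout `"b h p t d"` is assumed from the torch-backend branch; function and variable names are illustrative):

```python
import torch
from einops import rearrange

def packed_self_attention(projected, num_heads, dropout_p=0.0):
    # projected: [batch, time, 3 * num_heads * head_dim], output of one in_proj Linear.
    packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=num_heads)
    q, k, v = packed.unbind(dim=2)  # each [batch, heads, time, head_dim]
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=False, dropout_p=dropout_p)

x = torch.randn(2, 10, 3 * 8 * 64)           # batch=2, time=10, heads=8, head_dim=64
out = packed_self_attention(x, num_heads=8)  # [2, 8, 10, 64]
```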
models.py
CHANGED
@@ -511,7 +511,11 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):

 def _load_model(model_config, model_path):
     model = ASRCNN(**model_config)
-    params = torch.load(
+    params = torch.load(
+        model_path,
+        map_location='cpu',
+        weights_only=False
+    )['model']
     model.load_state_dict(params)
     return model

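For context on the `_load_model` change: recent PyTorch releases (2.6+) default `torch.load` to `weights_only=True`, which rejects checkpoints that pickle anything beyond plain tensors, so the flag is now passed explicitly, and `map_location='cpu'` keeps loading independent of which device the checkpoint was saved on. A minimal sketch of the same loading pattern (the helper name is illustrative):

```python
import torch

def load_params(model, model_path):
    # weights_only=False: the checkpoint stores a full dict, not just tensors.
    # map_location='cpu': deserialize on CPU first, move to GPU later if needed.
    params = torch.load(model_path, map_location='cpu', weights_only=False)['model']
    model.load_state_dict(params)
    return model
```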
requirements.txt
ADDED
@@ -0,0 +1,19 @@
+torch
+torchaudio
+numpy
+audiofile
+audresample
+cached_path
+einops
+flask
+librosa
+moviepy
+sentencepiece
+omegaconf
+opencv-python
+soundfile
+transformers
+munch
+srt
+nltk
+phonemizer