implement clip-guided generation (and never use it...)
- api.py +30 -6
- eval_multiple.py +1 -1
- models/autoregressive.py +17 -6
- read.py +4 -3
- sweep.py +8 -9
api.py
CHANGED
@@ -76,7 +76,30 @@ def load_conditioning(clip, cond_length=132300):
     return mel_clip.unsqueeze(0).cuda()


-def fix_autoregressive_output(codes, stop_token):
+def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token,
+                           tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs):
+    """
+    Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates
+    every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat.
+    This is a hybrid between beam search and sampling.
+    """
+    token_goal = tokens_per_clip_inference
+    finished = False
+    while not finished and token_goal < autoregressive_model.max_mel_tokens:
+        samples = []
+        for b in tqdm(range(num_batches)):
+            codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs)
+            samples.append(codes)
+        for batch in samples:
+            for i in range(batch.shape[0]):
+                batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
+            clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
+        clip_results = torch.cat(clip_results, dim=0)
+        samples = torch.cat(samples, dim=0)
+        best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
+
+
+def fix_autoregressive_output(codes, stop_token, complain=True):
     """
     This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
     trained on and what the autoregressive code generator creates (which has no padding or end).
@@ -89,7 +112,8 @@ def fix_autoregressive_output(codes, stop_token):
     # Strip off the autoregressive stop token and add padding.
     stop_token_indices = (codes == stop_token).nonzero()
     if len(stop_token_indices) == 0:
-        print("No stop tokens found, enjoy that output of yours!")
+        if complain:
+            print("No stop tokens found, enjoy that output of yours!")
         return codes
     else:
         codes[stop_token_indices] = 83
@@ -136,14 +160,14 @@ class TextToSpeech:
                                            heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                            train_solo_embeddings=False,
                                            average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('.models/
+        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))

         self.autoregressive_for_latents = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                                        model_dim=1024,
                                                        heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                                        train_solo_embeddings=False,
                                                        average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive_for_latents.load_state_dict(torch.load('.models/
+        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))

         self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                               text_seq_len=350, text_heads=8,
@@ -154,7 +178,7 @@ class TextToSpeech:
         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                       layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion.load_state_dict(torch.load('.models/
+        self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder_audiobooks.pth'))

         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
@@ -170,7 +194,7 @@ class TextToSpeech:
         presets = {
             'intelligible': {'temperature': .5, 'length_penalty': 2.0, 'repetition_penalty': 2.0, 'top_p': .5, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': .7, 'diffusion_temperature': .7},
             'mid': {'temperature': .7, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .7, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 1.5, 'diffusion_temperature': .8},
-            'realistic': {'temperature': .
+            'realistic': {'temperature': 1.0, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
         }
         kwargs.update(presets[preset])
         return self.tts(text, voice_samples, **kwargs)
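As committed, clip_guided_generation is never actually wired up: clip_results is appended to before it is initialized, token_goal and finished are never advanced, and the function returns nothing, which matches the commit title's admission that it goes unused. The following is a minimal sketch of how the loop could be closed, assuming the inference_speech interface added in this same commit; the prefix propagation via input_tokens/num_return_sequences and the name clip_guided_generation_sketch are illustrative assumptions, not part of the commit.

import torch
from tqdm import tqdm


def clip_guided_generation_sketch(autoregressive_model, clip_model, conditioning_input, text_input, num_batches,
                                  stop_mel_token, tokens_per_clip_inference=10, clip_results_to_reduce_to=8,
                                  **generation_kwargs):
    # Hypothetical completion: score partial generations with CLVP every few tokens and
    # carry the top candidates forward as prefixes for the next round.
    token_goal = tokens_per_clip_inference
    best_results = None  # top partial code sequences surviving from the previous round
    while token_goal < autoregressive_model.max_mel_tokens:
        samples, clip_results = [], []
        for _ in tqdm(range(num_batches)):
            if best_results is None:
                codes = autoregressive_model.inference_speech(conditioning_input, text_input,
                                                              max_generate_length=token_goal, **generation_kwargs)
            else:
                # Continue each surviving prefix; the returned codes include the prefix itself.
                codes = autoregressive_model.inference_speech(conditioning_input, text_input,
                                                              input_tokens=best_results,
                                                              num_return_sequences=best_results.shape[0],
                                                              max_generate_length=token_goal, **generation_kwargs)
            samples.append(codes)
        for batch in samples:
            for i in range(batch.shape[0]):
                batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
            # CLVP similarity between the full text and each partial clip.
            clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
        clip_results = torch.cat(clip_results, dim=0)
        samples = torch.cat(samples, dim=0)
        best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
        token_goal += tokens_per_clip_inference
    return best_results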
eval_multiple.py
CHANGED
@@ -8,7 +8,7 @@ from utils.audio import load_audio
 if __name__ == '__main__':
     fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
     stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\audiobooks'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'

     os.makedirs(outpath_real, exist_ok=True)
models/autoregressive.py
CHANGED
@@ -511,7 +511,8 @@ class UnifiedVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()

-    def inference_speech(self, speech_conditioning_input, text_inputs,
+    def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1,
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
@@ -541,13 +542,23 @@
         emb = torch.cat([conds, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)

-        fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long,
-                                 device=text_inputs.device)
+        fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
+                                 device=text_inputs.device)
+        fake_inputs[:, -1] = self.start_mel_token
+        trunc_index = fake_inputs.shape[1]
+        if input_tokens is None:
+            inputs = fake_inputs
+        else:
+            assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences"
+            fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
+            inputs = torch.cat([fake_inputs, input_tokens], dim=1)

         logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
-
-
-
+        max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
+        gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
+                                            max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs)
+        return gen[:, trunc_index:]


 if __name__ == '__main__':
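For reference, a small usage sketch of the reworked inference_speech. The model, conditioning mel and tokenized text are assumed to be prepared the way api.py does elsewhere; the helper name and the sampling kwargs (do_sample, top_p, forwarded to Hugging Face generate through **hf_generate_kwargs) are illustrative.

import torch


def sample_then_continue(model, cond_mel, text_tokens):
    # First pass: sample fresh mel codes, capped at 80 new tokens past the conditioning/text prefix.
    codes = model.inference_speech(cond_mel, text_tokens, max_generate_length=80,
                                   do_sample=True, top_p=0.8)
    # Second pass: feed the partial codes back in as input_tokens. num_return_sequences must be a
    # multiple of codes.shape[0]; here each prefix fans out into two continuations.
    continued = model.inference_speech(cond_mel, text_tokens, input_tokens=codes,
                                       num_return_sequences=2 * codes.shape[0],
                                       max_generate_length=160, do_sample=True, top_p=0.8)
    # Returned sequences start at the first mel token; the conditioning/text prefix is stripped.
    return continued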
read.py
CHANGED
@@ -32,15 +32,16 @@ if __name__ == '__main__':
     preselected_cond_voices = {
         'emma_stone': ['voices/emma_stone/1.wav','voices/emma_stone/2.wav','voices/emma_stone/3.wav'],
         'tom_hanks': ['voices/tom_hanks/1.wav','voices/tom_hanks/2.wav','voices/tom_hanks/3.wav'],
+        'patrick_stewart': ['voices/patrick_stewart/1.wav','voices/patrick_stewart/2.wav','voices/patrick_stewart/3.wav','voices/patrick_stewart/4.wav'],
     }

     parser = argparse.ArgumentParser()
     parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
-    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=
+    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='patrick_stewart')
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128)
     parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/')
-    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='
+    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='realistic')
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)

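The remainder of read.py is outside this diff. As a hedged illustration, the selected preset voice could be turned into conditioning clips along these lines, reusing the load_audio call seen in sweep.py below; the function name is hypothetical.

from utils.audio import load_audio


def load_voice_clips(voice_name, preselected_cond_voices, sample_rate=22050):
    # Load every reference wav for the chosen preset voice at the rate the pipeline expects.
    return [load_audio(path, sample_rate) for path in preselected_cond_voices[voice_name]]


# e.g. cond_clips = load_voice_clips(args.voice, preselected_cond_voices)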
sweep.py
CHANGED
@@ -25,16 +25,15 @@ def permutations(args):

 if __name__ == '__main__':
     fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
-    stop_after =
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep'
+    stop_after = 512
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep-2'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'

     arg_ranges = {
-        'top_p': [.
-        'temperature': [.
-        'diffusion_temperature': [.
-        'cond_free_k': [
-        'repetition_penalty': [1.0, 2.0]
+        'top_p': [.8,1],
+        'temperature': [.8,.9,1],
+        'diffusion_temperature': [.8,1],
+        'cond_free_k': [1,2,5,10],
     }
     cfgs = permutations(arg_ranges)
     shuffle(cfgs)
@@ -56,8 +55,8 @@ if __name__ == '__main__':
         path = os.path.join(os.path.dirname(fname), line[1])
         cond_audio = load_audio(path, 22050)
         torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=
-                         k=1, diffusion_iterations=
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=32, repetition_penalty=2.0,
+                         k=1, diffusion_iterations=32, length_penalty=1.0, **cfg)
         down = torchaudio.functional.resample(sample, 24000, 22050)
         fout_path = os.path.join(outpath, os.path.basename(line[1]))
         torchaudio.save(fout_path, down.squeeze(0), 22050)
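The permutations helper named in the hunk header is not included in this diff. A plausible sketch of what it does, a plain cartesian product over arg_ranges, is below; this implementation is an assumption rather than the repository's code.

import itertools
from random import shuffle


def permutations(args):
    # Expand {'top_p': [.8, 1], 'temperature': [.8, .9, 1], ...} into a list of dicts,
    # one per combination of values (cartesian product over every listed range).
    keys = list(args.keys())
    combos = itertools.product(*(args[k] for k in keys))
    return [dict(zip(keys, combo)) for combo in combos]


# Mirrors the sweep driver above: build every config, then visit them in random order.
arg_ranges = {
    'top_p': [.8, 1],
    'temperature': [.8, .9, 1],
    'diffusion_temperature': [.8, 1],
    'cond_free_k': [1, 2, 5, 10],
}
cfgs = permutations(arg_ranges)
shuffle(cfgs)
print(len(cfgs), cfgs[0])  # 48 configurations, e.g. {'top_p': 0.8, 'temperature': 0.9, ...}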