is this from tortoise?
Browse files- api.py +20 -1
- is_this_from_tortoise.py +14 -0
- models/classifier.py +14 -9
api.py
CHANGED
|
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|
| 8 |
import progressbar
|
| 9 |
import torchaudio
|
| 10 |
|
|
|
|
| 11 |
from models.cvvp import CVVP
|
| 12 |
from models.diffusion_decoder import DiffusionTts
|
| 13 |
from models.autoregressive import UnifiedVoice
|
|
@@ -24,7 +25,7 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
|
| 24 |
pbar = None
|
| 25 |
|
| 26 |
|
| 27 |
-
def download_models():
|
| 28 |
"""
|
| 29 |
Call to download all the models that Tortoise uses.
|
| 30 |
"""
|
|
@@ -50,6 +51,8 @@ def download_models():
|
|
| 50 |
pbar.finish()
|
| 51 |
pbar = None
|
| 52 |
for model_name, url in MODELS.items():
|
|
|
|
|
|
|
| 53 |
if os.path.exists(f'.models/{model_name}'):
|
| 54 |
continue
|
| 55 |
print(f'Downloading {model_name} from {url}...')
|
|
@@ -145,6 +148,22 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa
|
|
| 145 |
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
| 146 |
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
class TextToSpeech:
|
| 149 |
"""
|
| 150 |
Main entry point into Tortoise.
|
|
|
|
| 8 |
import progressbar
|
| 9 |
import torchaudio
|
| 10 |
|
| 11 |
+
from models.classifier import AudioMiniEncoderWithClassifierHead
|
| 12 |
from models.cvvp import CVVP
|
| 13 |
from models.diffusion_decoder import DiffusionTts
|
| 14 |
from models.autoregressive import UnifiedVoice
|
|
|
|
| 25 |
pbar = None
|
| 26 |
|
| 27 |
|
| 28 |
+
def download_models(specific_models=None):
|
| 29 |
"""
|
| 30 |
Call to download all the models that Tortoise uses.
|
| 31 |
"""
|
|
|
|
| 51 |
pbar.finish()
|
| 52 |
pbar = None
|
| 53 |
for model_name, url in MODELS.items():
|
| 54 |
+
if specific_models is not None and model_name not in specific_models:
|
| 55 |
+
continue
|
| 56 |
if os.path.exists(f'.models/{model_name}'):
|
| 57 |
continue
|
| 58 |
print(f'Downloading {model_name} from {url}...')
|
|
|
|
| 148 |
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
| 149 |
|
| 150 |
|
| 151 |
+
def classify_audio_clip(clip):
|
| 152 |
+
"""
|
| 153 |
+
Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
|
| 154 |
+
:param clip: torch tensor containing audio waveform data (get it from load_audio)
|
| 155 |
+
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
|
| 156 |
+
"""
|
| 157 |
+
download_models(['classifier'])
|
| 158 |
+
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
| 159 |
+
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
| 160 |
+
dropout=0, kernel_size=5, distribute_zero_label=False)
|
| 161 |
+
classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
|
| 162 |
+
clip = clip.cpu().unsqueeze(0)
|
| 163 |
+
results = F.softmax(classifier(clip), dim=-1)
|
| 164 |
+
return results[0][0]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
class TextToSpeech:
|
| 168 |
"""
|
| 169 |
Main entry point into Tortoise.
|
is_this_from_tortoise.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from api import classify_audio_clip
|
| 4 |
+
from utils.audio import load_audio
|
| 5 |
+
|
| 6 |
+
if __name__ == '__main__':
|
| 7 |
+
parser = argparse.ArgumentParser()
|
| 8 |
+
parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3")
|
| 9 |
+
args = parser.parse_args()
|
| 10 |
+
|
| 11 |
+
clip = load_audio(args.clip, 24000)
|
| 12 |
+
clip = clip[:, :220000]
|
| 13 |
+
prob = classify_audio_clip(clip)
|
| 14 |
+
print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.")
|
models/classifier.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
| 1 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
class ResBlock(nn.Module):
|
|
@@ -27,7 +32,7 @@ class ResBlock(nn.Module):
|
|
| 27 |
self.in_layers = nn.Sequential(
|
| 28 |
normalization(channels),
|
| 29 |
nn.SiLU(),
|
| 30 |
-
|
| 31 |
)
|
| 32 |
|
| 33 |
self.updown = up or down
|
|
@@ -46,18 +51,18 @@ class ResBlock(nn.Module):
|
|
| 46 |
nn.SiLU(),
|
| 47 |
nn.Dropout(p=dropout),
|
| 48 |
zero_module(
|
| 49 |
-
|
| 50 |
),
|
| 51 |
)
|
| 52 |
|
| 53 |
if self.out_channels == channels:
|
| 54 |
self.skip_connection = nn.Identity()
|
| 55 |
elif use_conv:
|
| 56 |
-
self.skip_connection =
|
| 57 |
dims, channels, self.out_channels, kernel_size, padding=padding
|
| 58 |
)
|
| 59 |
else:
|
| 60 |
-
self.skip_connection =
|
| 61 |
|
| 62 |
def forward(self, x):
|
| 63 |
if self.do_checkpoint:
|
|
@@ -94,21 +99,21 @@ class AudioMiniEncoder(nn.Module):
|
|
| 94 |
kernel_size=3):
|
| 95 |
super().__init__()
|
| 96 |
self.init = nn.Sequential(
|
| 97 |
-
|
| 98 |
)
|
| 99 |
ch = base_channels
|
| 100 |
res = []
|
| 101 |
self.layers = depth
|
| 102 |
for l in range(depth):
|
| 103 |
for r in range(resnet_blocks):
|
| 104 |
-
res.append(ResBlock(ch, dropout,
|
| 105 |
-
res.append(Downsample(ch, use_conv=True,
|
| 106 |
ch *= 2
|
| 107 |
self.res = nn.Sequential(*res)
|
| 108 |
self.final = nn.Sequential(
|
| 109 |
normalization(ch),
|
| 110 |
nn.SiLU(),
|
| 111 |
-
|
| 112 |
)
|
| 113 |
attn = []
|
| 114 |
for a in range(attn_blocks):
|
|
@@ -118,7 +123,7 @@ class AudioMiniEncoder(nn.Module):
|
|
| 118 |
|
| 119 |
def forward(self, x):
|
| 120 |
h = self.init(x)
|
| 121 |
-
h =
|
| 122 |
h = self.final(h)
|
| 123 |
for blk in self.attn:
|
| 124 |
h = checkpoint(blk, h)
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.checkpoint import checkpoint
|
| 5 |
+
|
| 6 |
+
from models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
|
| 7 |
|
| 8 |
|
| 9 |
class ResBlock(nn.Module):
|
|
|
|
| 32 |
self.in_layers = nn.Sequential(
|
| 33 |
normalization(channels),
|
| 34 |
nn.SiLU(),
|
| 35 |
+
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
|
| 36 |
)
|
| 37 |
|
| 38 |
self.updown = up or down
|
|
|
|
| 51 |
nn.SiLU(),
|
| 52 |
nn.Dropout(p=dropout),
|
| 53 |
zero_module(
|
| 54 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
|
| 55 |
),
|
| 56 |
)
|
| 57 |
|
| 58 |
if self.out_channels == channels:
|
| 59 |
self.skip_connection = nn.Identity()
|
| 60 |
elif use_conv:
|
| 61 |
+
self.skip_connection = nn.Conv1d(
|
| 62 |
dims, channels, self.out_channels, kernel_size, padding=padding
|
| 63 |
)
|
| 64 |
else:
|
| 65 |
+
self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
|
| 66 |
|
| 67 |
def forward(self, x):
|
| 68 |
if self.do_checkpoint:
|
|
|
|
| 99 |
kernel_size=3):
|
| 100 |
super().__init__()
|
| 101 |
self.init = nn.Sequential(
|
| 102 |
+
nn.Conv1d(spec_dim, base_channels, 3, padding=1)
|
| 103 |
)
|
| 104 |
ch = base_channels
|
| 105 |
res = []
|
| 106 |
self.layers = depth
|
| 107 |
for l in range(depth):
|
| 108 |
for r in range(resnet_blocks):
|
| 109 |
+
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
|
| 110 |
+
res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
|
| 111 |
ch *= 2
|
| 112 |
self.res = nn.Sequential(*res)
|
| 113 |
self.final = nn.Sequential(
|
| 114 |
normalization(ch),
|
| 115 |
nn.SiLU(),
|
| 116 |
+
nn.Conv1d(ch, embedding_dim, 1)
|
| 117 |
)
|
| 118 |
attn = []
|
| 119 |
for a in range(attn_blocks):
|
|
|
|
| 123 |
|
| 124 |
def forward(self, x):
|
| 125 |
h = self.init(x)
|
| 126 |
+
h = self.res(h)
|
| 127 |
h = self.final(h)
|
| 128 |
for blk in self.attn:
|
| 129 |
h = checkpoint(blk, h)
|