time varying speaker style
- Modules/hifigan.py +13 -5
- models.py +53 -80
- msinference.py +178 -75
- requirements.txt +1 -1
Modules/hifigan.py
CHANGED

@@ -12,16 +12,24 @@ import numpy as np
 LRELU_SLOPE = 0.1

 class AdaIN1d(nn.Module):
+
+    # used by HiFiGan & ProsodyPredictor
+
     def __init__(self, style_dim, num_features):
         super().__init__()
         self.norm = nn.InstanceNorm1d(num_features, affine=False)
         self.fc = nn.Linear(style_dim, num_features*2)

     def forward(self, x, s):
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        return (1 + gamma) * self.norm(x) + beta
+
+        s = self.fc(s)  # [bs, 1024, 130]
+        s = F.interpolate(s[:, :, 0, :].transpose(1, 2), x.shape[2], mode='linear')  # different time-resolution than Dur
+
+        gamma, beta = torch.chunk(s, chunks=2, dim=1)  # channels vary in for loop
+
+        # affine (1 + lin(x)) * inst(x) + lin(x) is this a skip connection where the weight is a lin of itself
+
+        return (1 + gamma) * self.norm(x) + beta  # norm(x) = PLBERT has norm / beta&gamma = style has no norm()

 class AdaINResBlock1(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):

@@ -443,7 +451,7 @@ class Decoder(nn.Module):
         self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)


-    def forward(self, asr, F0_curve, N, s):
+    def forward(self, asr=None, F0_curve=None, N=None, s=None):
         if self.training:
             downlist = [0, 3, 7]
             F0_down = downlist[random.randint(0, 2)]
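For context, a minimal sketch (not part of the commit) of the tensor flow the new AdaIN1d implements; the shapes are assumed from the comments above and the frame counts are illustrative only:

import torch
import torch.nn.functional as F

# toy shapes: 11 reference-style frames modulating 300 content frames
x = torch.randn(1, 512, 300)            # decoder features  [bs, num_features, T]
s = torch.randn(1, 11, 1, 128)          # frame-wise style   [bs, T_ref, 1, style_dim]
fc = torch.nn.Linear(128, 2 * 512)
norm = torch.nn.InstanceNorm1d(512, affine=False)

h = fc(s)                                            # [1, 11, 1, 1024]
h = F.interpolate(h[:, :, 0, :].transpose(1, 2),     # [1, 1024, 11], stretched to the content length
                  x.shape[2], mode='linear')         # [1, 1024, 300]
gamma, beta = torch.chunk(h, chunks=2, dim=1)        # per-frame scale and shift
out = (1 + gamma) * norm(x) + beta
print(out.shape)                                     # torch.Size([1, 512, 300])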
models.py
CHANGED

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torch.nn.utils import weight_norm, spectral_norm
 from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
-from
+from Modules.hifigan import AdaIN1d
 import yaml


@@ -110,7 +110,7 @@ class ResBlk(nn.Module):

 class StyleEncoder(nn.Module):

-    #
+    # for both acoustic & prosodic ref_s/p

     def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
         super().__init__()

@@ -125,15 +125,20 @@ class StyleEncoder(nn.Module):

         blocks += [nn.LeakyReLU(0.2)]
         blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
-        blocks += [nn.AdaptiveAvgPool2d(1)]
+
+        # blocks += [nn.AdaptiveAvgPool2d(1)]  # THIS AVERAGES THE TIME-FRAMES OF SPEAKER STYLE
+
         blocks += [nn.LeakyReLU(0.2)]
         self.shared = nn.Sequential(*blocks)

         self.unshared = nn.Linear(dim_out, style_dim)

     def forward(self, x):
-        h = self.shared(x)
-        h = h.view(h.size(0), -1)
+        h = self.shared(x)  # [bs, 512, 1, 11]
+
+        # h = h.mean(3, keepdims=True)  # UNCOMMENT FOR TIME-INVARIANT GLOBAL SPEAKER STYLE
+
+        h = h.transpose(1, 3)
         s = self.unshared(h)
         return s

@@ -289,21 +294,6 @@ class TextEncoder(nn.Module):
         mask = torch.gt(mask+1, lengths.unsqueeze(1))
         return mask

-
-
-class AdaIN1d(nn.Module):
-    def __init__(self, style_dim, num_features):
-        super().__init__()
-        self.norm = nn.InstanceNorm1d(num_features, affine=False)
-        self.fc = nn.Linear(style_dim, num_features*2)
-
-    def forward(self, x, s):
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        # affine (1 + lin(x)) * inst(x) + lin(x) is this a skip connection where the weight is a lin of itself
-        return (1 + gamma) * self.norm(x) + beta  # norm(x) = PLBERT has norm / beta&gamma = style has no norm()
-
 class UpSample1d(nn.Module):
     def __init__(self, layer_type):
         super().__init__()
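To make the StyleEncoder change concrete, a small sketch (illustrative only, with assumed shapes) contrasting the pooled global style with the frame-wise style that AdaptiveAvgPool2d / the mean would otherwise collapse:

import torch

# toy stand-in for StyleEncoder.shared output: [bs, channels, 1, time_frames]
h = torch.randn(1, 512, 1, 11)
unshared = torch.nn.Linear(512, 128)

# old behaviour: pool over time -> one global style vector per utterance
global_style = unshared(h.mean(3, keepdim=True).transpose(1, 3))   # [1, 1, 1, 128]

# new behaviour: keep the 11 mel windows -> one style vector per reference frame
framewise_style = unshared(h.transpose(1, 3))                      # [1, 11, 1, 128]

print(global_style.shape, framewise_style.shape)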
@@ -316,8 +306,15 @@ class UpSample1d(nn.Module):
         return F.interpolate(x, scale_factor=2, mode='nearest')

 class AdainResBlk1d(nn.Module):
-    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
-
+
+    # only instantiated in ProsodyPredictor
+
+    def __init__(self, dim_in,
+                 dim_out,
+                 style_dim=64,
+                 actv=nn.LeakyReLU(0.2),
+                 upsample='none',
+                 dropout_p=0.0):
         super().__init__()
         self.actv = actv
         self.upsample_type = upsample

@@ -362,26 +359,22 @@ class AdainResBlk1d(nn.Module):
         return out

 class AdaLayerNorm(nn.Module):
-    def __init__(self, style_dim, channels, eps=1e-5):
+
+    # only instantiated in DurationPredictor()
+
+    def __init__(self, style_dim, channels=None, eps=1e-5):
         super().__init__()
-        self.channels = channels
         self.eps = eps
-
-        self.fc = nn.Linear(style_dim, channels*2)
+        self.fc = nn.Linear(style_dim, 1024)

     def forward(self, x, s):
-
-
-
-        h = self.fc(s)
-        h = h.view(h.size(0), h.size(1), 1)
-        gamma, beta = torch.chunk(h, chunks=2, dim=1)
-        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+        h = self.fc(s.transpose(1, 2))  # has to be transposed due to interpolate needing the last dim to be frames
+        gamma = h[:, :, :512]
+        beta = h[:, :, 512:1024]

-
-        x = F.layer_norm(x, (self.channels,), eps=self.eps)
+        x = F.layer_norm(x.transpose(1, 2), (512, ), eps=self.eps)
         x = (1 + gamma) * x + beta
-        return x
+        return x  # [1, 75, 512]

 class ProsodyPredictor(nn.Module):
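A compact sketch of what the rewritten AdaLayerNorm computes once the style is a per-frame sequence; the 512/1024 sizes are hard-coded as in the diff, and the tensors below are stand-ins rather than the repo's actual inputs:

import torch
import torch.nn.functional as F

bs, T = 1, 75
x = torch.randn(bs, 512, T)          # text/prosody features, channels first
s = torch.randn(bs, 128, T)          # style already interpolated to T frames
fc = torch.nn.Linear(128, 1024)

h = fc(s.transpose(1, 2))            # [bs, T, 1024]: per-frame scale & shift
gamma, beta = h[:, :, :512], h[:, :, 512:1024]
out = (1 + gamma) * F.layer_norm(x.transpose(1, 2), (512,)) + beta
print(out.shape)                     # torch.Size([1, 75, 512])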
@@ -414,7 +407,12 @@ class ProsodyPredictor(nn.Module):
         x, _ = self.shared(x.transpose(-1, -2))

         F0 = x.transpose(-1, -2)
+
+
         for block in self.F0:
+            print(f'F)N {F0.shape=} {s.shape=}\n')
+            # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
+
             F0 = block(F0, s)
         F0 = self.F0_proj(F0)

@@ -452,21 +450,30 @@ class DurationEncoder(nn.Module):
     def forward(self, x, style, text_lengths, m):
         masks = m.to(text_lengths.device)

-
-
-        x
-        x.

-
+
+        # x : [bs, 512, 987]
+        # print('DURATION ENCODER', x.shape, style.shape, masks.shape)
+        # s = style.expand(x.shape[0], x.shape[1], -1)
+        style = style[:, :, 0, :].transpose(2, 1)  # [bs, 128, 11]
+        # print("S IN DURATION ENC", style.shape, x.shape)
+        style = F.interpolate(style, x.shape[2])
+        print(f'L468 IN DURATION ENC {x.shape=}, {style.shape=} {masks.shape=}')  # mask = [1, 75]
+        x = torch.cat([x, style], axis=1)  # [bs, 640, 75]
+        x.masked_fill_(masks[:, None, :], 0.0)
+
         input_lengths = text_lengths.cpu().numpy()
-        x = x.transpose(-1, -2)

         for block in self.lstms:
             if isinstance(block, AdaLayerNorm):
-
-
-                x
+
+                print(f'\n=========ENTER ADALAYNORM L479 models.py {x.shape=}, {style.shape=}')
+                x = block(x, style)  # [bs, 75, 512]
+                x = torch.cat([x.transpose(1, 2), style], axis=1)  # [bs, 512, 75]
+                x.masked_fill_(masks[:, None, :], 0.0)
             else:
+                # print(f'{x.shape=} ENTER LSTM')  # [bs, 640, 75] LSTM reduce ch 640 -> 512
                 x = x.transpose(-1, -2)
                 x = nn.utils.rnn.pack_padded_sequence(
                     x, input_lengths, batch_first=True, enforce_sorted=False)

@@ -481,6 +488,7 @@ class DurationEncoder(nn.Module):

         x_pad[:, :, :x.shape[-1]] = x
         x = x_pad.to(x.device)
+        # print(f'{x.shape=} EXIT LSTM')  # [bs, 512, 75]
         # print('Calling Duration Encoder\n\n\n\n', x.shape, x.min(), x.max())
         # Calling Duration Encoder
         # torch.Size([1, 640, 107]) tensor(-3.0903, device='cuda:0') tensor(2.3089, device='cuda:0')
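The core of the new DurationEncoder path is stretching the frame-wise reference style onto the token grid before concatenation; a standalone illustration under assumed shapes (not the repo's exact tensors):

import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 75)          # BERT/text features: [bs, 512, n_tokens]
style = torch.randn(1, 11, 1, 128)   # frame-wise reference style from StyleEncoder

style = style[:, :, 0, :].transpose(2, 1)      # [bs, 128, 11]
style = F.interpolate(style, x.shape[2])       # stretch 11 style frames onto 75 token frames
x = torch.cat([x, style], dim=1)               # [bs, 640, 75], fed to the LSTM / AdaLayerNorm stack
print(x.shape)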
@@ -493,7 +501,6 @@ def load_F0_models(path):
     # load F0 model

     F0_model = JDCNet(num_class=1, seq_len=192)
-    print(path, 'WHAT ARE YOU TRYING TO LOAD F0 L520')
     path = path.replace('.t7', '.pth')
     params = torch.load(path, map_location='cpu')['net']
     F0_model.load_state_dict(params)

@@ -524,37 +531,3 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
     _ = asr_model.train()

     return asr_model
-
-def build_model(args, text_aligner, pitch_extractor, bert):
-    print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
-
-    from Modules.hifigan import Decoder
-    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
-                      resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
-                      upsample_rates=args.decoder.upsample_rates,
-                      upsample_initial_channel=args.decoder.upsample_initial_channel,
-                      resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
-                      upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
-
-    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
-
-    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
-
-    style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)      # acoustic style encoder
-    predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)  # prosodic style encoder
-    nets = Munch(
-        bert=bert,
-        bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
-
-        predictor=predictor,
-        decoder=decoder,
-        text_encoder=text_encoder,
-
-        predictor_encoder=predictor_encoder,
-        style_encoder=style_encoder,
-
-        text_aligner=text_aligner,
-        pitch_extractor=pitch_extractor
-    )
-
-    return nets
msinference.py
CHANGED

@@ -7,8 +7,7 @@ import numpy as np
 import yaml
 import torchaudio
 import librosa
-from models import
-from munch import Munch
+from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_ASR_models, load_F0_models
 from nltk.tokenize import word_tokenize

 torch.manual_seed(0)

@@ -62,17 +61,6 @@ def alpha_num(f):
     return f


-
-def recursive_munch(d):
-    if isinstance(d, dict):
-        return Munch((k, recursive_munch(v)) for k, v in d.items())
-    elif isinstance(d, list):
-        return [recursive_munch(v) for v in d]
-    else:
-        return d
-
-
-
 # ======== UTILS ABOVE

 def length_to_mask(lengths):

@@ -94,10 +82,10 @@ def compute_style(path):
     mel_tensor = preprocess(audio).to(device)

     with torch.no_grad():
-        ref_s =
-        ref_p =
-
-        return torch.cat([ref_s, ref_p], dim=1)
+        ref_s = style_encoder(mel_tensor.unsqueeze(1))
+        ref_p = predictor_encoder(mel_tensor.unsqueeze(1))  # [bs, 11, 1, 128]
+        print(f'\n\n\n\nCOMPUTE STYLe {ref_s.shape=} {ref_p.shape=}')
+        return torch.cat([ref_s, ref_p], dim=3)  # [bs, 11, 1, 256]

 device = 'cpu'
 if torch.cuda.is_available():
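compute_style now returns a frame-wise reference rather than a single 256-dim vector; a hedged sketch of the new layout (the 11 frames depend on the reference clip length and are taken from the printed shapes above, not guaranteed):

import torch

# stand-ins for the two style encoders' outputs on one reference clip
ref_s = torch.randn(1, 11, 1, 128)   # acoustic style, one vector per mel window
ref_p = torch.randn(1, 11, 1, 128)   # prosodic style

ref = torch.cat([ref_s, ref_p], dim=3)                         # [1, 11, 1, 256]; was [1, 256] with dim=1 before
acoustic, prosodic = ref[:, :, :, :128], ref[:, :, :, 128:]    # how inference() splits it back
print(ref.shape, acoustic.shape, prosodic.shape)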
@@ -112,50 +100,151 @@ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_
 # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))


-
+args = yaml.safe_load(open(str('Utils/config.yml')))
+ASR_config = args['ASR_config']

-
-
-ASR_path = config.get('ASR_path', False)
-text_aligner = load_ASR_models(ASR_path, ASR_config)
+ASR_path = args['ASR_path']
+text_aligner = load_ASR_models(ASR_path, ASR_config).eval().to(device)

-
-
-pitch_extractor = load_F0_models(F0_path)
+F0_path = args['F0_path']
+pitch_extractor = load_F0_models(F0_path).eval().to(device)

-# load BERT model
 from Utils.PLBERT.util import load_plbert
-
-
-
-
-
-
-
+bert = load_plbert(args['PLBERT_dir']).eval().to(device)
+# model_params = recursive_munch(config['model_params'])
+# --
+# def build_model(args, text_aligner, pitch_extractor, bert):
+#     print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
+# ======================================
+# In [4]: args['model_params']
+# Out[4]:
+# {'decoder': {'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+#   'resblock_kernel_sizes': [3, 7, 11],
+#   'type': 'hifigan',
+#   'upsample_initial_channel': 512,
+#   'upsample_kernel_sizes': [20, 10, 6, 4],
+#   'upsample_rates': [10, 5, 3, 2]},
+#  'diffusion': {'dist': {'estimate_sigma_data': True,
+#    'mean': -3.0,
+#    'sigma_data': 0.19926648961191362,
+#    'std': 1.0},
+#   'embedding_mask_proba': 0.1,
+#   'transformer': {'head_features': 64,
+#    'multiplier': 2,
+#    'num_heads': 8,
+#    'num_layers': 3}},
+#  'dim_in': 64,
+#  'dropout': 0.2,
+#  'hidden_dim': 512,
+#  'max_conv_dim': 512,
+#  'max_dur': 50,
+#  'multispeaker': True,
+#  'n_layer': 3,
+#  'n_mels': 80,
+#  'n_token': 178,
+#  'slm': {'hidden': 768,
+#   'initial_channel': 64,
+#   'model': 'microsoft/wavlm-base-plus',
+#   'nlayers': 13,
+#   'sr': 16000},
+#  'style_dim': 128}
+# ===============================================
+from Modules.hifigan import Decoder
+decoder = Decoder(dim_in=512,
+                  style_dim=128,
+                  dim_out=80,  # n_mels
+                  resblock_kernel_sizes=[3, 7, 11],
+                  upsample_rates=[10, 5, 3, 2],
+                  upsample_initial_channel=512,
+                  resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                  upsample_kernel_sizes=[20, 10, 6, 4]).eval().to(device)
+
+text_encoder = TextEncoder(channels=512,
+                           kernel_size=5,
+                           depth=3,        # args['model_params']['n_layer']
+                           n_symbols=178,  # args['model_params']['n_token']
+                           ).eval().to(device)
+
+predictor = ProsodyPredictor(style_dim=128,
+                             d_hid=512,
+                             nlayers=3,  # OFFICIAL config.nlayers=5;
+                             max_dur=50,
+                             dropout=.2).eval().to(device)
+
+style_encoder = StyleEncoder(dim_in=64,
+                             style_dim=128,
+                             max_conv_dim=512).eval().to(device)      # acoustic style encoder
+predictor_encoder = StyleEncoder(dim_in=64,
+                                 style_dim=128,
+                                 max_conv_dim=512).eval().to(device)  # prosodic style encoder
+bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
+# --
+# model = build_model(model_params, text_aligner, pitch_extractor, plbert)
+# _ = [model[key].eval() for key in model]
+# _ = [model[key].to(device) for key in model]

 # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
 # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
 params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
 params = params_whole['net']

-for key in model:
-    if key in params:
-        print('%s loaded' % key)
-        try:
-            model[key].load_state_dict(params[key])
-        except:
-            from collections import OrderedDict
-            state_dict = params[key]
-            new_state_dict = OrderedDict()
-            for k, v in state_dict.items():
-                name = k[7:]  # remove `module.`
-                new_state_dict[name] = v
-            # load params
-            model[key].load_state_dict(new_state_dict, strict=False)
-            # except:
-            #     _load(params[key], model[key])
-_ = [model[key].eval() for key in model]
+# keys in params: 'bert', 'bert_encoder', 'predictor', 'decoder', 'text_encoder',
+# 'predictor_encoder', 'style_encoder', 'text_aligner', 'pitch_extractor'
+# --
+from collections import OrderedDict
+
+new_state_dict = OrderedDict()
+for k, v in params['bert'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+bert.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['bert_encoder'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+bert_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['predictor'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+predictor.load_state_dict(new_state_dict, strict=True)  # XTRA non-ckpt LSTM nlayers add slowness to the voice
+# --
+new_state_dict = OrderedDict()
+for k, v in params['decoder'].items():
+    new_state_dict[k[7:]] = v
+decoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['text_encoder'].items():
+    new_state_dict[k[7:]] = v
+text_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['predictor_encoder'].items():
+    new_state_dict[k[7:]] = v
+predictor_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['style_encoder'].items():
+    new_state_dict[k[7:]] = v
+style_encoder.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['text_aligner'].items():
+    new_state_dict[k[7:]] = v  # del 'module.'
+text_aligner.load_state_dict(new_state_dict, strict=True)
+# --
+new_state_dict = OrderedDict()
+for k, v in params['pitch_extractor'].items():
+    new_state_dict[k[7:]] = v
+pitch_extractor.load_state_dict(new_state_dict, strict=True)


 def inference(text,
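The per-module loading blocks above all repeat one idiom: checkpoints saved from DataParallel-wrapped modules prefix every key with 'module.', which k[7:] strips. A tiny self-contained illustration with a toy module (not the actual checkpoint):

import torch
from collections import OrderedDict

lin = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(lin)     # saving this prefixes every key with 'module.'
ckpt = wrapped.state_dict()              # keys: 'module.weight', 'module.bias'

new_state_dict = OrderedDict()
for k, v in ckpt.items():
    new_state_dict[k[7:]] = v            # del 'module.' -> 'weight', 'bias'

torch.nn.Linear(4, 2).load_state_dict(new_state_dict, strict=True)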
@@ -193,24 +282,31 @@ def inference(text,
     # 54, 156, 63, 158, 147, 83, 56, 16, 4]], device='cuda:0')


-    t_en =
-    bert_dur =
-    d_en =
+    t_en = text_encoder(tokens, input_lengths, text_mask)
+    bert_dur = bert(tokens, attention_mask=(~text_mask).int())
+    d_en = bert_encoder(bert_dur).transpose(-1, -2)
     # print('BERTdu', bert_dur.shape, tokens.shape, '\n')  # bert what is the 768 per token -> IS USED in sampler
     # BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11])



-    ref = ref_s[:, :128]
-    s = ref_s[:, 128:]
+    ref = ref_s[:, :, :, :128]  # [bs, 11, 1, 128]
+    s = ref_s[:, :, :, 128:]    # have channels as last dim so it can go through nn.Linear layers

-    # s = .74 * s  # prosody / arousal & fading unvoiced syllables [x0.7 - x1.2]
-
-
+    # ON compute_style we don't know yet the size to interpolate
+    # Perhaps we can interpolate ref_s here as now we know how many bert time-frames the text needs
+    # s = .74 * s  # prosody / arousal & fading unvoiced syllables [x0.7 - x1.2]


-    x, _ =
-    duration =
+    print(f'{d_en.shape=} {s.shape=} {input_lengths.shape=} {text_mask.shape=}')
+    d = predictor.text_encoder(d_en,
+                               s,
+                               input_lengths,
+                               text_mask)

+    x, _ = predictor.lstm(d)
+    duration = predictor.duration_proj(x)

     duration = torch.sigmoid(duration).sum(axis=-1)
     pred_dur = torch.round(duration.squeeze()).clamp(min=1)

@@ -224,23 +320,25 @@ def inference(text,

     # encode prosody
     en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-
-
-
-
-
-
-    F0_pred, N_pred =
+
+    asr_new = torch.zeros_like(en)
+    asr_new[:, :, 0] = en[:, :, 0]
+    asr_new[:, :, 1:] = en[:, :, 0:-1]
+    en = asr_new
+    print('_________________________________________F0_____________________________')
+    F0_pred, N_pred = predictor.F0Ntrain(en, s)

     asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
-
-
-
-
-
-
-    x =
-
+
+    asr_new = torch.zeros_like(asr)
+    asr_new[:, :, 0] = asr[:, :, 0]
+    asr_new[:, :, 1:] = asr[:, :, 0:-1]
+    asr = asr_new
+    print('_________________________________________HiFI_____________________________')
+    x = decoder(asr=asr,
+                F0_curve=F0_pred,
+                N=N_pred,
+                s=ref)

     x = x.cpu().numpy()[0, 0, :-400]  # weird pulse at the end of sentences

@@ -299,6 +397,11 @@ import re
 from num2words import num2words

 PHONEME_MAP = {
+    'služ': 'sloooozz',      # 'službeno'
+    'suver': 'siuveeerra',   # 'suverena'
+    'država': 'dirrezav',    # 'država'
+    'iči': 'ici',            # 'Graniči'
+    's ': 'se',              # an "s" followed by a space
     'q': 'ku',
     'w': 'aou',
     'z': 's',
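Both asr_new blocks above apply the same one-frame right shift (frame 0 duplicated, the rest delayed by one hop) before F0Ntrain and the HiFi-GAN decoder; a minimal illustration of that shift on a toy tensor:

import torch

feats = torch.arange(12.).reshape(1, 2, 6)     # [bs, channels, frames]
shifted = torch.zeros_like(feats)
shifted[:, :, 0] = feats[:, :, 0]              # keep the first frame
shifted[:, :, 1:] = feats[:, :, 0:-1]          # delay the rest by one frame
print(feats[0, 0], shifted[0, 0])              # tensor([0., 1., 2., 3., 4., 5.]) -> tensor([0., 0., 1., 2., 3., 4.])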
requirements.txt
CHANGED

@@ -13,7 +13,7 @@ omegaconf
 opencv-python
 soundfile
 transformers
-
+audresample
 srt
 nltk
 phonemizer