|
|
|
|
|
from transformers import PretrainedConfig |
|
|
from typing import List, Optional |
|
|
|
|
|
class QuantizerConfig(PretrainedConfig):
    """Configuration for a prosody (F0) quantizer model.

    Bundles four groups of hyper-parameters — the VQ bottleneck, the
    convolutional encoder, the convolutional decoder, and training /
    feature-processing options — and exposes each group as a dict through the
    ``f0_vq_params`` / ``f0_encoder_params`` / ``f0_decoder_params``
    properties. Can also be built from a yaml file via :meth:`from_yaml`.
    """

    model_type = "prosody_quantizer"

    def __init__(
        self,
        # --- VQ bottleneck ---
        l_bins: int = 320,       # codebook size (number of discrete codes)
        emb_width: int = 64,     # codebook embedding dimension
        mu: float = 0.99,        # presumably EMA decay for codebook updates — confirm in model code
        levels: int = 1,
        # --- encoder ---
        encoder_input_emb_width: int = 3,
        encoder_output_emb_width: int = 64,
        encoder_levels: int = 1,
        # NOTE: list defaults use a None sentinel — a literal [4]/[2] default
        # would be shared across every instance (mutable-default pitfall).
        encoder_downs_t: Optional[List[int]] = None,    # effective default: [4]
        encoder_strides_t: Optional[List[int]] = None,  # effective default: [2]
        encoder_width: int = 32,
        encoder_depth: int = 4,
        encoder_m_conv: float = 1.0,
        encoder_dilation_growth_rate: int = 3,
        # --- decoder ---
        decoder_input_emb_width: int = 3,
        decoder_output_emb_width: int = 64,
        decoder_levels: int = 1,
        decoder_downs_t: Optional[List[int]] = None,    # effective default: [4]
        decoder_strides_t: Optional[List[int]] = None,  # effective default: [2]
        decoder_width: int = 32,
        decoder_depth: int = 4,
        decoder_m_conv: float = 1.0,
        decoder_dilation_growth_rate: int = 3,
        # --- loss / feature processing ---
        lambda_commit: float = 0.02,        # weight of the VQ commitment loss
        f0_normalize: bool = True,
        intensity_normalize: bool = True,
        multispkr: str = "single",
        f0_feats: bool = False,
        f0_median: bool = False,
        # --- optimizer ---
        learning_rate: float = 0.0002,
        adam_b1: float = 0.8,
        adam_b2: float = 0.99,
        lr_decay: float = 0.999,
        **kwargs
    ):
        super().__init__(**kwargs)

        # VQ bottleneck
        self.l_bins = l_bins
        self.emb_width = emb_width
        self.mu = mu
        self.levels = levels

        # Encoder (fresh lists per instance when defaults are used)
        self.encoder_input_emb_width = encoder_input_emb_width
        self.encoder_output_emb_width = encoder_output_emb_width
        self.encoder_levels = encoder_levels
        self.encoder_downs_t = encoder_downs_t if encoder_downs_t is not None else [4]
        self.encoder_strides_t = encoder_strides_t if encoder_strides_t is not None else [2]
        self.encoder_width = encoder_width
        self.encoder_depth = encoder_depth
        self.encoder_m_conv = encoder_m_conv
        self.encoder_dilation_growth_rate = encoder_dilation_growth_rate

        # Decoder (fresh lists per instance when defaults are used)
        self.decoder_input_emb_width = decoder_input_emb_width
        self.decoder_output_emb_width = decoder_output_emb_width
        self.decoder_levels = decoder_levels
        self.decoder_downs_t = decoder_downs_t if decoder_downs_t is not None else [4]
        self.decoder_strides_t = decoder_strides_t if decoder_strides_t is not None else [2]
        self.decoder_width = decoder_width
        self.decoder_depth = decoder_depth
        self.decoder_m_conv = decoder_m_conv
        self.decoder_dilation_growth_rate = decoder_dilation_growth_rate

        # Loss / feature processing
        self.lambda_commit = lambda_commit
        self.f0_normalize = f0_normalize
        self.intensity_normalize = intensity_normalize
        self.multispkr = multispkr
        self.f0_feats = f0_feats
        self.f0_median = f0_median

        # Optimizer
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay

    @property
    def f0_vq_params(self):
        """VQ bottleneck hyper-parameters as a keyword dict."""
        return {
            "l_bins": self.l_bins,
            "emb_width": self.emb_width,
            "mu": self.mu,
            "levels": self.levels
        }

    @property
    def f0_encoder_params(self):
        """Encoder hyper-parameters as a keyword dict (``encoder_`` prefix stripped)."""
        return {
            "input_emb_width": self.encoder_input_emb_width,
            "output_emb_width": self.encoder_output_emb_width,
            "levels": self.encoder_levels,
            "downs_t": self.encoder_downs_t,
            "strides_t": self.encoder_strides_t,
            "width": self.encoder_width,
            "depth": self.encoder_depth,
            "m_conv": self.encoder_m_conv,
            "dilation_growth_rate": self.encoder_dilation_growth_rate
        }

    @property
    def f0_decoder_params(self):
        """Decoder hyper-parameters as a keyword dict (``decoder_`` prefix stripped)."""
        return {
            "input_emb_width": self.decoder_input_emb_width,
            "output_emb_width": self.decoder_output_emb_width,
            "levels": self.decoder_levels,
            "downs_t": self.decoder_downs_t,
            "strides_t": self.decoder_strides_t,
            "width": self.decoder_width,
            "depth": self.decoder_depth,
            "m_conv": self.decoder_m_conv,
            "dilation_growth_rate": self.decoder_dilation_growth_rate
        }

    @classmethod
    def from_yaml(cls, yaml_path: str):
        """Load config from a yaml file.

        The file must contain ``f0_vq_params``, ``f0_encoder_params`` and
        ``f0_decoder_params`` mappings (raises ``KeyError`` otherwise);
        encoder/decoder keys are re-prefixed to match ``__init__``'s flat
        keyword names. All remaining options fall back to their defaults.
        """
        import yaml  # local import: yaml only needed on this path

        with open(yaml_path, 'r') as f:
            config = yaml.safe_load(f)

        kwargs = {
            # VQ keys already match __init__'s parameter names.
            **config['f0_vq_params'],
            **{f"encoder_{k}": v for k, v in config['f0_encoder_params'].items()},
            **{f"decoder_{k}": v for k, v in config['f0_decoder_params'].items()},
            "lambda_commit": config.get('lambda_commit', 0.02),
            "f0_normalize": config.get('f0_normalize', True),
            "intensity_normalize": config.get('intensity_normalize', True),
            "multispkr": config.get('multispkr', "single"),
            "f0_feats": config.get('f0_feats', False),
            "f0_median": config.get('f0_median', False),
            "learning_rate": config.get('learning_rate', 0.0002),
            "adam_b1": config.get('adam_b1', 0.8),
            "adam_b2": config.get('adam_b2', 0.99),
            "lr_decay": config.get('lr_decay', 0.999),
        }

        return cls(**kwargs)
|
|
|