# F0_Energy_joint_VQVAE_embeddings / quantizer_config.py
# Author: Daporte (commit 76c0443)
from transformers import PretrainedConfig
from typing import List, Optional
class QuantizerConfig(PretrainedConfig):
    """Configuration for a prosody (F0/energy) VQ-VAE quantizer.

    Stores vector-quantizer, encoder, decoder and training hyperparameters,
    and exposes them regrouped into the nested-dict shapes expected by the
    model via the ``f0_vq_params`` / ``f0_encoder_params`` /
    ``f0_decoder_params`` properties.
    """

    model_type = "prosody_quantizer"

    def __init__(
        self,
        # VQ parameters
        l_bins: int = 320,
        emb_width: int = 64,
        mu: float = 0.99,
        levels: int = 1,
        # Encoder parameters.
        # NOTE: the list-valued parameters default to None (resolved to
        # fresh [4]/[2] lists below) to avoid the shared mutable-default
        # pitfall — a literal `= [4]` default would be one list object
        # shared by every instance of this config.
        encoder_input_emb_width: int = 3,
        encoder_output_emb_width: int = 64,
        encoder_levels: int = 1,
        encoder_downs_t: Optional[List[int]] = None,
        encoder_strides_t: Optional[List[int]] = None,
        encoder_width: int = 32,
        encoder_depth: int = 4,
        encoder_m_conv: float = 1.0,
        encoder_dilation_growth_rate: int = 3,
        # Decoder parameters (same None-sentinel convention as above)
        decoder_input_emb_width: int = 3,
        decoder_output_emb_width: int = 64,
        decoder_levels: int = 1,
        decoder_downs_t: Optional[List[int]] = None,
        decoder_strides_t: Optional[List[int]] = None,
        decoder_width: int = 32,
        decoder_depth: int = 4,
        decoder_m_conv: float = 1.0,
        decoder_dilation_growth_rate: int = 3,
        # Training parameters
        lambda_commit: float = 0.02,
        f0_normalize: bool = True,
        intensity_normalize: bool = True,
        multispkr: str = "single",
        f0_feats: bool = False,
        f0_median: bool = False,
        # Optional training hyperparameters
        learning_rate: float = 0.0002,
        adam_b1: float = 0.8,
        adam_b2: float = 0.99,
        lr_decay: float = 0.999,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # VQ parameters
        self.l_bins = l_bins
        self.emb_width = emb_width
        self.mu = mu
        self.levels = levels

        # Encoder parameters — each instance gets its own list objects.
        self.encoder_input_emb_width = encoder_input_emb_width
        self.encoder_output_emb_width = encoder_output_emb_width
        self.encoder_levels = encoder_levels
        self.encoder_downs_t = [4] if encoder_downs_t is None else encoder_downs_t
        self.encoder_strides_t = [2] if encoder_strides_t is None else encoder_strides_t
        self.encoder_width = encoder_width
        self.encoder_depth = encoder_depth
        self.encoder_m_conv = encoder_m_conv
        self.encoder_dilation_growth_rate = encoder_dilation_growth_rate

        # Decoder parameters
        self.decoder_input_emb_width = decoder_input_emb_width
        self.decoder_output_emb_width = decoder_output_emb_width
        self.decoder_levels = decoder_levels
        self.decoder_downs_t = [4] if decoder_downs_t is None else decoder_downs_t
        self.decoder_strides_t = [2] if decoder_strides_t is None else decoder_strides_t
        self.decoder_width = decoder_width
        self.decoder_depth = decoder_depth
        self.decoder_m_conv = decoder_m_conv
        self.decoder_dilation_growth_rate = decoder_dilation_growth_rate

        # Training parameters
        self.lambda_commit = lambda_commit
        self.f0_normalize = f0_normalize
        self.intensity_normalize = intensity_normalize
        self.multispkr = multispkr
        self.f0_feats = f0_feats
        self.f0_median = f0_median

        # Training hyperparameters
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay

    @property
    def f0_vq_params(self) -> dict:
        """Vector-quantizer hyperparameters as a flat dict."""
        return {
            "l_bins": self.l_bins,
            "emb_width": self.emb_width,
            "mu": self.mu,
            "levels": self.levels,
        }

    @property
    def f0_encoder_params(self) -> dict:
        """Encoder hyperparameters with the ``encoder_`` prefix stripped."""
        return {
            "input_emb_width": self.encoder_input_emb_width,
            "output_emb_width": self.encoder_output_emb_width,
            "levels": self.encoder_levels,
            "downs_t": self.encoder_downs_t,
            "strides_t": self.encoder_strides_t,
            "width": self.encoder_width,
            "depth": self.encoder_depth,
            "m_conv": self.encoder_m_conv,
            "dilation_growth_rate": self.encoder_dilation_growth_rate,
        }

    @property
    def f0_decoder_params(self) -> dict:
        """Decoder hyperparameters with the ``decoder_`` prefix stripped."""
        return {
            "input_emb_width": self.decoder_input_emb_width,
            "output_emb_width": self.decoder_output_emb_width,
            "levels": self.decoder_levels,
            "downs_t": self.decoder_downs_t,
            "strides_t": self.decoder_strides_t,
            "width": self.decoder_width,
            "depth": self.decoder_depth,
            "m_conv": self.decoder_m_conv,
            "dilation_growth_rate": self.decoder_dilation_growth_rate,
        }

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "QuantizerConfig":
        """Build a config from a YAML file.

        The YAML is expected to contain nested ``f0_vq_params``,
        ``f0_encoder_params`` and ``f0_decoder_params`` mappings (required —
        a KeyError is raised if absent) plus optional flat training keys,
        which fall back to the same defaults as ``__init__``.
        """
        import yaml  # local import: PyYAML only needed for this path

        with open(yaml_path, "r") as f:
            config = yaml.safe_load(f)

        # Flatten the nested mappings into constructor kwargs, prefixing
        # encoder/decoder keys to match the __init__ parameter names.
        kwargs = {
            **config["f0_vq_params"],
            **{f"encoder_{k}": v for k, v in config["f0_encoder_params"].items()},
            **{f"decoder_{k}": v for k, v in config["f0_decoder_params"].items()},
            # Training params
            "lambda_commit": config.get("lambda_commit", 0.02),
            "f0_normalize": config.get("f0_normalize", True),
            "intensity_normalize": config.get("intensity_normalize", True),
            "multispkr": config.get("multispkr", "single"),
            "f0_feats": config.get("f0_feats", False),
            "f0_median": config.get("f0_median", False),
            # Training hyperparams
            "learning_rate": config.get("learning_rate", 0.0002),
            "adam_b1": config.get("adam_b1", 0.8),
            "adam_b2": config.get("adam_b2", 0.99),
            "lr_decay": config.get("lr_decay", 0.999),
        }
        return cls(**kwargs)