from typing import List, Optional

from transformers import PretrainedConfig


class QuantizerConfig(PretrainedConfig):
    """Configuration for the prosody (F0) quantizer: VQ codebook, encoder,
    decoder, and training hyperparameters."""

    model_type = "prosody_quantizer"

    def __init__(
        self,
        # VQ parameters
        l_bins: int = 320,
        emb_width: int = 64,
        mu: float = 0.99,
        levels: int = 1,
        # Encoder parameters
        encoder_input_emb_width: int = 3,
        encoder_output_emb_width: int = 64,
        encoder_levels: int = 1,
        encoder_downs_t: Optional[List[int]] = None,
        encoder_strides_t: Optional[List[int]] = None,
        encoder_width: int = 32,
        encoder_depth: int = 4,
        encoder_m_conv: float = 1.0,
        encoder_dilation_growth_rate: int = 3,
        # Decoder parameters
        decoder_input_emb_width: int = 3,
        decoder_output_emb_width: int = 64,
        decoder_levels: int = 1,
        decoder_downs_t: Optional[List[int]] = None,
        decoder_strides_t: Optional[List[int]] = None,
        decoder_width: int = 32,
        decoder_depth: int = 4,
        decoder_m_conv: float = 1.0,
        decoder_dilation_growth_rate: int = 3,
        # Training parameters
        lambda_commit: float = 0.02,
        f0_normalize: bool = True,
        intensity_normalize: bool = True,
        multispkr: str = "single",
        f0_feats: bool = False,
        f0_median: bool = False,
        # Optional training hyperparameters
        learning_rate: float = 0.0002,
        adam_b1: float = 0.8,
        adam_b2: float = 0.99,
        lr_decay: float = 0.999,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # VQ parameters
        self.l_bins = l_bins
        self.emb_width = emb_width
        self.mu = mu
        self.levels = levels

        # Encoder parameters (list defaults are resolved here rather than in
        # the signature, so instances never share one mutable default list)
        self.encoder_input_emb_width = encoder_input_emb_width
        self.encoder_output_emb_width = encoder_output_emb_width
        self.encoder_levels = encoder_levels
        self.encoder_downs_t = encoder_downs_t if encoder_downs_t is not None else [4]
        self.encoder_strides_t = encoder_strides_t if encoder_strides_t is not None else [2]
        self.encoder_width = encoder_width
        self.encoder_depth = encoder_depth
        self.encoder_m_conv = encoder_m_conv
        self.encoder_dilation_growth_rate = encoder_dilation_growth_rate

        # Decoder parameters
        self.decoder_input_emb_width = decoder_input_emb_width
        self.decoder_output_emb_width = decoder_output_emb_width
        self.decoder_levels = decoder_levels
        self.decoder_downs_t = decoder_downs_t if decoder_downs_t is not None else [4]
        self.decoder_strides_t = decoder_strides_t if decoder_strides_t is not None else [2]
        self.decoder_width = decoder_width
        self.decoder_depth = decoder_depth
        self.decoder_m_conv = decoder_m_conv
        self.decoder_dilation_growth_rate = decoder_dilation_growth_rate

        # Training parameters
        self.lambda_commit = lambda_commit
        self.f0_normalize = f0_normalize
        self.intensity_normalize = intensity_normalize
        self.multispkr = multispkr
        self.f0_feats = f0_feats
        self.f0_median = f0_median

        # Training hyperparameters
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay

    @property
    def f0_vq_params(self):
        return {
            "l_bins": self.l_bins,
            "emb_width": self.emb_width,
            "mu": self.mu,
            "levels": self.levels,
        }

    @property
    def f0_encoder_params(self):
        return {
            "input_emb_width": self.encoder_input_emb_width,
            "output_emb_width": self.encoder_output_emb_width,
            "levels": self.encoder_levels,
            "downs_t": self.encoder_downs_t,
            "strides_t": self.encoder_strides_t,
            "width": self.encoder_width,
            "depth": self.encoder_depth,
            "m_conv": self.encoder_m_conv,
            "dilation_growth_rate": self.encoder_dilation_growth_rate,
        }

    @property
    def f0_decoder_params(self):
        return {
            "input_emb_width": self.decoder_input_emb_width,
            "output_emb_width": self.decoder_output_emb_width,
            "levels": self.decoder_levels,
            "downs_t": self.decoder_downs_t,
            "strides_t": self.decoder_strides_t,
            "width": self.decoder_width,
            "depth": self.decoder_depth,
            "m_conv": self.decoder_m_conv,
            "dilation_growth_rate": self.decoder_dilation_growth_rate,
        }

    @classmethod
    def from_yaml(cls, yaml_path: str):
        """Load config from a yaml file."""
        import yaml  # imported lazily so pyyaml is only required by this path

        with open(yaml_path, "r") as f:
            config = yaml.safe_load(f)

        # Convert the nested yaml layout into flat constructor kwargs
        kwargs = {
            # VQ params
            **config["f0_vq_params"],
            # Encoder params, flattened with an "encoder_" prefix
            **{f"encoder_{k}": v for k, v in config["f0_encoder_params"].items()},
            # Decoder params, flattened with a "decoder_" prefix
            **{f"decoder_{k}": v for k, v in config["f0_decoder_params"].items()},
            # Training params
            "lambda_commit": config.get("lambda_commit", 0.02),
            "f0_normalize": config.get("f0_normalize", True),
            "intensity_normalize": config.get("intensity_normalize", True),
            "multispkr": config.get("multispkr", "single"),
            "f0_feats": config.get("f0_feats", False),
            "f0_median": config.get("f0_median", False),
            # Training hyperparams
            "learning_rate": config.get("learning_rate", 0.0002),
            "adam_b1": config.get("adam_b1", 0.8),
            "adam_b2": config.get("adam_b2", 0.99),
            "lr_decay": config.get("lr_decay", 0.999),
        }
        return cls(**kwargs)
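

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal example of how this config might be used, assuming the class
# above is importable. `save_pretrained`/`from_pretrained` come from the
# transformers `PretrainedConfig` base class; the directory name
# "./prosody_quantizer" and the `Encoder(...)` call in the comment below
# are hypothetical.
if __name__ == "__main__":
    config = QuantizerConfig(l_bins=512, encoder_width=64)

    # The grouped properties return plain dicts that can be splatted into
    # model constructors, e.g. Encoder(**config.f0_encoder_params).
    print(config.f0_vq_params)                # {'l_bins': 512, 'emb_width': 64, ...}
    print(config.f0_encoder_params["width"])  # 64

    # Round-trip through the standard transformers serialization.
    config.save_pretrained("./prosody_quantizer")
    reloaded = QuantizerConfig.from_pretrained("./prosody_quantizer")
    assert reloaded.l_bins == 512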