#!/usr/bin/env python3
# coding: utf-8
import logging
import os
import os.path as osp
import re
import shutil
import sys
import warnings
from functools import reduce
from logging import StreamHandler

# Suppress warnings before the heavy third-party imports below, so that
# anything they raise at import time is silenced as well.
warnings.simplefilter('ignore')

import click
import numpy as np
import paddle
import yaml
from munch import Munch
from visualdl import LogWriter

from starganv2vc_paddle.meldataset import build_dataloader
from starganv2vc_paddle.models import build_model
from starganv2vc_paddle.optimizers import build_optimizer
from starganv2vc_paddle.trainer import Trainer
from starganv2vc_paddle.Utils.ASR.models import ASRCNN
from starganv2vc_paddle.Utils.JDC.model import JDCNet
# Module-level logger: DEBUG-level records are mirrored to stderr; main()
# attaches an additional file handler once the log directory is known.
logger = logging.getLogger(__name__)
handler = StreamHandler()
for target in (logger, handler):
    target.setLevel(logging.DEBUG)
logger.addHandler(handler)
def main(config_path):
    """Train StarGANv2-VC: load the config, data and pretrained ASR/F0
    models, build the GAN, then run the train/eval loop with periodic
    checkpointing.

    Args:
        config_path: Path to a YAML training configuration file.

    Returns:
        0 on completion.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(config_path) as f:
        config = yaml.safe_load(f)

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)  # exist_ok makes a pre-check redundant
    # Keep a copy of the config next to the logs for reproducibility.
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
    writer = LogWriter(log_dir + "/visualdl")

    # Also persist log records to <log_dir>/train.log.
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.addHandler(file_handler)

    batch_size = config.get('batch_size', 10)
    epochs = config.get('epochs', 1000)
    save_freq = config.get('save_freq', 20)  # checkpoint every N epochs
    train_path = config.get('train_data', None)
    val_path = config.get('val_data', None)
    fp16_run = config.get('fp16_run', False)

    # Data loaders for the training and validation file lists.
    train_list, val_list = get_data_path_list(train_path, val_path)
    train_dataloader = build_dataloader(train_list,
                                        batch_size=batch_size,
                                        num_workers=4)
    val_dataloader = build_dataloader(val_list,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=2)

    # Pretrained ASR model, frozen (eval mode) during GAN training.
    ASR_config = config.get('ASR_config', False)
    ASR_path = config.get('ASR_path', False)
    with open(ASR_config) as f:
        ASR_config = yaml.safe_load(f)
    ASR_model = ASRCNN(**ASR_config['model_params'])
    ASR_model.set_state_dict(paddle.load(ASR_path)['model'])
    _ = ASR_model.eval()

    # Pretrained F0 (pitch) model.
    F0_path = config.get('F0_path', False)
    F0_model = JDCNet(num_class=1, seq_len=192)
    F0_model.set_state_dict(paddle.load(F0_path)['net'])

    # Build the networks and their EMA copies.
    model, model_ema = build_model(Munch(config['model_params']), F0_model, ASR_model)

    # One-cycle LR schedule shared by every sub-model; the mapping network
    # deliberately trains with a much smaller peak LR.
    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 2e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    scheduler_params_dict = {key: scheduler_params.copy() for key in model}
    scheduler_params_dict['mapping_network']['max_lr'] = 2e-6

    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict=scheduler_params_dict)

    trainer = Trainer(args=Munch(config['loss_params']), model=model,
                      model_ema=model_ema,
                      optimizer=optimizer,
                      train_dataloader=train_dataloader,
                      val_dataloader=val_dataloader,
                      logger=logger,
                      fp16_run=fp16_run)

    # Optionally resume from a checkpoint.
    if config.get('pretrained_model', '') != '':
        trainer.load_checkpoint(config['pretrained_model'],
                                load_only_params=config.get('load_only_params', True))

    for _ in range(1, epochs + 1):
        epoch = trainer.epochs
        train_results = trainer._train_epoch()
        eval_results = trainer._eval_epoch()
        results = train_results.copy()
        results.update(eval_results)
        logger.info('--- epoch %d ---' % epoch)
        for key, value in results.items():
            if isinstance(value, float):
                logger.info('%-15s: %.4f' % (key, value))
                writer.add_scalar(key, value, epoch)
            else:
                # Non-scalar eval results are iterables of spectrogram
                # samples — log each as a histogram.
                for v in value:
                    writer.add_histogram('eval_spec', v, epoch)
        if (epoch % save_freq) == 0:
            trainer.save_checkpoint(osp.join(log_dir, 'epoch_%05d.pd' % epoch))

    return 0
def get_data_path_list(train_path=None, val_path=None):
    """Read the train/val file lists, defaulting to the Data/ locations.

    Returns:
        A (train_list, val_list) pair of raw lines, newlines preserved.
    """
    def read_lines(path, fallback):
        # Fall back to the conventional list location when no path is given.
        with open(fallback if path is None else path, 'r') as handle:
            return handle.readlines()

    return (read_lines(train_path, "Data/train_list.txt"),
            read_lines(val_path, "Data/val_list.txt"))
| if __name__=="__main__": | |
| main() | |