Spaces:
Running
Running
| import argparse | |
| from concurrent.futures import ThreadPoolExecutor | |
| import warnings | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| import utils | |
| from common.log import logger | |
| from common.stdout_wrapper import SAFE_STDOUT | |
| from config import config | |
# Suppress every UserWarning for the rest of the process. This runs before the
# pyannote import so that warnings emitted while importing it are hidden too.
# NOTE(review): this also hides UserWarnings from unrelated libraries — confirm intended.
warnings.filterwarnings("ignore", category=UserWarning)
from pyannote.audio import Inference, Model

# The model and inference wrapper are created once, at module import time, and
# reused by every call to get_style_vector.
# NOTE(review): presumably fetches/caches the weights from the Hugging Face Hub
# on first run — confirm network/auth requirements.
model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
# NOTE(review): window="whole" presumably yields one embedding per audio file
# (not a sliding-window sequence) — confirm against pyannote docs.
inference = Inference(model, window="whole")
# Target device is taken from the style-generation section of the project config.
device = torch.device(config.style_gen_config.device)
inference.to(device)
class NaNValueError(ValueError):
    """Custom error raised when a computed style vector contains NaN values."""
# Short, but kept as a standalone function so it can be imported at inference time.
def get_style_vector(wav_path):
    """Return the embedding computed by the module-level pyannote ``inference`` for ``wav_path``."""
    embedding = inference(wav_path)
    return embedding
def save_style_vector(wav_path):
    """Compute the style vector for ``wav_path`` and save it beside the audio file.

    The vector is written to ``f"{wav_path}.npy"`` (``test.wav`` -> ``test.wav.npy``).

    Args:
        wav_path: Path to the audio file to embed.

    Raises:
        NaNValueError: If the computed vector contains any NaN; nothing is saved.
        Exception: Any error raised by ``get_style_vector`` is logged and re-raised.
    """
    try:
        vector = get_style_vector(wav_path)
    except Exception as err:
        # Blank lines first so the log does not collide with the tqdm progress bar.
        print("\n")
        logger.error(f"Error occurred with file: {wav_path}, Details:\n{err}\n")
        raise

    # NaN values would be harmful downstream, so such vectors are rejected outright.
    if bool(np.isnan(vector).any()):
        message = f"NaN value found in style vector: {wav_path}"
        print("\n")
        logger.warning(message)
        raise NaNValueError(message)

    output_path = f"{wav_path}.npy"
    np.save(output_path, vector)
def process_line(line):
    """Generate and save the style vector for one ``wav_path|...`` filelist line.

    Returns:
        ``(line, None)`` on success, or ``(line, "nan_error")`` if the vector
        contained NaN values. Any other exception propagates to the caller.
    """
    wav_path = line.partition("|")[0]
    try:
        save_style_vector(wav_path)
    except NaNValueError:
        return line, "nan_error"
    return line, None
def save_average_style_vector(style_vectors, filename="style_vectors.npy"):
    """Average ``style_vectors`` along axis 0 and save the result with ``np.save``.

    Args:
        style_vectors: Sequence (or array) of style vectors to average.
        filename: Destination ``.npy`` path.
    """
    np.save(filename, np.mean(style_vectors, axis=0))
def _generate_for_split(list_path, split_name, num_processes):
    """Generate style vectors for every entry of one filelist and drop NaN entries.

    The filelist itself is NOT rewritten here; the caller does that once all
    splits have been processed.

    Args:
        list_path: Filelist whose lines look like ``wav_path|...``.
        split_name: Split label used in the log message ("training"/"validation").
        num_processes: Number of worker threads for the executor.

    Returns:
        The lines whose style vector was saved without NaN values, in original order.
    """
    with open(list_path, encoding="utf-8") as f:
        # Blank lines carry no wav path; feeding them to the model would abort the run.
        lines = [line for line in f if line.strip()]

    # executor.map preserves input order, so results line up with `lines`.
    with ThreadPoolExecutor(max_workers=num_processes) as executor:
        results = list(
            tqdm(
                executor.map(process_line, lines),
                total=len(lines),
                file=SAFE_STDOUT,
            )
        )

    ok_lines = [line for line, error in results if error is None]
    nan_lines = [line for line, error in results if error == "nan_error"]
    if nan_lines:
        nan_files = [line.split("|")[0] for line in nan_lines]
        logger.warning(
            f"Found NaN value in {len(nan_lines)} files: {nan_files}, so they will be deleted from {split_name} data."
        )
    return ok_lines


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, default=config.style_gen_config.config_path
    )
    parser.add_argument(
        "--num_processes", type=int, default=config.style_gen_config.num_processes
    )
    args, _ = parser.parse_known_args()
    config_path = args.config
    num_processes = args.num_processes

    hps = utils.get_hparams_from_file(config_path)

    ok_training_lines = _generate_for_split(
        hps.data.training_files, "training", num_processes
    )
    ok_val_lines = _generate_for_split(
        hps.data.validation_files, "validation", num_processes
    )

    # Rewrite the filelists only after both splits succeeded, so a crash while
    # processing validation data leaves the training filelist untouched.
    with open(hps.data.training_files, "w", encoding="utf-8") as f:
        f.writelines(ok_training_lines)
    with open(hps.data.validation_files, "w", encoding="utf-8") as f:
        f.writelines(ok_val_lines)

    ok_num = len(ok_training_lines) + len(ok_val_lines)
    logger.info(f"Finished generating style vectors! total: {ok_num} npy files.")