# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Export a trained model from the full checkpoint (with optimizer etc.) to
a final checkpoint, with only the model itself. The model is always stored as
half float to save space, since this has zero impact on the final loss.
When DiffQ was used for training, the model will actually be quantized and bitpacked."""
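# Example invocation from the repository root (SIG stands for a Dora XP signature
# produced by a training run; the flags match the parser defined in main() below):
#
#   python3 -m tools.export SIG [SIG ...] -o release_models -s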
from argparse import ArgumentParser
from fractions import Fraction
import logging
from pathlib import Path
import sys

import torch

from demucs import train
from demucs.states import serialize_model, save_with_checksum


logger = logging.getLogger(__name__)
def main():
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)
    parser = ArgumentParser("tools.export", description="Export trained models from XP sigs.")
    parser.add_argument('signatures', nargs='*', help='XP signatures.')
    parser.add_argument('-o', '--out', type=Path, default=Path("release_models"),
                        help="Path where to store release models (default release_models)")
    parser.add_argument('-s', '--sign', action='store_true',
                        help='Add sha256 prefix checksum to the filename.')
    args = parser.parse_args()
    args.out.mkdir(exist_ok=True, parents=True)

    for sig in args.signatures:
        xp = train.main.get_xp_from_sig(sig)
        name = train.main.get_name(xp)
        logger.info('Handling %s/%s', sig, name)

        out_path = args.out / (sig + ".th")

        solver = train.get_solver_from_sig(sig)
        if len(solver.history) < solver.args.epochs:
            logger.warning(
                'Model %s has fewer epochs than expected (%d / %d)',
                sig, len(solver.history), solver.args.epochs)

        # Restore the best weights recorded during training, then serialize the
        # model alone (half precision, quantized/bitpacked if DiffQ was used).
        solver.model.load_state_dict(solver.best_state)
        pkg = serialize_model(solver.model, solver.args, solver.quantizer, half=True)
        if getattr(solver.model, 'use_train_segment', False):
            # Record the training segment length (in seconds) so inference can
            # default to the same segment the model was trained on.
            batch = solver.augment(next(iter(solver.loaders['train'])))
            pkg['kwargs']['segment'] = Fraction(batch.shape[-1], solver.model.samplerate)
            print("Override", pkg['kwargs']['segment'])

        # Keep the latest valid/test metrics from the training history.
        valid, test = None, None
        for m in solver.history:
            if 'valid' in m:
                valid = m['valid']
            if 'test' in m:
                test = m['test']
        pkg['metrics'] = (valid, test)

        if args.sign:
            save_with_checksum(pkg, out_path)
        else:
            torch.save(pkg, out_path)


if __name__ == '__main__':
    main()
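# Loading sketch (assumes demucs.states.load_model can rebuild a model from the
# exported package; check demucs/states.py for the exact behavior):
#
#   from demucs.states import load_model
#   model = load_model('release_models/SIG.th')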