Spaces:
				
			
			
	
			
			
		Configuration error
		
	
	
	
			
			
	
	
	
	
		
		
		Configuration error
		
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Script to evaluate pretrained models.
| from argparse import ArgumentParser | |
| import logging | |
| import sys | |
| import torch | |
| from demucs import train, pretrained, evaluate | |
def _build_parser():
    """Create the command-line parser for this evaluation script.

    The parser carries the model-selection flags registered by
    ``pretrained.add_model_flags`` plus a positional ``overrides`` list
    (zero or more items) that is later handed to ``train.main.get_xp``.
    """
    cli = ArgumentParser(
        "tools.test_pretrained",
        description="Evaluate pre-trained models or bags of models "
                    "on MusDB.",
    )
    pretrained.add_model_flags(cli)
    cli.add_argument(
        'overrides',
        nargs='*',
        help='Extra overrides, e.g. test.shifts=2.',
    )
    return cli


def main():
    """Evaluate a pretrained model (or bag of models) and print the results.

    Steps:
      1. Restrict torch to a single CPU thread and route INFO-level logging
         to stderr.
      2. Parse the model flags and config overrides from the command line.
      3. Obtain the experiment for those overrides via ``train.main.get_xp``,
         build a solver from its config, and install the requested
         pretrained model on it (moved to ``solver.device``, in eval mode).
      4. Call ``evaluate.evaluate`` under ``torch.no_grad()`` with the
         config's ``test.sdr`` value and print whatever it returns.
    """
    torch.set_num_threads(1)
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    cli_args = _build_parser().parse_args()

    experiment = train.main.get_xp(cli_args.overrides)
    with experiment.enter():
        solver = train.get_solver(experiment.cfg)

        # Swap the solver's model for the pretrained one selected on the CLI.
        net = pretrained.get_model_from_args(cli_args)
        solver.model = net.to(solver.device)
        solver.model.eval()

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            metrics = evaluate.evaluate(solver, experiment.cfg.test.sdr)
        print(metrics)


if __name__ == '__main__':
    main()