Spaces:
				
			
			
	
			
			
		Configuration error
		
	
	
	
			
			
	
	
	
	
		
		
		Configuration error
		
	| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| benchmarking script, useful to check for OOM, reasonable train time, | |
| and for the MDX competition, estimate if we will match the time limit.""" | |
| from contextlib import contextmanager | |
| import logging | |
| import sys | |
| import time | |
| import torch | |
| from demucs.train import get_solver, main | |
| from demucs.apply import apply_model | |
| logging.basicConfig(level=logging.INFO, stream=sys.stderr) | |
class Result:
    """Plain attribute bag; `bench` stores its `mem` and `tim` measurements on it."""
@contextmanager
def bench():
    """Context manager measuring peak CUDA memory and wall-clock time of its body.

    Yields a `Result` instance. When the `with` body exits (normally or through
    an exception) two attributes are set on it:

    - `mem`: peak CUDA memory allocated, in MiB (bytes / 2 ** 20).
    - `tim`: elapsed wall-clock time, in seconds.

    The `@contextmanager` decorator is required: this generator is used as
    `with bench() as res:`, and a bare generator object does not implement the
    context-manager protocol.
    """
    import gc

    # Drop unreachable Python objects and cached CUDA blocks so that leftovers
    # from earlier work are released before measuring.
    gc.collect()
    # NOTE: recent PyTorch releases document this call as deprecated in favor
    # of `torch.cuda.reset_peak_memory_stats()`.
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.empty_cache()

    result = Result()
    # The baseline is fixed at 0, so `mem` is the absolute peak rather than a
    # delta relative to the memory already allocated on entry.
    before = 0
    begin = time.time()
    try:
        yield result
    finally:
        # Wait for queued CUDA work to finish before reading the clock and the
        # memory counter.
        torch.cuda.synchronize()
        result.mem = (torch.cuda.max_memory_allocated() - before) / 2 ** 20
        result.tim = time.time() - begin
# `Fraction` is needed below when the model uses its training segment length;
# it is not among the file's top-level imports, so bring it into scope here.
from fractions import Fraction

# The first CLI argument is an experiment signature; any remaining CLI
# arguments are appended to that experiment's own argv.
xp = main.get_xp_from_sig(sys.argv[1])
xp = main.get_xp(xp.argv + sys.argv[2:])
with xp.enter():
    solver = get_solver(xp.cfg)
    if getattr(solver.model, 'use_train_segment', False):
        # Set the model segment (in seconds, as an exact fraction) from the
        # length of one augmented training batch.
        batch = solver.augment(next(iter(solver.loaders['train'])))
        solver.model.segment = Fraction(batch.shape[-1], solver.model.samplerate)
        solver.model.eval()
    model = solver.model
    model.cuda()

    # FB: forward + backward pass on a batch of 2 inputs of 10 * samplerate samples.
    x = torch.randn(2, xp.cfg.dset.channels, int(10 * model.samplerate), device='cuda')
    with bench() as res:
        y = model(x)
        y.sum().backward()
    del y
    # Release the gradients so they do not count toward the next measurement.
    for p in model.parameters():
        p.grad = None
    print(f"FB: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms")

    # FV: gradient-free forward pass on a single input of one model segment.
    x = torch.randn(1, xp.cfg.dset.channels, int(model.segment * model.samplerate), device='cuda')
    with bench() as res:
        with torch.no_grad():
            y = model(x)
    del y
    print(f"FV: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms")

    # CPU timing: single-threaded `apply_model` on samplerate * 40 samples.
    model.cpu()
    torch.set_num_threads(1)
    test = torch.randn(1, xp.cfg.dset.channels, model.samplerate * 40)
    b = time.time()
    apply_model(model, test, split=True, shifts=1)
    print("CPU 40 sec:", time.time() - b)