# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed training utilities."""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

from dora import distrib as dora_distrib

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init():
    """Initialize the distributed process group (through Dora) and set the
    module-level `rank` and `world_size` globals."""
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()
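# Illustrative sketch only: call `init()` once at the start of the training
# entry point, before building models or dataloaders, so that `rank` and
# `world_size` reflect the actual distributed setup.
#
#     init()
#     if rank == 0:
#         logger.info("world size: %d", world_size)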


def average(metrics, count=1.):
    """Average `metrics` (a dict or a sequence of floats) over all workers,
    weighting each worker's contribution by `count` (e.g. a number of examples)."""
    if isinstance(metrics, dict):
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    # Append the weight as an extra entry so a single all_reduce sums both the
    # weighted metrics and the total weight.
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
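# Illustrative sketch only (hypothetical `epoch_loss` / `num_examples` names):
# averaging a per-epoch loss across workers, weighted by how many examples
# each worker processed.
#
#     metrics = {'loss': epoch_loss / num_examples}
#     metrics = average(metrics, count=num_examples)
#     # every worker now holds the same globally averaged 'loss'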


def wrap(model):
    """Wrap `model` in `DistributedDataParallel` when running on more than one
    process, otherwise return it unchanged."""
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
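# Illustrative sketch only (`MyModel` and `batch` are hypothetical):
#
#     model = MyModel().cuda()
#     dmodel = wrap(model)           # DDP wrapper when distributed, no-op otherwise
#     loss = dmodel(batch).mean()
#     loss.backward()                # gradients are synchronized across workers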


def barrier():
    if world_size > 1:
        torch.distributed.barrier()


def share(obj=None, src=0):
    """Broadcast the picklable `obj` from the process with rank `src` to all
    other processes, and return it on every process."""
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # `size` now holds the length of the pickled obj on all processes.
    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # `buffer` now holds the pickled obj on all processes.
    if rank != src:
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj
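# Illustrative sketch only (hypothetical `build_file_index` helper): compute an
# expensive or non-deterministic structure once on rank 0 and share it so that
# every worker sees exactly the same result.
#
#     index = build_file_index(root) if rank == 0 else None
#     index = share(index, src=0)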


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader that behaves correctly under distributed training.
    If gradients are going to be computed (i.e. training), you must pass
    `shuffle=True` so that a `DistributedSampler` is used.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # Training (backward pass will be computed): use DistributedSampler.
        sampler = DistributedSampler(dataset)
        # `shuffle` is dropped here, DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # Evaluation: shard manually, as DistributedSampler would otherwise
        # replicate some examples to even out the shards.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
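# Illustrative sketch only (hypothetical `train_set` / `valid_set` datasets):
#
#     train_loader = loader(train_set, batch_size=16, shuffle=True, num_workers=4)
#     valid_loader = loader(valid_set, batch_size=16, shuffle=False, num_workers=4)
#
# When `klass` is the default DataLoader, per-epoch reshuffling in training can
# be requested through `train_loader.sampler.set_epoch(epoch)`.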