# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits
from vision_transformer import DINOHead

torchvision_archs = sorted(name for name in torchvision_models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(torchvision_models.__dict__[name]))
def get_args_parser():
    parser = argparse.ArgumentParser('DINO', add_help=False)

    # Model parameters
    parser.add_argument('--arch', default='vit_small', type=str,
        choices=['vit_tiny', 'vit_small', 'vit_base', 'xcit', 'deit_tiny', 'deit_small'] \
                + torchvision_archs + torch.hub.list("facebookresearch/xcit"),
        help="""Name of architecture to train. For quick experiments with ViTs,
        we recommend using vit_tiny or vit_small.""")
    parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
        of input square patches - default 16 (for 16x16 patches). Using smaller
        values leads to better performance but requires more memory. Applies only
        for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
        mixed precision training (--use_fp16 false) to avoid instabilities.""")
    parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
        the DINO head output. For complex and large datasets large values (like 65k) work well.""")
    parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
        help="""Whether or not to weight normalize the last layer of the DINO head.
        Not normalizing leads to better performance but can make the training unstable.
        In our experiments, we typically set this parameter to False with vit_small and True with vit_base.""")
    parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
        parameter for teacher update. The value is increased to 1 during training with cosine schedule.
        We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
    parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
        help="Whether to use batch normalizations in projection head (Default: False)")

    # Teacher temperature parameters
    parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
        help="""Initial value for the teacher temperature: 0.04 works well in most cases.
        Try decreasing it if the training loss does not decrease.""")
    parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
        of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
        starting with the default value of 0.04 and increasing it slightly if needed.""")
    parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
        help='Number of warmup epochs for the teacher temperature (Default: 0).')

    # Training/Optimization parameters
    parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
        to use half precision for training. Improves training time and memory requirements,
        but can provoke instability and slight decay of performance. We recommend disabling
        mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
    parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
        weight decay. With ViT, a smaller value at the beginning of training works well.""")
    parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
        gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
        help optimization for larger ViT architectures. 0 for disabling.""")
    parser.add_argument('--batch_size_per_gpu', default=64, type=int,
        help='Per-GPU batch size: number of distinct images loaded on one GPU.')
    parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
    parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
        during which we keep the output layer fixed. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""")
    parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.""")
    parser.add_argument("--warmup_epochs", default=10, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
        end of optimization. We use a cosine LR schedule with linear warmup.""")
    parser.add_argument('--optimizer', default='adamw', type=str,
        choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
    parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")

    # Multi-crop parameters
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
        help="""Scale range of the cropped image before resizing, relative to the original image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend using "--global_crops_scale 0.14 1." """)
    parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
        help="""Scale range of the cropped image before resizing, relative to the original image.
        Used for small local view cropping of multi-crop.""")

    # Misc
    parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str,
        help='Please specify path to the ImageNet training data.')
    parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
    parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.')
    parser.add_argument('--seed', default=0, type=int, help='Random seed.')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser
def train_dino(args):
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    dataset = datasets.ImageFolder(args.data_path, transform=transform)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # we changed the name DeiT-S to ViT-S to avoid confusion
    args.arch = args.arch.replace("deit", "vit")
    # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=args.drop_path_rate,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim
    # if the network is an XCiT
    elif args.arch in torch.hub.list("facebookresearch/xcit"):
        student = torch.hub.load('facebookresearch/xcit', args.arch,
                                 pretrained=False, drop_path_rate=args.drop_path_rate)
        teacher = torch.hub.load('facebookresearch/xcit', args.arch, pretrained=False)
        embed_dim = student.embed_dim
    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        print(f"Unknown architecture: {args.arch}")
    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(student, DINOHead(
        embed_dim,
        args.out_dim,
        use_bn=args.use_bn_in_head,
        norm_last_layer=args.norm_last_layer,
    ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} networks.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number + 2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches

    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        args.min_lr,
        args.epochs, len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
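    # Illustrative arithmetic for the linear scaling rule above, assuming 8 GPUs at
    # the default 64 images per GPU: peak lr = 0.0005 * (64 * 8) / 256 = 0.001,
    # reached at the end of the linear warmup before the cosine decay to min_lr.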
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print("Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
            data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
            epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch,
                    fp16_scaler, args):
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
    for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + it  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]
        # move images to gpu
        images = [im.cuda(non_blocking=True) for im in images]
        # teacher and student forward passes + compute dino loss
        with torch.cuda.amp.autocast(fp16_scaler is not None):
            teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
            student_output = student(images)
            loss = dino_loss(student_output, teacher_output, epoch)
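            # The loss compares every student crop (2 global + local_crops_number local
            # views) against each of the 2 teacher global views, skipping pairs where
            # both networks see the same crop (see DINOLoss.forward below).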
        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
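            # teacher_params <- m * teacher_params + (1 - m) * student_params,
            # with m following a cosine schedule from args.momentum_teacher to 1.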
            for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


class DINOLoss(nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)
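        # For every (teacher view q, student view v) pair with v != iq, accumulate the
        # cross-entropy H(q, v) = - sum_d q_d * log_softmax(student_out[v])_d and
        # average it over all such pairs.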
        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        dist.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
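        # batch_center is the mean teacher output over this batch across all GPUs;
        # subtracting a running center from the teacher logits (see forward) is what
        # keeps the outputs from collapsing to a single distribution.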
        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


class DataAugmentationDINO(object):
    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # first global crop
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(1.0),
            normalize,
        ])
        # second global crop
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(0.1),
            utils.Solarization(0.2),
            normalize,
        ])
        # transformation for the local small crops
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(p=0.5),
            normalize,
        ])
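    # Each call returns the list of views for one image: the two 224x224 global crops
    # (with different blur/solarization settings) followed by local_crops_number
    # 96x96 local crops, all flipped, color-jittered and normalized as above.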
    def __call__(self, image):
        crops = []
        crops.append(self.global_transfo1(image))
        crops.append(self.global_transfo2(image))
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        return crops


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
    args = parser.parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    train_dino(args)
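# The script expects a distributed launcher to set up the process group (see
# utils.init_distributed_mode together with --dist_url and --local_rank above).
# Assuming this file is saved as main_dino.py and 8 GPUs are available, a typical
# invocation would look something like:
#   python -m torch.distributed.launch --nproc_per_node=8 main_dino.py \
#       --arch vit_small --data_path /path/to/imagenet/train --output_dir ./dino_output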