Spaces:
Sleeping
Sleeping
| import sys | |
| from pprint import pformat | |
| from typing import Any | |
| import os | |
| import torch | |
| import ignite.distributed as idist | |
| import yaml | |
| from ignite.engine import Events | |
| from ignite.metrics import Accuracy, Loss | |
| from ignite.utils import manual_seed | |
| from torch import nn, optim | |
| from modelguidedattacks.data.setup import setup_data | |
| from modelguidedattacks.losses.boilerplate import BoilerplateLoss | |
| from modelguidedattacks.losses.energy import Energy, EnergyLoss | |
| from modelguidedattacks.metrics.topk_accuracy import TopKAccuracy | |
| from modelguidedattacks.models import setup_model | |
| from modelguidedattacks.trainers import setup_evaluator, setup_trainer | |
| from modelguidedattacks.utils import setup_parser, setup_output_dir | |
| from modelguidedattacks.utils import setup_logging, log_metrics, Engine | |
# Attack-success-rate floor checked at the first metrics checkpoint of an
# evaluation; if ASR is below it, the evaluation is terminated early.
_EARLY_STOP_ASR_THRESHOLD = 1e-3


def _save_results(config: Any, metrics: dict, early_stopped: bool) -> None:
    """Write final evaluation metrics and run metadata to ``<out_dir>/results.save``.

    The saved object is a shallow copy of ``metrics`` extended with the
    ``"config"`` and ``"early_stopped"`` entries, so the caller's dict (in
    ``run`` this is the evaluator's live ``state.metrics``) is not modified.

    Args:
        config: Run configuration; ``config.out_dir`` is the destination
            directory (created if missing) and the object itself is saved.
        metrics: Mapping of metric name to computed value.
        early_stopped: Whether the evaluation was terminated early.
    """
    os.makedirs(config.out_dir, exist_ok=True)
    results = dict(metrics)
    results["config"] = config
    results["early_stopped"] = early_stopped
    torch.save(results, os.path.join(config.out_dir, "results.save"))


def run(local_rank: int, config: Any):
    """Evaluate an attack in one process of an ``idist.Parallel`` job.

    Seeds the RNG, prepares output/logging, builds data, model, evaluator and
    metrics, then runs the evaluator over ``dataloader_eval``. Metrics are
    recomputed and logged every ``config.log_every_iters`` iterations. If the
    first such checkpoint shows an attack success rate ("ASR") below
    ``_EARLY_STOP_ASR_THRESHOLD``, evaluation is terminated early.

    Args:
        local_rank: Process-local rank supplied by ``idist.Parallel.run``. It
            is only printed; seeding and data setup use the rank returned by
            ``idist.get_rank()``.
        config: Parsed run configuration. Attributes read here include
            ``seed``, ``guide_model``, ``log_every_iters``,
            ``eval_epoch_length`` and ``out_dir``. ``config.output_dir`` is
            overwritten with the value returned by ``setup_output_dir``.

    Raises:
        NotImplementedError: if ``config.guide_model`` is neither
            ``"unguided"`` nor ``"instance_guided"`` (the guided/training path
            does not exist yet).
    """
    print("Running ", local_rank)

    # Offset the seed by rank so each process gets its own random stream.
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    config.output_dir = setup_output_dir(config, rank)

    # Set up logging, print the configuration, and store it as a YAML lock file.
    logger = setup_logging(config)
    logger.info("Configuration: \n%s", pformat(vars(config)))
    (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))

    # Download datasets and create dataloaders (only the eval loader is used).
    _dataloader_train, dataloader_eval = setup_data(config, rank)

    # Model, loss, device.
    device = idist.device()
    model = idist.auto_model(setup_model(config, device))
    # NOTE(review): loss_fn is never used below. It is still constructed so that
    # any side effects of BoilerplateLoss construction (e.g. RNG draws that
    # could shift later data shuffling) are unchanged. Confirm before removing.
    loss_fn = BoilerplateLoss().to(device=device)
    l2_energy_loss = Energy(p=2).to(device)
    l1_energy_loss = Energy(p=1).to(device)
    l_inf_energy_loss = Energy(p=torch.inf).to(device)

    evaluator = setup_evaluator(config, model, device)
    evaluator.logger = logger

    # Metrics are attached in insertion order, which is also the order in
    # which compute_metrics refreshes them below.
    metrics = {
        "ASR": TopKAccuracy(device=device),
        "L2 Energy": EnergyLoss(l2_energy_loss, device=device),
        "L1 Energy": EnergyLoss(l1_energy_loss, device=device),
        "L_inf Energy": EnergyLoss(l_inf_energy_loss, device=device),
        "L2 Energy Min": EnergyLoss(l2_energy_loss, reduction="min", device=device),
        "L1 Energy Min": EnergyLoss(l1_energy_loss, reduction="min", device=device),
        "L_inf Energy Min": EnergyLoss(l_inf_energy_loss, reduction="min", device=device),
        "L2 Energy Max": EnergyLoss(l2_energy_loss, reduction="max", device=device),
        "L1 Energy Max": EnergyLoss(l1_energy_loss, reduction="max", device=device),
        "L_inf Energy Max": EnergyLoss(l_inf_energy_loss, reduction="max", device=device),
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if config.guide_model not in ("unguided", "instance_guided"):
        # Guided models would need a training loop, which is not written yet.
        # Raise instead of `assert False` so this still fails under `python -O`.
        raise NotImplementedError(
            f"guide_model={config.guide_model!r}: This code path is for the future"
        )

    # "first batch" here means the first metrics checkpoint, i.e. the first
    # iteration that is a multiple of config.log_every_iters.
    first_batch_passed = False
    early_stopped = False

    def compute_metrics(engine: Engine, tag: str):
        """Refresh all running metrics; stop early if the first checkpoint shows ~0 ASR."""
        nonlocal first_batch_passed, early_stopped
        for metric_name, running_metric in metrics.items():
            running_metric.completed(engine, metric_name)
        if first_batch_passed:
            return
        if engine.state.metrics["ASR"] < _EARLY_STOP_ASR_THRESHOLD:
            print("Early stop, assuming no success throughout")
            early_stopped = True
            engine.terminate()
        else:
            first_batch_passed = True

    # compute_metrics is registered first so log_metrics sees fresh values.
    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.log_every_iters),
        compute_metrics,
        tag="eval",
    )
    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.log_every_iters),
        log_metrics,
        tag="eval",
    )

    evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
    log_metrics(evaluator, "eval")

    # out_dir is optional: an empty string (or None) disables saving results.
    # NOTE(review): every process of a multi-process run reaches this point
    # and out_dir is not rank-specific here; confirm concurrent writes of the
    # same results.save file are acceptable.
    if config.out_dir:
        _save_results(config, evaluator.state.metrics, early_stopped)

    # No need to train with an unguided model.
| # main entrypoint | |
def launch(config=None):
    """Build the configuration if needed and run ``run`` under ``idist.Parallel``.

    When ``config`` is None, it is parsed from the command line.
    ``sys.argv[1]`` is the config file path passed to ``setup_parser``, and the
    remaining arguments are parsed by the returned parser.

    If ``config.nproc_per_node`` is 0 or ``config.backend`` is None, both are
    passed to ``idist.Parallel`` as None, i.e. no backend and no per-node
    process count.

    Args:
        config: Pre-built configuration object, or None to parse ``sys.argv``.

    Raises:
        SystemExit: if ``config`` is None and no config path was given on the
            command line (mirrors argparse's behavior for usage errors).
    """
    if config is None:
        if len(sys.argv) < 2:
            raise SystemExit("usage: python <script> CONFIG_PATH [overrides...]")
        config_path = sys.argv[1]
        config = setup_parser(config_path).parse_args(sys.argv[2:])

    backend = config.backend
    nproc_per_node = config.nproc_per_node
    # Either setting alone disables distributed execution, so clear both.
    if nproc_per_node == 0 or backend is None:
        backend = None
        nproc_per_node = None

    with idist.Parallel(backend, nproc_per_node) as p:
        p.run(run, config=config)
# Script entry point: read the configuration from the command line and launch.
if __name__ == "__main__":
    launch()