Spaces:
Sleeping
Sleeping
| from torchvision import models | |
| from modelguidedattacks.guides.instance_guide import InstanceGuide | |
| from modelguidedattacks.guides.unguided import Unguided | |
| from modelguidedattacks import losses | |
| from .cls_models.registry import get_model | |
# Lookup table from the ``config.guide_model`` string to the guide class that
# ``setup_model`` instantiates as ``cls(model, config, **kwargs)``.
guide_model_registry = {
    "instance_guided": InstanceGuide,
    "unguided": Unguided,
}
# Lookup table from the ``config.loss`` string to a loss class from
# ``modelguidedattacks.losses``. ``setup_model`` passes the class itself
# (not an instance) to the unguided guide as its ``loss_fn`` argument.
loss_registry = {
    "cvxproj": losses.CVXProjLoss,
    "cwk": losses.CWExtensionLoss,
    "ad": losses.AdversarialDistillationLoss,
}
def _registry_lookup(registry, key, option_name):
    """Return ``registry[key]``, raising a descriptive KeyError if it is missing."""
    try:
        return registry[key]
    except KeyError:
        # Same exception type as a plain dict lookup, so existing callers that
        # catch KeyError keep working; the message lists the valid choices.
        raise KeyError(
            f"Unknown {option_name} {key!r}; expected one of {sorted(registry)}"
        ) from None


def setup_model(config, device):
    """Build the guide model selected by ``config`` around a classifier.

    Args:
        config: Run configuration object. Always reads ``dataset``, ``model``
            and ``guide_model``. When ``guide_model == "unguided"`` it also
            reads ``unguided_iterations``, ``unguided_lr``, ``loss``,
            ``binary_search_steps`` and ``topk_loss_coef_upper``.
        device: Forwarded unchanged to ``get_model``.

    Returns:
        ``guide_model_registry[config.guide_model](model, config, **kwargs)``,
        where ``model`` is the result of ``get_model`` and ``kwargs`` is empty
        except for the unguided guide.

    Raises:
        KeyError: If ``config.guide_model`` is not a key of
            ``guide_model_registry``, or (unguided only) ``config.loss`` is not
            a key of ``loss_registry``. Both are checked before ``get_model``
            is called.
    """
    # Resolve registry entries first so a config typo fails immediately rather
    # than after loading the classifier (presumably the expensive step).
    guide_cls = _registry_lookup(
        guide_model_registry, config.guide_model, "guide_model"
    )

    kwargs = {}
    if config.guide_model == "unguided":
        # Only the unguided guide takes optimisation hyper-parameters here.
        kwargs = {
            "iterations": config.unguided_iterations,
            "lr": config.unguided_lr,
            "loss_fn": _registry_lookup(loss_registry, config.loss, "loss"),
            "binary_search_steps": config.binary_search_steps,
            "topk_loss_coef_upper": config.topk_loss_coef_upper,
        }

    model = get_model(config.dataset, config.model, device)
    return guide_cls(model, config, **kwargs)