Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from torchvision.ops import MLP | |
| from .. import losses | |
class InstanceGuide(nn.Module):
    """Optimize an input-space perturbation that realizes a top-K targeted attack.

    The wrapped ``model`` is frozen and used as a fixed oracle. ``forward``
    alternates between two inner optimizations:

    1. Train a small MLP to propose feature-space offsets whose resulting
       logits rank the ``attack_targets`` classes in the top-K positions.
    2. Optimize a pixel-space perturbation so the model's logits on the
       perturbed input match the best attack-successful logits found so far.
    """

    def __init__(self, model: nn.Module, optimizer=torch.optim.AdamW,
                 loss_fn=losses.CWExtensionLoss) -> None:
        """
        model: classifier exposing ``num_classes()``, ``head_features()``,
            ``head(feats)`` and ``forward(x, return_features=True)``.
        optimizer: optimizer *class* (not instance); instantiated twice in
            ``forward`` (once for the MLP, once for the perturbation).
        loss_fn: loss *class*; instantiated once here.
        """
        super().__init__()
        self.guided = True
        self.model = model
        # The attacked model is never trained here — freeze all parameters.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.loss = loss_fn()
        self.optimizer = optimizer
        self.epochs = 30
        self.mlp_iterations = 5
        self.perturbation_iterations = 5

    def surject_perturbation(self, x):
        """Map a raw perturbation onto the feasible set.

        Identity here; subclasses may override (e.g. clamp to an eps-ball).
        """
        return x

    def forward(self, x, attack_targets):
        """
        x: [B, channels, H, W]
        attack_targets: [B, K] -- class indices that must occupy the top-K
            ranks of the perturbed prediction for the attack to count as
            successful.

        Returns the model's logits on the final perturbed input.
        """
        B = x.shape[0]
        K = attack_targets.shape[-1]
        C = self.model.num_classes()
        with torch.no_grad():
            pred_clean, feats = self.model(x, return_features=True)
            # We are assuming the clean predictions are ground truth since we
            # make that constraint on the dataset side.
            attack_ground_truth = pred_clean.argmax(dim=-1)  # [B]
        feat_dim = self.model.head_features()
        # 4-layer MLP (3 hidden + output) predicting a feature-space offset.
        mlp = MLP(feat_dim, [feat_dim] * 3 + [feat_dim],
                  activation_layer=nn.GELU, inplace=None).to(x.device)
        # Small random init keeps the starting point near the clean input.
        x_perturbation = nn.Parameter(
            torch.randn(x.shape, device=x.device) * 1e-3)
        perturbation_optimizer = self.optimizer([x_perturbation], lr=1e-1)
        mlp_optimizer = self.optimizer(mlp.parameters(), lr=1e-3)
        # Fall back to the clean prediction until an attack succeeds.
        logits_target_best = pred_clean
        feats_target_best = feats
        with torch.enable_grad():
            for _ in range(self.epochs):
                for _ in range(self.mlp_iterations):
                    # NOTE(fix): removed the per-iteration
                    # torch.cuda.synchronize() — it crashed on CPU-only hosts
                    # and only slowed down CUDA runs.
                    feature_offset = mlp(feats)
                    feats_target_pred = feature_offset + feats
                    logits_target_pred = self.model.head(feats_target_pred)
                    pred_classes = logits_target_pred.argsort(
                        dim=-1, descending=True)  # [B, C]
                    # Success = the K target classes occupy the K top ranks.
                    attack_successful = (
                        pred_classes[:, :K] == attack_targets).all(dim=-1)  # [B]
                    with torch.no_grad():
                        # Keep, per sample, the latest logits/features that
                        # realize the attack.
                        logits_target_best = torch.where(
                            attack_successful[:, None].expand(-1, C),
                            logits_target_pred,
                            logits_target_best,
                        )
                        feats_target_best = torch.where(
                            attack_successful[:, None].expand(-1, feat_dim),
                            feats_target_pred,
                            feats_target_best,
                        )
                    mlp_loss = self.loss(logits_pred=logits_target_pred,
                                         prediction_feats=feats_target_pred,
                                         attack_targets=attack_targets,
                                         attack_ground_truth=attack_ground_truth,
                                         model=self.model)
                    # BUG FIX: the offset-norm regularizer is a per-sample
                    # vector [B]; adding it to the scalar mean loss made
                    # mlp_loss non-scalar, so .backward() raised. Reduce it to
                    # a scalar before combining.
                    offset_penalty = feature_offset.view(B, -1).norm(
                        dim=-1, p=2).mean()
                    mlp_loss = mlp_loss.mean() + offset_penalty
                    mlp_optimizer.zero_grad()
                    mlp_loss.backward()
                    mlp_optimizer.step()
                # The target features are a fixed objective for the
                # perturbation phase — cut them from the MLP graph.
                feats_target_best = feats_target_best.detach()
                for _ in range(self.perturbation_iterations):
                    x_perturbed = x + self.surject_perturbation(x_perturbation)
                    prediction, perturbed_feats = self.model(
                        x_perturbed, return_features=True)
                    # Pull the perturbed logits toward the best target logits
                    # (L2 distance per sample, averaged over the batch).
                    perturbation_loss = (
                        prediction - logits_target_best
                    ).view(B, -1).norm(dim=-1).mean()
                    perturbation_optimizer.zero_grad()
                    perturbation_loss.backward()
                    perturbation_optimizer.step()
        return prediction