import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.cm
from PIL import Image


# Adapted from: https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb
class Hook:
    """Attaches to a module and records its activations and gradients."""

    def __init__(self, module: nn.Module):
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        self.data = output
        output.requires_grad_(True)
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        return self.data.grad
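
A quick way to see the hook in action is to attach it to a toy convolution; the layer, input shape, and expected shapes below are illustrative assumptions, not part of the original notebook:

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    x = torch.randn(1, 3, 32, 32)
    with Hook(conv) as hook:
        conv(x).sum().backward()
    print(hook.activation.shape)  # torch.Size([1, 8, 32, 32])
    print(hook.gradient.shape)    # same shape: the retained gradient of the output

The hook keeps a reference to the layer's output tensor, so both the activation and its gradient remain accessible even after the `with` block removes the hook.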
# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradient computation for the model parameters,
    # remembering each parameter's original setting.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    # Attach a hook to the model at the desired layer.
    assert isinstance(layer, nn.Module)
    with Hook(layer) as hook:
        # Do a forward and backward pass.
        output = model(input)
        output.backward(target)

        grad = hook.gradient.float()
        act = hook.activation.float()

        # Global average pool the gradient across the spatial
        # dimensions to obtain the importance weights.
        alpha = grad.mean(dim=(2, 3), keepdim=True)
        # Weighted combination of activation maps over the
        # channel dimension.
        gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
        # We only want neurons with positive influence, so we
        # clamp any negative ones.
        gradcam = torch.clamp(gradcam, min=0)

    # Resize gradcam to the input resolution.
    gradcam = F.interpolate(
        gradcam,
        input.shape[2:],
        mode='bicubic',
        align_corners=False)

    # Restore the original gradient settings.
    for name, param in model.named_parameters():
        param.requires_grad_(requires_grad[name])

    return gradcam
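
Since gradCAM only assumes a model with a spatial feature layer, a plain image classifier works as a sanity check. A minimal sketch with a torchvision ResNet-50; the weights enum, the random input, and class index 281 ("tabby cat" in ImageNet) are assumptions for illustration:

    from torchvision import models

    resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval()
    x = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch
    # The target is the vector fed into backward(); a one-hot vector
    # selects which logit to explain.
    target = F.one_hot(torch.tensor([281]), num_classes=1000).float()
    cam = gradCAM(resnet, x, target, layer=resnet.layer4)
    print(cam.shape)  # torch.Size([1, 1, 224, 224]): one heatmap at input resolution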
# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
def getAttMap(img, attn_map):
    # Normalize the attention map to [0, 1].
    attn_map = attn_map - attn_map.min()
    if attn_map.max() > 0:
        attn_map = attn_map / attn_map.max()
    # Apply the jet colormap and drop the alpha channel.
    H = matplotlib.cm.jet(attn_map)
    H = (H * 255).astype(np.uint8)[:, :, :3]
    img_heatmap = Image.fromarray(H)
    img_heatmap = img_heatmap.resize((256, 256))
    # Overlay the heatmap on the image.
    return Image.blend(
        img.resize((256, 256)), img_heatmap, 0.4)
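
Putting the pieces together in the spirit of the source notebook: a hedged end-to-end sketch that assumes OpenAI's clip package is installed and a local image cat.jpg exists, both of which are assumptions beyond the code above:

    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("RN50", device=device, jit=False)
    model = model.float()  # keep fp32 so the fp32 preprocessed image is accepted

    img = Image.open("cat.jpg").convert("RGB")  # hypothetical local image
    image_input = preprocess(img).unsqueeze(0).to(device)
    text_input = clip.tokenize(["a photo of a cat"]).to(device)

    # Grad-CAM over the last convolutional stage of CLIP's ResNet visual
    # backbone, using the text embedding as the backward target.
    attn_map = gradCAM(
        model.visual,
        image_input,
        model.encode_text(text_input).float(),
        layer=model.visual.layer4)
    attn_map = attn_map.squeeze().detach().cpu().numpy()

    getAttMap(img, attn_map).save("gradcam_overlay.png")

Passing the text embedding as the target makes the backward pass differentiate the image-text similarity rather than a class logit, which is what lets this classifier-free setup highlight the image regions that match the prompt.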