Commit e7d7e74 · Parent(s): 367e473

Removed MMCV dependency

Files changed:
- .gitattributes +1 -0
- README.md +33 -50
- assets/pikachu_seg.png +3 -0
- hf_demo.ipynb +0 -0
- hf_model/__init__.py +0 -0
- hf_model/hooks.py +52 -0
- hf_model/masker.py +246 -0
- hf_model/model.py +757 -0
- hf_model/modules.py +243 -0
- hf_model/pamr.py +146 -0
- hf_model/talk2dino.py +432 -0
- hf_model/templates.py +148 -0
- hf_model/us.py +119 -0
.gitattributes CHANGED

```diff
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
```
README.md CHANGED

````diff
@@ -31,7 +31,7 @@ Talking to DINO: Bridging Self-Supervised Vision Backbones with Language for Ope
 
 <div align="center">
 <figure>
-<img alt="Overview of Talk2DINO" src="./assets/overview.png" width="
+<img alt="Overview of Talk2DINO" src="./assets/overview.png" width="90%">
 </figure>
 </div>
 
@@ -43,75 +43,58 @@ Open-Vocabulary Segmentation (OVS) aims at segmenting images from free-form text
 ### Mapping CLIP Text Embeddings to DINOv2 space with Talk2DINO
 We can use Talk2DINO to map CLIP text embeddings into the DINOv2 patch embedding space.
 ```python
-import
-from
-
-import os
+from hf_model.talk2dino import Talk2DINO
+from torchvision.io import read_image
+
 # Device setup
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-# Configuration and weights
-proj_name = 'vitb_mlp_infonce'
-config_path = os.path.join("configs", f"{proj_name}.yaml")
-weights_path = os.path.join("weights", f"{proj_name}.pth")
-# Load Talk2DINO projection layer
-talk2dino = ProjectionLayer.from_config(config_path)
-talk2dino.load_state_dict(torch.load(weights_path, map_location=device))
-talk2dino.to(device)
-# Load CLIP model
-clip_model, clip_preprocess = clip.load("ViT-L/14", device=device, jit=False)
-tokenizer = clip.tokenize
-# Example: Tokenize and project text features
-texts = ["a cat"]
-text_tokens = tokenizer(texts).to(device)
-text_features = clip_model.encode_text(text_tokens)
-projected_text_features = talk2dino.project_clip_txt(text_features)
-```
 
-
-
+# Model Loading
+model = Talk2DINO.from_pretrained("lorebianchi98/Talk2DINO-ViTL").to(device).eval()
 
-
-
-
+# Embedding generation
+with torch.no_grad():
+    text_embed = model.encode_text("a pikachu")
+    image_embed = model.encode_image(image)
 
-
-
-
+# normalize the features to perform cosine similarity
+text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
+image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
+
+similarity = (image_embed @ text_embed.T).squeeze(0, -1).cpu().numpy()
 ```
 
+### Demo
+In `demo.ipynb` we provide a simple example on how to use Talk2DINO for inference on a given image with custom textual categories.
 Result:
 <div align="center">
 <table><tr><td><figure>
 <img alt="" src="./assets/pikachu.png" width=300>
 </figure></td><td><figure>
-<img alt="" src="./pikachu_seg.png" width=300>
+<img alt="" src="./assets/pikachu_seg.png" width=300>
 </figure></td></tr></table>
 </div>
 
 ## Installation
+
+To use the **Hugging Face interface** for inference:
+
 ```bash
-#
-
-
-
-
-# Install CUDA toolkit and cuDNN
-conda install -c nvidia/label/cuda-11.7.0 cuda
-conda install -c nvidia/label/cuda-11.7.0 cuda-nvcc
-conda install -c conda-forge cudnn cudatoolkit=11.7.0
-# Install PyTorch 2.1 with CUDA 11.8 support
-# Note: This is crucial, as it matches the requirements of mmcv-full 1.7.2
-pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
-# Install other dependencies
+# Clone the repository
+git clone https://huggingface.co/lorebianchi98/Talk2DINO-ViTL
+cd Talk2DINO-ViTL
+
+# Install dependencies
 pip install -r requirements.txt
-
-
-
-pip install mmcv-full==1.7.2 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1.0/index.html
-# Install mmsegmentation
-pip install mmsegmentation==0.30.0
+
+# Install PyTorch and torchvision with the appropriate CUDA version
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
 ```
 
+> For the **full MMCV interface** to perform evaluation on segmentation benchmarks, please refer to the [original Talk2DINO repository](https://github.com/lorebianchi98/Talk2DINO).
+
+
+
 <details>
 <summary>Qualitative Results</summary>
 
````
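The new README snippet imports `read_image` but never defines the `image` tensor it encodes, and it relies on `torch` being imported elsewhere. Below is a minimal end-to-end sketch, not part of the commit, that fills those gaps under stated assumptions: it assumes the `Talk2DINO.from_pretrained`, `encode_text`, and `encode_image` signatures shown in the diff, that `encode_image` accepts the raw tensor returned by `torchvision.io.read_image` and handles preprocessing internally, and that `assets/pikachu.png` exists locally.

```python
import torch
from torchvision.io import read_image
from hf_model.talk2dino import Talk2DINO

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Hugging Face interface exactly as in the updated README
model = Talk2DINO.from_pretrained("lorebianchi98/Talk2DINO-ViTL").to(device).eval()

# Assumption: encode_image accepts the uint8 CHW tensor and does its own
# preprocessing/device handling; otherwise convert and move it first.
image = read_image("assets/pikachu.png")

with torch.no_grad():
    text_embed = model.encode_text("a pikachu")   # text embedding in the shared DINOv2 space
    image_embed = model.encode_image(image)       # image/patch embedding in the same space

# Cosine similarity = dot product of L2-normalized embeddings
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
similarity = image_embed @ text_embed.T
print(similarity.squeeze().cpu().numpy())
```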
assets/pikachu_seg.png ADDED (Git LFS)

hf_demo.ipynb ADDED
The diff for this file is too large to render.

hf_model/__init__.py ADDED
File without changes.
hf_model/hooks.py ADDED (new file, 52 lines)

```python
import torch

feats = {}

def get_self_attention(module, input, output):
    feats['self_attn'] = output

def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
    qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0] * scale, qkv[1], qkv[2]
    attn = q @ k.transpose(-2, -1)
    self_attn_maps = attn[:, :, 0, num_global_tokens:]
    self_attn = self_attn_maps.mean(dim=1)
    self_attn = self_attn.softmax(dim=-1)
    if ret_self_attn_maps:
        return self_attn, self_attn_maps
    else:
        return self_attn

def get_vit_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    feats['vit_out'] = output

def get_second_last_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    feats['second_last_out'] = output

def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    feats['clip_txt_out_tokens'] = output

def get_clip_second_last_dense_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    feats['clip_second_last_out'] = output.permute(1, 0, 2)

def get_dinov1_patches(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    feats['dinov1_patches'] = output

def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
    feats['clip_txt_out_tokens'] = output

def average_text_tokens(text_embeddings, mask, keep_cls=False, keep_end_seq=False):
    if not keep_end_seq:
        mask[torch.arange(mask.shape[0]), mask.sum(dim=1) - 1] = False  # excluding end of sequence
    if not keep_cls:
        mask[:, 0] = False  # excluding CLS token

    masked_embeddings = text_embeddings * mask.unsqueeze(-1)  # shape: [BS, SEQ_LEN, 512]

    sum_embeddings = masked_embeddings.sum(dim=1)  # shape: [BS, 512]

    valid_elements = mask.sum(dim=1, keepdim=True)  # shape: [BS, 1]

    mean_embeddings = sum_embeddings / valid_elements  # shape: [BS, 512]

    return mean_embeddings
```
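`get_self_attention` stores the raw qkv output of a ViT attention block in the shared `feats` dict, and `process_self_attention` reshapes it into per-head CLS-to-patch attention and averages it over heads. The backbone wiring lives in `hf_model/talk2dino.py`, which is not shown here; the sketch below is only an illustration of how the two functions fit together, assuming a DINOv2 ViT-B/14 backbone loaded from `torch.hub` and the hook registered on the last block's `attn.qkv` linear layer.

```python
import torch
from hf_model.hooks import feats, get_self_attention, process_self_attention

# Assumption: DINOv2 backbone from torch.hub (requires network access).
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval()
dino.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)

image = torch.randn(1, 3, 518, 518)  # dummy input: 518/14 = 37x37 patch grid
with torch.no_grad():
    _ = dino(image)                  # the hook fills feats['self_attn']

num_heads = dino.blocks[-1].attn.num_heads
embed_dim = dino.embed_dim
num_tokens = feats['self_attn'].shape[1]            # CLS (+ any registers) + patches
scale = (embed_dim // num_heads) ** -0.5
num_global_tokens = num_tokens - (518 // 14) ** 2   # tokens that are not patches

# CLS->patch attention averaged over heads, softmaxed over patches
self_attn = process_self_attention(
    feats['self_attn'], batch_size=1, num_tokens=num_tokens,
    num_attn_heads=num_heads, embed_dim=embed_dim, scale=scale,
    num_global_tokens=num_global_tokens)
print(self_attn.shape)  # [1, num_patches]
```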
hf_model/masker.py ADDED (new file, 246 lines)

```python
# ------------------------------------------------------------------------------
# Talk2DINO
# ------------------------------------------------------------------------------
import copy
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import hf_model.us as us
from einops import rearrange, repeat

# from models.dinotext.gumbel import gumbel_sigmoid
from hf_model.modules import FeatureEncoder
from omegaconf import OmegaConf


def build_model(config):
    model = OmegaConf.to_container(config, resolve=True)
    return model

class Sim2Mask(nn.Module):
    def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        assert not ((init_w is None) ^ (init_b is None))
        if learnable:
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    def forward(self, x, deterministic=False):
        logits = x * self.w + self.b

        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'


class MaskerBackbone(nn.Module):
    """Masker image encoder backbone.
    """
    def __init__(self, clip_visual, freeze_idx):
        super().__init__()
        self.transformer = copy.deepcopy(clip_visual.transformer)
        self.transformer.resblocks = self.transformer.resblocks[freeze_idx:]

        for block in self.transformer.resblocks:
            if hasattr(block, "hook_handler"):
                block.hook_handler.remove()

        self.ln_post = copy.deepcopy(clip_visual.ln_post)
        self.proj = copy.deepcopy(clip_visual.proj)

        self.layers = len(self.transformer.resblocks)
        self.patch_size = clip_visual.patch_size

        self.output_dim = clip_visual.output_dim if self.proj is not None else clip_visual.width

    def forward(self, x, spatial=True, ignore_last_attn=True):
        if self.layers:
            x = self.transformer(x, ignore_last_attn=ignore_last_attn)

        x = x.permute(1, 0, 2)  # LND -> NLD

        if spatial:
            x = self.ln_post(x)
        else:
            x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

class MaskerImageFeatureEncoder(FeatureEncoder):
    def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn
        self.patch_size = backbone.patch_size
        self.backbone = backbone
        self.decoder = decoder

        for resblock in self.backbone.transformer.resblocks:
            resblock.hook_handler = resblock.register_forward_hook(self.hook)

    def _encode(self, image, image_feat):
        H, W = image.shape[-2:]
        h = H // self.patch_size
        w = W // self.patch_size

        x = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn)  # BLC
        x = rearrange(x[:, 1:], "B (H W) C -> B C H W", H=h, W=w)
        x = self.decoder(x)

        return x

class Masker(nn.Module):
    def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn

        decoder["C"] = backbone.output_dim
        decoder = MODELS.build(decoder)
        decoder = nn.Sequential(OrderedDict([
            ("decoder", decoder),
            ("image_proj", image_proj)
        ]))

        self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)

        self.sim2mask = Sim2Mask(**sim2mask)

    def forward(self, image, image_feat, text_emb, deterministic=False):
        B = image.size(0)
        image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [BCHW]

        image_emb_norm = us.normalize(image_emb, dim=1)
        text_emb_norm = us.normalize(text_emb, dim=-1)

        H, W = image_emb.shape[2:]
        D = dist.get_world_size()

        # simmap [B, B*D, H, W] where D is #devices
        all_text_emb_norm = us.gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
        simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
        mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)

        # mask [B, B*D, H, W] where D is #devices
        # positive global label
        pos_indices = torch.arange(B, dtype=torch.long, device=image_emb.device) + B * dist.get_rank()
        pos_mask = mask[torch.arange(B), pos_indices].unsqueeze(1)  # [B, 1, H, W]

        offdiag = torch.ones(B, B*D, dtype=torch.bool, device=mask.device)
        offdiag[torch.arange(B), pos_indices] = False

        soft_pos_mask = soft_mask[torch.arange(B), pos_indices].unsqueeze(1)
        soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B*D-1, H, W)

        masks = {
            "pos": pos_mask,  # [B, 1, H, W]

            "soft_pos": soft_pos_mask,
            "soft_neg": soft_neg_mask,
            "soft_all": soft_mask,  # [B, N, H, W]
        }

        return masks, image_emb, text_emb, feats

    @torch.no_grad()
    def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image [B, 3, H, W]
            image_feat [L, B, C]: CLIP features
            text_emb [N, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): decide hard or soft returning segmentation mask.
                Note that soft mask is required for proper evaluation

        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W)
        """
        image_emb = self.image_encoder(image, image_feat)  # [BCHW]

        image_emb = us.normalize(image_emb, dim=1)  # BCHW
        text_emb = us.normalize(text_emb, dim=-1)  # NC

        simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap

class DINOTextMasker(nn.Module):
    def __init__(self, similarity_type="cosine"):
        super().__init__()
        self.sim2mask = DINOTextSim2Mask()
        self.sim2mask = self.sim2mask.eval()
        self.similarity_type = similarity_type

    def forward(self, image, image_feat, text_emb, deterministic=False):
        pass

    @torch.no_grad()
    def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image [B, 3, H, W]
            image_feat [L, B, C]: CLIP features
            text_emb [N, K, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): decide hard or soft returning segmentation mask.
                Note that soft mask is required for proper evaluation
            use_k_nn (bool): use kNN to segment
            k_nn (int): number of nearest neighbors for kNN segmentation

        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W)
        """
        b, c, h, w = image_feat.shape
        n, c = text_emb.shape

        if self.similarity_type == "cosine":
            image_feat = us.normalize(image_feat, dim=1)  # BCHW
            # text_emb = us.normalize(text_emb, dim=-1)  # NKC
            simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
        else:
            raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap


class DINOTextSim2Mask(nn.Module):
    def __init__(self, gumbel_tau=1.0):
        super().__init__()
        self.gumbel_tau = gumbel_tau

    def forward(self, x, deterministic=False):
        soft_mask = torch.sigmoid(x)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(x.dtype)
        else:
            hard_mask = gumbel_sigmoid(x, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask
```
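`DINOTextMasker.forward_seg` is the inference path that turns per-patch image embeddings and per-class text embeddings into segmentation maps: it L2-normalizes the patch grid, takes a cosine-similarity einsum against the text embeddings, and passes the result through a sigmoid. A small sketch with dummy tensors, not part of the commit, illustrates the expected shapes; it assumes `hf_model.us.normalize` is an L2-normalization helper and that the embeddings are already in the shared Talk2DINO space, and the dimensions below are illustrative only.

```python
import torch
from hf_model.masker import DINOTextMasker

masker = DINOTextMasker(similarity_type="cosine").eval()

B, C, H, W = 1, 768, 37, 37   # dummy patch-embedding grid
N = 3                         # number of textual categories

image_feat = torch.randn(B, C, H, W)   # projected DINOv2 patch embeddings
text_emb = torch.randn(N, C)           # projected CLIP text embeddings
text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

# soft per-class masks plus the raw cosine-similarity map
mask, simmap = masker.forward_seg(image_feat, text_emb, deterministic=True, hard=False)
print(mask.shape, simmap.shape)   # both [B, N, H, W]
seg = mask.argmax(dim=1)          # [B, H, W] label map over the N categories
```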
hf_model/model.py ADDED (new file, 757 lines; only lines 1-646 are shown below)

```python
import clip
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F

from hf_model.hooks import get_self_attention, process_self_attention, feats

class VisualProjectionLayer(nn.Module):
    """
    Creates a projection layer on top of the DINO encoder.
    The forward method calculate the similarity between the projected DINO token and the CLIP textual CLS token.
    """
    def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, hidden_embed_dim=None, dino_embed_dim=1024, clip_embed_dim=512):
        # mlp_dims list of mlp dimensions
        super().__init__()
        if hidden_embed_dim is None:
            hidden_embed_dim = clip_embed_dim

        self.linear_layer = nn.Linear(dino_embed_dim, hidden_embed_dim)
        if hidden_layer:
            self.linear_layer2 = nn.Linear(hidden_embed_dim, clip_embed_dim)
        self.act = act
        self.cosine = cosine

    @classmethod
    def from_config(cls, config):
        if type(config) is str:
            # if the configuration is a string, we treat it as a file path
            with open(config, 'r') as f:
                config = yaml.safe_load(f)['model']

        # loading the activation function
        act = config.get('act', None)
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'sigmoid':
            act = nn.Sigmoid()
        elif act is not None:
            raise Exception("Unknown activation function")

        model = cls(
            act=act,
            hidden_layer=config.get('hidden_layer', False),
            cosine=config.get('cosine', True),
            hidden_embed_dim=config.get('hidden_embed_dim', None) if config.get('hidden_layer', False) else None,
            dino_embed_dim=config.get('dino_embed_dim', 1024),
            clip_embed_dim=config.get('clip_embed_dim', 512)

        )
        return model


    def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False):
        visual_embedding = self.project_dino(visual_embedding)
        textual_embedding = textual_embedding.float()

        if self.cosine:
            textual_embedding = F.normalize(textual_embedding, p=2, dim=1)
            visual_embedding = F.normalize(visual_embedding, p=2, dim=1)
        if ret_embeds:
            return textual_embedding, visual_embedding
        x = textual_embedding @ visual_embedding.transpose(1, 0)
        if not ret_similarity_matrix:
            x = x[torch.eye(len(x)) > 0.5]  # only diagonal elements

        return x

    def project_dino(self, visual_embedding):
        visual_embedding = visual_embedding.float()

        x = self.linear_layer(visual_embedding)
        if self.act:
            x = self.act(x)
        if hasattr(self, 'linear_layer2'):
            x = self.linear_layer2(x)

        return x

    def __len__(self):
        return sum(p.numel() for p in self.parameters())



class ProjectionLayer(nn.Module):
    """
    Creates a projection layer on top of the CLIP-text encoder.
    The forward method calculate the similarity between the DINO CLS token and the projected CLIP textual CLS token.
    """
    def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, num_attn_head=16, weight_attn_heads=None,
                 alignment_strategy='max_score', alpha=0.6, keep_cls=False, keep_end_seq=False):
        # mlp_dims list of mlp dimensions
        super().__init__()
        self.num_attn_head = num_attn_head

        self.linear_layer = nn.Linear(clip_embed_dim, dino_embed_dim)
        if hidden_layer:
            hidden_layer = 1 if hidden_layer is True else hidden_layer  # ensuring compatibility with old code
            # self.linear_layer2 = nn.Linear(dino_embed_dim, dino_embed_dim)
            self.hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)])
        self.act = act
        self.cosine = cosine

        self.weight_attn_heads = weight_attn_heads
        if weight_attn_heads == 'static':
            self.attn_weights = nn.Parameter(torch.rand(self.num_attn_head))
        elif weight_attn_heads == 'conditioned':
            self.weight_layer1 = nn.Linear(dino_embed_dim, dino_embed_dim)
            self.weight_layer2 = nn.Linear(dino_embed_dim, self.num_attn_head)

        self.alignment_strategy = alignment_strategy  # relevant only if we use disentangled_self_attn
        self.keep_cls = keep_cls  # relevant only if we use clip_txt_tokens_out
        self.keep_end_seq = keep_end_seq  # relevant only if we use clip_txt_tokens_out
        self.alpha = alpha

    @classmethod
    def from_config(cls, config):
        if type(config) is str:
            # if the configuration is a string, we treat it as a file path
            with open(config, 'r') as f:
                config = yaml.safe_load(f)['model']

        # loading the activation function
        act = config.get('act', None)
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'sigmoid':
            act = nn.Sigmoid()
        elif act is not None:
            raise Exception("Unknown activation function")

        model = cls(
            act=act,
            hidden_layer=config.get('hidden_layer', False),
            cosine=config.get('cosine', True),
            dino_embed_dim=config.get('dino_embed_dim', 1024),
            num_attn_head=config.get('num_attn_head', 16),
            clip_embed_dim=config.get('clip_embed_dim', 512),
            weight_attn_heads=config.get('weight_attn_heads', None),
            alignment_strategy=config.get('alignment_strategy', 'max_score'),
            alpha=config.get('alpha', 0.6),
            keep_cls=config.get('keep_cls', None),
            keep_end_seq=config.get('keep_end_seq', None),
        )
        if config.get('starting_checkpoint', None) is not None:
            model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu'))

        return model

    def compute_similarity(self, visual_embedding, textual_embedding, text_input_mask=None, return_index=False):
        if len(visual_embedding.shape) == 3 or len(textual_embedding.shape) == 3:
            # at least one embedding is decomposed: either we have all textual tokens or we have all the attention head tokens

            if self.alignment_strategy == 'weighted_avg':
                if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
                    raise Exception("Alignment strategy not implemented for this type of embeddings!")
                sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
                sims = sims.softmax(dim=-1)
                # in this case, we keep as visual_embedding the averaged token weighted by the text similarities
                visual_embedding = (visual_embedding * sims.unsqueeze(dim=-1)).mean(dim=1)
                sims = textual_embedding @ visual_embedding.transpose(1, 0)

            # in this case we sample the visual embedding from the softmax similarities of attention heads tokens and the textual tokens
            elif self.alignment_strategy == 'sampled_attn_map':
                if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
                    raise Exception("Alignment strategy not implemented for this type of embeddings!")
                sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
                sims = sims.softmax(dim=-1)
                # in this case, we sample from the distribution given byt text2attn-maps similarities the attention map to align
                index = torch.multinomial(sims, 1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
                visual_embedding = torch.gather(visual_embedding, 1, index).squeeze(1)
                sims = textual_embedding @ visual_embedding.transpose(1, 0)

            elif self.alignment_strategy == 'max_score':
                sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
                sims = sims.softmax(dim=-1)
                index = sims.argmax(dim=-1)
                index_reshaped = sims.argmax(dim=-1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
                visual_embedding = torch.gather(visual_embedding, 1, index_reshaped).squeeze(1)
                sims = textual_embedding @ visual_embedding.transpose(1, 0)
            else:
                # in this case we construct a similarity matrix between attention head tokens and textual tokens

                # we ensure that both the batch embeddings have the same number of dimensions
                textual_embedding = textual_embedding.unsqueeze(1) if len(textual_embedding.shape) == 2 else textual_embedding
                visual_embedding = visual_embedding.unsqueeze(1) if len(visual_embedding.shape) == 2 else visual_embedding
                if textual_embedding.shape[1] > 1:
                    assert text_input_mask is not None, "If we use all the textual embeddings, we need the input mask"
                    if not self.keep_end_seq:
                        # we take the last True value of the mask and we set it to False
                        text_input_mask[torch.arange(text_input_mask.shape[0]), torch.sum(text_input_mask, dim=1) - 1] = False
                    if not self.keep_cls:
                        text_input_mask[:, 0] = False

                # do not consider cls and eos tokens
                im_set = visual_embedding
                s_seq = textual_embedding

                im_set_batch = im_set.size(0)
                im_set_len = im_set.size(1)
                s_seq_batch = s_seq.size(0)
                s_seq_len = s_seq.size(1)

                im_set = im_set.unsqueeze(1).expand(-1, s_seq_batch, -1, -1)  # B x B x S_im x dim
                s_seq = s_seq.unsqueeze(0).expand(im_set_batch, -1, -1, -1)  # B x B x S_s x dim
                alignments = torch.matmul(im_set, s_seq.permute(0, 1, 3, 2))  # B x B x S_im x S_s

                # compute mask for the alignments tensor
                if text_input_mask is not None:
                    alignment_mask = text_input_mask.unsqueeze(1).unsqueeze(0).expand(im_set_batch, -1, im_set_len, -1).logical_not()

                    alignments.masked_fill_(alignment_mask, value=0)
                # alignments = F.relu(alignments)
                # alignments = F.normalize(alignments,p=2, dim=2)

                if self.alignment_strategy == 'sum':
                    sims = alignments.sum(dim=(2,3))
                elif self.alignment_strategy == 'mean':
                    sims = alignments.mean(dim=(2,3))
                elif self.alignment_strategy == 'max-row_sum':
                    sims = alignments.max(2)[0].sum(2)
                elif self.alignment_strategy == 'nucleus-sampling':
                    max_alignments = alignments.max(2)[0]
                    sorted_alignments = max_alignments.sort(dim=2, descending=True)[0]
                    # min-max normalization
                    mins = sorted_alignments.min(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
                    maxs = sorted_alignments.max(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
                    norm_alignments = ((sorted_alignments - mins) / (maxs - mins))
                    # transform values in percentage
                    sums = norm_alignments.sum(dim=-1).unsqueeze(-1).expand(-1, -1, s_seq_len)
                    norm_alignments = norm_alignments / sums
                    # finding the element indices which surpasses alpha
                    cumsums = norm_alignments.cumsum(2)
                    indices = torch.argmax((cumsums > self.alpha).int() + 1, dim=2)

                    mask = torch.arange(s_seq_len).unsqueeze(0).unsqueeze(0).expand(s_seq_batch, s_seq_batch, s_seq_len).to(indices.device) < indices.unsqueeze(-1).expand(-1, -1, s_seq_len) + 1
                    relevant_alignments = (sorted_alignments * mask)
                    sims = relevant_alignments.sum(dim=2)
        else:
            # default case: dot-product
            sims = textual_embedding @ visual_embedding.transpose(1, 0)

        if not return_index:
            return sims
        else:
            return sims, index



    def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False, self_attn_maps=None, cls=None, text_input_mask=None, return_index=False):
        if self.weight_attn_heads is not None:
            assert self_attn_maps is not None, "In case we have attention maps weights, we have to weight patch tokens mean by the weighted self-attention maps"
            visual_embedding = self.get_visual_embed(visual_embedding, self_attn_maps=self_attn_maps, cls=cls)

        textual_embedding = self.project_clip_txt(textual_embedding)

        if self.cosine:
            textual_embedding = F.normalize(textual_embedding, p=2, dim=-1)
            visual_embedding = F.normalize(visual_embedding, p=2, dim=-1)


        if ret_embeds:
            return textual_embedding, visual_embedding

        if not return_index:
            x = self.compute_similarity(visual_embedding, textual_embedding, text_input_mask, return_index)
        else:
            x, index = self.compute_similarity(visual_embedding, textual_embedding, text_input_mask, return_index)

        if not ret_similarity_matrix:
            x = x[torch.eye(len(x)) > 0.5]  # only diagonal elements

        if not return_index:
            return x
        else:
            return x, index

    def get_visual_embed(self, visual_embedding, self_attn_maps=None, cls=None):
        if self_attn_maps is not None:
            # we weight each attention head to obtain a weighted self-attention map
            assert len(visual_embedding.shape) == 3, "In case we have attention maps weights, the visual_embedding should contain patch embeddings, with shape BS x NUM_PATCHES x EMBED_DIM"
            if self.weight_attn_heads == 'conditioned':
                assert cls is not None, "cls must be setted in case of dinamic attention weighting"
                x = self.weight_layer1(cls)
                x = self.act(x)
                x = self.weight_layer2(x)
                normalized_attn_weights = x.softmax(dim=1)
                self_attn = (self_attn_maps * normalized_attn_weights.unsqueeze(dim=-1)).mean(dim=1)
            else:
                normalized_attn_weights = self.attn_weights.softmax(dim=0)
                self_attn = (self_attn_maps * normalized_attn_weights.view(1, normalized_attn_weights.shape[0], 1)).mean(dim=1)
            self_attn = self_attn.softmax(dim=-1)

            # then we perform the weighted mean of patches
            visual_embedding = (self_attn.unsqueeze(-1) * visual_embedding).mean(dim=1)
        return visual_embedding

    def project_clip_txt(self, textual_embedding):
        textual_embedding = textual_embedding.float()
        x = self.linear_layer(textual_embedding)

        if hasattr(self, 'hidden_layers'):
            for hidden_layer in self.hidden_layers:
                if self.act:
                    x = self.act(x)
                x = hidden_layer(x)

        return x
    def load_state_dict(self, state_dict, strict=True):
        # compatibility with old code
        if 'linear_layer2.weight' in state_dict:
            state_dict['hidden_layers.0.weight'] = state_dict.pop('linear_layer2.weight')
            state_dict['hidden_layers.0.bias'] = state_dict.pop('linear_layer2.bias')
        # Call the parent class's load_state_dict with the modified state_dict
        super(ProjectionLayer, self).load_state_dict(state_dict, strict)

    def set_alignment_strategy(self, alignment_strategy):
        self.alignment_strategy = alignment_strategy
        return

    def __len__(self):
        return sum(p.numel() for p in self.parameters())

class DoubleMLP(nn.Module):
    def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, num_attn_head=16, weight_attn_heads=None,
                 alignment_strategy='max_score', alpha=0.6, keep_cls=False, keep_end_seq=False):
        super().__init__()
        self.num_attn_head = num_attn_head

        self.linear_layer = nn.Linear(clip_embed_dim, dino_embed_dim)
        if hidden_layer:
            hidden_layer = 1 if hidden_layer is True else hidden_layer  # ensuring compatibility with old code
            # self.linear_layer2 = nn.Linear(dino_embed_dim, dino_embed_dim)
            self.hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)])
        self.act = act
        self.cosine = cosine

        self.weight_attn_heads = weight_attn_heads
        if weight_attn_heads == 'static':
            self.attn_weights = nn.Parameter(torch.rand(self.num_attn_head))
        elif weight_attn_heads == 'conditioned':
            self.weight_layer1 = nn.Linear(dino_embed_dim, dino_embed_dim)
            self.weight_layer2 = nn.Linear(dino_embed_dim, self.num_attn_head)

        self.alignment_strategy = alignment_strategy  # relevant only if we use disentangled_self_attn
        self.keep_cls = keep_cls  # relevant only if we use clip_txt_tokens_out
        self.keep_end_seq = keep_end_seq  # relevant only if we use clip_txt_tokens_out
        self.alpha = alpha

        self.visual_linear = nn.Linear(dino_embed_dim, dino_embed_dim)
        if hidden_layer:
            hidden_layer = 1 if hidden_layer is True else hidden_layer  # ensuring compatibility with old code
            self.visual_hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)])

    @classmethod
    def from_config(cls, config):
        if type(config) is str:
            # if the configuration is a string, we treat it as a file path
            with open(config, 'r') as f:
                config = yaml.safe_load(f)['model']

        # loading the activation function
        act = config.get('act', None)
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'sigmoid':
            act = nn.Sigmoid()
        elif act is not None:
            raise Exception("Unknown activation function")

        model = cls(
            act=act,
            hidden_layer=config.get('hidden_layer', False),
            cosine=config.get('cosine', True),
            dino_embed_dim=config.get('dino_embed_dim', 1024),
            num_attn_head=config.get('num_attn_head', 16),
            clip_embed_dim=config.get('clip_embed_dim', 512),
            weight_attn_heads=config.get('weight_attn_heads', None),
            alignment_strategy=config.get('alignment_strategy', 'max_score'),
            alpha=config.get('alpha', 0.6),
            keep_cls=config.get('keep_cls', None),
            keep_end_seq=config.get('keep_end_seq', None),
        )
        if config.get('starting_checkpoint', None) is not None:
            model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu'))

        return model

    def compute_similarity(self, visual_embedding, textual_embedding, text_input_mask=None):
        if len(visual_embedding.shape) == 3 or len(textual_embedding.shape) == 3:
            # at least one embedding is decomposed: either we have all textual tokens or we have all the attention head tokens

            if self.alignment_strategy == 'weighted_avg':
                if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
                    raise Exception("Alignment strategy not implemented for this type of embeddings!")
                sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
                sims = sims.softmax(dim=-1)
                # in this case, we keep as visual_embedding the averaged token weighted by the text similarities
                visual_embedding = (visual_embedding * sims.unsqueeze(dim=-1)).mean(dim=1)
                sims = textual_embedding @ visual_embedding.transpose(1, 0)

            # in this case we sample the visual embedding from the softmax similarities of attention heads tokens and the textual tokens
            elif self.alignment_strategy == 'sampled_attn_map':
                if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
                    raise Exception("Alignment strategy not implemented for this type of embeddings!")
                sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
                sims = sims.softmax(dim=-1)
                # in this case, we sample from the distribution given byt text2attn-maps similarities the attention map to align
                index = torch.multinomial(sims, 1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
                visual_embedding = torch.gather(visual_embedding, 1, index).squeeze(1)
                sims = textual_embedding @ visual_embedding.transpose(1, 0)

            elif self.alignment_strategy == 'max_score':
                sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
                sims = sims.softmax(dim=-1)
                index = sims.argmax(dim=-1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
                visual_embedding = torch.gather(visual_embedding, 1, index).squeeze(1)
                sims = textual_embedding @ visual_embedding.transpose(1, 0)
            else:
                # in this case we construct a similarity matrix between attention head tokens and textual tokens

                # we ensure that both the batch embeddings have the same number of dimensions
                textual_embedding = textual_embedding.unsqueeze(1) if len(textual_embedding.shape) == 2 else textual_embedding
                visual_embedding = visual_embedding.unsqueeze(1) if len(visual_embedding.shape) == 2 else visual_embedding
                if textual_embedding.shape[1] > 1:
                    assert text_input_mask is not None, "If we use all the textual embeddings, we need the input mask"
                    if not self.keep_end_seq:
                        # we take the last True value of the mask and we set it to False
                        text_input_mask[torch.arange(text_input_mask.shape[0]), torch.sum(text_input_mask, dim=1) - 1] = False
                    if not self.keep_cls:
                        text_input_mask[:, 0] = False

                # do not consider cls and eos tokens
                im_set = visual_embedding
                s_seq = textual_embedding

                im_set_batch = im_set.size(0)
                im_set_len = im_set.size(1)
                s_seq_batch = s_seq.size(0)
                s_seq_len = s_seq.size(1)

                im_set = im_set.unsqueeze(1).expand(-1, s_seq_batch, -1, -1)  # B x B x S_im x dim
                s_seq = s_seq.unsqueeze(0).expand(im_set_batch, -1, -1, -1)  # B x B x S_s x dim
                alignments = torch.matmul(im_set, s_seq.permute(0, 1, 3, 2))  # B x B x S_im x S_s

                # compute mask for the alignments tensor
                if text_input_mask is not None:
                    alignment_mask = text_input_mask.unsqueeze(1).unsqueeze(0).expand(im_set_batch, -1, im_set_len, -1).logical_not()

                    alignments.masked_fill_(alignment_mask, value=0)
                # alignments = F.relu(alignments)
                # alignments = F.normalize(alignments,p=2, dim=2)

                if self.alignment_strategy == 'sum':
                    sims = alignments.sum(dim=(2,3))
                elif self.alignment_strategy == 'mean':
                    sims = alignments.mean(dim=(2,3))
                elif self.alignment_strategy == 'max-row_sum':
                    sims = alignments.max(2)[0].sum(2)
                elif self.alignment_strategy == 'nucleus-sampling':
                    max_alignments = alignments.max(2)[0]
                    sorted_alignments = max_alignments.sort(dim=2, descending=True)[0]
                    # min-max normalization
                    mins = sorted_alignments.min(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
                    maxs = sorted_alignments.max(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
                    norm_alignments = ((sorted_alignments - mins) / (maxs - mins))
                    # transform values in percentage
                    sums = norm_alignments.sum(dim=-1).unsqueeze(-1).expand(-1, -1, s_seq_len)
                    norm_alignments = norm_alignments / sums
                    # finding the element indices which surpasses alpha
                    cumsums = norm_alignments.cumsum(2)
                    indices = torch.argmax((cumsums > self.alpha).int() + 1, dim=2)

                    mask = torch.arange(s_seq_len).unsqueeze(0).unsqueeze(0).expand(s_seq_batch, s_seq_batch, s_seq_len).to(indices.device) < indices.unsqueeze(-1).expand(-1, -1, s_seq_len) + 1
                    relevant_alignments = (sorted_alignments * mask)
                    sims = relevant_alignments.sum(dim=2)
        else:
            # default case: dot-product
            sims = textual_embedding @ visual_embedding.transpose(1, 0)

        return sims



    def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False, self_attn_maps=None, cls=None, text_input_mask=None):
        if self.weight_attn_heads is not None:
            assert self_attn_maps is not None, "In case we have attention maps weights, we have to weight patch tokens mean by the weighted self-attention maps"
            visual_embedding = self.get_visual_embed(visual_embedding, self_attn_maps=self_attn_maps, cls=cls)

        visual_embedding = self.project_visual(visual_embedding)

        textual_embedding = self.project_clip_txt(textual_embedding)

        if self.cosine:
            textual_embedding = F.normalize(textual_embedding, p=2, dim=-1)
            visual_embedding = F.normalize(visual_embedding, p=2, dim=-1)


        if ret_embeds:
            return textual_embedding, visual_embedding

        x = self.compute_similarity(visual_embedding, textual_embedding, text_input_mask)
        if not ret_similarity_matrix:
            x = x[torch.eye(len(x)) > 0.5]  # only diagonal elements

        return x

    def get_visual_embed(self, visual_embedding, self_attn_maps=None, cls=None):
        if self_attn_maps is not None:
            # we weight each attention head to obtain a weighted self-attention map
            assert len(visual_embedding.shape) == 3, "In case we have attention maps weights, the visual_embedding should contain patch embeddings, with shape BS x NUM_PATCHES x EMBED_DIM"
            if self.weight_attn_heads == 'conditioned':
                assert cls is not None, "cls must be setted in case of dinamic attention weighting"
                x = self.weight_layer1(cls)
                x = self.act(x)
                x = self.weight_layer2(x)
                normalized_attn_weights = x.softmax(dim=1)
                self_attn = (self_attn_maps * normalized_attn_weights.unsqueeze(dim=-1)).mean(dim=1)
            else:
                normalized_attn_weights = self.attn_weights.softmax(dim=0)
                self_attn = (self_attn_maps * normalized_attn_weights.view(1, normalized_attn_weights.shape[0], 1)).mean(dim=1)
            self_attn = self_attn.softmax(dim=-1)

            # then we perform the weighted mean of patches
            visual_embedding = (self_attn.unsqueeze(-1) * visual_embedding).mean(dim=1)
        return visual_embedding

    def project_clip_txt(self, textual_embedding):
        textual_embedding = textual_embedding.float()
        x = self.linear_layer(textual_embedding)

        for hidden_layer in self.hidden_layers:
            if self.act:
                x = self.act(x)
            x = hidden_layer(x)

        return x

    def project_visual(self, visual_embedding):
        visual_embedding = visual_embedding.float()
        x = self.visual_linear(visual_embedding)

        for hidden_layer in self.visual_hidden_layers:
            if self.act:
                x = self.act(x)
            x = hidden_layer(x)

        return x

    def load_state_dict(self, state_dict, strict=True):
        # compatibility with old code
        if 'linear_layer2.weight' in state_dict:
            state_dict['hidden_layers.0.weight'] = state_dict.pop('linear_layer2.weight')
            state_dict['hidden_layers.0.bias'] = state_dict.pop('linear_layer2.bias')
        # Call the parent class's load_state_dict with the modified state_dict
        super(DoubleMLP, self).load_state_dict(state_dict, strict)

    def set_alignment_strategy(self, alignment_strategy):
        self.alignment_strategy = alignment_strategy
        return

    def __len__(self):
        return sum(p.numel() for p in self.parameters())


class CLIPLastLayer(nn.Module):
    def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, weight_attn_heads=None, alignment_strategy='max_score', clip_model='ViT-B/16', text_input_mask=None, projection_weights=None):
        import clip
        super().__init__()
        self.clip_model, _ = clip.load(clip_model)
        self.clip_model.to(dtype=torch.float32)
        # self.last_resblock = copy.deepcopy(self.clip_model.transformer.resblocks[-1])
        self.last_resblock = self.clip_model.transformer.resblocks[-1]
        # self.last_resblock.requires_grad_(False)
        # self.last_ln = copy.deepcopy(self.clip_model.ln_final)
        self.last_ln = self.clip_model.ln_final
        # self.last_ln.requires_grad_(False)
        # self.clip_text_proj = copy.deepcopy(self.clip_model.text_projection)
        self.clip_text_proj = self.clip_model.text_projection
        # self.clip_text_proj.requires_grad_(False)
        self.clip_dtype = self.clip_model.dtype
        del self.clip_model

        self.projection_layer = ProjectionLayer(act=act, hidden_layer=hidden_layer, cosine=cosine, dino_embed_dim=dino_embed_dim,
                                                clip_embed_dim=clip_embed_dim, weight_attn_heads=weight_attn_heads, alignment_strategy=alignment_strategy)

        if projection_weights is not None:
            self.projection_layer.load_state_dict(torch.load(projection_weights, 'cpu'))

    def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False, self_attn_maps=None, cls=None, text_argmax=None, text_input_mask=None):
        x = self.last_resblock(textual_embedding.permute(1, 0, 2))
        x = x.permute(1, 0, 2)
        x = self.last_ln(x).type(self.clip_dtype)
        x = x[torch.arange(x.shape[0]), text_argmax] @ self.clip_text_proj
        if ret_embeds:
            textual_embedding, visual_embedding = self.projection_layer(visual_embedding, x, ret_similarity_matrix=ret_similarity_matrix, ret_embeds=ret_embeds, self_attn_maps=self_attn_maps, cls=cls)
            return textual_embedding, visual_embedding
        x = self.projection_layer(visual_embedding, x, ret_similarity_matrix=ret_similarity_matrix, ret_embeds=ret_embeds, self_attn_maps=self_attn_maps, cls=cls)
        return x

    def project_clip_txt(self, textual_embedding, text_argmax):
        x = self.last_resblock(textual_embedding.permute(1, 0, 2))
        x = x.permute(1, 0, 2)
        x = self.last_ln(x).type(self.clip_dtype)
        x = x[torch.arange(x.shape[0]), text_argmax] @ self.clip_text_proj
        x = self.projection_layer.project_clip_txt(x)
        return x

    @classmethod
    def from_config(cls, config):
        if type(config) is str:
            # if the configuration is a string, we treat it as a file path
            with open(config, 'r') as f:
                config = yaml.safe_load(f)['model']

        # loading the activation function
        act = config.get('act', None)
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'sigmoid':
            act = nn.Sigmoid()
        elif act is not None:
            raise Exception("Unknown activation function")

        model = cls(
            act=act,
            hidden_layer=config.get('hidden_layer', False),
            cosine=config.get('cosine', True),
            dino_embed_dim=config.get('dino_embed_dim', 1024),
            clip_embed_dim=config.get('clip_embed_dim', 512),
            weight_attn_heads=config.get('weight_attn_heads', None),
            alignment_strategy=config.get('alignment_strategy', 'max_score'),
            clip_model=config.get('clip_model', 'ViT-B/16'),
            projection_weights=config.get('projection_weights', None),

        )
        if config.get('starting_checkpoint', None) is not None:
            model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu'))
```
|
| 647 |
+
|
| 648 |
+
return model
|
| 649 |
+
|
| 650 |
+
def __len__(self):
|
| 651 |
+
return sum(p.numel() for p in self.parameters())
|
| 652 |
+
|
| 653 |
+
class DinoText(nn.Module):
|
| 654 |
+
"""
|
| 655 |
+
Project images and texts into DINOv2 latent space.
|
| 656 |
+
"""
|
| 657 |
+
def __init__(self, dino_cfg="dinov2_vitl14_reg", clip_cfg="ViT-B/16", projection_cfg="configs/linear.yaml", projection_weights="weights/linear_avg_self_attn_out.pth", freeze_text_encoder=True, avg_self_attn_token=True, use_disentangled_self_attn=False):
|
| 658 |
+
super().__init__()
|
| 659 |
+
# DINO parameters
|
| 660 |
+
self.num_global_tokens = 1 if "reg" not in dino_cfg else 5
|
| 661 |
+
self.embed_dim = 1024 if "vitl" in dino_cfg else 768
|
| 662 |
+
self.num_attn_heads = 16
|
| 663 |
+
self.scale = 0.125
|
| 664 |
+
|
| 665 |
+
self.visual_backbone = torch.hub.load('facebookresearch/dinov2', dino_cfg)
|
| 666 |
+
self.text_backbone, _ = clip.load(clip_cfg)
|
| 667 |
+
self.clip2dino_proj = ProjectionLayer.from_config(projection_cfg)
|
| 668 |
+
if projection_weights is not None:
|
| 669 |
+
self.clip2dino_proj.load_state_dict(torch.load(projection_weights, 'cpu'))
|
| 670 |
+
self.use_avg_self_attn = avg_self_attn_token
|
| 671 |
+
self.use_disentangled_self_attn = use_disentangled_self_attn
|
| 672 |
+
if self.use_avg_self_attn or self.use_disentangled_self_attn:
|
| 673 |
+
self.visual_backbone.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)
|
| 674 |
+
if self.use_disentangled_self_attn:
|
| 675 |
+
self.visual_backbone.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)
|
| 676 |
+
if freeze_text_encoder:
|
| 677 |
+
self.text_backbone.eval()
|
| 678 |
+
self.text_backbone.requires_grad_(False)
|
| 679 |
+
self.avg_self_attn_token = avg_self_attn_token
|
| 680 |
+
if self.avg_self_attn_token or self.use_disentangled_self_attn:
|
| 681 |
+
self.visual_backbone.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
|
| 682 |
+
self.feats = {}
|
| 683 |
+
self.num_global_tokens = 1 if "reg" not in dino_cfg else 5
|
| 684 |
+
self.num_attn_heads = 16
|
| 685 |
+
self.scale = 0.125
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
@classmethod
|
| 689 |
+
def from_config(cls, cfg):
|
| 690 |
+
if type(cfg) is str:
|
| 691 |
+
# if the configuration is a string, we treat it as a file path
|
| 692 |
+
with open(cfg, 'r') as f:
|
| 693 |
+
cfg = yaml.safe_load(f)['model']
|
| 694 |
+
|
| 695 |
+
model = cls(
|
| 696 |
+
dino_cfg=cfg.get('dino_cfg', "dinov2_vitl14_reg"),
|
| 697 |
+
clip_cfg=cfg.get('clip_cfg', "ViT-B/16"),
|
| 698 |
+
projection_cfg=cfg.get('projection_cfg', "configs/linear.yaml"),
|
| 699 |
+
projection_weights=cfg.get('projection_weights', None),
|
| 700 |
+
avg_self_attn_token=cfg.get('use_avg_self_attn', False),
|
| 701 |
+
use_disentangled_self_attn=cfg.get('use_disentangled_self_attn', False),
|
| 702 |
+
)
|
| 703 |
+
return model
|
| 704 |
+
|
| 705 |
+
def encode_text(self, tokenized_texts):
|
| 706 |
+
x = self.text_backbone.encode_text(tokenized_texts)
|
| 707 |
+
x = self.clip2dino_proj.project_clip_txt(x)
|
| 708 |
+
return x
|
| 709 |
+
|
| 710 |
+
def encode_image(self, images):
|
| 711 |
+
batch_size, _, _, _ = images.shape
|
| 712 |
+
x = self.visual_backbone(images, is_training=self.avg_self_attn_token or self.use_disentangled_self_attn)
|
| 713 |
+
if self.avg_self_attn_token:
|
| 714 |
+
batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
|
| 715 |
+
num_tokens = num_tokens + self.num_global_tokens
|
| 716 |
+
self_attn = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens)
|
| 717 |
+
x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
|
| 718 |
+
if self.use_disentangled_self_attn:
|
| 719 |
+
batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
|
| 720 |
+
num_tokens = num_tokens + self.num_global_tokens
|
| 721 |
+
self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
|
| 722 |
+
self_attn_maps = self_attn_maps.softmax(dim=-1)
|
| 723 |
+
x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)
|
| 724 |
+
return x
|
| 725 |
+
|
| 726 |
+
def get_self_attention(self, module, input, output):
|
| 727 |
+
self.feats['self_attn'] = output
|
| 728 |
+
|
| 729 |
+
def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
|
| 730 |
+
qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
|
| 731 |
+
q, k, v = qkv[0] * scale, qkv[1], qkv[2]
|
| 732 |
+
attn = q @ k.transpose(-2, -1)
|
| 733 |
+
self_attn_maps = attn[:, : , 0, num_global_tokens:]
|
| 734 |
+
self_attn = self_attn_maps.mean(dim=1)
|
| 735 |
+
self_attn = self_attn.softmax(dim=-1)
|
| 736 |
+
if ret_self_attn_maps:
|
| 737 |
+
return self_attn, self_attn_maps
|
| 738 |
+
else:
|
| 739 |
+
return self_attn
|
| 740 |
+
|
| 741 |
+
def forward(self, images, tokenized_texts, cosine=True, ret_similarity_matrix=True):
|
| 742 |
+
img_embed = self.encode_image(images)
|
| 743 |
+
txt_embed = self.encode_text(tokenized_texts)
|
| 744 |
+
|
| 745 |
+
if cosine:
|
| 746 |
+
img_embed = F.normalize(img_embed, p=2, dim=1)
|
| 747 |
+
txt_embed = F.normalize(txt_embed, p=2, dim=1)
|
| 748 |
+
x = img_embed @ txt_embed.transpose(1, 0)
|
| 749 |
+
if not ret_similarity_matrix:
|
| 750 |
+
x = x[torch.eye(len(x)) > 0.5] # only diagonal elements
|
| 751 |
+
|
| 752 |
+
return x
|
| 753 |
+
|
| 754 |
+
def __len__(self):
|
| 755 |
+
def count_parameters(model):
|
| 756 |
+
return sum(p.numel() for p in model.parameters())
|
| 757 |
+
return count_parameters(self.visual_backbone) + count_parameters(self.clip2dino_proj) + count_parameters(self.text_backbone.transformer)
|
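
The `process_self_attention` helper above is the core of the self-attention weighting used by `DinoText` and by the evaluation wrapper in `hf_model/talk2dino.py` below. Here is a minimal sketch of what it computes, with random tensors standing in for the hooked DINOv2 `qkv` output; the shapes are assumptions for a ViT-L/14 backbone with four register tokens:

```python
import torch

# assumed shapes: 16 heads, 1024-dim embeddings, 1 CLS + 4 registers + 256 patches
batch_size, num_heads, embed_dim, num_global_tokens, num_patches = 2, 16, 1024, 5, 256
num_tokens = num_global_tokens + num_patches
scale = 0.125

# the forward hook on blocks[-1].attn.qkv captures a [B, tokens, 3 * embed_dim] tensor
qkv_out = torch.randn(batch_size, num_tokens, 3 * embed_dim)

qkv = qkv_out.reshape(batch_size, num_tokens, 3, num_heads, embed_dim // num_heads).permute(2, 0, 3, 1, 4)
q, k = qkv[0] * scale, qkv[1]
attn = q @ k.transpose(-2, -1)                      # [B, heads, tokens, tokens]
self_attn_maps = attn[:, :, 0, num_global_tokens:]  # CLS row, patch columns: [B, heads, patches]
self_attn = self_attn_maps.mean(dim=1).softmax(dim=-1)

print(self_attn.shape)  # torch.Size([2, 256]); these weights pool the patch embeddings
```
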
hf_model/modules.py
ADDED
@@ -0,0 +1,243 @@

# ------------------------------------------------------------------------------
# FreeDA
# ------------------------------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class BLCModuleCompatibleBCHW(nn.Module):
    def forward_blc(self, x):
        raise NotImplementedError()

    def forward(self, x):
        is2d = x.ndim == 4
        if is2d:
            _, _, H, W = x.shape
            x = rearrange(x, "B C H W -> B (H W) C")

        x = self.forward_blc(x)

        if is2d:
            x = rearrange(x, "B (H W) C -> B C H W", H=H, W=W)

        return x


class FeatureEncoder(nn.Module):
    """Encoder + Feature extractor
    """
    def __init__(self, safe=True):
        super().__init__()
        self.safe = safe  # clone returned features to protect them from later modification
        self._features = []

    def hook(self, module, input, output):
        self._features.append(output)

    def clear_features(self):
        self._features.clear()

    def _encode(self, x):
        raise NotImplementedError()

    def forward(self, *args, ret_feats=False, **kwargs):
        self.clear_features()

        x = self._encode(*args, **kwargs)

        if ret_feats:
            if self.safe:
                features = [t.clone() for t in self._features]
                self.clear_features()
            else:
                features = self._features
            return x, features
        else:
            self.clear_features()
            return x


class Project2d(nn.Module):
    """2d projection by 1x1 conv

    Args:
        p: [C_in, C_out]
    """
    def __init__(self, p):
        # convert to 1x1 conv weight
        super().__init__()
        p = rearrange(p, "Cin Cout -> Cout Cin 1 1")
        self.p = nn.Parameter(p.detach().clone())

    def forward(self, x):
        return F.conv2d(x, self.p)  # 1x1 conv


def dispatcher(dispatch_fn):
    def decorated(key, *args):
        if callable(key):
            return key

        if key is None:
            key = "none"

        return dispatch_fn(key, *args)

    return decorated


@dispatcher
def activ_dispatch(activ):
    return {
        "none": nn.Identity,
        "relu": nn.ReLU,
        "lrelu": partial(nn.LeakyReLU, negative_slope=0.2),
        "gelu": nn.GELU,
    }[activ.lower()]


def get_norm_fn(norm, C):
    """2d normalization layers
    """
    if norm is None or norm == "none":
        return nn.Identity()

    return {
        "bn": nn.BatchNorm2d(C),
        "syncbn": nn.SyncBatchNorm(C),
        "ln": LayerNorm2d(C),
        "gn": nn.GroupNorm(32, C),
    }[norm]


class LayerNorm2d(nn.LayerNorm):
    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x):
        return F.layer_norm(
            x.permute(0, 2, 3, 1),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps
        ).permute(0, 3, 1, 2)


class Gate(nn.Module):
    """Tanh gate"""
    def __init__(self, init=0.0):
        super().__init__()
        self.gate = nn.Parameter(torch.as_tensor(init))

    def forward(self, x):
        return torch.tanh(self.gate) * x


class ConvBlock(nn.Module):
    def __init__(
        self,
        C_in,
        C_out,
        kernel_size=3,
        stride=1,
        padding=1,
        norm="none",
        activ="relu",
        bias=True,
        upsample=False,
        downsample=False,
        pad_type="zeros",
        dropout=0.0,
        gate=False,
    ):
        super().__init__()
        if kernel_size == 1:
            assert padding == 0
        self.C_in = C_in
        self.C_out = C_out

        activ = activ_dispatch(activ)
        self.upsample = upsample
        self.downsample = downsample

        self.norm = get_norm_fn(norm, C_in)
        self.activ = activ()
        if dropout > 0.0:
            self.dropout = nn.Dropout2d(p=dropout)
        self.conv = nn.Conv2d(
            C_in, C_out, kernel_size, stride, padding,
            bias=bias, padding_mode=pad_type
        )

        self.gate = Gate() if gate else None

    def forward(self, x):
        # pre-act
        x = self.norm(x)
        x = self.activ(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2)
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        x = self.conv(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)

        if self.gate is not None:
            x = self.gate(x)

        return x


class ResConv(nn.Module):
    """Pre-activate residual block with single or double conv block"""

    def __init__(
        self,
        C_in,
        C_out,
        kernel_size=3,
        stride=1,
        padding=1,
        norm="none",
        activ="relu",
        upsample=False,
        pad_type="zeros",
        dropout=0.0,
        gate=True,  # if True, use zero-init gate
        double=False,
        # norm2 and activ2 are only used when double is True
        norm2=None,  # if given, apply it to second conv
        activ2=None  # if given, apply it to second conv
    ):
        super().__init__()

        self.C_in = C_in
        self.C_out = C_out
        self.upsample = upsample
        self.double = double
        self.conv = ConvBlock(
            C_in, C_out, kernel_size, stride, padding, norm, activ,
            pad_type=pad_type, dropout=dropout, gate=gate,
        )
        if double:
            norm2 = norm2 or norm
            activ2 = activ2 or activ
            self.conv2 = ConvBlock(
                C_out, C_out, kernel_size, stride, padding, norm2, activ2,
                pad_type=pad_type, dropout=dropout, gate=gate
            )

    def forward(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2)
        x = x + self.conv(x)

        if self.double:
            x = x + self.conv2(x)

        return x
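
The blocks in `hf_model/modules.py` are generic pre-activation conv utilities inherited from FreeDA. A small usage sketch, assuming the repository root is on `PYTHONPATH` so that `hf_model` is importable:

```python
import torch
from hf_model.modules import ConvBlock, ResConv

x = torch.randn(1, 64, 32, 32)

# pre-activation conv with LayerNorm2d and GELU, halving the spatial size
down = ConvBlock(64, 128, norm="ln", activ="gelu", downsample=True)
print(down(x).shape)              # torch.Size([1, 128, 16, 16])

# residual block with a zero-init tanh gate: at initialization it is an identity map
res = ResConv(64, 64, norm="ln", activ="gelu", gate=True)
print(torch.allclose(res(x), x))  # True (the gate starts at 0, so the branch adds nothing)
```
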
hf_model/pamr.py
ADDED
@@ -0,0 +1,146 @@

# Copyright 2020 TU Darmstadt
# License: Apache 2.0 License.
# https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py
import torch
import torch.nn.functional as F
import torch.nn as nn

from functools import partial

#
# Helper modules
#
class LocalAffinity(nn.Module):

    def __init__(self, dilations=[1]):
        super(LocalAffinity, self).__init__()
        self.dilations = dilations
        weight = self._init_aff()
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        # initialising the shift kernel
        weight = torch.zeros(8, 1, 3, 3)

        for i in range(weight.size(0)):
            weight[i, 0, 1, 1] = 1

        weight[0, 0, 0, 0] = -1
        weight[1, 0, 0, 1] = -1
        weight[2, 0, 0, 2] = -1

        weight[3, 0, 1, 0] = -1
        weight[4, 0, 1, 2] = -1

        weight[5, 0, 2, 0] = -1
        weight[6, 0, 2, 1] = -1
        weight[7, 0, 2, 2] = -1

        self.weight_check = weight.clone()

        return weight

    def forward(self, x):
        self.weight_check = self.weight_check.type_as(x)
        assert torch.all(self.weight_check.eq(self.kernel))

        B, K, H, W = x.size()
        x = x.view(B * K, 1, H, W)

        x_affs = []
        for d in self.dilations:
            x_pad = F.pad(x, [d] * 4, mode='replicate')
            x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
            x_affs.append(x_aff)

        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B, K, -1, H, W)


class LocalAffinityCopy(LocalAffinity):

    def _init_aff(self):
        # initialising the shift kernel
        weight = torch.zeros(8, 1, 3, 3)

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 2] = 1

        weight[5, 0, 2, 0] = 1
        weight[6, 0, 2, 1] = 1
        weight[7, 0, 2, 2] = 1

        self.weight_check = weight.clone()
        return weight


class LocalStDev(LocalAffinity):

    def _init_aff(self):
        weight = torch.zeros(9, 1, 3, 3)
        weight.zero_()

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 1] = 1
        weight[5, 0, 1, 2] = 1

        weight[6, 0, 2, 0] = 1
        weight[7, 0, 2, 1] = 1
        weight[8, 0, 2, 2] = 1

        self.weight_check = weight.clone()
        return weight

    def forward(self, x):
        # returns (B,K,P,H,W), where P is the number
        # of locations
        x = super(LocalStDev, self).forward(x)

        return x.std(2, keepdim=True)


class LocalAffinityAbs(LocalAffinity):

    def forward(self, x):
        x = super(LocalAffinityAbs, self).forward(x)
        return torch.abs(x)

#
# PAMR module
#
class PAMR(nn.Module):

    def __init__(self, num_iter=1, dilations=[1]):
        super(PAMR, self).__init__()

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)
        self.aff_m = LocalAffinityCopy(dilations)
        self.aff_std = LocalStDev(dilations)

    def forward(self, x, mask):
        mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)

        # x: [BxKxHxW]
        # mask: [BxCxHxW]
        B, K, H, W = x.size()
        _, C, _, _ = mask.size()

        x_std = self.aff_std(x)

        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
        x = x.mean(1, keepdim=True)
        x = F.softmax(x, 2)

        for _ in range(self.num_iter):
            m = self.aff_m(mask)  # [BxCxPxHxW]
            mask = (m * x).sum(2)

        # xvals: [BxCxHxW]
        return mask
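
`PAMR` refines coarse per-class soft masks with local pixel affinities computed on the image, which is how `DINOText.apply_pamr` in the next file uses it (`num_iter=10`, dilations `[1, 2, 4, 8, 12, 24]`). A minimal sketch with random inputs; the sizes and number of classes are illustrative only:

```python
import torch
from hf_model.pamr import PAMR

pamr = PAMR(num_iter=10, dilations=[1, 2, 4, 8, 12, 24]).eval()

image = torch.rand(1, 3, 128, 128)   # RGB image in [0, 1]
coarse = torch.rand(1, 8, 32, 32)    # one coarse soft mask per class

with torch.no_grad():
    refined = pamr(image, coarse)    # masks are upsampled to the image size internally

print(refined.shape)                 # torch.Size([1, 8, 128, 128])
```
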
hf_model/talk2dino.py
ADDED
@@ -0,0 +1,432 @@

import itertools
import os
import pickle
from math import sqrt
import re
import yaml

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from einops import rearrange
from transformers import BertModel, AutoTokenizer
import torchvision.transforms as T
import clip
import importlib
import hf_model.us as us

from hf_model.pamr import PAMR
from hf_model.masker import DINOTextMasker
from hf_model.templates import get_template

from hf_model.model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
from hf_model.hooks import average_text_tokens, get_vit_out, feats

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DINOText(nn.Module):

    def get_self_attention(self, module, input, output):
        self.feats['self_attn'] = output

    def get_clip_second_last_dense_out(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        self.feats['clip_second_last_out'] = output
        self.feats['clip_second_last_out'].to(dtype=torch.float32)

    def get_all_out_tokens(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        self.feats['clip_txt_out_tokens'] = output

    def __init__(
        self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
        unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
    ):
        super().__init__()
        self.feats = {}
        self.model_name = model_name
        # loading the model
        if 'dinov2' in model_name:
            self.model_family = 'facebookresearch/dinov2' if 'dinov2' in model_name else 'facebookresearch/dino:main'
            self.model = torch.hub.load(self.model_family, model_name)
        elif 'dinov3' in model_name:
            def extract_dinov3_name(path, n_parts=2):
                filename = os.path.basename(path)
                parts = filename.split("_")
                return "_".join(parts[:n_parts])
            self.model = torch.hub.load('src/dinov3', extract_dinov3_name(model_name), source='local', weights=model_name)

        elif 'mae' in model_name or 'sam' in model_name or 'clip' in model_name or 'dino' in model_name:
            self.model = timm.create_model(
                model_name,
                pretrained=True,
                num_classes=0,  # remove classifier nn.Linear
                img_size=resize_dim
            )

            if 'sam' in model_name:
                self.model.blocks[-1].register_forward_hook(get_vit_out)
        else:
            raise Exception("Unknown ViT model")
        # self.model.eval()
        mean = (0.485, 0.456, 0.406) if 'clip' not in model_name else (0.4815, 0.4578, 0.4082)
        std = (0.229, 0.224, 0.225) if 'clip' not in model_name else (0.2686, 0.2613, 0.2758)
        self.image_transforms = T.Compose([
            T.Resize((resize_dim, resize_dim)),
            lambda x: T.ToTensor()(x) if not isinstance(x, torch.Tensor) else x / 255.0,  # ensure tensor
            T.Normalize(mean, std),
        ])

        self.model.to(device)
        self.model.requires_grad_(False)

        self.clip_model_name = clip_model_name
        if 'bert' in self.clip_model_name:
            self.clip_model = BertModel.from_pretrained(self.clip_model_name, output_hidden_states=False)
            # load the corresponding word tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
        else:
            self.clip_model, _ = clip.load(clip_model_name, device=device)
        self.clip_model.eval()
        self.clip_model.requires_grad_(False)
        if unfreeze_last_text_layer:
            for param in self.clip_model.transformer.resblocks[-1].parameters():
                param.requires_grad = True
            for param in self.clip_model.ln_final.parameters():
                param.requires_grad = True
            self.clip_model.text_projection.requires_grad = True
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # with open(os.path.join('configs', f"{proj_class}.yaml"), 'r') as config_file:
        #     config = yaml.safe_load(config_file)['model']
        if 'vitb_mlp_infonce' in proj_class:
            config = {
                'act': 'tanh',  # None, tanh, relu or sigmoid
                'hidden_layer': True,
                'dino_embed_dim': 768
            }
        elif 'vitl_mlp_infonce' in proj_class:
            config = {
                'act': 'tanh',  # None, tanh, relu or sigmoid
                'hidden_layer': True,
                'dino_embed_dim': 1024
            }

        self.proj = ProjectionLayer.from_config(config)
        if type(self.proj) == CLIPLastLayer:
            self.clip_model.transformer.resblocks[-2].register_forward_hook(self.get_clip_second_last_dense_out)

        # if pre_trained:
        #     self.proj.load_state_dict(torch.load(os.path.join("weights", f"{proj_name}.pth"), 'cpu'))
        self.proj.to(device)

        self.masker = DINOTextMasker(similarity_type="cosine")
        self.masker = self.masker.eval()

        self.pamr = None

        self.avg_self_attn_token = avg_self_attn_token
        self.disentangled_self_attn_token = disentangled_self_attn_token

        if self.avg_self_attn_token or self.disentangled_self_attn_token or is_eval:
            self.model.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
            self.num_global_tokens = 5 if 'reg' in model_name or 'dinov3' in model_name else 1
            if 'sam' in self.model_name:
                self.num_global_tokens = 0
            self.num_attn_heads = self.model.num_heads
            self.scale = 0.125

        self.use_avg_text_token = use_avg_text_token
        if self.use_avg_text_token:
            self.feats = {}
            # in this case we register a forward hook with the aim of getting all the tokens and not only the cls
            self.clip_model.ln_final.register_forward_hook(self.get_all_out_tokens)
            self.keep_cls = keep_cls
            self.keep_end_seq = keep_end_seq

        self.with_bg_clean = with_bg_clean

    def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
        qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)
        self_attn_maps = attn[:, :, 0, num_global_tokens:]
        self_attn = self_attn_maps.mean(dim=1)
        self_attn = self_attn.softmax(dim=-1)
        if ret_self_attn_maps:
            return self_attn, self_attn_maps
        else:
            return self_attn

    def encode_text(self, tokenized_texts):
        if type(self.proj) == CLIPLastLayer:
            self.clip_model.encode_text(tokenized_texts)
            x = self.feats['clip_second_last_out']
            x = x.to(dtype=torch.float32)
        else:
            x = self.clip_model.encode_text(tokenized_texts)
        return x

    def encode_image(self, images):
        batch_size, _, _, _ = images.shape
        self_attn_maps = None
        x = self.model(images, is_training=(self.avg_self_attn_token or self.disentangled_self_attn_token))
        batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
        num_tokens = num_tokens + self.num_global_tokens
        if self.avg_self_attn_token or self.disentangled_self_attn_token:
            self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
            if self.avg_self_attn_token:
                x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
            elif self.disentangled_self_attn_token:
                self_attn_maps = self_attn_maps.softmax(dim=-1)
                x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)

        return x, self_attn_maps

    def forward(self, image, text, return_logit_scale=False):
        with torch.no_grad():
            txt_embed = self.encode_text(text)

        img_embed, self_attn_maps = self.encode_image(image)

        if type(self.proj) == CLIPLastLayer:
            img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps, text_argmax=text.argmax(dim=-1))
        else:
            img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps)

        if return_logit_scale:
            return txt_embed, img_embed, self.logit_scale

        return txt_embed, img_embed

    def compute_loss(self, image, text, cosine=True, ret_similarity_matrix=True):
        ret = {}
        # embed the batch before computing the similarity matrix
        txt_embed, img_embed = self.forward(image, text)
        if cosine:
            img_embed = F.normalize(img_embed, p=2, dim=1)
            txt_embed = F.normalize(txt_embed, p=2, dim=1)
        sim = img_embed @ txt_embed.transpose(1, 0)
        if not ret_similarity_matrix:
            sim = sim[torch.eye(len(sim)) > 0.5]  # only diagonal elements

        ret['contrastive_loss'] = self.contrastive_loss.compute_contrastive_loss(sim)

        return ret

    @torch.no_grad()
    def build_dataset_class_tokens(self, template_set, classnames):
        tokens = []
        templates = get_template(template_set)
        for classname in classnames:
            if 'bert' not in self.clip_model_name:
                tokens.append(
                    clip.tokenize([template.format(classname) for template in templates])
                )
            else:
                tokens.append(self.tokenizer([template.format(classname) for template in templates], return_tensors='pt', padding='max_length')['input_ids'])
        # [N, T, L], N: number of instances, T: number of captions (including ensembled), L: sequence length
        tokens = torch.stack(tokens)

        return tokens

    @torch.no_grad()
    def build_text_embedding(self, text):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] text tokens

        Returns:
            text_embs
        """
        text = text.to(next(self.parameters()).device)
        num_classes, num_templates = text.shape[:2]
        text_argmax = text.argmax(dim=-1)
        text_argmax = rearrange(text_argmax, 'n t -> (n t)', n=num_classes, t=num_templates)
        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
        # chunked inference for memory limitation
        chunk_size = 32
        N = text.size(0)
        if type(self.proj) == CLIPLastLayer:
            text_embs = torch.cat([
                self.proj.project_clip_txt(self.encode_text(text[i:i + chunk_size]).permute(1, 0, 2), text_argmax=text_argmax[i:i + chunk_size])
                for i in range(0, N, chunk_size)
            ])
        else:
            if not self.use_avg_text_token:
                # performing classification using the CLS textual token
                if 'bert' not in self.clip_model_name:
                    text_embs = torch.cat([
                        self.clip_model.encode_text(text[i:i + chunk_size])
                        for i in range(0, N, chunk_size)
                    ])
                else:
                    # encoding with BERT
                    text_embs = []
                    for i in range(0, N, chunk_size):
                        outputs = self.clip_model(text[i:i + chunk_size])
                        text_embs.append(outputs['pooler_output'])
                    text_embs = torch.cat(text_embs)
            else:
                # using the text token average
                text_embs = []
                for i in range(0, N, chunk_size):
                    self.clip_model.encode_text(text[i:i + chunk_size])
                    text_embs.append(average_text_tokens(self.feats['clip_txt_out_tokens'] @ self.clip_model.text_projection, text[i:i + chunk_size] > 0, self.keep_cls, self.keep_end_seq))
                text_embs = torch.cat(text_embs)
        # [N, T, C]
        text_embs = rearrange(text_embs, '(n t) c -> n t c', n=num_classes, t=num_templates)
        # [N, C]
        text_embs = text_embs.mean(dim=1).float()
        if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
            text_embs = self.proj.project_clip_txt(text_embs)
        text_embs = us.normalize(text_embs, dim=-1)

        return text_embs

    def apply_pamr(self, image, mask):
        image = F.interpolate(image, mask.shape[-2:], mode="bilinear", align_corners=True)
        if self.pamr is None:
            pamr_iter = 10
            pamr_kernel = [1, 2, 4, 8, 12, 24]
            self.pamr = PAMR(pamr_iter, pamr_kernel)
            self.pamr.eval()
            self.pamr.to(next(self.parameters()).device)

        mask = self.pamr(image, mask)
        return mask

    def compute_padsize(self, H: int, W: int, patch_size: int):
        l, r, t, b = 0, 0, 0, 0
        if W % patch_size:
            lr = patch_size - (W % patch_size)
            l = lr // 2
            r = lr - l

        if H % patch_size:
            tb = patch_size - (H % patch_size)
            t = tb // 2
            b = tb - t

        return l, r, t, b

    @torch.no_grad()
    def generate_masks(
        self, image, img_metas, text_emb, classnames, text_is_token=False, apply_pamr=False, background_func="weighted_average_sigmoid", lambda_bg=0.2,
        # kp_w=0.3,
    ):
        """Generate masks for each text embedding

        Args:
            image [B, 3, H, W]

        Returns:
            softmask [B, N, H, W]: softmasks for each text embedding
        """

        H, W = image.shape[2:]  # original image shape

        # padded image size
        pH, pW = image.shape[2:]
        num_classes = text_emb.shape[0]
        batch_size = image.shape[0]

        image = image[:, [2, 1, 0], :, :]  # BGR to RGB
        ori_image = image.clone()

        img_preprocessed = self.image_transforms(image).to(next(self.parameters()).device)
        if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
            image_feat = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
        elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
            image_feat = self.model.forward_features(img_preprocessed)[:, 1:, :]
        elif 'sam' in self.model_name:
            self.model.forward_features(img_preprocessed)
            image_feat = feats['vit_out'].reshape(feats['vit_out'].shape[0], feats['vit_out'].shape[1]**2, feats['vit_out'].shape[-1])  # BS x N_PATCHES x EMBED_DIM

        batch_size, num_tokens, embed_dim = image_feat.shape
        if type(self.proj) == VisualProjectionLayer:
            image_feat = self.proj.project_dino(image_feat.float())
        if type(self.proj) == DoubleMLP:
            image_feat = self.proj.project_visual(image_feat.float())
        b, np, c = image_feat.shape
        np_h = np_w = int(sqrt(np))
        image_feat = image_feat.reshape(b, np_h, np_w, c).permute(0, 3, 1, 2)

        self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens + self.num_global_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
        mask, simmap = self.masker.forward_seg(image_feat, text_emb, hard=False)  # [B, N, H', W']

        if self.with_bg_clean:
            mask = self.similarity_assignment_weighted(mask, image_feat, self_attn_maps, text_emb, lambda_bg)

        # resize
        mask = F.interpolate(mask, (pH, pW), mode='bilinear', align_corners=True)  # [B, N, H, W]

        if apply_pamr:
            for c in range(0, mask.shape[1], 30):
                mask[:, c:c + 30] = self.apply_pamr(ori_image, mask[:, c:c + 30])

        assert mask.shape[2] == H and mask.shape[3] == W, f"shape mismatch: ({H}, {W}) / {mask.shape}"

        return mask, simmap

    def similarity_assignment_weighted(self, mask, image_feat, self_attn_maps, text_emb, lambda_bg=0.2):
        bs, c, h, w = image_feat.shape
        bs, num_classes, h, w = mask.shape
        bs, num_heads, hw = self_attn_maps.shape
        image_feat = image_feat.reshape(bs, c, hw)
        num_classes, c = text_emb.shape
        avg_head_embed = (self_attn_maps.unsqueeze(2) * image_feat.unsqueeze(1)).mean(dim=-1)
        avg_head_embed = avg_head_embed / avg_head_embed.norm(dim=-1, keepdim=True)
        avg_head_embed = avg_head_embed.permute(0, 2, 1)  # [B, C, M]
        head_text_sim = text_emb.unsqueeze(0) @ avg_head_embed  # [B, M, N]
        head_text_sim = head_text_sim.softmax(dim=-1)
        head_text_sim_sum = head_text_sim.sum(dim=-1)

        self_attn_maps_repeat = self_attn_maps.unsqueeze(1).repeat(1, num_classes, 1, 1)
        head_text_sim_repeat = head_text_sim.unsqueeze(-1).repeat(1, 1, 1, hw)
        avg_self_attn_per_class = (self_attn_maps_repeat * head_text_sim_repeat).sum(dim=2) / head_text_sim_sum.unsqueeze(-1).repeat(1, 1, hw)
        avg_self_attn_per_class = avg_self_attn_per_class.softmax(dim=-1)

        min_self_attn = avg_self_attn_per_class.min().item()
        max_self_attn = avg_self_attn_per_class.max().item()
        max_self_attn = max(max_self_attn, max_self_attn - min_self_attn)
        avg_self_attn_per_class = avg_self_attn_per_class - min_self_attn
        avg_self_attn_per_class = avg_self_attn_per_class / max_self_attn
        avg_self_attn_per_class = avg_self_attn_per_class * (mask.max() - mask.min()) + mask.min()
        mask = mask.reshape(num_classes, hw)  # [N, P]
        mask_output = (mask + lambda_bg * avg_self_attn_per_class).reshape(bs, num_classes, h, w) / (1 + lambda_bg)
        return mask_output


from huggingface_hub import PyTorchModelHubMixin

class Talk2DINO(DINOText, PyTorchModelHubMixin):
    def encode_text(self, texts):
        """ texts: string or list of strings
        returns: text embeddings (N, D) where N is the number of texts, D is the embedding dimension
        """
        text_tokens = clip.tokenize(texts).to(self.parameters().__next__().device)
        txt_embed = self.clip_model.encode_text(text_tokens)
        txt_embed = self.proj.project_clip_txt(txt_embed)
        return txt_embed

    def encode_image(self, images):
        """ images: PIL image or list of PIL images
        returns: image embeddings (N, L, D) where N is the number of images, L is the number of patches, D is the embedding dimension
        """
        if type(images) is not list:
            images = [images]
        img_preprocessed = [self.image_transforms(img).to(next(self.parameters()).device) for img in images]
        img_preprocessed = torch.stack(img_preprocessed)
        if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
            img_embed = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
        elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
            img_embed = self.model.forward_features(img_preprocessed)[:, 1:, :]

        return img_embed
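
`compute_padsize` pads height and width up to the next multiple of the ViT patch size, splitting the padding as evenly as possible between the two sides. A standalone copy of the arithmetic with a worked example (the input resolution is hypothetical):

```python
def compute_padsize(H, W, patch_size):
    # returns (left, right, top, bottom) so that W + l + r and H + t + b
    # are multiples of patch_size
    l, r, t, b = 0, 0, 0, 0
    if W % patch_size:
        lr = patch_size - (W % patch_size)
        l = lr // 2
        r = lr - l
    if H % patch_size:
        tb = patch_size - (H % patch_size)
        t = tb // 2
        b = tb - t
    return l, r, t, b

l, r, t, b = compute_padsize(H=500, W=353, patch_size=14)
print((l, r, t, b))                              # (5, 6, 2, 2)
print((353 + l + r) % 14, (500 + t + b) % 14)    # 0 0
```
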
hf_model/templates.py
ADDED
@@ -0,0 +1,148 @@

# ------------------------------------------------------------------------------
# FreeDA
# ------------------------------------------------------------------------------
# Modified from CLIP (https://github.com/openai/CLIP)
# Copyright (c) 2021 OpenAI. All Rights Reserved.
# ------------------------------------------------------------------------------

full_imagenet_templates = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]

maskclip_templates = [
    "there is a {} in the scene.",
    "there is the {} in the scene.",
    "this is a {} in the scene.",
    "this is the {} in the scene.",
    "this is one {} in the scene.",  # maskclip
]

sub_imagenet_template = [
    "itap of a {}.",
    "a bad photo of a {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
]

simple_imagenet_template = [
    "a photo of a {}.",
]

plural_template = [
    "a photo of {}s.",
]

identity_template = [
    "{}",
]

template_meta = {
    "full": full_imagenet_templates,
    "full+maskclip": full_imagenet_templates + maskclip_templates,  # templates used in the maskclip paper
    "subset": sub_imagenet_template,
    "subset+maskclip": sub_imagenet_template + maskclip_templates,
    "maskclip": maskclip_templates,
    "simple": simple_imagenet_template,
    "plural": plural_template,
    "identity": identity_template,
}


def get_template(key):
    if key in template_meta:
        return template_meta[key]

    gdic = globals()
    if key in gdic:
        return gdic[key]

    raise ValueError(key)


# custom template boosts performance a little.
custom = sub_imagenet_template + [
    "a photo of many {}.",
    "a photo of {}s.",
]
hf_model/us.py
ADDED
@@ -0,0 +1,119 @@

# ------------------------------------------------------------------------------
# FreeDA
# ------------------------------------------------------------------------------
from typing import Dict, List, Any
from datetime import datetime
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

# ImageNet mean/std (from timm)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

DEFAULT_MEAN = IMAGENET_DEFAULT_MEAN
DEFAULT_STD = IMAGENET_DEFAULT_STD

# NOTE Originally CLIP statistics should be used, but the legacy of ImageNet statistics
# from GroupViT is applied. Fortunately, CLIP is quite robust to slightly different
# normalization constants (https://github.com/openai/CLIP/issues/20#issuecomment-764985771).


def unnorm(x):
    mean = torch.as_tensor(DEFAULT_MEAN, device=x.device)[None, ..., None, None]
    std = torch.as_tensor(DEFAULT_STD, device=x.device)[None, ..., None, None]
    return x.mul(std).add(mean)


# DEBUG NaN
def check_nonfinite(x, name=""):
    rank = dist.get_rank()
    n_nan = x.isnan().sum()
    n_inf = x.isinf().sum()
    if n_nan or n_inf:
        print(f"[RANK {rank}] {name} is not finite: #nan={n_nan}, #inf={n_inf}")
        return True

    print(f"[RANK {rank}] {name} is OK ...")
    return False


def normalize(t, dim, eps=1e-6):
    """Large default eps for fp16"""
    return F.normalize(t, dim=dim, eps=eps)


def timestamp(fmt="%y%m%d-%H%M%S"):
    return datetime.now().strftime(fmt)


def merge_dicts_by_key(dics: List[Dict]) -> Dict[Any, List]:
    """Merge dictionaries by key. All dicts must have the same keys."""
    ret = {key: [] for key in dics[0].keys()}
    for dic in dics:
        for key, value in dic.items():
            ret[key].append(value)

    return ret


def flatten_2d_list(list2d):
    return list(chain.from_iterable(list2d))


def num_params(module):
    return sum(p.numel() for p in module.parameters())


def param_trace(name, module, depth=0, max_depth=999, threshold=0, printf=print):
    if depth > max_depth:
        return
    prefix = " " * depth
    n_params = num_params(module)
    if n_params > threshold:
        printf("{:60s}\t{:10.3f}M".format(prefix + name, n_params / 1024 / 1024))
    for n, m in module.named_children():
        if depth == 0:
            child_name = n
        else:
            child_name = "{}.{}".format(name, n)
        param_trace(child_name, m, depth + 1, max_depth, threshold, printf)


@torch.no_grad()
def hash_bn(module):
    summary = []
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            w = m.weight.detach().mean().item()
            b = m.bias.detach().mean().item()
            rm = m.running_mean.detach().mean().item()
            rv = m.running_var.detach().mean().item()
            summary.append((w, b, rm, rv))

    if not summary:
        return 0.0, 0.0

    w, b, rm, rv = [np.mean(col) for col in zip(*summary)]
    p = np.mean([w, b])
    s = np.mean([rm, rv])

    return p, s


@torch.no_grad()
def hash_params(module):
    return torch.as_tensor([p.mean() for p in module.parameters()]).mean().item()


@torch.no_grad()
def hashm(module):
    p = hash_params(module)
    _, s = hash_bn(module)

    return p, s
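
The helpers in `hf_model/us.py` are small utilities used by the masker and evaluation code; `normalize` is the one the segmentation path relies on. A quick sketch, again assuming `hf_model` is importable (note that `check_nonfinite` additionally requires an initialized `torch.distributed` process group):

```python
import torch
import torch.nn as nn
from hf_model import us

feats = torch.randn(4, 512)
unit = us.normalize(feats, dim=-1)   # wraps F.normalize with eps=1e-6, chosen for fp16 safety
print(unit.norm(dim=-1))             # ~1.0 per row

proj = nn.Linear(512, 1024)
print(us.num_params(proj))           # 525312 (weights + biases)
```
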