lorebianchi98 committed
Commit e7d7e74 · 1 Parent(s): 367e473

Removed MMCV dependency

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -31,7 +31,7 @@ Talking to DINO: Bridging Self-Supervised Vision Backbones with Language for Ope
 
  <div align="center">
  <figure>
- <img alt="Overview of Talk2DINO" src="./assets/overview.png" width="40%">
  </figure>
  </div>
 
@@ -43,75 +43,58 @@ Open-Vocabulary Segmentation (OVS) aims at segmenting images from free-form text
  ### Mapping CLIP Text Embeddings to DINOv2 space with Talk2DINO
  We can use Talk2DINO to map CLIP text embeddings into the DINOv2 patch embedding space.
  ```python
- import clip
- from src.model import ProjectionLayer
- import torch
- import os
  # Device setup
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
- # Configuration and weights
- proj_name = 'vitb_mlp_infonce'
- config_path = os.path.join("configs", f"{proj_name}.yaml")
- weights_path = os.path.join("weights", f"{proj_name}.pth")
- # Load Talk2DINO projection layer
- talk2dino = ProjectionLayer.from_config(config_path)
- talk2dino.load_state_dict(torch.load(weights_path, map_location=device))
- talk2dino.to(device)
- # Load CLIP model
- clip_model, clip_preprocess = clip.load("ViT-L/14", device=device, jit=False)
- tokenizer = clip.tokenize
- # Example: Tokenize and project text features
- texts = ["a cat"]
- text_tokens = tokenizer(texts).to(device)
- text_features = clip_model.encode_text(text_tokens)
- projected_text_features = talk2dino.project_clip_txt(text_features)
- ```
 
- ### Demo
- In `demo.py` we provide a simple example on how to use Talk2DINO for inference on a given image with custom textual categories. Run
 
- ```bash
- python demo.py --input custom_input_image --output custom_output_seg [--with_background] --textual_categories category_1,category_2,..
- ```
 
- Example:
- ```bash
- python demo.py --input assets/pikachu.png --output pikachu_seg.png --textual_categories pikachu,traffic_sign,forest,route
  ```
 
  Result:
  <div align="center">
  <table><tr><td><figure>
  <img alt="" src="./assets/pikachu.png" width=300>
  </figure></td><td><figure>
- <img alt="" src="./pikachu_seg.png" width=300>
  </figure></td></tr></table>
  </div>
 
  ## Installation
  ```bash
- # Create a new environment with Python 3.10
- conda create --name talk2dino python=3.10 -c conda-forge
- conda activate talk2dino
- # Install compilers for C++/CUDA extensions
- conda install -c conda-forge "gxx_linux-64=11.*" "gcc_linux-64=11.*"
- # Install CUDA toolkit and cuDNN
- conda install -c nvidia/label/cuda-11.7.0 cuda
- conda install -c nvidia/label/cuda-11.7.0 cuda-nvcc
- conda install -c conda-forge cudnn cudatoolkit=11.7.0
- # Install PyTorch 2.1 with CUDA 11.8 support
- # Note: This is crucial, as it matches the requirements of mmcv-full 1.7.2
- pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
- # Install other dependencies
  pip install -r requirements.txt
- pip install -U openmim
- mim install mmengine
- # Install a compatible version of mmcv-full (1.7.2) for PyTorch 2.1
- pip install mmcv-full==1.7.2 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1.0/index.html
- # Install mmsegmentation
- pip install mmsegmentation==0.30.0
  ```
 
  <details>
  <summary>Qualitative Results</summary>
 
 
  <div align="center">
  <figure>
+ <img alt="Overview of Talk2DINO" src="./assets/overview.png" width="90%">
  </figure>
  </div>
 
  ### Mapping CLIP Text Embeddings to DINOv2 space with Talk2DINO
  We can use Talk2DINO to map CLIP text embeddings into the DINOv2 patch embedding space.
  ```python
+ import torch
+ from hf_model.talk2dino import Talk2DINO
+ from torchvision.io import read_image
+
  # Device setup
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
+ # Model loading
+ model = Talk2DINO.from_pretrained("lorebianchi98/Talk2DINO-ViTL").to(device).eval()
+
+ # Load the input image (resizing/normalization as in hf_demo.ipynb may be required)
+ image = read_image("assets/pikachu.png").float().unsqueeze(0).to(device) / 255.0
 
+ # Embedding generation
+ with torch.no_grad():
+     text_embed = model.encode_text("a pikachu")
+     image_embed = model.encode_image(image)
 
+ # Normalize the features to perform cosine similarity
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
+ image_embed = image_embed / image_embed.norm(dim=-1, keepdim=True)
+
+ similarity = (image_embed @ text_embed.T).squeeze(0, -1).cpu().numpy()
  ```
 
+ ### Demo
+ In `hf_demo.ipynb` we provide a simple example of how to use Talk2DINO for inference on a given image with custom textual categories.
  Result:
  <div align="center">
  <table><tr><td><figure>
  <img alt="" src="./assets/pikachu.png" width=300>
  </figure></td><td><figure>
+ <img alt="" src="./assets/pikachu_seg.png" width=300>
  </figure></td></tr></table>
  </div>
 
  ## Installation
+
+ To use the **Hugging Face interface** for inference:
+
  ```bash
+ # Clone the repository
+ git clone https://huggingface.co/lorebianchi98/Talk2DINO-ViTL
+ cd Talk2DINO-ViTL
+
+ # Install dependencies
  pip install -r requirements.txt
+
+ # Install PyTorch and torchvision with the appropriate CUDA version
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
  ```
 
+ > For the **full MMCV interface** to perform evaluation on segmentation benchmarks, please refer to the [original Talk2DINO repository](https://github.com/lorebianchi98/Talk2DINO).
+
+
+
  <details>
  <summary>Qualitative Results</summary>
 
assets/pikachu_seg.png ADDED

Git LFS Details

  • SHA256: 53eb872bee3c849aeca202853fc8d38019916f2a465f8620542647c4a8baa852
  • Pointer size: 131 Bytes
  • Size of remote file: 220 kB
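The newly added `assets/pikachu_seg.png` is the kind of visual the README example can produce. As a rough illustration, the per-patch `similarity` vector from that example can be reshaped onto the DINOv2 patch grid and upsampled into a coarse mask; a minimal sketch continuing from that snippet, assuming the input side lengths are multiples of the 14-pixel patch size and that patch embeddings are returned in row-major order:

```python
# Continues from the README example above (`image`, `similarity`); illustrative post-processing only.
import torch
import torch.nn.functional as F

_, _, H, W = image.shape
h, w = H // 14, W // 14                                    # ViT-L/14 patch-grid resolution

sim_map = torch.from_numpy(similarity).float().reshape(1, 1, h, w)
sim_map = F.interpolate(sim_map, size=(H, W), mode="bilinear", align_corners=False)
mask = sim_map[0, 0] > sim_map.mean()                      # crude foreground mask for "a pikachu"
```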
hf_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
hf_model/__init__.py ADDED
File without changes
hf_model/hooks.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ feats = {}
3
+ def get_self_attention(module, input, output):
4
+ feats['self_attn'] = output
5
+
6
+ def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
7
+ qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
8
+ q, k, v = qkv[0] * scale, qkv[1], qkv[2]
9
+ attn = q @ k.transpose(-2, -1)
10
+ self_attn_maps = attn[:, : , 0, num_global_tokens:]
11
+ self_attn = self_attn_maps.mean(dim=1)
12
+ self_attn = self_attn.softmax(dim=-1)
13
+ if ret_self_attn_maps:
14
+ return self_attn, self_attn_maps
15
+ else:
16
+ return self_attn
17
+
18
+ def get_vit_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
19
+ feats['vit_out'] = output
20
+
21
+ def get_second_last_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
22
+ feats['second_last_out'] = output
23
+
24
+ def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
25
+ feats['clip_txt_out_tokens'] = output
26
+
27
+ def get_clip_second_last_dense_out(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
28
+ feats['clip_second_last_out'] = output.permute(1,0,2)
29
+
30
+ def get_dinov1_patches(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
31
+ feats['dinov1_patches'] = output
32
+
33
+ def get_all_out_tokens(model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
34
+ feats['clip_txt_out_tokens'] = output
35
+
36
+ def average_text_tokens(text_embeddings, mask, keep_cls=False, keep_end_seq=False):
37
+ if not keep_end_seq:
38
+ mask[torch.arange(mask.shape[0]), mask.sum(dim=1) - 1] = False # excluding end of sequence
39
+ if not keep_cls:
40
+ mask[:, 0] = False # excluding CLS token
41
+
42
+
43
+ masked_embeddings = text_embeddings * mask.unsqueeze(-1) # shape: [BS, SEQ_LEN, 512]
44
+
45
+ sum_embeddings = masked_embeddings.sum(dim=1) # shape: [BS, 512]
46
+
47
+ valid_elements = mask.sum(dim=1, keepdim=True) # shape: [BS, 1]
48
+
49
+ mean_embeddings = sum_embeddings / valid_elements # shape: [BS, 512]
50
+
51
+ return mean_embeddings
52
+
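These hooks are consumed by `DinoText.encode_image` in `hf_model/model.py`: `get_self_attention` is registered on the last DINOv2 attention block, and `process_self_attention` turns the captured qkv output into a CLS-to-patch attention map used to average the patch tokens. A minimal sketch of that wiring, assuming a DINOv2 ViT-L/14 backbone with registers and an input whose sides are multiples of 14:

```python
import torch
from hf_model.hooks import feats, get_self_attention, process_self_attention

# DINOv2 ViT-L/14 with registers: 16 heads, 1024-dim tokens, 1 CLS + 4 register tokens
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
backbone.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)

images = torch.randn(1, 3, 518, 518)                  # dummy batch, 518 = 37 * 14
with torch.no_grad():
    out = backbone(images, is_training=True)          # keeps per-patch tokens

bs, num_patches, embed_dim = out['x_norm_patchtokens'].shape
self_attn = process_self_attention(
    feats['self_attn'], bs, num_patches + 5,          # +5 global tokens (CLS + registers)
    num_attn_heads=16, embed_dim=embed_dim, scale=0.125, num_global_tokens=5,
)
# Self-attention-weighted average of the patch tokens (Talk2DINO's visual-side embedding)
avg_token = (self_attn.unsqueeze(-1) * out['x_norm_patchtokens']).mean(dim=1)
```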
hf_model/masker.py ADDED
@@ -0,0 +1,246 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Talk2DINO
3
+ # ------------------------------------------------------------------------------
4
+ import copy
5
+ from collections import OrderedDict
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import hf_model.us as us
12
+ from einops import rearrange, repeat
13
+
14
+ # from models.dinotext.gumbel import gumbel_sigmoid
15
+ from hf_model.modules import FeatureEncoder
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ def build_model(config):
20
+ model = OmegaConf.to_container(config, resolve=True)
21
+ return model
22
+
23
+ class Sim2Mask(nn.Module):
24
+ def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
25
+ super().__init__()
26
+ self.init_w = init_w
27
+ self.init_b = init_b
28
+ self.gumbel_tau = gumbel_tau
29
+ self.learnable = learnable
30
+
31
+ assert not ((init_w is None) ^ (init_b is None))
32
+ if learnable:
33
+ self.w = nn.Parameter(torch.full([], float(init_w)))
34
+ self.b = nn.Parameter(torch.full([], float(init_b)))
35
+ else:
36
+ self.w = init_w
37
+ self.b = init_b
38
+
39
+ def forward(self, x, deterministic=False):
40
+ logits = x * self.w + self.b
41
+
42
+ soft_mask = torch.sigmoid(logits)
43
+ if deterministic:
44
+ hard_mask = soft_mask.gt(0.5).type(logits.dtype)
45
+ else:
46
+ hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)
47
+
48
+ return hard_mask, soft_mask
49
+
50
+ def extra_repr(self):
51
+ return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
52
+
53
+
54
+ class MaskerBackbone(nn.Module):
55
+ """Masker image encoder backbone.
56
+ """
57
+ def __init__(self, clip_visual, freeze_idx):
58
+ super().__init__()
59
+ self.transformer = copy.deepcopy(clip_visual.transformer)
60
+ self.transformer.resblocks = self.transformer.resblocks[freeze_idx:]
61
+
62
+ for block in self.transformer.resblocks:
63
+ if hasattr(block, "hook_handler"):
64
+ block.hook_handler.remove()
65
+
66
+ self.ln_post = copy.deepcopy(clip_visual.ln_post)
67
+ self.proj = copy.deepcopy(clip_visual.proj)
68
+
69
+ self.layers = len(self.transformer.resblocks)
70
+ self.patch_size = clip_visual.patch_size
71
+
72
+ self.output_dim = clip_visual.output_dim if self.proj is not None else clip_visual.width
73
+
74
+ def forward(self, x, spatial=True, ignore_last_attn=True):
75
+ if self.layers:
76
+ x = self.transformer(x, ignore_last_attn=ignore_last_attn)
77
+
78
+ x = x.permute(1, 0, 2) # LND -> NLD
79
+
80
+ if spatial:
81
+ x = self.ln_post(x)
82
+ else:
83
+ x = self.ln_post(x[:, 0, :])
84
+
85
+ if self.proj is not None:
86
+ x = x @ self.proj
87
+
88
+ return x
89
+
90
+ class MaskerImageFeatureEncoder(FeatureEncoder):
91
+ def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
92
+ super().__init__()
93
+ self.ignore_last_attn = ignore_last_attn
94
+ self.patch_size = backbone.patch_size
95
+ self.backbone = backbone
96
+ self.decoder = decoder
97
+
98
+ for resblock in self.backbone.transformer.resblocks:
99
+ resblock.hook_handler = resblock.register_forward_hook(self.hook)
100
+
101
+ def _encode(self, image, image_feat):
102
+ H, W = image.shape[-2:]
103
+ h = H // self.patch_size
104
+ w = W // self.patch_size
105
+
106
+ x = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn) # BLC
107
+ x = rearrange(x[:, 1:], "B (H W) C -> B C H W", H=h, W=w)
108
+ x = self.decoder(x)
109
+
110
+ return x
111
+
112
+ class Masker(nn.Module):
113
+ def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
114
+ super().__init__()
115
+ self.ignore_last_attn = ignore_last_attn
116
+
117
+ decoder["C"] = backbone.output_dim
118
+ decoder = MODELS.build(decoder)
119
+ decoder = nn.Sequential(OrderedDict([
120
+ ("decoder", decoder),
121
+ ("image_proj", image_proj)
122
+ ]))
123
+
124
+ self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)
125
+
126
+ self.sim2mask = Sim2Mask(**sim2mask)
127
+
128
+ def forward(self, image, image_feat, text_emb, deterministic=False):
129
+ B = image.size(0)
130
+ image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True) # [BCHW]
131
+
132
+ image_emb_norm = us.normalize(image_emb, dim=1)
133
+ text_emb_norm = us.normalize(text_emb, dim=-1)
134
+
135
+ H, W = image_emb.shape[2:]
136
+ D = dist.get_world_size()
137
+
138
+ # simmap [B, B*D, H, W] where D is #devices
139
+ all_text_emb_norm = us.gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
140
+ simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
141
+ mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
142
+
143
+ # mask [B, B*D, H, W] where D is #devices
144
+ # positive global label
145
+ pos_indices = torch.arange(B, dtype=torch.long, device=image_emb.device) + B * dist.get_rank()
146
+ pos_mask = mask[torch.arange(B), pos_indices].unsqueeze(1) # [B, 1, H, W]
147
+
148
+ offdiag = torch.ones(B, B*D, dtype=torch.bool, device=mask.device)
149
+ offdiag[torch.arange(B), pos_indices] = False
150
+
151
+ soft_pos_mask = soft_mask[torch.arange(B), pos_indices].unsqueeze(1)
152
+ soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B*D-1, H, W)
153
+
154
+ masks = {
155
+ "pos": pos_mask, # [B, 1, H, W]
156
+
157
+ "soft_pos": soft_pos_mask,
158
+ "soft_neg": soft_neg_mask,
159
+ "soft_all": soft_mask, # [B, N, H, W]
160
+ }
161
+
162
+ return masks, image_emb, text_emb, feats
163
+
164
+ @torch.no_grad()
165
+ def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
166
+ """Make mask by 1:N matching
167
+
168
+ Args:
169
+ image [B, 3, H, W]
170
+ image_feat [L, B, C]: CLIP features
171
+ text_emb [N, C]
172
+ deterministic (bool): deterministic inference flag for gumbel noise
173
+ hard (bool): decide hard or soft returning segmentation mask.
174
+ Note that soft mask is required for proper evaluation
175
+
176
+ Return:
177
+ mask [B, N, H', W'] (H' and W' are downsampled H/W)
178
+ """
179
+ image_emb = self.image_encoder(image, image_feat) # [BCHW]
180
+
181
+ image_emb = us.normalize(image_emb, dim=1) # BCHW
182
+ text_emb = us.normalize(text_emb, dim=-1) # NC
183
+
184
+ simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)
185
+
186
+ hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
187
+ mask = hard_mask if hard else soft_mask
188
+
189
+ return mask, simmap
190
+
191
+ class DINOTextMasker(nn.Module):
192
+ def __init__(self, similarity_type="cosine"):
193
+ super().__init__()
194
+ self.sim2mask = DINOTextSim2Mask()
195
+ self.sim2mask = self.sim2mask.eval()
196
+ self.similarity_type = similarity_type
197
+
198
+ def forward(self, image, image_feat, text_emb, deterministic=False):
199
+ pass
200
+
201
+ @torch.no_grad()
202
+ def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
203
+ """Make mask by 1:N matching
204
+
205
+ Args:
206
+ image [B, 3, H, W]
207
+ image_feat [L, B, C]: CLIP features
208
+ text_emb [N, K, C]
209
+ deterministic (bool): deterministic inference flag for gumbel noise
210
+ hard (bool): decide hard or soft returning segmentation mask.
211
+ Note that soft mask is required for proper evaluation
212
+ use_k_nn (bool): use kNN to segment
213
+ k_nn (int): number of nearest neighbors for kNN segmentation
214
+
215
+ Return:
216
+ mask [B, N, H', W'] (H' and W' are downsampled H/W)
217
+ """
218
+ b, c, h, w = image_feat.shape
219
+ n, c = text_emb.shape
220
+
221
+ if self.similarity_type == "cosine":
222
+ image_feat = us.normalize(image_feat, dim=1) # BCHW
223
+ # text_emb = us.normalize(text_emb, dim=-1) # NKC
224
+ simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
225
+ else:
226
+ raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))
227
+
228
+ hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
229
+ mask = hard_mask if hard else soft_mask
230
+
231
+ return mask, simmap
232
+
233
+
234
+ class DINOTextSim2Mask(nn.Module):
235
+ def __init__(self, gumbel_tau=1.0):
236
+ super().__init__()
237
+ self.gumbel_tau = gumbel_tau
238
+
239
+ def forward(self, x, deterministic=False):
240
+ soft_mask = torch.sigmoid(x)
241
+ if deterministic:
242
+ hard_mask = soft_mask.gt(0.5).type(x.dtype)
243
+ else:
244
+ hard_mask = gumbel_sigmoid(x, hard=True, tau=self.gumbel_tau)
245
+
246
+ return hard_mask, soft_mask
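`DINOTextMasker.forward_seg` converts a dense feature map and a set of text embeddings living in the same space into per-category masks. A minimal sketch with illustrative shapes only; in the real pipeline the inputs are projected DINOv2 patch features and Talk2DINO text embeddings:

```python
import torch
from hf_model.masker import DINOTextMasker

masker = DINOTextMasker(similarity_type="cosine").eval()

image_feat = torch.randn(1, 1024, 37, 37)   # [B, C, H', W'] dense patch-level image features
text_emb = torch.randn(3, 1024)             # [N, C] one embedding per textual category

mask, simmap = masker.forward_seg(image_feat, text_emb, deterministic=True, hard=False)
print(mask.shape)                           # torch.Size([1, 3, 37, 37]) -- soft per-category masks
```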
hf_model/model.py ADDED
@@ -0,0 +1,757 @@
1
+ import clip
2
+ import yaml
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from hf_model.hooks import get_self_attention, process_self_attention, feats
8
+
9
+ class VisualProjectionLayer(nn.Module):
10
+ """
11
+ Creates a projection layer on top of the DINO encoder.
12
+ The forward method calculate the similarity between the projected DINO token and the CLIP textual CLS token.
13
+ """
14
+ def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, hidden_embed_dim=None, dino_embed_dim=1024, clip_embed_dim=512):
15
+ # mlp_dims list of mlp dimensions
16
+ super().__init__()
17
+ if hidden_embed_dim is None:
18
+ hidden_embed_dim = clip_embed_dim
19
+
20
+ self.linear_layer = nn.Linear(dino_embed_dim, hidden_embed_dim)
21
+ if hidden_layer:
22
+ self.linear_layer2 = nn.Linear(hidden_embed_dim, clip_embed_dim)
23
+ self.act = act
24
+ self.cosine = cosine
25
+
26
+ @classmethod
27
+ def from_config(cls, config):
28
+ if type(config) is str:
29
+ # if the configuration is a string, we treat it as a file path
30
+ with open(config, 'r') as f:
31
+ config = yaml.safe_load(f)['model']
32
+
33
+ # loading the activation function
34
+ act = config.get('act', None)
35
+ if act == 'tanh':
36
+ act = nn.Tanh()
37
+ elif act == 'relu':
38
+ act = nn.ReLU()
39
+ elif act == 'sigmoid':
40
+ act = nn.Sigmoid()
41
+ elif act is not None:
42
+ raise Exception("Unknown activation function")
43
+
44
+ model = cls(
45
+ act=act,
46
+ hidden_layer=config.get('hidden_layer', False),
47
+ cosine=config.get('cosine', True),
48
+ hidden_embed_dim=config.get('hidden_embed_dim', None) if config.get('hidden_layer', False) else None,
49
+ dino_embed_dim=config.get('dino_embed_dim', 1024),
50
+ clip_embed_dim=config.get('clip_embed_dim', 512)
51
+
52
+ )
53
+ return model
54
+
55
+
56
+ def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False):
57
+ visual_embedding = self.project_dino(visual_embedding)
58
+ textual_embedding = textual_embedding.float()
59
+
60
+ if self.cosine:
61
+ textual_embedding = F.normalize(textual_embedding, p=2, dim=1)
62
+ visual_embedding = F.normalize(visual_embedding, p=2, dim=1)
63
+ if ret_embeds:
64
+ return textual_embedding, visual_embedding
65
+ x = textual_embedding @ visual_embedding.transpose(1, 0)
66
+ if not ret_similarity_matrix:
67
+ x = x[torch.eye(len(x)) > 0.5] # only diagonal elements
68
+
69
+ return x
70
+
71
+ def project_dino(self, visual_embedding):
72
+ visual_embedding = visual_embedding.float()
73
+
74
+ x = self.linear_layer(visual_embedding)
75
+ if self.act:
76
+ x = self.act(x)
77
+ if hasattr(self, 'linear_layer2'):
78
+ x = self.linear_layer2(x)
79
+
80
+ return x
81
+
82
+ def __len__(self):
83
+ return sum(p.numel() for p in self.parameters())
84
+
85
+
86
+
87
+ class ProjectionLayer(nn.Module):
88
+ """
89
+ Creates a projection layer on top of the CLIP-text encoder.
90
+ The forward method calculate the similarity between the DINO CLS token and the projected CLIP textual CLS token.
91
+ """
92
+ def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, num_attn_head=16, weight_attn_heads=None,
93
+ alignment_strategy='max_score', alpha=0.6, keep_cls=False, keep_end_seq=False):
94
+ # mlp_dims list of mlp dimensions
95
+ super().__init__()
96
+ self.num_attn_head = num_attn_head
97
+
98
+ self.linear_layer = nn.Linear(clip_embed_dim, dino_embed_dim)
99
+ if hidden_layer:
100
+ hidden_layer = 1 if hidden_layer is True else hidden_layer # ensuring compatibility with old code
101
+ # self.linear_layer2 = nn.Linear(dino_embed_dim, dino_embed_dim)
102
+ self.hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)])
103
+ self.act = act
104
+ self.cosine = cosine
105
+
106
+ self.weight_attn_heads = weight_attn_heads
107
+ if weight_attn_heads == 'static':
108
+ self.attn_weights = nn.Parameter(torch.rand(self.num_attn_head))
109
+ elif weight_attn_heads == 'conditioned':
110
+ self.weight_layer1 = nn.Linear(dino_embed_dim, dino_embed_dim)
111
+ self.weight_layer2 = nn.Linear(dino_embed_dim, self.num_attn_head)
112
+
113
+ self.alignment_strategy = alignment_strategy # relevant only if we use disentangled_self_attn
114
+ self.keep_cls = keep_cls # relevant only if we use clip_txt_tokens_out
115
+ self.keep_end_seq = keep_end_seq # relevant only if we use clip_txt_tokens_out
116
+ self.alpha = alpha
117
+
118
+ @classmethod
119
+ def from_config(cls, config):
120
+ if type(config) is str:
121
+ # if the configuration is a string, we treat it as a file path
122
+ with open(config, 'r') as f:
123
+ config = yaml.safe_load(f)['model']
124
+
125
+ # loading the activation function
126
+ act = config.get('act', None)
127
+ if act == 'tanh':
128
+ act = nn.Tanh()
129
+ elif act == 'relu':
130
+ act = nn.ReLU()
131
+ elif act == 'sigmoid':
132
+ act = nn.Sigmoid()
133
+ elif act is not None:
134
+ raise Exception("Unknown activation function")
135
+
136
+ model = cls(
137
+ act=act,
138
+ hidden_layer=config.get('hidden_layer', False),
139
+ cosine=config.get('cosine', True),
140
+ dino_embed_dim=config.get('dino_embed_dim', 1024),
141
+ num_attn_head=config.get('num_attn_head', 16),
142
+ clip_embed_dim=config.get('clip_embed_dim', 512),
143
+ weight_attn_heads=config.get('weight_attn_heads', None),
144
+ alignment_strategy=config.get('alignment_strategy', 'max_score'),
145
+ alpha=config.get('alpha', 0.6),
146
+ keep_cls=config.get('keep_cls', None),
147
+ keep_end_seq=config.get('keep_end_seq', None),
148
+ )
149
+ if config.get('starting_checkpoint', None) is not None:
150
+ model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu'))
151
+
152
+ return model
153
+
154
+ def compute_similarity(self, visual_embedding, textual_embedding, text_input_mask=None, return_index=False):
155
+ if len(visual_embedding.shape) == 3 or len(textual_embedding.shape) == 3:
156
+ # at least one embedding is decomposed: either we have all textual tokens or we have all the attention head tokens
157
+
158
+ if self.alignment_strategy == 'weighted_avg':
159
+ if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
160
+ raise Exception("Alignment strategy not implemented for this type of embeddings!")
161
+ sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
162
+ sims = sims.softmax(dim=-1)
163
+ # in this case, we keep as visual_embedding the averaged token weighted by the text similarities
164
+ visual_embedding = (visual_embedding * sims.unsqueeze(dim=-1)).mean(dim=1)
165
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
166
+
167
+ # in this case we sample the visual embedding from the softmax similarities of attention heads tokens and the textual tokens
168
+ elif self.alignment_strategy == 'sampled_attn_map':
169
+ if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
170
+ raise Exception("Alignment strategy not implemented for this type of embeddings!")
171
+ sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
172
+ sims = sims.softmax(dim=-1)
173
+ # in this case, we sample from the distribution given byt text2attn-maps similarities the attention map to align
174
+ index = torch.multinomial(sims, 1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
175
+ visual_embedding = torch.gather(visual_embedding, 1, index).squeeze(1)
176
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
177
+
178
+ elif self.alignment_strategy == 'max_score':
179
+ sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
180
+ sims = sims.softmax(dim=-1)
181
+ index = sims.argmax(dim=-1)
182
+ index_reshaped = sims.argmax(dim=-1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
183
+ visual_embedding = torch.gather(visual_embedding, 1, index_reshaped).squeeze(1)
184
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
185
+ else:
186
+ # in this case we construct a similarity matrix between attention head tokens and textual tokens
187
+
188
+ # we ensure that both the batch embeddings have the same number of dimensions
189
+ textual_embedding = textual_embedding.unsqueeze(1) if len(textual_embedding.shape) == 2 else textual_embedding
190
+ visual_embedding = visual_embedding.unsqueeze(1) if len(visual_embedding.shape) == 2 else visual_embedding
191
+ if textual_embedding.shape[1] > 1:
192
+ assert text_input_mask is not None, "If we use all the textual embeddings, we need the input mask"
193
+ if not self.keep_end_seq:
194
+ # we take the last True value of the mask and we set it to False
195
+ text_input_mask[torch.arange(text_input_mask.shape[0]), torch.sum(text_input_mask, dim=1) - 1] = False
196
+ if not self.keep_cls:
197
+ text_input_mask[:, 0] = False
198
+
199
+ # do not consider cls and eos tokens
200
+ im_set = visual_embedding
201
+ s_seq = textual_embedding
202
+
203
+ im_set_batch = im_set.size(0)
204
+ im_set_len = im_set.size(1)
205
+ s_seq_batch = s_seq.size(0)
206
+ s_seq_len = s_seq.size(1)
207
+
208
+ im_set = im_set.unsqueeze(1).expand(-1, s_seq_batch, -1, -1) # B x B x S_im x dim
209
+ s_seq = s_seq.unsqueeze(0).expand(im_set_batch, -1, -1, -1) # B x B x S_s x dim
210
+ alignments = torch.matmul(im_set, s_seq.permute(0, 1, 3, 2)) # B x B x S_im x S_s
211
+
212
+ # compute mask for the alignments tensor
213
+ if text_input_mask is not None:
214
+ alignment_mask = text_input_mask.unsqueeze(1).unsqueeze(0).expand(im_set_batch, -1, im_set_len, -1).logical_not()
215
+
216
+ alignments.masked_fill_(alignment_mask, value=0)
217
+ # alignments = F.relu(alignments)
218
+ # alignments = F.normalize(alignments,p=2, dim=2)
219
+
220
+ if self.alignment_strategy == 'sum':
221
+ sims = alignments.sum(dim=(2,3))
222
+ elif self.alignment_strategy == 'mean':
223
+ sims = alignments.mean(dim=(2,3))
224
+ elif self.alignment_strategy == 'max-row_sum':
225
+ sims = alignments.max(2)[0].sum(2)
226
+ elif self.alignment_strategy == 'nucleus-sampling':
227
+ max_alignments = alignments.max(2)[0]
228
+ sorted_alignments = max_alignments.sort(dim=2, descending=True)[0]
229
+ # min-max normalization
230
+ mins = sorted_alignments.min(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
231
+ maxs = sorted_alignments.max(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
232
+ norm_alignments = ((sorted_alignments - mins) / (maxs - mins))
233
+ # transform values in percentage
234
+ sums = norm_alignments.sum(dim=-1).unsqueeze(-1).expand(-1, -1, s_seq_len)
235
+ norm_alignments = norm_alignments / sums
236
+ # finding the element indices which surpasses alpha
237
+ cumsums = norm_alignments.cumsum(2)
238
+ indices = torch.argmax((cumsums > self.alpha).int() + 1, dim=2)
239
+
240
+ mask = torch.arange(s_seq_len).unsqueeze(0).unsqueeze(0).expand(s_seq_batch, s_seq_batch, s_seq_len).to(indices.device) < indices.unsqueeze(-1).expand(-1, -1, s_seq_len) + 1
241
+ relevant_alignments = (sorted_alignments * mask)
242
+ sims = relevant_alignments.sum(dim=2)
243
+ else:
244
+ # default case: dot-product
245
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
246
+
247
+ if not return_index:
248
+ return sims
249
+ else:
250
+ return sims, index
251
+
252
+
253
+
254
+ def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False, self_attn_maps=None, cls=None, text_input_mask=None, return_index=False):
255
+ if self.weight_attn_heads is not None:
256
+ assert self_attn_maps is not None, "In case we have attention maps weights, we have to weight patch tokens mean by the weighted self-attention maps"
257
+ visual_embedding = self.get_visual_embed(visual_embedding, self_attn_maps=self_attn_maps, cls=cls)
258
+
259
+ textual_embedding = self.project_clip_txt(textual_embedding)
260
+
261
+ if self.cosine:
262
+ textual_embedding = F.normalize(textual_embedding, p=2, dim=-1)
263
+ visual_embedding = F.normalize(visual_embedding, p=2, dim=-1)
264
+
265
+
266
+ if ret_embeds:
267
+ return textual_embedding, visual_embedding
268
+
269
+ if not return_index:
270
+ x = self.compute_similarity(visual_embedding, textual_embedding, text_input_mask, return_index)
271
+ else:
272
+ x, index = self.compute_similarity(visual_embedding, textual_embedding, text_input_mask, return_index)
273
+
274
+ if not ret_similarity_matrix:
275
+ x = x[torch.eye(len(x)) > 0.5] # only diagonal elements
276
+
277
+ if not return_index:
278
+ return x
279
+ else:
280
+ return x, index
281
+
282
+ def get_visual_embed(self, visual_embedding, self_attn_maps=None, cls=None):
283
+ if self_attn_maps is not None:
284
+ # we weight each attention head to obtain a weighted self-attention map
285
+ assert len(visual_embedding.shape) == 3, "In case we have attention maps weights, the visual_embedding should contain patch embeddings, with shape BS x NUM_PATCHES x EMBED_DIM"
286
+ if self.weight_attn_heads == 'conditioned':
287
+ assert cls is not None, "cls must be set in case of dynamic attention weighting"
288
+ x = self.weight_layer1(cls)
289
+ x = self.act(x)
290
+ x = self.weight_layer2(x)
291
+ normalized_attn_weights = x.softmax(dim=1)
292
+ self_attn = (self_attn_maps * normalized_attn_weights.unsqueeze(dim=-1)).mean(dim=1)
293
+ else:
294
+ normalized_attn_weights = self.attn_weights.softmax(dim=0)
295
+ self_attn = (self_attn_maps * normalized_attn_weights.view(1, normalized_attn_weights.shape[0], 1)).mean(dim=1)
296
+ self_attn = self_attn.softmax(dim=-1)
297
+
298
+ # then we perform the weighted mean of patches
299
+ visual_embedding = (self_attn.unsqueeze(-1) * visual_embedding).mean(dim=1)
300
+ return visual_embedding
301
+
302
+ def project_clip_txt(self, textual_embedding):
303
+ textual_embedding = textual_embedding.float()
304
+ x = self.linear_layer(textual_embedding)
305
+
306
+ if hasattr(self, 'hidden_layers'):
307
+ for hidden_layer in self.hidden_layers:
308
+ if self.act:
309
+ x = self.act(x)
310
+ x = hidden_layer(x)
311
+
312
+ return x
313
+ def load_state_dict(self, state_dict, strict=True):
314
+ # compatibility with old code
315
+ if 'linear_layer2.weight' in state_dict:
316
+ state_dict['hidden_layers.0.weight'] = state_dict.pop('linear_layer2.weight')
317
+ state_dict['hidden_layers.0.bias'] = state_dict.pop('linear_layer2.bias')
318
+ # Call the parent class's load_state_dict with the modified state_dict
319
+ super(ProjectionLayer, self).load_state_dict(state_dict, strict)
320
+
321
+ def set_alignment_strategy(self, alignment_strategy):
322
+ self.alignment_strategy = alignment_strategy
323
+ return
324
+
325
+ def __len__(self):
326
+ return sum(p.numel() for p in self.parameters())
327
+
328
+ class DoubleMLP(nn.Module):
329
+ def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, num_attn_head=16, weight_attn_heads=None,
330
+ alignment_strategy='max_score', alpha=0.6, keep_cls=False, keep_end_seq=False):
331
+ super().__init__()
332
+ self.num_attn_head = num_attn_head
333
+
334
+ self.linear_layer = nn.Linear(clip_embed_dim, dino_embed_dim)
335
+ if hidden_layer:
336
+ hidden_layer = 1 if hidden_layer is True else hidden_layer # ensuring compatibility with old code
337
+ # self.linear_layer2 = nn.Linear(dino_embed_dim, dino_embed_dim)
338
+ self.hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)])
339
+ self.act = act
340
+ self.cosine = cosine
341
+
342
+ self.weight_attn_heads = weight_attn_heads
343
+ if weight_attn_heads == 'static':
344
+ self.attn_weights = nn.Parameter(torch.rand(self.num_attn_head))
345
+ elif weight_attn_heads == 'conditioned':
346
+ self.weight_layer1 = nn.Linear(dino_embed_dim, dino_embed_dim)
347
+ self.weight_layer2 = nn.Linear(dino_embed_dim, self.num_attn_head)
348
+
349
+ self.alignment_strategy = alignment_strategy # relevant only if we use disentangled_self_attn
350
+ self.keep_cls = keep_cls # relevant only if we use clip_txt_tokens_out
351
+ self.keep_end_seq = keep_end_seq # relevant only if we use clip_txt_tokens_out
352
+ self.alpha = alpha
353
+
354
+ self.visual_linear = nn.Linear(dino_embed_dim, dino_embed_dim)
355
+ if hidden_layer:
356
+ hidden_layer = 1 if hidden_layer is True else hidden_layer # ensuring compatibility with old code
357
+ self.visual_hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)])
358
+
359
+ @classmethod
360
+ def from_config(cls, config):
361
+ if type(config) is str:
362
+ # if the configuration is a string, we treat it as a file path
363
+ with open(config, 'r') as f:
364
+ config = yaml.safe_load(f)['model']
365
+
366
+ # loading the activation function
367
+ act = config.get('act', None)
368
+ if act == 'tanh':
369
+ act = nn.Tanh()
370
+ elif act == 'relu':
371
+ act = nn.ReLU()
372
+ elif act == 'sigmoid':
373
+ act = nn.Sigmoid()
374
+ elif act is not None:
375
+ raise Exception("Unknown activation function")
376
+
377
+ model = cls(
378
+ act=act,
379
+ hidden_layer=config.get('hidden_layer', False),
380
+ cosine=config.get('cosine', True),
381
+ dino_embed_dim=config.get('dino_embed_dim', 1024),
382
+ num_attn_head=config.get('num_attn_head', 16),
383
+ clip_embed_dim=config.get('clip_embed_dim', 512),
384
+ weight_attn_heads=config.get('weight_attn_heads', None),
385
+ alignment_strategy=config.get('alignment_strategy', 'max_score'),
386
+ alpha=config.get('alpha', 0.6),
387
+ keep_cls=config.get('keep_cls', None),
388
+ keep_end_seq=config.get('keep_end_seq', None),
389
+ )
390
+ if config.get('starting_checkpoint', None) is not None:
391
+ model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu'))
392
+
393
+ return model
394
+
395
+ def compute_similarity(self, visual_embedding, textual_embedding, text_input_mask=None):
396
+ if len(visual_embedding.shape) == 3 or len(textual_embedding.shape) == 3:
397
+ # at least one embedding is decomposed: either we have all textual tokens or we have all the attention head tokens
398
+
399
+ if self.alignment_strategy == 'weighted_avg':
400
+ if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
401
+ raise Exception("Alignment strategy not implemented for this type of embeddings!")
402
+ sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
403
+ sims = sims.softmax(dim=-1)
404
+ # in this case, we keep as visual_embedding the averaged token weighted by the text similarities
405
+ visual_embedding = (visual_embedding * sims.unsqueeze(dim=-1)).mean(dim=1)
406
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
407
+
408
+ # in this case we sample the visual embedding from the softmax similarities of attention heads tokens and the textual tokens
409
+ elif self.alignment_strategy == 'sampled_attn_map':
410
+ if len(visual_embedding.shape) != 3 or len(textual_embedding.shape) != 2:
411
+ raise Exception("Alignment strategy not implemented for this type of embeddings!")
412
+ sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
413
+ sims = sims.softmax(dim=-1)
414
+ # in this case, we sample from the distribution given byt text2attn-maps similarities the attention map to align
415
+ index = torch.multinomial(sims, 1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
416
+ visual_embedding = torch.gather(visual_embedding, 1, index).squeeze(1)
417
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
418
+
419
+ elif self.alignment_strategy == 'max_score':
420
+ sims = torch.einsum('ik,ijk->ij', textual_embedding, visual_embedding)
421
+ sims = sims.softmax(dim=-1)
422
+ index = sims.argmax(dim=-1).view(-1, 1, 1).expand(-1, 1, visual_embedding.shape[-1])
423
+ visual_embedding = torch.gather(visual_embedding, 1, index).squeeze(1)
424
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
425
+ else:
426
+ # in this case we construct a similarity matrix between attention head tokens and textual tokens
427
+
428
+ # we ensure that both the batch embeddings have the same number of dimensions
429
+ textual_embedding = textual_embedding.unsqueeze(1) if len(textual_embedding.shape) == 2 else textual_embedding
430
+ visual_embedding = visual_embedding.unsqueeze(1) if len(visual_embedding.shape) == 2 else visual_embedding
431
+ if textual_embedding.shape[1] > 1:
432
+ assert text_input_mask is not None, "If we use all the textual embeddings, we need the input mask"
433
+ if not self.keep_end_seq:
434
+ # we take the last True value of the mask and we set it to False
435
+ text_input_mask[torch.arange(text_input_mask.shape[0]), torch.sum(text_input_mask, dim=1) - 1] = False
436
+ if not self.keep_cls:
437
+ text_input_mask[:, 0] = False
438
+
439
+ # do not consider cls and eos tokens
440
+ im_set = visual_embedding
441
+ s_seq = textual_embedding
442
+
443
+ im_set_batch = im_set.size(0)
444
+ im_set_len = im_set.size(1)
445
+ s_seq_batch = s_seq.size(0)
446
+ s_seq_len = s_seq.size(1)
447
+
448
+ im_set = im_set.unsqueeze(1).expand(-1, s_seq_batch, -1, -1) # B x B x S_im x dim
449
+ s_seq = s_seq.unsqueeze(0).expand(im_set_batch, -1, -1, -1) # B x B x S_s x dim
450
+ alignments = torch.matmul(im_set, s_seq.permute(0, 1, 3, 2)) # B x B x S_im x S_s
451
+
452
+ # compute mask for the alignments tensor
453
+ if text_input_mask is not None:
454
+ alignment_mask = text_input_mask.unsqueeze(1).unsqueeze(0).expand(im_set_batch, -1, im_set_len, -1).logical_not()
455
+
456
+ alignments.masked_fill_(alignment_mask, value=0)
457
+ # alignments = F.relu(alignments)
458
+ # alignments = F.normalize(alignments,p=2, dim=2)
459
+
460
+ if self.alignment_strategy == 'sum':
461
+ sims = alignments.sum(dim=(2,3))
462
+ elif self.alignment_strategy == 'mean':
463
+ sims = alignments.mean(dim=(2,3))
464
+ elif self.alignment_strategy == 'max-row_sum':
465
+ sims = alignments.max(2)[0].sum(2)
466
+ elif self.alignment_strategy == 'nucleus-sampling':
467
+ max_alignments = alignments.max(2)[0]
468
+ sorted_alignments = max_alignments.sort(dim=2, descending=True)[0]
469
+ # min-max normalization
470
+ mins = sorted_alignments.min(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
471
+ maxs = sorted_alignments.max(2)[0].unsqueeze(-1).expand(-1, -1, s_seq_len)
472
+ norm_alignments = ((sorted_alignments - mins) / (maxs - mins))
473
+ # transform values in percentage
474
+ sums = norm_alignments.sum(dim=-1).unsqueeze(-1).expand(-1, -1, s_seq_len)
475
+ norm_alignments = norm_alignments / sums
476
+ # finding the element indices which surpasses alpha
477
+ cumsums = norm_alignments.cumsum(2)
478
+ indices = torch.argmax((cumsums > self.alpha).int() + 1, dim=2)
479
+
480
+ mask = torch.arange(s_seq_len).unsqueeze(0).unsqueeze(0).expand(s_seq_batch, s_seq_batch, s_seq_len).to(indices.device) < indices.unsqueeze(-1).expand(-1, -1, s_seq_len) + 1
481
+ relevant_alignments = (sorted_alignments * mask)
482
+ sims = relevant_alignments.sum(dim=2)
483
+ else:
484
+ # default case: dot-product
485
+ sims = textual_embedding @ visual_embedding.transpose(1, 0)
486
+
487
+ return sims
488
+
489
+
490
+
491
+ def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False, self_attn_maps=None, cls=None, text_input_mask=None):
492
+ if self.weight_attn_heads is not None:
493
+ assert self_attn_maps is not None, "In case we have attention maps weights, we have to weight patch tokens mean by the weighted self-attention maps"
494
+ visual_embedding = self.get_visual_embed(visual_embedding, self_attn_maps=self_attn_maps, cls=cls)
495
+
496
+ visual_embedding = self.project_visual(visual_embedding)
497
+
498
+ textual_embedding = self.project_clip_txt(textual_embedding)
499
+
500
+ if self.cosine:
501
+ textual_embedding = F.normalize(textual_embedding, p=2, dim=-1)
502
+ visual_embedding = F.normalize(visual_embedding, p=2, dim=-1)
503
+
504
+
505
+ if ret_embeds:
506
+ return textual_embedding, visual_embedding
507
+
508
+ x = self.compute_similarity(visual_embedding, textual_embedding, text_input_mask)
509
+ if not ret_similarity_matrix:
510
+ x = x[torch.eye(len(x)) > 0.5] # only diagonal elements
511
+
512
+ return x
513
+
514
+ def get_visual_embed(self, visual_embedding, self_attn_maps=None, cls=None):
515
+ if self_attn_maps is not None:
516
+ # we weight each attention head to obtain a weighted self-attention map
517
+ assert len(visual_embedding.shape) == 3, "In case we have attention maps weights, the visual_embedding should contain patch embeddings, with shape BS x NUM_PATCHES x EMBED_DIM"
518
+ if self.weight_attn_heads == 'conditioned':
519
+ assert cls is not None, "cls must be set in case of dynamic attention weighting"
520
+ x = self.weight_layer1(cls)
521
+ x = self.act(x)
522
+ x = self.weight_layer2(x)
523
+ normalized_attn_weights = x.softmax(dim=1)
524
+ self_attn = (self_attn_maps * normalized_attn_weights.unsqueeze(dim=-1)).mean(dim=1)
525
+ else:
526
+ normalized_attn_weights = self.attn_weights.softmax(dim=0)
527
+ self_attn = (self_attn_maps * normalized_attn_weights.view(1, normalized_attn_weights.shape[0], 1)).mean(dim=1)
528
+ self_attn = self_attn.softmax(dim=-1)
529
+
530
+ # then we perform the weighted mean of patches
531
+ visual_embedding = (self_attn.unsqueeze(-1) * visual_embedding).mean(dim=1)
532
+ return visual_embedding
533
+
534
+ def project_clip_txt(self, textual_embedding):
535
+ textual_embedding = textual_embedding.float()
536
+ x = self.linear_layer(textual_embedding)
537
+
538
+ for hidden_layer in self.hidden_layers:
539
+ if self.act:
540
+ x = self.act(x)
541
+ x = hidden_layer(x)
542
+
543
+ return x
544
+
545
+ def project_visual(self, visual_embedding):
546
+ visual_embedding = visual_embedding.float()
547
+ x = self.visual_linear(visual_embedding)
548
+
549
+ for hidden_layer in self.visual_hidden_layers:
550
+ if self.act:
551
+ x = self.act(x)
552
+ x = hidden_layer(x)
553
+
554
+ return x
555
+
556
+ def load_state_dict(self, state_dict, strict=True):
557
+ # compatibility with old code
558
+ if 'linear_layer2.weight' in state_dict:
559
+ state_dict['hidden_layers.0.weight'] = state_dict.pop('linear_layer2.weight')
560
+ state_dict['hidden_layers.0.bias'] = state_dict.pop('linear_layer2.bias')
561
+ # Call the parent class's load_state_dict with the modified state_dict
562
+ super(DoubleMLP, self).load_state_dict(state_dict, strict)
563
+
564
+ def set_alignment_strategy(self, alignment_strategy):
565
+ self.alignment_strategy = alignment_strategy
566
+ return
567
+
568
+ def __len__(self):
569
+ return sum(p.numel() for p in self.parameters())
570
+
571
+
572
+ class CLIPLastLayer(nn.Module):
573
+ def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, weight_attn_heads=None, alignment_strategy='max_score', clip_model='ViT-B/16', text_input_mask=None, projection_weights=None):
574
+ import clip
575
+ super().__init__()
576
+ self.clip_model, _ = clip.load(clip_model)
577
+ self.clip_model.to(dtype=torch.float32)
578
+ # self.last_resblock = copy.deepcopy(self.clip_model.transformer.resblocks[-1])
579
+ self.last_resblock = self.clip_model.transformer.resblocks[-1]
580
+ # self.last_resblock.requires_grad_(False)
581
+ # self.last_ln = copy.deepcopy(self.clip_model.ln_final)
582
+ self.last_ln = self.clip_model.ln_final
583
+ # self.last_ln.requires_grad_(False)
584
+ # self.clip_text_proj = copy.deepcopy(self.clip_model.text_projection)
585
+ self.clip_text_proj = self.clip_model.text_projection
586
+ # self.clip_text_proj.requires_grad_(False)
587
+ self.clip_dtype = self.clip_model.dtype
588
+ del self.clip_model
589
+
590
+ self.projection_layer = ProjectionLayer(act=act, hidden_layer=hidden_layer, cosine=cosine, dino_embed_dim=dino_embed_dim,
591
+ clip_embed_dim=clip_embed_dim, weight_attn_heads=weight_attn_heads, alignment_strategy=alignment_strategy)
592
+
593
+ if projection_weights is not None:
594
+ self.projection_layer.load_state_dict(torch.load(projection_weights, 'cpu'))
595
+
596
+ def forward(self, visual_embedding, textual_embedding, ret_similarity_matrix=True, ret_embeds=False, self_attn_maps=None, cls=None, text_argmax=None, text_input_mask=None):
597
+ x = self.last_resblock(textual_embedding.permute(1, 0, 2))
598
+ x = x.permute(1, 0, 2)
599
+ x = self.last_ln(x).type(self.clip_dtype)
600
+ x = x[torch.arange(x.shape[0]), text_argmax] @ self.clip_text_proj
601
+ if ret_embeds:
602
+ textual_embedding, visual_embedding = self.projection_layer(visual_embedding, x, ret_similarity_matrix=ret_similarity_matrix, ret_embeds=ret_embeds, self_attn_maps=self_attn_maps, cls=cls)
603
+ return textual_embedding, visual_embedding
604
+ x = self.projection_layer(visual_embedding, x, ret_similarity_matrix=ret_similarity_matrix, ret_embeds=ret_embeds, self_attn_maps=self_attn_maps, cls=cls)
605
+ return x
606
+
607
+ def project_clip_txt(self, textual_embedding, text_argmax):
608
+ x = self.last_resblock(textual_embedding.permute(1, 0, 2))
609
+ x = x.permute(1, 0, 2)
610
+ x = self.last_ln(x).type(self.clip_dtype)
611
+ x = x[torch.arange(x.shape[0]), text_argmax] @ self.clip_text_proj
612
+ x = self.projection_layer.project_clip_txt(x)
613
+ return x
614
+
615
+ @classmethod
616
+ def from_config(cls, config):
617
+ if type(config) is str:
618
+ # if the configuration is a string, we treat it as a file path
619
+ with open(config, 'r') as f:
620
+ config = yaml.safe_load(f)['model']
621
+
622
+ # loading the activation function
623
+ act = config.get('act', None)
624
+ if act == 'tanh':
625
+ act = nn.Tanh()
626
+ elif act == 'relu':
627
+ act = nn.ReLU()
628
+ elif act == 'sigmoid':
629
+ act = nn.Sigmoid()
630
+ elif act is not None:
631
+ raise Exception("Unknown activation function")
632
+
633
+ model = cls(
634
+ act=act,
635
+ hidden_layer=config.get('hidden_layer', False),
636
+ cosine=config.get('cosine', True),
637
+ dino_embed_dim=config.get('dino_embed_dim', 1024),
638
+ clip_embed_dim=config.get('clip_embed_dim', 512),
639
+ weight_attn_heads=config.get('weight_attn_heads', None),
640
+ alignment_strategy=config.get('alignment_strategy', 'max_score'),
641
+ clip_model=config.get('clip_model', 'ViT-B/16'),
642
+ projection_weights=config.get('projection_weights', None),
643
+
644
+ )
645
+ if config.get('starting_checkpoint', None) is not None:
646
+ model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu'))
647
+
648
+ return model
649
+
650
+ def __len__(self):
651
+ return sum(p.numel() for p in self.parameters())
652
+
653
+ class DinoText(nn.Module):
654
+ """
655
+ Project images and texts into DINOv2 latent space.
656
+ """
657
+ def __init__(self, dino_cfg="dinov2_vitl14_reg", clip_cfg="ViT-B/16", projection_cfg="configs/linear.yaml", projection_weights="weights/linear_avg_self_attn_out.pth", freeze_text_encoder=True, avg_self_attn_token=True, use_disentangled_self_attn=False):
658
+ super().__init__()
659
+ # DINO parameters
660
+ self.num_global_tokens = 1 if "reg" not in dino_cfg else 5
661
+ self.embed_dim = 1024 if "vitl" in dino_cfg else 768
662
+ self.num_attn_heads = 16
663
+ self.scale = 0.125
664
+
665
+ self.visual_backbone = torch.hub.load('facebookresearch/dinov2', dino_cfg)
666
+ self.text_backbone, _ = clip.load(clip_cfg)
667
+ self.clip2dino_proj = ProjectionLayer.from_config(projection_cfg)
668
+ if projection_weights is not None:
669
+ self.clip2dino_proj.load_state_dict(torch.load(projection_weights, 'cpu'))
670
+ self.use_avg_self_attn = avg_self_attn_token
671
+ self.use_disentangled_self_attn = use_disentangled_self_attn
672
+ if self.use_avg_self_attn or self.use_disentangled_self_attn:
673
+ self.visual_backbone.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)
674
+ if self.use_disentangled_self_attn:
675
+ self.visual_backbone.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)
676
+ if freeze_text_encoder:
677
+ self.text_backbone.eval()
678
+ self.text_backbone.requires_grad_(False)
679
+ self.avg_self_attn_token = avg_self_attn_token
680
+ if self.avg_self_attn_token or self.use_disentangled_self_attn:
681
+ self.visual_backbone.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
682
+ self.feats = {}
683
+ self.num_global_tokens = 1 if "reg" not in dino_cfg else 5
684
+ self.num_attn_heads = 16
685
+ self.scale = 0.125
686
+
687
+
688
+ @classmethod
689
+ def from_config(cls, cfg):
690
+ if type(cfg) is str:
691
+ # if the configuration is a string, we treat it as a file path
692
+ with open(cfg, 'r') as f:
693
+ cfg = yaml.safe_load(f)['model']
694
+
695
+ model = cls(
696
+ dino_cfg=cfg.get('dino_cfg', "dinov2_vitl14_reg"),
697
+ clip_cfg=cfg.get('clip_cfg', "ViT-B/16"),
698
+ projection_cfg=cfg.get('projection_cfg', "configs/linear.yaml"),
699
+ projection_weights=cfg.get('projection_weights', None),
700
+ avg_self_attn_token=cfg.get('use_avg_self_attn', False),
701
+ use_disentangled_self_attn=cfg.get('use_disentangled_self_attn', False),
702
+ )
703
+ return model
704
+
705
+ def encode_text(self, tokenized_texts):
706
+ x = self.text_backbone.encode_text(tokenized_texts)
707
+ x = self.clip2dino_proj.project_clip_txt(x)
708
+ return x
709
+
710
+ def encode_image(self, images):
711
+ batch_size, _, _, _ = images.shape
712
+ x = self.visual_backbone(images, is_training=self.avg_self_attn_token or self.use_disentangled_self_attn)
713
+ if self.avg_self_attn_token:
714
+ batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
715
+ num_tokens = num_tokens + self.num_global_tokens
716
+ self_attn = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens)
717
+ x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
718
+ if self.use_disentangled_self_attn:
719
+ batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
720
+ num_tokens = num_tokens + self.num_global_tokens
721
+ self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
722
+ self_attn_maps = self_attn_maps.softmax(dim=-1)
723
+ x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)
724
+ return x
725
+
726
+ def get_self_attention(self, module, input, output):
727
+ self.feats['self_attn'] = output
728
+
729
+ def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
730
+ qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
731
+ q, k, v = qkv[0] * scale, qkv[1], qkv[2]
732
+ attn = q @ k.transpose(-2, -1)
733
+ self_attn_maps = attn[:, : , 0, num_global_tokens:]
734
+ self_attn = self_attn_maps.mean(dim=1)
735
+ self_attn = self_attn.softmax(dim=-1)
736
+ if ret_self_attn_maps:
737
+ return self_attn, self_attn_maps
738
+ else:
739
+ return self_attn
740
+
741
+ def forward(self, images, tokenized_texts, cosine=True, ret_similarity_matrix=True):
742
+ img_embed = self.encode_image(images)
743
+ txt_embed = self.encode_text(tokenized_texts)
744
+
745
+ if cosine:
746
+ img_embed = F.normalize(img_embed, p=2, dim=1)
747
+ txt_embed = F.normalize(txt_embed, p=2, dim=1)
748
+ x = img_embed @ txt_embed.transpose(1, 0)
749
+ if not ret_similarity_matrix:
750
+ x = x[torch.eye(len(x)) > 0.5] # only diagonal elements
751
+
752
+ return x
753
+
754
+ def __len__(self):
755
+ def count_parameters(model):
756
+ return sum(p.numel() for p in model.parameters())
757
+ return count_parameters(self.visual_backbone) + count_parameters(self.clip2dino_proj) + count_parameters(self.text_backbone.transformer)
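`ProjectionLayer.from_config` accepts either a YAML path or a plain dict, so the CLIP-to-DINOv2 text projection can also be instantiated on its own. A minimal sketch; the hyper-parameters below are illustrative, the real values come from the repository's configuration and checkpoint:

```python
import torch
from hf_model.model import ProjectionLayer

# Illustrative configuration: project 512-d CLIP text embeddings into the 1024-d DINOv2 space
proj = ProjectionLayer.from_config({
    "act": "tanh",
    "hidden_layer": True,
    "clip_embed_dim": 512,
    "dino_embed_dim": 1024,
})

clip_text_features = torch.randn(2, 512)     # e.g. the output of clip_model.encode_text(...)
dino_space_features = proj.project_clip_txt(clip_text_features)
print(dino_space_features.shape)             # torch.Size([2, 1024])
```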
hf_model/modules.py ADDED
@@ -0,0 +1,243 @@
1
+ # ------------------------------------------------------------------------------
2
+ # FreeDA
3
+ # ------------------------------------------------------------------------------
4
+ from functools import partial
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class BLCModuleCompatibleBCHW(nn.Module):
12
+ def forward_blc(self, x):
13
+ raise NotImplementedError()
14
+
15
+ def forward(self, x):
16
+ is2d = x.ndim == 4
17
+ if is2d:
18
+ _, _, H, W = x.shape
19
+ x = rearrange(x, "B C H W -> B (H W) C")
20
+
21
+ x = self.forward_blc(x)
22
+
23
+ if is2d:
24
+ x = rearrange(x, "B (H W) C -> B C H W", H=H, W=W)
25
+
26
+ return x
27
+
28
+
29
+ class FeatureEncoder(nn.Module):
30
+ """Encoder + Feature extractor
31
+ """
32
+ def __init__(self, safe=True):
33
+ super().__init__()
34
+ self.safe = safe  # clone returned features to protect them from later in-place modification
35
+ self._features = []
36
+
37
+ def hook(self, module, input, output):
38
+ self._features.append(output)
39
+
40
+ def clear_features(self):
41
+ self._features.clear()
42
+
43
+ def _encode(self, x):
44
+ raise NotImplementedError()
45
+
46
+ def forward(self, *args, ret_feats=False, **kwargs):
47
+ self.clear_features()
48
+
49
+ x = self._encode(*args, **kwargs)
50
+
51
+ if ret_feats:
52
+ if self.safe:
53
+ features = [t.clone() for t in self._features]
54
+ self.clear_features()
55
+ else:
56
+ features = self._features
57
+ return x, features
58
+ else:
59
+ self.clear_features()
60
+ return x
61
+
62
+
63
+ class Project2d(nn.Module):
64
+ """2d projection by 1x1 conv
65
+
66
+ Args:
67
+ p: [C_in, C_out]
68
+ """
69
+ def __init__(self, p):
70
+ # convert to 1x1 conv weight
71
+ super().__init__()
72
+ p = rearrange(p, "Cin Cout -> Cout Cin 1 1")
73
+ self.p = nn.Parameter(p.detach().clone())
74
+
75
+ def forward(self, x):
76
+ return F.conv2d(x, self.p) # 1x1 conv
77
+
78
+
79
+ def dispatcher(dispatch_fn):
80
+ def decorated(key, *args):
81
+ if callable(key):
82
+ return key
83
+
84
+ if key is None:
85
+ key = "none"
86
+
87
+ return dispatch_fn(key, *args)
88
+
89
+ return decorated
90
+
91
+
92
+ @dispatcher
93
+ def activ_dispatch(activ):
94
+ return {
95
+ "none": nn.Identity,
96
+ "relu": nn.ReLU,
97
+ "lrelu": partial(nn.LeakyReLU, negative_slope=0.2),
98
+ "gelu": nn.GELU,
99
+ }[activ.lower()]
100
+
101
+
102
+ def get_norm_fn(norm, C):
103
+ """2d normalization layers
104
+ """
105
+ if norm is None or norm == "none":
106
+ return nn.Identity()
107
+
108
+ return {
109
+ "bn": nn.BatchNorm2d(C),
110
+ "syncbn": nn.SyncBatchNorm(C),
111
+ "ln": LayerNorm2d(C),
112
+ "gn": nn.GroupNorm(32, C),
113
+ }[norm]
114
+
115
+
116
+ class LayerNorm2d(nn.LayerNorm):
117
+ def __init__(self, num_channels, eps=1e-5, affine=True):
118
+ super().__init__(num_channels, eps=eps, elementwise_affine=affine)
119
+
120
+ def forward(self, x):
121
+ return F.layer_norm(
122
+ x.permute(0, 2, 3, 1),
123
+ self.normalized_shape,
124
+ self.weight,
125
+ self.bias,
126
+ self.eps
127
+ ).permute(0, 3, 1, 2)
128
+
129
+
130
+ class Gate(nn.Module):
131
+ """Tanh gate"""
132
+ def __init__(self, init=0.0):
133
+ super().__init__()
134
+ self.gate = nn.Parameter(torch.as_tensor(init))
135
+
136
+ def forward(self, x):
137
+ return torch.tanh(self.gate) * x
138
+
139
+
140
+ class ConvBlock(nn.Module):
141
+ def __init__(
142
+ self,
143
+ C_in,
144
+ C_out,
145
+ kernel_size=3,
146
+ stride=1,
147
+ padding=1,
148
+ norm="none",
149
+ activ="relu",
150
+ bias=True,
151
+ upsample=False,
152
+ downsample=False,
153
+ pad_type="zeros",
154
+ dropout=0.0,
155
+ gate=False,
156
+ ):
157
+ super().__init__()
158
+ if kernel_size == 1:
159
+ assert padding == 0
160
+ self.C_in = C_in
161
+ self.C_out = C_out
162
+
163
+ activ = activ_dispatch(activ)
164
+ self.upsample = upsample
165
+ self.downsample = downsample
166
+
167
+ self.norm = get_norm_fn(norm, C_in)
168
+ self.activ = activ()
169
+ if dropout > 0.0:
170
+ self.dropout = nn.Dropout2d(p=dropout)
171
+ self.conv = nn.Conv2d(
172
+ C_in, C_out, kernel_size, stride, padding,
173
+ bias=bias, padding_mode=pad_type
174
+ )
175
+
176
+ self.gate = Gate() if gate else None
177
+
178
+ def forward(self, x):
179
+ # pre-act
180
+ x = self.norm(x)
181
+ x = self.activ(x)
182
+ if self.upsample:
183
+ x = F.interpolate(x, scale_factor=2)
184
+ if hasattr(self, "dropout"):
185
+ x = self.dropout(x)
186
+ x = self.conv(x)
187
+ if self.downsample:
188
+ x = F.avg_pool2d(x, 2)
189
+
190
+ if self.gate is not None:
191
+ x = self.gate(x)
192
+
193
+ return x
194
+
195
+
196
+ class ResConv(nn.Module):
197
+ """Pre-activate residual block with single or double conv block"""
198
+
199
+ def __init__(
200
+ self,
201
+ C_in,
202
+ C_out,
203
+ kernel_size=3,
204
+ stride=1,
205
+ padding=1,
206
+ norm="none",
207
+ activ="relu",
208
+ upsample=False,
209
+ pad_type="zeros",
210
+ dropout=0.0,
211
+ gate=True, # if True, use zero-init gate
212
+ double=False,
213
+ # norm2 and activ2 are only used when double is True
214
+ norm2=None, # if given, apply it to second conv
215
+ activ2=None # if given, apply it to second conv
216
+ ):
217
+ super().__init__()
218
+
219
+ self.C_in = C_in
220
+ self.C_out = C_out
221
+ self.upsample = upsample
222
+ self.double = double
223
+ self.conv = ConvBlock(
224
+ C_in, C_out, kernel_size, stride, padding, norm, activ,
225
+ pad_type=pad_type, dropout=dropout, gate=gate,
226
+ )
227
+ if double:
228
+ norm2 = norm2 or norm
229
+ activ2 = activ2 or activ
230
+ self.conv2 = ConvBlock(
231
+ C_out, C_out, kernel_size, stride, padding, norm2, activ2,
232
+ pad_type=pad_type, dropout=dropout, gate=gate
233
+ )
234
+
235
+ def forward(self, x):
236
+ if self.upsample:
237
+ x = F.interpolate(x, scale_factor=2)
238
+ x = x + self.conv(x)
239
+
240
+ if self.double:
241
+ x = x + self.conv2(x)
242
+
243
+ return x
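The blocks above are pre-activation convolutional modules with an optional zero-initialized tanh gate. A minimal usage sketch (channel sizes and options are illustrative assumptions):

```python
import torch
from hf_model.modules import ConvBlock, ResConv

x = torch.randn(2, 64, 32, 32)   # [B, C, H, W]

block = ConvBlock(64, 64, norm="ln", activ="gelu", gate=True)
res = ResConv(64, 64, norm="ln", activ="gelu", gate=True, double=True)

y = block(x)   # pre-activation: norm -> activation -> conv, then tanh gate
z = res(x)     # x + gated residual branch; the gate starts at 0, so z equals x at initialization
print(y.shape, z.shape)
```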
hf_model/pamr.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright 2020 TU Darmstadt
2
+ # License: Apache 2.0.
3
+ # https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.nn as nn
7
+
8
+ from functools import partial
9
+
10
+ #
11
+ # Helper modules
12
+ #
13
+ class LocalAffinity(nn.Module):
14
+
15
+ def __init__(self, dilations=[1]):
16
+ super(LocalAffinity, self).__init__()
17
+ self.dilations = dilations
18
+ weight = self._init_aff()
19
+ self.register_buffer('kernel', weight)
20
+
21
+ def _init_aff(self):
22
+ # initialising the shift kernel
23
+ weight = torch.zeros(8, 1, 3, 3)
24
+
25
+ for i in range(weight.size(0)):
26
+ weight[i, 0, 1, 1] = 1
27
+
28
+ weight[0, 0, 0, 0] = -1
29
+ weight[1, 0, 0, 1] = -1
30
+ weight[2, 0, 0, 2] = -1
31
+
32
+ weight[3, 0, 1, 0] = -1
33
+ weight[4, 0, 1, 2] = -1
34
+
35
+ weight[5, 0, 2, 0] = -1
36
+ weight[6, 0, 2, 1] = -1
37
+ weight[7, 0, 2, 2] = -1
38
+
39
+ self.weight_check = weight.clone()
40
+
41
+ return weight
42
+
43
+ def forward(self, x):
44
+
45
+ self.weight_check = self.weight_check.type_as(x)
46
+ assert torch.all(self.weight_check.eq(self.kernel))
47
+
48
+ B,K,H,W = x.size()
49
+ x = x.view(B*K,1,H,W)
50
+
51
+ x_affs = []
52
+ for d in self.dilations:
53
+ x_pad = F.pad(x, [d]*4, mode='replicate')
54
+ x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
55
+ x_affs.append(x_aff)
56
+
57
+ x_aff = torch.cat(x_affs, 1)
58
+ return x_aff.view(B,K,-1,H,W)
59
+
60
+ class LocalAffinityCopy(LocalAffinity):
61
+
62
+ def _init_aff(self):
63
+ # initialising the shift kernel
64
+ weight = torch.zeros(8, 1, 3, 3)
65
+
66
+ weight[0, 0, 0, 0] = 1
67
+ weight[1, 0, 0, 1] = 1
68
+ weight[2, 0, 0, 2] = 1
69
+
70
+ weight[3, 0, 1, 0] = 1
71
+ weight[4, 0, 1, 2] = 1
72
+
73
+ weight[5, 0, 2, 0] = 1
74
+ weight[6, 0, 2, 1] = 1
75
+ weight[7, 0, 2, 2] = 1
76
+
77
+ self.weight_check = weight.clone()
78
+ return weight
79
+
80
+ class LocalStDev(LocalAffinity):
81
+
82
+ def _init_aff(self):
83
+ weight = torch.zeros(9, 1, 3, 3)
84
+ weight.zero_()
85
+
86
+ weight[0, 0, 0, 0] = 1
87
+ weight[1, 0, 0, 1] = 1
88
+ weight[2, 0, 0, 2] = 1
89
+
90
+ weight[3, 0, 1, 0] = 1
91
+ weight[4, 0, 1, 1] = 1
92
+ weight[5, 0, 1, 2] = 1
93
+
94
+ weight[6, 0, 2, 0] = 1
95
+ weight[7, 0, 2, 1] = 1
96
+ weight[8, 0, 2, 2] = 1
97
+
98
+ self.weight_check = weight.clone()
99
+ return weight
100
+
101
+ def forward(self, x):
102
+ # returns (B,K,P,H,W), where P is the number
103
+ # of locations
104
+ x = super(LocalStDev, self).forward(x)
105
+
106
+ return x.std(2, keepdim=True)
107
+
108
+ class LocalAffinityAbs(LocalAffinity):
109
+
110
+ def forward(self, x):
111
+ x = super(LocalAffinityAbs, self).forward(x)
112
+ return torch.abs(x)
113
+
114
+ #
115
+ # PAMR module
116
+ #
117
+ class PAMR(nn.Module):
118
+
119
+ def __init__(self, num_iter=1, dilations=[1]):
120
+ super(PAMR, self).__init__()
121
+
122
+ self.num_iter = num_iter
123
+ self.aff_x = LocalAffinityAbs(dilations)
124
+ self.aff_m = LocalAffinityCopy(dilations)
125
+ self.aff_std = LocalStDev(dilations)
126
+
127
+ def forward(self, x, mask):
128
+ mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)
129
+
130
+ # x: [BxKxHxW]
131
+ # mask: [BxCxHxW]
132
+ B,K,H,W = x.size()
133
+ _,C,_,_ = mask.size()
134
+
135
+ x_std = self.aff_std(x)
136
+
137
+ x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
138
+ x = x.mean(1, keepdim=True)
139
+ x = F.softmax(x, 2)
140
+
141
+ for _ in range(self.num_iter):
142
+ m = self.aff_m(mask) # [BxCxPxHxW]
143
+ mask = (m * x).sum(2)
144
+
145
+ # xvals: [BxCxHxW]
146
+ return mask
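PAMR refines coarse soft masks using low-level pixel affinities. A minimal sketch with the same settings later used by `apply_pamr()` in `hf_model/talk2dino.py` (input shapes are illustrative assumptions):

```python
import torch
from hf_model.pamr import PAMR

pamr = PAMR(num_iter=10, dilations=[1, 2, 4, 8, 12, 24]).eval()

image = torch.rand(1, 3, 224, 224)   # [B, 3, H, W] input image
masks = torch.rand(1, 5, 224, 224)   # [B, C, H, W] one soft mask per class

with torch.no_grad():
    refined = pamr(image, masks)     # affinity-guided refinement of the masks
print(refined.shape)                 # torch.Size([1, 5, 224, 224])
```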
hf_model/talk2dino.py ADDED
@@ -0,0 +1,432 @@
1
+ import itertools
2
+ import os
3
+ import pickle
4
+ from math import sqrt
5
+ import re
6
+ import yaml
7
+
8
+ import numpy as np
9
+ import timm
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torchvision
14
+ from einops import rearrange
15
+ from transformers import BertModel, AutoTokenizer
16
+ import torchvision.transforms as T
17
+ import clip
18
+ import importlib
19
+ import hf_model.us as us
20
+
21
+ from hf_model.pamr import PAMR
22
+ from hf_model.masker import DINOTextMasker
23
+ from hf_model.templates import get_template
24
+
25
+ from hf_model.model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
26
+ from hf_model.hooks import average_text_tokens, get_vit_out, feats
27
+
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+
31
+
32
+ class DINOText(nn.Module):
33
+
34
+ def get_self_attention(self, module, input, output):
35
+ self.feats['self_attn'] = output
36
+
37
+ def get_clip_second_last_dense_out(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
38
+ self.feats['clip_second_last_out'] = output
39
+ self.feats['clip_second_last_out'] = self.feats['clip_second_last_out'].to(dtype=torch.float32)  # .to() is not in-place, so assign the result
40
+
41
+ def get_all_out_tokens(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
42
+ self.feats['clip_txt_out_tokens'] = output
43
+
44
+ def __init__(
45
+ self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
46
+ unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
47
+ ):
48
+ super().__init__()
49
+ self.feats = {}
50
+ self.model_name = model_name
51
+ # loading the model
52
+
53
+ if 'dinov2' in model_name:
54
+ self.model_family = 'facebookresearch/dinov2' if 'dinov2' in model_name else 'facebookresearch/dino:main'
55
+ self.model = torch.hub.load(self.model_family, model_name)
56
+ elif 'dinov3' in model_name:
57
+ def extract_dinov3_name(path, n_parts=2):
58
+ filename = os.path.basename(path)
59
+ parts = filename.split("_")
60
+ return "_".join(parts[:n_parts])
61
+ self.model = torch.hub.load('src/dinov3', extract_dinov3_name(model_name), source='local', weights=model_name)
62
+
63
+
64
+ elif 'mae' in model_name or 'sam' in model_name or 'clip' in model_name or 'dino' in model_name:
65
+ self.model = timm.create_model(
66
+ model_name,
67
+ pretrained=True,
68
+ num_classes=0, # remove classifier nn.Linear
69
+ img_size=resize_dim
70
+ )
71
+
72
+ if 'sam' in model_name:
73
+ self.model.blocks[-1].register_forward_hook(get_vit_out)
74
+ else:
75
+ raise Exception("Unknown ViT model")
76
+ # self.model.eval()
77
+ mean = (0.485, 0.456, 0.406) if not 'clip' in model_name else (0.4815, 0.4578, 0.4082)
78
+ std = (0.229, 0.224, 0.225) if not 'clip' in model_name else (0.2686, 0.2613, 0.2758)
79
+ self.image_transforms = T.Compose([
80
+ T.Resize((resize_dim, resize_dim)),
81
+ lambda x: T.ToTensor()(x) if not isinstance(x, torch.Tensor) else x / 255.0, # ensure tensor
82
+ T.Normalize(mean, std),
83
+ ])
84
+
85
+ self.model.to(device)
86
+ self.model.requires_grad_(False)
87
+
88
+ self.clip_model_name = clip_model_name
89
+ if 'bert' in self.clip_model_name:
90
+ self.clip_model = BertModel.from_pretrained(self.clip_model_name, output_hidden_states = False)
91
+ # load the corresponding wordtokenizer
92
+ self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
93
+ else:
94
+ self.clip_model, _ = clip.load(clip_model_name, device=device)
95
+ self.clip_model.eval()
96
+ self.clip_model.requires_grad_(False)
97
+ if unfreeze_last_text_layer:
98
+ for param in self.clip_model.transformer.resblocks[-1].parameters():
99
+ param.requires_grad = True
100
+ for param in self.clip_model.ln_final.parameters():
101
+ param.requires_grad = True
102
+ self.clip_model.text_projection.requires_grad = True
103
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
104
+
105
+ # with open(os.path.join('configs', f"{proj_class}.yaml"), 'r') as config_file:
106
+ # config = yaml.safe_load(config_file)['model']
107
+ if 'vitb_mlp_infonce' in proj_class:
108
+ config = {
109
+ 'act': 'tanh', # None, tanh, relu or sigmoid
110
+ 'hidden_layer': True,
111
+ 'dino_embed_dim': 768
112
+ }
113
+ elif 'vitl_mlp_infonce' in proj_class:
114
+ config = {
115
+ 'act': 'tanh', # None, tanh, relu or sigmoid
116
+ 'hidden_layer': True,
117
+ 'dino_embed_dim': 1024
118
+ }
119
+
120
+ self.proj = ProjectionLayer.from_config(config)
121
+ if type(self.proj) == CLIPLastLayer:
122
+ self.clip_model.transformer.resblocks[-2].register_forward_hook(self.get_clip_second_last_dense_out)
123
+
124
+
125
+ # if pre_trained:
126
+ # self.proj.load_state_dict(torch.load(os.path.join("weights", f"{proj_name}.pth"), 'cpu'))
127
+ self.proj.to(device)
128
+
129
+ self.masker = DINOTextMasker(similarity_type="cosine")
130
+ self.masker = self.masker.eval()
131
+
132
+ self.pamr = None
133
+
134
+ self.avg_self_attn_token = avg_self_attn_token
135
+ self.disentangled_self_attn_token = disentangled_self_attn_token
136
+
137
+ if self.avg_self_attn_token or self.disentangled_self_attn_token or is_eval:
138
+ self.model.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
139
+ self.num_global_tokens = 5 if 'reg' in model_name or 'dinov3' in model_name else 1
140
+ if 'sam' in self.model_name:
141
+ self.num_global_tokens = 0
142
+ self.num_attn_heads = self.model.num_heads
143
+ self.scale = 0.125
144
+
145
+ self.use_avg_text_token = use_avg_text_token
146
+ if self.use_avg_text_token:
147
+ self.feats = {}
148
+ # in this case we register a forward hook with the aim of getting all the tokens and not only the cls
149
+ self.clip_model.ln_final.register_forward_hook(self.get_all_out_tokens)
150
+ self.keep_cls = keep_cls
151
+ self.keep_end_seq = keep_end_seq
152
+
153
+ self.with_bg_clean = with_bg_clean
154
+
155
+
156
+ def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
157
+ qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
158
+ q, k, v = qkv[0] * scale, qkv[1], qkv[2]
159
+ attn = q @ k.transpose(-2, -1)
160
+ self_attn_maps = attn[:, : , 0, num_global_tokens:]
161
+ self_attn = self_attn_maps.mean(dim=1)
162
+ self_attn = self_attn.softmax(dim=-1)
163
+ if ret_self_attn_maps:
164
+ return self_attn, self_attn_maps
165
+ else:
166
+ return self_attn
167
+
168
+ def encode_text(self, tokenized_texts):
169
+ if type(self.proj) == CLIPLastLayer:
170
+ self.clip_model.encode_text(tokenized_texts)
171
+ x = self.feats['clip_second_last_out']
172
+ x = x.to(dtype=torch.float32)
173
+ else:
174
+ x = self.clip_model.encode_text(tokenized_texts)
175
+ return x
176
+
177
+ def encode_image(self, images):
178
+ batch_size, _, _, _ = images.shape
179
+ self_attn_maps = None
180
+ x = self.model(images, is_training=(self.avg_self_attn_token or self.disentangled_self_attn_token))
181
+ batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
182
+ num_tokens = num_tokens + self.num_global_tokens
183
+ if self.avg_self_attn_token or self.disentangled_self_attn_token:
184
+ self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
185
+ if self.avg_self_attn_token:
186
+ x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
187
+ elif self.disentangled_self_attn_token:
188
+ self_attn_maps = self_attn_maps.softmax(dim=-1)
189
+ x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)
190
+
191
+ return x, self_attn_maps
192
+
193
+ def forward(self, image, text, return_logit_scale=False):
194
+ with torch.no_grad():
195
+ txt_embed = self.encode_text(text)
196
+
197
+ img_embed, self_attn_maps = self.encode_image(image)
198
+
199
+ if type(self.proj) == CLIPLastLayer:
200
+ img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps, text_argmax=text.argmax(dim=-1))
201
+ else:
202
+ img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps)
203
+
204
+ if return_logit_scale:
205
+ return txt_embed, img_embed, self.logit_scale
206
+
207
+ return txt_embed, img_embed
208
+
209
+ def compute_loss(self, image, text, cosine=True, ret_similarity_matrix=True):
210
+ ret = {}
+ # the embeddings used below must be computed from the inputs first
+ txt_embed, img_embed = self.forward(image, text)
211
+ if cosine:
212
+ img_embed = F.normalize(img_embed, p=2, dim=1)
213
+ txt_embed = F.normalize(txt_embed, p=2, dim=1)
214
+ sim = img_embed @ txt_embed.transpose(1, 0)
215
+ if not ret_similarity_matrix:
216
+ sim = sim[torch.eye(len(sim)) > 0.5] # only diagonal elements
217
+
218
+ ret['contrastive_loss'] = self.contrastive_loss.compute_contrastive_loss(sim)
219
+
220
+ return ret
221
+
222
+
223
+ @torch.no_grad()
224
+ def build_dataset_class_tokens(self, template_set, classnames):
225
+ tokens = []
226
+ templates = get_template(template_set)
227
+ for classname in classnames:
228
+ if 'bert' not in self.clip_model_name:
229
+ tokens.append(
230
+ clip.tokenize([template.format(classname) for template in templates])
231
+ )
232
+ else:
233
+ tokens.append(self.tokenizer([template.format(classname) for template in templates], return_tensors='pt', padding='max_length')['input_ids'])
234
+ # [N, T, L], N: number of instances, T: number of captions (including ensembled ones), L: sequence length
235
+ tokens = torch.stack(tokens)
236
+
237
+ return tokens
238
+
239
+ @torch.no_grad()
240
+ def build_text_embedding(self, text):
241
+ """
242
+ Args:
243
+ text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] text tokens
244
+
245
+ Returns:
246
+ text_embs
247
+ """
248
+ text = text.to(next(self.parameters()).device)
249
+ num_classes, num_templates = text.shape[:2]
250
+ text_argmax = text.argmax(dim=-1)
251
+ text_argmax = rearrange(text_argmax, 'n t -> (n t)', n=num_classes, t=num_templates)
252
+ text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)
253
+ # chunked inference for memory limitation
254
+ chunk_size = 32
255
+ N = text.size(0)
256
+ if type(self.proj) == CLIPLastLayer:
257
+ text_embs = torch.cat([
258
+ self.proj.project_clip_txt(self.encode_text(text[i:i + chunk_size]).permute(1, 0, 2), text_argmax=text_argmax[i:i + chunk_size])
259
+ for i in range(0, N, chunk_size)
260
+ ])
261
+ else:
262
+ if not self.use_avg_text_token:
263
+ # performing classification using CLS textual token
264
+ if 'bert' not in self.clip_model_name:
265
+ text_embs = torch.cat([
266
+ self.clip_model.encode_text(text[i:i + chunk_size])
267
+ for i in range(0, N, chunk_size)
268
+ ])
269
+ else:
270
+ # encoding with BERT
271
+ text_embs = []
272
+ for i in range(0, N, chunk_size):
273
+ outputs = self.clip_model(text[i:i + chunk_size])
274
+ text_embs.append(outputs['pooler_output'])
275
+ text_embs = torch.cat(text_embs)
276
+ else:
277
+ # using text token average
278
+ text_embs = []
279
+ for i in range(0, N, chunk_size):
280
+ self.clip_model.encode_text(text[i:i + chunk_size])
281
+ text_embs.append(average_text_tokens(self.feats['clip_txt_out_tokens'] @ self.clip_model.text_projection, text[i:i + chunk_size] > 0, self.keep_cls, self.keep_end_seq))
282
+ text_embs = torch.cat(text_embs)
283
+ # [N, T, C]
284
+ text_embs = rearrange(text_embs, '(n t) c -> n t c', n=num_classes, t=num_templates)
285
+ # [N, C]
286
+ text_embs = text_embs.mean(dim=1).float()
287
+ if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
288
+ text_embs = self.proj.project_clip_txt(text_embs)
289
+ text_embs = us.normalize(text_embs, dim=-1)
290
+
291
+ return text_embs
292
+
293
+ def apply_pamr(self, image, mask):
294
+ image = F.interpolate(image, mask.shape[-2:], mode="bilinear", align_corners=True)
295
+ if self.pamr is None:
296
+ pamr_iter = 10
297
+ pamr_kernel = [1, 2, 4, 8, 12, 24]
298
+ self.pamr = PAMR(pamr_iter, pamr_kernel)
299
+ self.pamr.eval()
300
+ self.pamr.to(next(self.parameters()).device)
301
+
302
+ mask = self.pamr(image, mask)
303
+ return mask
304
+
305
+ def compute_padsize(self, H: int, W: int, patch_size: int):
306
+ l, r, t, b = 0, 0, 0, 0
307
+ if W % patch_size:
308
+ lr = patch_size - (W % patch_size)
309
+ l = lr // 2
310
+ r = lr - l
311
+
312
+ if H % patch_size:
313
+ tb = patch_size - (H % patch_size)
314
+ t = tb // 2
315
+ b = tb - t
316
+
317
+ return l, r, t, b
318
+
319
+ @torch.no_grad()
320
+ def generate_masks(
321
+ self, image, img_metas, text_emb, classnames, text_is_token=False, apply_pamr=False, background_func="weighted_average_sigmoid", lambda_bg=0.2,
322
+ # kp_w=0.3,
323
+ ):
324
+ """Generate masks for each text embeddings
325
+
326
+ Args:
327
+ image [B, 3, H, W]
328
+
329
+ Returns:
330
+ softmask [B, N, H, W]: soft masks, one per text embedding
331
+ """
332
+
333
+ H, W = image.shape[2:] # original image shape
334
+
335
+ # padded image size
336
+ pH, pW = image.shape[2:]
337
+ num_classes = text_emb.shape[0]
338
+ batch_size = image.shape[0]
339
+
340
+ image = image[:, [2, 1, 0], :, :] # BGR to RGB
341
+ ori_image = image.clone()
342
+
343
+ img_preprocessed = self.image_transforms(image).to(next(self.parameters()).device)
344
+ if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
345
+ image_feat = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
346
+ elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
347
+ image_feat = self.model.forward_features(img_preprocessed)[:, 1:, :]
348
+ elif 'sam' in self.model_name:
349
+ self.model.forward_features(img_preprocessed)
350
+ image_feat = feats['vit_out'].reshape(feats['vit_out'].shape[0], feats['vit_out'].shape[1]**2, feats['vit_out'].shape[-1]) # BS x N_PATCHES x EMBED_DIM
351
+
352
+ batch_size, num_tokens, embed_dim = image_feat.shape
353
+ if type(self.proj) == VisualProjectionLayer:
354
+ image_feat = self.proj.project_dino(image_feat.float())
355
+ if type(self.proj) == DoubleMLP:
356
+ image_feat = self.proj.project_visual(image_feat.float())
357
+ b, np, c = image_feat.shape
358
+ np_h = np_w = int(sqrt(np))
359
+ image_feat = image_feat.reshape(b, np_h, np_w, c).permute(0, 3, 1, 2)
360
+
361
+ self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens + self.num_global_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
362
+ mask, simmap = self.masker.forward_seg(image_feat, text_emb, hard=False) # [B, N, H', W']
363
+
364
+ if self.with_bg_clean:
365
+ mask = self.similarity_assignment_weighted(mask, image_feat, self_attn_maps, text_emb, lambda_bg)
366
+
367
+ # resize
368
+ mask = F.interpolate(mask, (pH, pW), mode='bilinear', align_corners=True) # [B, N, H, W]
369
+
370
+ if apply_pamr:
371
+ for c in range(0, mask.shape[1], 30):
372
+ mask[:, c:c + 30] = self.apply_pamr(ori_image, mask[:, c:c + 30])
373
+
374
+ assert mask.shape[2] == H and mask.shape[3] == W, f"shape mismatch: ({H}, {W}) / {mask.shape}"
375
+
376
+ return mask, simmap
377
+
378
+ def similarity_assignment_weighted(self, mask, image_feat, self_attn_maps, text_emb, lambda_bg=0.2):
379
+ bs, c, h, w = image_feat.shape
380
+ bs, num_classes, h, w = mask.shape
381
+ bs, num_heads, hw = self_attn_maps.shape
382
+ image_feat = image_feat.reshape(bs, c, hw)
383
+ num_classes, c = text_emb.shape
384
+ avg_head_embed = (self_attn_maps.unsqueeze(2) * image_feat.unsqueeze(1)).mean(dim=-1)
385
+ avg_head_embed = avg_head_embed / avg_head_embed.norm(dim=-1, keepdim=True)
386
+ avg_head_embed = avg_head_embed.permute(0, 2, 1) # [B, C, M]
387
+ head_text_sim = text_emb.unsqueeze(0) @ avg_head_embed # [B, M, N]
388
+ head_text_sim = (head_text_sim).softmax(dim=-1)
389
+ head_text_sim_sum = head_text_sim.sum(dim=-1)
390
+
391
+ self_attn_maps_repeat = self_attn_maps.unsqueeze(1).repeat(1, num_classes, 1, 1)
392
+ head_text_sim_repeat = head_text_sim.unsqueeze(-1).repeat(1, 1, 1, hw)
393
+ avg_self_attn_per_class = (self_attn_maps_repeat * head_text_sim_repeat).sum(dim=2) / head_text_sim_sum.unsqueeze(-1).repeat(1, 1, hw)
394
+ avg_self_attn_per_class = avg_self_attn_per_class.softmax(dim=-1)
395
+
396
+ min_self_attn = avg_self_attn_per_class.min().item()
397
+ max_self_attn = avg_self_attn_per_class.max().item()
398
+ max_self_attn = max(max_self_attn, max_self_attn - min_self_attn)
399
+ avg_self_attn_per_class = avg_self_attn_per_class - min_self_attn
400
+ avg_self_attn_per_class = avg_self_attn_per_class / max_self_attn
401
+ avg_self_attn_per_class = avg_self_attn_per_class * (mask.max() - mask.min()) + mask.min()
402
+ mask = mask.reshape(num_classes, hw) # [N, P]
403
+ mask_output = (mask + lambda_bg * avg_self_attn_per_class).reshape(bs, num_classes, h, w) / (1 + lambda_bg)
404
+ return mask_output
405
+
406
+
407
+ from huggingface_hub import PyTorchModelHubMixin
408
+
409
+ class Talk2DINO(DINOText, PyTorchModelHubMixin):
410
+ def encode_text(self, texts):
411
+ """ texts: string or list of strings
412
+ returns: text embeddings (N, D) where N is the number of texts, D is the embedding dimension
413
+ """
414
+ text_tokens = clip.tokenize(texts).to(next(self.parameters()).device)
415
+ txt_embed = self.clip_model.encode_text(text_tokens)
416
+ txt_embed = self.proj.project_clip_txt(txt_embed)
417
+ return txt_embed
418
+
419
+ def encode_image(self, images):
420
+ """ images: PIL image or list of PIL images
421
+ returns: image embeddings (N, L, D) where N is the number of images, L is the number of patches, D is the embedding dimension
422
+ """
423
+ if type(images) is not list:
424
+ images = [images]
425
+ img_preprocessed = [self.image_transforms(img).to(next(self.parameters()).device) for img in images]
426
+ img_preprocessed = torch.stack(img_preprocessed)
427
+ if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
428
+ img_embed = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
429
+ elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
430
+ img_embed = self.model.forward_features(img_preprocessed)[:, 1:, :]
431
+
432
+ return img_embed
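`Talk2DINO` wraps `DINOText` with `PyTorchModelHubMixin`, so it can be loaded from the Hub and used to embed raw strings and PIL images. A hypothetical usage sketch; the repository id, image path, and prompts are assumptions, not taken from this commit:

```python
import torch
import torch.nn.functional as F
from PIL import Image
from hf_model.talk2dino import Talk2DINO

model = Talk2DINO.from_pretrained("lorebianchi98/Talk2DINO-ViTB")  # assumed repo id
model.eval()

image = Image.open("example.jpg").convert("RGB")        # any RGB image
with torch.no_grad():
    img_embed = model.encode_image(image)               # [1, L, D] DINOv2 patch embeddings
    txt_embed = model.encode_text(["a dog", "grass"])   # [2, D] CLIP text projected to DINOv2 space

# patch-level cosine similarity between image patches and each prompt
sim = F.normalize(img_embed, dim=-1) @ F.normalize(txt_embed, dim=-1).t()
print(sim.shape)                                        # [1, L, 2]
```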
hf_model/templates.py ADDED
@@ -0,0 +1,148 @@
1
+ # ------------------------------------------------------------------------------
2
+ # FreeDA
3
+ # ------------------------------------------------------------------------------
4
+ # Modified from CLIP (https://github.com/openai/CLIP)
5
+ # Copyright (c) 2021 OpenAI. All Rights Reserved.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ full_imagenet_templates = [
9
+ "a bad photo of a {}.",
10
+ "a photo of many {}.",
11
+ "a sculpture of a {}.",
12
+ "a photo of the hard to see {}.",
13
+ "a low resolution photo of the {}.",
14
+ "a rendering of a {}.",
15
+ "graffiti of a {}.",
16
+ "a bad photo of the {}.",
17
+ "a cropped photo of the {}.",
18
+ "a tattoo of a {}.",
19
+ "the embroidered {}.",
20
+ "a photo of a hard to see {}.",
21
+ "a bright photo of a {}.",
22
+ "a photo of a clean {}.",
23
+ "a photo of a dirty {}.",
24
+ "a dark photo of the {}.",
25
+ "a drawing of a {}.",
26
+ "a photo of my {}.",
27
+ "the plastic {}.",
28
+ "a photo of the cool {}.",
29
+ "a close-up photo of a {}.",
30
+ "a black and white photo of the {}.",
31
+ "a painting of the {}.",
32
+ "a painting of a {}.",
33
+ "a pixelated photo of the {}.",
34
+ "a sculpture of the {}.",
35
+ "a bright photo of the {}.",
36
+ "a cropped photo of a {}.",
37
+ "a plastic {}.",
38
+ "a photo of the dirty {}.",
39
+ "a jpeg corrupted photo of a {}.",
40
+ "a blurry photo of the {}.",
41
+ "a photo of the {}.",
42
+ "a good photo of the {}.",
43
+ "a rendering of the {}.",
44
+ "a {} in a video game.",
45
+ "a photo of one {}.",
46
+ "a doodle of a {}.",
47
+ "a close-up photo of the {}.",
48
+ "a photo of a {}.",
49
+ "the origami {}.",
50
+ "the {} in a video game.",
51
+ "a sketch of a {}.",
52
+ "a doodle of the {}.",
53
+ "a origami {}.",
54
+ "a low resolution photo of a {}.",
55
+ "the toy {}.",
56
+ "a rendition of the {}.",
57
+ "a photo of the clean {}.",
58
+ "a photo of a large {}.",
59
+ "a rendition of a {}.",
60
+ "a photo of a nice {}.",
61
+ "a photo of a weird {}.",
62
+ "a blurry photo of a {}.",
63
+ "a cartoon {}.",
64
+ "art of a {}.",
65
+ "a sketch of the {}.",
66
+ "a embroidered {}.",
67
+ "a pixelated photo of a {}.",
68
+ "itap of the {}.",
69
+ "a jpeg corrupted photo of the {}.",
70
+ "a good photo of a {}.",
71
+ "a plushie {}.",
72
+ "a photo of the nice {}.",
73
+ "a photo of the small {}.",
74
+ "a photo of the weird {}.",
75
+ "the cartoon {}.",
76
+ "art of the {}.",
77
+ "a drawing of the {}.",
78
+ "a photo of the large {}.",
79
+ "a black and white photo of a {}.",
80
+ "the plushie {}.",
81
+ "a dark photo of a {}.",
82
+ "itap of a {}.",
83
+ "graffiti of the {}.",
84
+ "a toy {}.",
85
+ "itap of my {}.",
86
+ "a photo of a cool {}.",
87
+ "a photo of a small {}.",
88
+ "a tattoo of the {}.",
89
+ ]
90
+
91
+ maskclip_templates = [
92
+ "there is a {} in the scene.",
93
+ "there is the {} in the scene.",
94
+ "this is a {} in the scene.",
95
+ "this is the {} in the scene.",
96
+ "this is one {} in the scene.", # maskclip
97
+ ]
98
+
99
+ sub_imagenet_template = [
100
+ "itap of a {}.",
101
+ "a bad photo of a {}.",
102
+ "a origami {}.",
103
+ "a photo of the large {}.",
104
+ "a {} in a video game.",
105
+ "art of the {}.",
106
+ "a photo of the small {}.",
107
+ ]
108
+
109
+ simple_imagenet_template = [
110
+ "a photo of a {}.",
111
+ ]
112
+
113
+ plural_template = [
114
+ "a photo of {}s.",
115
+ ]
116
+
117
+ identity_template = [
118
+ "{}",
119
+ ]
120
+
121
+ template_meta = {
122
+ "full": full_imagenet_templates,
123
+ "full+maskclip": full_imagenet_templates + maskclip_templates, # templates used in maskclip paper
124
+ "subset": sub_imagenet_template,
125
+ "subset+maskclip": sub_imagenet_template + maskclip_templates,
126
+ "maskclip": maskclip_templates,
127
+ "simple": simple_imagenet_template,
128
+ "plural": plural_template,
129
+ "identity": identity_template,
130
+ }
131
+
132
+
133
+ def get_template(key):
134
+ if key in template_meta:
135
+ return template_meta[key]
136
+
137
+ gdic = globals()
138
+ if key in gdic:
139
+ return gdic[key]
140
+
141
+ raise ValueError(key)
142
+
143
+
144
+ # custom template boosts performance a little.
145
+ custom = sub_imagenet_template + [
146
+ "a photo of many {}.",
147
+ "a photo of {}s.",
148
+ ]
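The template sets above are expanded per class name before CLIP tokenization; `build_dataset_class_tokens()` in `hf_model/talk2dino.py` follows the same pattern. A minimal sketch (the class name is an arbitrary example):

```python
import clip
from hf_model.templates import get_template

templates = get_template("subset")              # 7 ImageNet-style prompt templates
prompts = [t.format("dog") for t in templates]  # e.g. "itap of a dog."
tokens = clip.tokenize(prompts)                 # [7, 77] CLIP token ids
print(prompts[0], tuple(tokens.shape))
```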
hf_model/us.py ADDED
@@ -0,0 +1,119 @@
1
+ # ------------------------------------------------------------------------------
2
+ # FreeDA
3
+ # ------------------------------------------------------------------------------
4
+ from typing import Dict, List, Any
5
+ from datetime import datetime
6
+ from itertools import chain
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+ import numpy as np
13
+
14
+ # ImageNet mean/std (from timm)
15
+
16
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
17
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
18
+
19
+ DEFAULT_MEAN = IMAGENET_DEFAULT_MEAN
20
+ DEFAULT_STD = IMAGENET_DEFAULT_STD
21
+
22
+ # NOTE Originally CLIP statistics should be used, but the legacy of ImageNet statistics
23
+ # from GroupViT is applied. Fortunately, CLIP is quite robust to slightly different
24
+ # normalization constants (https://github.com/openai/CLIP/issues/20#issuecomment-764985771).
25
+
26
+
27
+ def unnorm(x):
28
+ mean = torch.as_tensor(DEFAULT_MEAN, device=x.device)[None, ..., None, None]
29
+ std = torch.as_tensor(DEFAULT_STD, device=x.device)[None, ..., None, None]
30
+ return x.mul(std).add(mean)
31
+
32
+
33
+ # DEBUG NaN
34
+ def check_nonfinite(x, name=""):
35
+ rank = dist.get_rank()
36
+ n_nan = x.isnan().sum()
37
+ n_inf = x.isinf().sum()
38
+ if n_nan or n_inf:
39
+ print(f"[RANK {rank}] {name} is not finite: #nan={n_nan}, #inf={n_inf}")
40
+ return True
41
+
42
+ print(f"[RANK {rank}] {name} is OK ...")
43
+ return False
44
+
45
+
46
+ def normalize(t, dim, eps=1e-6):
47
+ """Large default eps for fp16"""
48
+ return F.normalize(t, dim=dim, eps=eps)
49
+
50
+
51
+ def timestamp(fmt="%y%m%d-%H%M%S"):
52
+ return datetime.now().strftime(fmt)
53
+
54
+
55
+ def merge_dicts_by_key(dics: List[Dict]) -> Dict[Any, List]:
56
+ """Merge dictionaries by key. All of dicts must have same keys."""
57
+ ret = {key: [] for key in dics[0].keys()}
58
+ for dic in dics:
59
+ for key, value in dic.items():
60
+ ret[key].append(value)
61
+
62
+ return ret
63
+
64
+
65
+ def flatten_2d_list(list2d):
66
+ return list(chain.from_iterable(list2d))
67
+
68
+
69
+ def num_params(module):
70
+ return sum(p.numel() for p in module.parameters())
71
+
72
+
73
+ def param_trace(name, module, depth=0, max_depth=999, threshold=0, printf=print):
74
+ if depth > max_depth:
75
+ return
76
+ prefix = " " * depth
77
+ n_params = num_params(module)
78
+ if n_params > threshold:
79
+ printf("{:60s}\t{:10.3f}M".format(prefix + name, n_params / 1024 / 1024))
80
+ for n, m in module.named_children():
81
+ if depth == 0:
82
+ child_name = n
83
+ else:
84
+ child_name = "{}.{}".format(name, n)
85
+ param_trace(child_name, m, depth + 1, max_depth, threshold, printf)
86
+
87
+
88
+ @torch.no_grad()
89
+ def hash_bn(module):
90
+ summary = []
91
+ for m in module.modules():
92
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
93
+ w = m.weight.detach().mean().item()
94
+ b = m.bias.detach().mean().item()
95
+ rm = m.running_mean.detach().mean().item()
96
+ rv = m.running_var.detach().mean().item()
97
+ summary.append((w, b, rm, rv))
98
+
99
+ if not summary:
100
+ return 0.0, 0.0
101
+
102
+ w, b, rm, rv = [np.mean(col) for col in zip(*summary)]
103
+ p = np.mean([w, b])
104
+ s = np.mean([rm, rv])
105
+
106
+ return p, s
107
+
108
+
109
+ @torch.no_grad()
110
+ def hash_params(module):
111
+ return torch.as_tensor([p.mean() for p in module.parameters()]).mean().item()
112
+
113
+
114
+ @torch.no_grad()
115
+ def hashm(module):
116
+ p = hash_params(module)
117
+ _, s = hash_bn(module)
118
+
119
+ return p, s
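A short sketch of the helpers above: `normalize` is an fp16-safe wrapper around `F.normalize`, and `param_trace` prints a per-module parameter-count breakdown (the toy model is an illustrative assumption):

```python
import torch
import torch.nn as nn
import hf_model.us as us

feats = torch.randn(4, 768)
feats = us.normalize(feats, dim=-1)          # L2 normalization with eps=1e-6

model = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 10))
print(us.num_params(model))                  # total parameter count
us.param_trace("toy", model, threshold=0)    # per-child parameter counts (in "M" = 2**20)
```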