check
- app.py +25 -12
- clip_dinoiser.yaml +39 -0
- models/__init__.py +2 -0
- models/builder.py +16 -0
- models/clip_dinoiser/__init__.py +1 -0
- models/clip_dinoiser/clip_dinoiser.py +120 -0
- models/maskclip/__init__.py +1 -0
- models/maskclip/maskclip.py +221 -0
- models/maskclip/utils/__init__.py +4 -0
- models/maskclip/utils/embed.py +334 -0
- models/maskclip/utils/prompt_templates.py +82 -0
- models/maskclip/vit.py +470 -0
- segmentation/configs/_base_/custom_import.py +12 -0
- segmentation/configs/_base_/datasets/ade20k.py +58 -0
- segmentation/configs/_base_/datasets/cityscapes.py +37 -0
- segmentation/configs/_base_/datasets/coco.py +39 -0
- segmentation/configs/_base_/datasets/pascal_context.py +38 -0
- segmentation/configs/_base_/datasets/pascal_context59.py +38 -0
- segmentation/configs/_base_/datasets/pascal_voc12.py +40 -0
- segmentation/configs/_base_/datasets/pascal_voc12_20.py +40 -0
- segmentation/configs/_base_/datasets/stuff.py +39 -0
- segmentation/datasets/__init__.py +5 -0
- segmentation/datasets/coco_object.py +42 -0
- segmentation/datasets/coco_stuff.py +97 -0
- segmentation/datasets/pascal_context.py +108 -0
- segmentation/datasets/pascal_voc.py +40 -0
- segmentation/datasets/pascal_voc20.py +40 -0
- segmentation/evaluation/__init__.py +1 -0
- segmentation/evaluation/builder.py +66 -0
- segmentation/evaluation/clip_dinoiser_eval.py +34 -0
- visualization.py +7 -0
app.py
CHANGED
@@ -1,18 +1,8 @@
-import git
-
-git_url = "https://github.com/ariG23498/clip_dinoiser.git"
-repo_dir = "clip_dinoiser"
-git.Repo.clone_from(git_url, repo_dir)
-
-import os
-
-print(os.getcwd())
-os.chdir("clip_dinoiser/")
-
 from models.builder import build_model
-from
+from visualization import mask2rgb
 from segmentation.datasets import PascalVOCDataset
 
+import os
 from hydra import compose, initialize
 from PIL import Image
 import matplotlib.pyplot as plt
@@ -23,6 +13,29 @@ from operator import itemgetter
 import torch
 import warnings
 
+warnings.filterwarnings("ignore")
+initialize(config_path="configs", version_base=None)
+
+from huggingface_hub import Repository
+
+repo = Repository(
+    local_dir="models",
+    clone_from="ariG23498/clip-dinoiser",
+    use_auth_token=os.environ.get("token")
+)
+
+check_path = 'models/checkpoints/last.pt'
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+check = torch.load(check_path, map_location=device)
+dinoclip_cfg = "clip_dinoiser.yaml"
+cfg = compose(config_name=dinoclip_cfg)
+
+model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
+model.clip_backbone.decode_head.use_templates=False # switching off the imagenet templates for fast inference
+model.load_state_dict(check['model_state_dict'], strict=False)
+model = model.eval()
+
 import gradio as gr
 
 def greet(name):
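The rewritten app.py builds the CLIP-DINOiser model once at startup: it clones the ariG23498/clip-dinoiser checkpoint repo, composes clip_dinoiser.yaml with Hydra, and loads the state dict in eval mode. Below is a minimal sketch of how a Gradio callback could wrap this model; the `segment` function, its 448x448 resize, and returning raw class indices are illustrative assumptions, not part of this commit (the imported mask2rgb from visualization.py would presumably colorize such a label map, but its signature is not shown here).

# Hypothetical inference callback (not part of this commit); relies on the
# `model`, `device`, and PascalVOCDataset objects created above in app.py.
import torchvision.transforms as T

preprocess = T.Compose([
    T.Resize((448, 448)),   # assumed input size; MaskClip applies CLIP normalization internally
    T.ToTensor(),
])

def segment(pil_image):
    x = preprocess(pil_image).unsqueeze(0).to(device)   # (1, 3, 448, 448)
    with torch.no_grad():
        probs = model(x)                                # (1, num_classes, h, w) class probabilities
    return probs.argmax(dim=1)[0].cpu().numpy()         # per-pixel class indices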
clip_dinoiser.yaml
ADDED
@@ -0,0 +1,39 @@
+_base_: "default.yml"
+defaults:
+  - _self_
+
+seed: 0
+model_name: clip_dinoiser
+model:
+  type: CLIP_DINOiser
+  clip_backbone: maskclip
+  mask_th: 0.2
+  in_dim: 256
+  certainty_th: 0.9
+  found_th: 0.5
+  feats_idx: -3
+
+checkpoint_path: "checkpoints/last.pt"
+output: logs
+
+evaluate:
+  eval_only: true
+  task:
+    - voc
+    - voc20
+    - context
+    - context59
+    - coco_stuff
+    - coco_object
+    - cityscapes
+    - ade20k
+
+  # evaluation
+  voc: segmentation/configs/_base_/datasets/pascal_voc12.py
+  voc20: segmentation/configs/_base_/datasets/pascal_voc12_20.py
+  context: segmentation/configs/_base_/datasets/pascal_context.py
+  context59: segmentation/configs/_base_/datasets/pascal_context59.py
+  coco_stuff: segmentation/configs/_base_/datasets/stuff.py
+  coco_object: segmentation/configs/_base_/datasets/coco.py
+  cityscapes: segmentation/configs/_base_/datasets/cityscapes.py
+  ade20k: segmentation/configs/_base_/datasets/ade20k.py
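This is the config that app.py composes at startup. A minimal sketch of loading it outside the app, assuming the file and its default.yml base are both reachable under a configs/ search path as app.py expects:

# Minimal sketch, mirroring app.py; the configs/ location is an assumption.
from hydra import compose, initialize

with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="clip_dinoiser.yaml")

print(cfg.model.type, cfg.model.mask_th)   # CLIP_DINOiser 0.2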
models/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .maskclip import *
+from .clip_dinoiser import *
models/builder.py
ADDED
@@ -0,0 +1,16 @@
+# ------------------------------------------------------------------------------
+# CLIP-DINOiser
+# author: Monika Wysoczanska
+# ------------------------------------------------------------------------------
+# Modified from GroupViT (https://github.com/NVlabs/GroupViT)
+# Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------------
+from mmcv.utils import Registry
+MODELS = Registry('models')
+from omegaconf import OmegaConf
+
+
+def build_model(config, class_names):
+    model = MODELS.build(OmegaConf.to_container(config, resolve=True),
+                         default_args={'class_names': class_names})
+    return model
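build_model is a thin wrapper around an mmcv Registry: any nn.Module decorated with @MODELS.register_module() can be instantiated from an OmegaConf node whose `type` field names the class, with class_names injected as a default argument. A toy sketch of that pattern with a made-up TinySegmenter module (not part of this commit):

# Toy registry example; TinySegmenter is hypothetical and for illustration only.
import torch.nn as nn
from omegaconf import OmegaConf

from models.builder import MODELS, build_model


@MODELS.register_module()
class TinySegmenter(nn.Module):
    def __init__(self, class_names, hidden_dim=16):
        super().__init__()
        self.class_names = class_names
        self.backbone = nn.Conv2d(3, hidden_dim, 3, padding=1)
        self.head = nn.Conv2d(hidden_dim, len(class_names), 1)


cfg = OmegaConf.create({"type": "TinySegmenter", "hidden_dim": 8})
model = build_model(cfg, class_names=["cat", "dog"])  # registry looks up "type"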
models/clip_dinoiser/__init__.py
ADDED
@@ -0,0 +1 @@
+from .clip_dinoiser import *
models/clip_dinoiser/clip_dinoiser.py
ADDED
@@ -0,0 +1,120 @@
+# ---------------------------------------------------------------------------------------------------
+# CLIP-DINOiser
+# authors: Monika Wysoczanska, Warsaw University of Technology & Oriane Simeoni, valeo.ai
+# ---------------------------------------------------------------------------------------------------
+import torch.nn as nn
+from models.builder import MODELS
+from models.builder import build_model
+import torch
+import torchvision.transforms as T
+from omegaconf import OmegaConf
+import torch.nn.functional as F
+
+NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+
+
+@MODELS.register_module()
+class CLIP_DINOiser(nn.Module):
+    def __init__(self, clip_backbone, class_names, mask_th=None, found_th=0.5, certainty_th=0.9, apply_found=False,
+                 in_dim=256, conv_kernel=3, feats_idx=-3):
+
+        super(CLIP_DINOiser, self).__init__()
+        self.mask_th = mask_th
+        self.apply_found = apply_found
+        self.found_th = found_th
+        self.certainty_th = certainty_th
+        self.sigmoid = nn.Sigmoid()
+        maskclip_cfg = OmegaConf.load(f"configs/{clip_backbone}.yaml")
+        self.clip_backbone = build_model(maskclip_cfg["model"], class_names=class_names)
+        self.vit_patch_size = self.clip_backbone.patch_size
+        self.feats_idx = feats_idx
+        self.in_dim = [in_dim]
+        in_size = 768 if self.feats_idx != 'final' else 512
+        self.bkg_decoder = nn.Conv2d(in_size, 1, (1, 1))
+        self.obj_proj = nn.Conv2d(in_size, in_dim, (conv_kernel, conv_kernel),
+                                  padding=conv_kernel // 2, padding_mode='replicate')
+
+        # setup clip feature for training
+        if feats_idx != 'final':
+            train_feats = {}
+            def get_activation(name):
+                def hook(model, input, output):
+                    train_feats[name] = output.detach()
+                return hook
+            self.clip_backbone.backbone.layers[feats_idx].ln2.register_forward_hook(get_activation('clip_inter'))
+            self.train_feats = train_feats
+
+
+    def forward_pass(self, x):
+        clip_feats = self.get_clip_map(x)[0]
+        B, c_dim, h, w = clip_feats.shape
+        _, _, H, W = x.shape
+        if self.feats_idx != 'final':
+            clip_feats = self.train_feats['clip_inter']
+            c_dim = clip_feats.shape[-1]
+            clip_feats = clip_feats[:, 1:, ].permute(0, 2, 1).reshape(B, c_dim, h, w)
+
+        proj_feats = self.obj_proj(clip_feats).reshape(B, self.in_dim[-1], -1)
+        proj_feats = proj_feats / proj_feats.norm(dim=1, keepdim=True)
+
+        corrs = torch.matmul(proj_feats.permute(0, 2, 1), proj_feats).reshape(B, h*w, h, w)
+        output = clip_feats / clip_feats.norm(dim=1, keepdim=True)
+        output = self.bkg_decoder(output)
+
+        return output, corrs
+
+    def forward(self, x):
+        preds, corrs = self.forward_pass(x)
+        output, _, _ = self.get_clip_map(x)
+        B, C, hf, wf = output.shape
+        preds = F.interpolate(preds, (hf, wf), mode="bilinear", align_corners=False)
+
+        # Compute weighted pooling
+        if self.mask_th:
+            corrs[corrs < self.mask_th] = 0.0
+        output = self.compute_weighted_pool(output, corrs)
+        output = output.reshape(B, C, hf, wf)
+        output = self.clip_backbone.decode_head.cls_seg(output)
+
+        if self.apply_found:
+            # Compute FOUND --------------------------------------------------
+            soft_found = self.sigmoid(preds.detach())
+            r_soft_found = soft_found.reshape(-1)
+            nb_cls = output.shape[1]
+            r_hard_found = (r_soft_found > self.found_th).float()
+
+            # TODO: make it work for Batch Size != 1
+            uncertain = (output.max(dim=1)[0] < self.certainty_th).reshape(-1)
+            output.reshape(1, nb_cls, -1)[:, 0, uncertain & (~r_hard_found.bool())] = 1.0  # background class
+
+        return output
+
+    def predict(self, x):
+        return self(x)
+
+    @torch.no_grad()
+    def get_clip_map(self, img):
+        maskclip_map, feat, k = self.clip_backbone(img, return_feat=True)
+
+        return feat, k, maskclip_map
+
+    @torch.no_grad()
+    def compute_weighted_pool(self, clipmap, corrs):
+        # upsampling
+        B = clipmap.shape[0]
+        h_m, w_m = clipmap.shape[-2:]
+        h_w, w_w = corrs.shape[-2:]
+
+        if (h_m != h_w) or (w_m != w_w):
+            clipmap = F.interpolate(clipmap, (h_w, w_w), mode="bilinear", align_corners=False)
+            h_m, w_m = h_w, w_w
+
+        corrs[corrs < 0.0] = 0.0  # B HW H W
+        clipmap_refined = torch.einsum("bnij, bcij -> bcn", corrs, clipmap)  # B C HW
+        norm_factor = corrs.flatten(-2, -1).sum(dim=-1)[:, None]  # B 1 HW
+        clipmap_refined = clipmap_refined / (norm_factor + 1e-6)
+
+        # RESHAPE back to 2d
+        clipmap_refined = clipmap_refined.reshape(B, -1, h_m, w_m)
+
+        return clipmap_refined
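The core of CLIP_DINOiser.forward is compute_weighted_pool: each patch's MaskCLIP feature is replaced by a correlation-weighted average over all patches, using the thresholded correlations derived from the learned obj_proj features. A standalone re-run of that pooling on dummy tensors (the 28x28 patch grid and 512 channels are illustrative shapes, not values taken from this commit):

# Standalone sketch of the correlation-weighted pooling used in compute_weighted_pool.
import torch

B, C, h, w = 1, 512, 28, 28          # MaskCLIP features on a 28x28 patch grid
clipmap = torch.randn(B, C, h, w)
corrs = torch.rand(B, h * w, h, w)   # patch-to-patch correlations, already thresholded

corrs[corrs < 0.0] = 0.0
refined = torch.einsum("bnij, bcij -> bcn", corrs, clipmap)                 # (B, C, HW)
refined = refined / (corrs.flatten(-2, -1).sum(dim=-1)[:, None] + 1e-6)     # normalize by total weight
refined = refined.reshape(B, C, h, w)
print(refined.shape)                 # torch.Size([1, 512, 28, 28])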
models/maskclip/__init__.py
ADDED
@@ -0,0 +1 @@
+from .maskclip import *
models/maskclip/maskclip.py
ADDED
@@ -0,0 +1,221 @@
+# ------------------------------------------------------------------------------
+# CLIP-DINOiser
+# author: Monika Wysoczanska, Warsaw University of Technology
+# ------------------------------------------------------------------------------
+# Modified from OpenMMLab https://github.com/chongzhou96/MaskCLIP
+# ------------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmseg.ops import resize
+from typing import Any, List
+from torch import Tensor
+from mmcv.utils import print_log
+from mmseg.utils import get_root_logger
+from open_clip import get_tokenizer, create_model_from_pretrained
+from models.builder import MODELS
+from .vit import VisionTransformer
+import torchvision.transforms as T
+from .utils.embed import AdaptivePadding
+from .utils.prompt_templates import imagenet_templates
+
+OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+
+def make_vision_transformer(backbone_cfg):
+    model = VisionTransformer(**backbone_cfg)
+    model.init_weights()
+    return model
+
+
+@MODELS.register_module()
+class MaskClip(nn.Module):
+    def __init__(
+            self,
+            backbone,
+            decode_head,
+            clip_model,
+            class_names
+    ):
+        super(MaskClip, self).__init__()
+
+        self.decode_head = eval(decode_head.get('type'))(clip_model, class_names, **decode_head)
+        self.backbone = make_vision_transformer(backbone)
+        self.clip_T = OPENAI_NORMALIZE
+
+        self.to_PIL = T.ToPILImage()
+        self.patch_size = backbone.get('patch_size')
+        self.padding = AdaptivePadding(self.patch_size, self.patch_size)
+
+    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
+        """Extract features from images."""
+        x = self.backbone(inputs)
+        return x
+
+    def forward(self, inputs: Tensor, return_feat=False) -> Tensor:
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        inputs = self.clip_T(inputs)
+        x = self.extract_feat(inputs)
+
+        seg_logits, feats, k = self.decode_head(x, return_feat)
+
+        if return_feat:
+            return seg_logits, feats, k
+        return seg_logits
+
+class MaskClipHead(nn.Module):
+    def __init__(self, clip_model, class_names, visual_projs_path=None, in_index=-1, in_channels=3, norm_cfg=None, channels=0,
+                 text_channels=512, attn_pooling=False, align_corners=False, model_prefix='hf-hub:laion', use_templates=False, **kwargs):
+        super(MaskClipHead, self).__init__()
+
+        self.text_channels = text_channels
+        self.visual_projs_path = visual_projs_path
+        self.clip_model = clip_model
+        self.class_names = class_names
+        self.in_channels = in_channels
+        self.in_index = in_index  # from base decode head default
+        self._init_inputs(in_channels, in_index, None)
+        self.channels = channels
+        self.norm_cfg = norm_cfg
+        self.align_corners = align_corners
+        self.use_templates = use_templates
+
+        self.proj = nn.Conv2d(self.in_channels, text_channels, 1, bias=False)
+        self.load_visual_projs()
+
+        self.attn_pooling = attn_pooling
+        self.tokenizer = get_tokenizer(f'{model_prefix}/{clip_model}')
+        self.hf_modelname = f'{model_prefix}/{clip_model}'
+        model, _ = create_model_from_pretrained(f'{model_prefix}/{clip_model}')
+        model.eval()
+        self.register_buffer("class_embeddings", self._get_class_embeddings(model, class_names))
+
+    @torch.no_grad()
+    def update_vocab(self, class_names):
+        model, _ = create_model_from_pretrained(self.hf_modelname)
+        model.eval()
+        self.class_embeddings = self._get_class_embeddings(model, class_names)
+
+    @torch.no_grad()
+    def _embed_label(self, text_model: torch.nn.Module, label: str) -> torch.Tensor:
+        """
+        Encode label name into a single vector
+        """
+        if self.use_templates:
+            templates = imagenet_templates
+        else:
+            templates = ['a photo of an {}' if label.startswith('aeiou') else 'a photo of a {}']
+
+        all_prompts = [self.tokenizer(template.format(label)) for template in templates]
+        out = text_model.encode_text(torch.cat(all_prompts))
+        out /= out.norm(dim=-1, keepdim=True)
+        out = out.mean(dim=0)
+        return out
+
+    def _get_class_embeddings(self, text_model: torch.nn.Module, class_names: List[str]):
+        aug_embeddings = torch.stack([self._embed_label(text_model, label) for label in class_names])
+        # normalize vector
+        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
+        return aug_embeddings.squeeze(1)
+
+    def load_visual_projs(self):
+        loaded = torch.load(self.visual_projs_path, map_location='cuda')
+        attrs = ['proj']
+        for attr in attrs:
+            current_attr = getattr(self, attr)
+            state_dict = loaded[attr]
+            for key in state_dict:
+                if 'weight' in key:
+                    state_dict[key] = state_dict[key][:, :, None, None]
+            current_attr.load_state_dict(state_dict)
+        print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())
+
+    def forward(self, inputs, return_feat=False):
+        x = self._transform_inputs(inputs)
+        q, k, v, cls_token = None, None, None, None
+        if isinstance(x, list) and len(x) == 4:
+            x, q, k, v = x
+        if isinstance(x, list) and len(x) == 2:
+            x, cls_token = x
+        if v is not None:
+            feat = self.proj(v)
+        else:
+            feat = self.proj(x)
+        output = self.cls_seg(feat)
+        if return_feat:
+            return output, feat, k
+
+        return output
+
+    def _init_inputs(self, in_channels, in_index, input_transform):
+        """Check and initialize input transforms.
+
+        The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only single feature map
+        will be selected. So in_channels and in_index must be of type int.
+        When input_transform
+
+        Args:
+            in_channels (int|Sequence[int]): Input channels.
+            in_index (int|Sequence[int]): Input feature index.
+            input_transform (str|None): Transformation type of input features.
+                Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resize to the
+                    same size as first one and than concat together.
+                    Usually used in FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundle into
+                    a list and passed into decode head.
+                None: Only one select feature map is allowed.
+        """
+
+        if input_transform is not None:
+            assert input_transform in ['resize_concat', 'multiple_select']
+        self.input_transform = input_transform
+        self.in_index = in_index
+        if input_transform is not None:
+            assert isinstance(in_channels, (list, tuple))
+            assert isinstance(in_index, (list, tuple))
+            assert len(in_channels) == len(in_index)
+            if input_transform == 'resize_concat':
+                self.in_channels = sum(in_channels)
+            else:
+                self.in_channels = in_channels
+        else:
+            assert isinstance(in_channels, int)
+            assert isinstance(in_index, int)
+            self.in_channels = in_channels
+
+    def cls_seg(self, feat):
+        feat = feat / feat.norm(dim=1, keepdim=True)
+        output = F.conv2d(feat, self.class_embeddings[:, :, None, None])
+        output = F.softmax(output * 100, dim=1)  # softmax of similarities with temp scaling
+
+        return output
+
+    def _transform_inputs(self, inputs):
+        """Transform inputs for decoder.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+
+        Returns:
+            Tensor: The transformed inputs
+        """
+        if self.input_transform == 'resize_concat':
+            inputs = [inputs[i] for i in self.in_index]
+            upsampled_inputs = [
+                resize(
+                    input=x,
+                    size=inputs[0].shape[2:],
+                    mode='bilinear',
+                    align_corners=self.align_corners) for x in inputs
+            ]
+            inputs = torch.cat(upsampled_inputs, dim=1)
+        elif self.input_transform == 'multiple_select':
+            inputs = [inputs[i] for i in self.in_index]
+        else:
+            inputs = inputs[self.in_index]
+
+        return inputs
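MaskClipHead.cls_seg classifies every patch feature by cosine similarity to the frozen CLIP text embeddings, implemented as a 1x1 convolution followed by a temperature-scaled softmax. A standalone sketch of that step on dummy tensors (21 classes and a 28x28 grid are illustrative choices, not values fixed by this commit):

# Standalone sketch of the cls_seg logic above, with random stand-in embeddings.
import torch
import torch.nn.functional as F

num_classes, text_channels, h, w = 21, 512, 28, 28
feat = torch.randn(1, text_channels, h, w)
class_embeddings = F.normalize(torch.randn(num_classes, text_channels), dim=-1)

feat = feat / feat.norm(dim=1, keepdim=True)                    # unit-norm per pixel
logits = F.conv2d(feat, class_embeddings[:, :, None, None])     # (1, 21, 28, 28) cosine similarities
probs = F.softmax(logits * 100, dim=1)                          # temperature scaling, as in cls_seg
print(probs.shape, probs.sum(dim=1).allclose(torch.ones(1, h, w)))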
models/maskclip/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .embed import PatchEmbed
+from .prompt_templates import imagenet_templates
+
+__all__ = ['PatchEmbed', 'imagenet_templates']
models/maskclip/utils/embed.py
ADDED
@@ -0,0 +1,334 @@
+# From OpenMMLab https://github.com/chongzhou96/MaskCLIP
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# ------------------------------------------------------------------------------
+
+import math
+from typing import Sequence
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner.base_module import BaseModule
+from mmcv.utils import to_2tuple
+
+
+class AdaptivePadding(nn.Module):
+    """Applies padding to input (if needed) so that input can get fully covered
+    by filter you specified. It support two modes "same" and "corner". The
+    "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
+    input. The "corner" mode would pad zero to bottom right.
+
+    Args:
+        kernel_size (int | tuple): Size of the kernel:
+        stride (int | tuple): Stride of the filter. Default: 1:
+        dilation (int | tuple): Spacing between kernel elements.
+            Default: 1.
+        padding (str): Support "same" and "corner", "corner" mode
+            would pad zero to bottom right, and "same" mode would
+            pad zero around input. Default: "corner".
+    Example:
+        >>> kernel_size = 16
+        >>> stride = 16
+        >>> dilation = 1
+        >>> input = torch.rand(1, 1, 15, 17)
+        >>> adap_pad = AdaptivePadding(
+        >>>     kernel_size=kernel_size,
+        >>>     stride=stride,
+        >>>     dilation=dilation,
+        >>>     padding="corner")
+        >>> out = adap_pad(input)
+        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+        >>> input = torch.rand(1, 1, 16, 17)
+        >>> out = adap_pad(input)
+        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+    """
+
+    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+
+        super(AdaptivePadding, self).__init__()
+
+        assert padding in ('same', 'corner')
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        self.padding = padding
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+
+    def get_pad_shape(self, input_shape):
+        input_h, input_w = input_shape
+        kernel_h, kernel_w = self.kernel_size
+        stride_h, stride_w = self.stride
+        output_h = math.ceil(input_h / stride_h)
+        output_w = math.ceil(input_w / stride_w)
+        pad_h = max((output_h - 1) * stride_h +
+                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+        pad_w = max((output_w - 1) * stride_w +
+                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+        return pad_h, pad_w
+
+    def forward(self, x):
+        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+        if pad_h > 0 or pad_w > 0:
+            if self.padding == 'corner':
+                x = F.pad(x, [0, pad_w, 0, pad_h])
+            elif self.padding == 'same':
+                x = F.pad(x, [
+                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+                    pad_h - pad_h // 2
+                ])
+        return x
+
+
+class PatchEmbed(BaseModule):
+    """Image to Patch Embedding.
+
+    We use a conv layer to implement PatchEmbed.
+
+    Args:
+        in_channels (int): The num of input channels. Default: 3
+        embed_dims (int): The dimensions of embedding. Default: 768
+        conv_type (str): The config dict for embedding
+            conv layer type selection. Default: "Conv2d".
+        kernel_size (int): The kernel_size of embedding conv. Default: 16.
+        stride (int, optional): The slide stride of embedding conv.
+            Default: None (Would be set as `kernel_size`).
+        padding (int | tuple | string ): The padding length of
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding, support "same" and "corner" now.
+            Default: "corner".
+        dilation (int): The dilation rate of embedding conv. Default: 1.
+        bias (bool): Bias of embed conv. Default: True.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Default: None.
+        input_size (int | tuple | None): The size of input, which will be
+            used to calculate the out size. Only work when `dynamic_size`
+            is False. Default: None.
+        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 embed_dims=768,
+                 conv_type='Conv2d',
+                 kernel_size=16,
+                 stride=None,
+                 padding='corner',
+                 dilation=1,
+                 bias=True,
+                 norm_cfg=None,
+                 input_size=None,
+                 init_cfg=None):
+        super(PatchEmbed, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        if stride is None:
+            stride = kernel_size
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        if isinstance(padding, str):
+            self.adap_padding = AdaptivePadding(
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=padding)
+            # disable the padding of conv
+            padding = 0
+        else:
+            self.adap_padding = None
+        padding = to_2tuple(padding)
+
+        self.projection = build_conv_layer(
+            dict(type=conv_type),
+            in_channels=in_channels,
+            out_channels=embed_dims,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias)
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+        else:
+            self.norm = None
+
+        if input_size:
+            input_size = to_2tuple(input_size)
+            # `init_out_size` would be used outside to
+            # calculate the num_patches
+            # when `use_abs_pos_embed` outside
+            self.init_input_size = input_size
+            if self.adap_padding:
+                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
+                input_h, input_w = input_size
+                input_h = input_h + pad_h
+                input_w = input_w + pad_w
+                input_size = (input_h, input_w)
+
+            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+                     (kernel_size[0] - 1) - 1) // stride[0] + 1
+            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+                     (kernel_size[1] - 1) - 1) // stride[1] + 1
+            self.init_out_size = (h_out, w_out)
+        else:
+            self.init_input_size = None
+            self.init_out_size = None
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+        Returns:
+            tuple: Contains merged results and its spatial shape.
+
+            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+            - out_size (tuple[int]): Spatial shape of x, arrange as
+                (out_h, out_w).
+        """
+
+        if self.adap_padding:
+            x = self.adap_padding(x)
+
+        x = self.projection(x)
+        out_size = (x.shape[2], x.shape[3])
+        x = x.flatten(2).transpose(1, 2)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, out_size
+
+
+class PatchMerging(BaseModule):
+    """Merge patch feature map.
+
+    This layer groups feature map by kernel_size, and applies norm and linear
+    layers to the grouped feature map. Our implementation uses `nn.Unfold` to
+    merge patch, which is about 25% faster than original implementation.
+    Instead, we need to modify pretrained models for compatibility.
+
+    Args:
+        in_channels (int): The num of input channels.
+        out_channels (int): The num of output channels.
+        kernel_size (int | tuple, optional): the kernel size in the unfold
+            layer. Defaults to 2.
+        stride (int | tuple, optional): the stride of the sliding blocks in the
+            unfold layer. Default: None. (Would be set as `kernel_size`)
+        padding (int | tuple | string ): The padding length of
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding, support "same" and "corner" now.
+            Default: "corner".
+        dilation (int | tuple, optional): dilation parameter in the unfold
+            layer. Default: 1.
+        bias (bool, optional): Whether to add bias in linear layer or not.
+            Defaults: False.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Default: dict(type='LN').
+        init_cfg (dict, optional): The extra config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=2,
+                 stride=None,
+                 padding='corner',
+                 dilation=1,
+                 bias=False,
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        if stride:
+            stride = stride
+        else:
+            stride = kernel_size
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        if isinstance(padding, str):
+            self.adap_padding = AdaptivePadding(
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=padding)
+            # disable the padding of unfold
+            padding = 0
+        else:
+            self.adap_padding = None
+
+        padding = to_2tuple(padding)
+        self.sampler = nn.Unfold(
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=padding,
+            stride=stride)
+
+        sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+        else:
+            self.norm = None
+
+        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+    def forward(self, x, input_size):
+        """
+        Args:
+            x (Tensor): Has shape (B, H*W, C_in).
+            input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+                Default: None.
+
+        Returns:
+            tuple: Contains merged results and its spatial shape.
+
+            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+            - out_size (tuple[int]): Spatial shape of x, arrange as
+                (Merged_H, Merged_W).
+        """
+        B, L, C = x.shape
+        assert isinstance(input_size, Sequence), f'Expect ' \
+            f'input_size is ' \
+            f'`Sequence` ' \
+            f'but get {input_size}'
+
+        H, W = input_size
+        assert L == H * W, 'input feature has wrong size'
+
+        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
+        # Use nn.Unfold to merge patch. About 25% faster than original method,
+        # but need to modify pretrained model for compatibility
+
+        if self.adap_padding:
+            x = self.adap_padding(x)
+            H, W = x.shape[-2:]
+
+        x = self.sampler(x)
+        # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
+
+        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+                 (self.sampler.kernel_size[0] - 1) -
+                 1) // self.sampler.stride[0] + 1
+        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+                 (self.sampler.kernel_size[1] - 1) -
+                 1) // self.sampler.stride[1] + 1
+
+        output_size = (out_h, out_w)
+        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
+        x = self.norm(x) if self.norm else x
+        x = self.reduction(x)
+        return x, output_size
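PatchEmbed turns an image into a token sequence with a strided convolution; with padding='corner', AdaptivePadding first pads the bottom-right so inputs of any size are covered by whole patches. A small usage sketch on a deliberately non-multiple-of-16 input (the shapes below are illustrative):

# Usage sketch for the PatchEmbed defined above, on a dummy batch.
import torch

patch_embed = PatchEmbed(in_channels=3, embed_dims=768, kernel_size=16,
                         stride=16, padding='corner')

x = torch.randn(2, 3, 449, 465)      # not a multiple of 16 on purpose
tokens, hw = patch_embed(x)          # AdaptivePadding pads to 464x480 before the conv
print(tokens.shape, hw)              # torch.Size([2, 870, 768]) (29, 30)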
models/maskclip/utils/prompt_templates.py
ADDED
@@ -0,0 +1,82 @@
+imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
models/maskclip/vit.py
ADDED
@@ -0,0 +1,470 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.utils import _pair as to_2tuple
+import torch.nn.functional as F
+
+from mmseg.ops import resize
+from mmseg.utils import get_root_logger
+
+from models.maskclip.utils import PatchEmbed
+
+
+class TransformerEncoderLayer(BaseModule):
+    """Implements one encoder layer in Vision Transformer.
+
+    Args:
+        embed_dims (int): The feature dimension.
+        num_heads (int): Parallel attention heads.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Default: 0.0.
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default: 0.0.
+        drop_path_rate (float): stochastic depth rate. Default 0.0.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Default: 2.
+        qkv_bias (bool): enable bias for qkv if True. Default: True
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        batch_first (bool): Key, Query and Value are shape of
+            (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default: True.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 num_fcs=2,
+                 qkv_bias=True,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN'),
+                 batch_first=True):
+        super(TransformerEncoderLayer, self).__init__()
+
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+
+        self.attn = MultiheadAttention(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop_rate,
+            proj_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            batch_first=batch_first,
+            bias=qkv_bias)
+
+        self.norm2_name, norm2 = build_norm_layer(
+            norm_cfg, embed_dims, postfix=2)
+        self.add_module(self.norm2_name, norm2)
+
+        self.ffn = FFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            num_fcs=num_fcs,
+            ffn_drop=drop_rate,
+            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+            act_cfg=act_cfg)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        return getattr(self, self.norm2_name)
+
+    def forward(self, x, return_qkv=False):
+        q, k, v = None, None, None
+        if return_qkv:
+            y = self.norm1(x)
+            y = F.linear(y, self.attn.attn.in_proj_weight, self.attn.attn.in_proj_bias)
+            N, L, C = y.shape
+            y = y.view(N, L, 3, C // 3).permute(2, 0, 1, 3).reshape(3 * N, L, C // 3)
+            y = F.linear(y, self.attn.attn.out_proj.weight, self.attn.attn.out_proj.bias)
+            q, k, v = y.tensor_split(3, dim=0)
+            v += x
+            v = self.ffn(self.norm2(v), identity=v)
+
+        x = self.attn(self.norm1(x), identity=x)
+        x = self.ffn(self.norm2(x), identity=x)
+        return x, q, k, v
+
+
+class VisionTransformer(BaseModule):
+    """Vision Transformer.
+
+    This backbone is the implementation of `An Image is Worth 16x16 Words:
+    Transformers for Image Recognition at
+    Scale <https://arxiv.org/abs/2010.11929>`_.
+
+    Args:
+        img_size (int | tuple): Input image size. Default: 224.
+        patch_size (int): The patch size. Default: 16.
+        in_channels (int): Number of input channels. Default: 3.
+        embed_dims (int): embedding dimension. Default: 768.
+        num_layers (int): depth of transformer. Default: 12.
+        num_heads (int): number of attention heads. Default: 12.
+        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        out_indices (list | tuple | int): Output from which stages.
+            Default: -1.
+        qkv_bias (bool): enable bias for qkv if True. Default: True.
+        drop_rate (float): Probability of an element to be zeroed.
+            Default 0.0
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default 0.0
+        drop_path_rate (float): stochastic depth rate. Default 0.0
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Default: True.
+        output_cls_token (bool): Whether output the cls_token. If set True,
+            `with_cls_token` must be True. Default: False.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN')
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+            Default: False.
+        final_norm (bool): Whether to add a additional layer to normalize
+            final feature map. Default: False.
+        interpolate_mode (str): Select the interpolate mode for position
+            embeding vector resize. Default: bicubic.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Default: 2.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed. Default: False.
+        pretrained (str, optional): model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 patch_bias=True,
+                 in_channels=3,
+                 embed_dims=768,
+                 num_layers=12,
+                 num_heads=12,
+                 mlp_ratio=4,
+                 out_indices=-1,
+                 qkv_bias=True,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 with_cls_token=True,
+                 output_cls_token=False,
+                 norm_cfg=dict(type='LN'),
+                 act_cfg=dict(type='GELU'),
+                 patch_norm=False,
+                 pre_norm=False,
+                 final_norm=False,
+                 return_qkv=False,
+                 skip_last_attn=False,
+                 interpolate_mode='bicubic',
+                 num_fcs=2,
+                 norm_eval=False,
+                 with_cp=False,
+                 pretrained=None,
+                 init_cfg=None):
+        super(VisionTransformer, self).__init__(init_cfg=init_cfg)
+
+        if isinstance(img_size, int):
+            img_size = to_2tuple(img_size)
+        elif isinstance(img_size, tuple):
+            if len(img_size) == 1:
+                img_size = to_2tuple(img_size[0])
+            assert len(img_size) == 2, \
+                f'The size of image should have length 1 or 2, ' \
+                f'but got {len(img_size)}'
+
+        if output_cls_token:
+            assert with_cls_token is True, f'with_cls_token must be True if' \
+                f'set output_cls_token to True, but got {with_cls_token}'
+
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
+            raise TypeError('pretrained must be a str or None')
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.interpolate_mode = interpolate_mode
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.pretrained = pretrained
+
+        self.patch_embed = PatchEmbed(
+            in_channels=in_channels,
+            embed_dims=embed_dims,
+            conv_type='Conv2d',
+            kernel_size=patch_size,
+            stride=patch_size,
+            padding='corner',
+            bias=patch_bias,
+            norm_cfg=norm_cfg if patch_norm else None,
+            init_cfg=None,
+        )
+
+        num_patches = (img_size[0] // patch_size) * \
+            (img_size[1] // patch_size)
+
+        self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, embed_dims))
+        self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+        if isinstance(out_indices, int):
+            if out_indices == -1:
+                out_indices = num_layers - 1
+            self.out_indices = [out_indices]
+        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+            self.out_indices = out_indices
+        else:
+            raise TypeError('out_indices must be type of int, list or tuple')
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
+        ]  # stochastic depth decay rule
+
+        self.layers = ModuleList()
+        for i in range(num_layers):
+            self.layers.append(
+                TransformerEncoderLayer(
+                    embed_dims=embed_dims,
+                    num_heads=num_heads,
+                    feedforward_channels=mlp_ratio * embed_dims,
+                    attn_drop_rate=attn_drop_rate,
+                    drop_rate=drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=num_fcs,
+                    qkv_bias=qkv_bias,
+                    act_cfg=act_cfg,
+                    norm_cfg=norm_cfg,
+                    batch_first=True))
+
+        self.pre_norm = pre_norm
+        if pre_norm:
+            self.norm0_name, norm0 = build_norm_layer(
+                norm_cfg, embed_dims, postfix=0)
+            self.add_module(self.norm0_name, norm0)
+
+        self.final_norm = final_norm
+        if final_norm:
+            self.norm1_name, norm1 = build_norm_layer(
+                norm_cfg, embed_dims, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+
+        self.return_qkv = [False] * num_layers
+        if isinstance(return_qkv, bool):
+            for out_i in self.out_indices:
+                self.return_qkv[out_i] = return_qkv
+        elif isinstance(return_qkv, list) or isinstance(return_qkv, tuple):
+            for i, out_i in enumerate(self.out_indices):
+                self.return_qkv[out_i] = return_qkv[i]
+        else:
+            raise TypeError('return_qkv must be type of bool, list or tuple')
+
+        self.skip_last_attn = skip_last_attn
+
+    @property
+    def norm0(self):
+        return getattr(self, self.norm0_name)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    def init_weights(self):
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg.get('type') == 'Pretrained'):
+            logger = get_root_logger()
+
+            checkpoint = _load_checkpoint(
+                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
+
+            if 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            else:
+                state_dict = checkpoint
+
+            if 'pos_embed' in state_dict.keys():
+                if self.pos_embed.shape != state_dict['pos_embed'].shape:
+                    logger.info(msg=f'Resize the pos_embed shape from '
+                                f'{state_dict["pos_embed"].shape} to '
+                                f'{self.pos_embed.shape}')
+                    h, w = self.img_size
+                    pos_size = int(
+                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+                    state_dict['pos_embed'] = self.resize_pos_embed(
+                        state_dict['pos_embed'],
+                        (h // self.patch_size, w // self.patch_size),
+                        (pos_size, pos_size), self.interpolate_mode)
+
+            print(self.load_state_dict(state_dict, False))
+        elif self.init_cfg is not None:
+            super(VisionTransformer, self).init_weights()
+        else:
+            # We only implement the 'jax_impl' initialization implemented at
+            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
+            for n, m in self.named_modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_(m.weight, std=.02)
+                    if m.bias is not None:
+                        if 'ffn' in n:
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
+                        else:
+                            nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.Conv2d):
+                    kaiming_init(m, mode='fan_in', bias=0.)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
| 348 |
+
constant_init(m, val=1.0, bias=0.)
|
| 349 |
+
|
| 350 |
+
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
| 351 |
+
"""Positiong embeding method.
|
| 352 |
+
|
| 353 |
+
Resize the pos_embed, if the input image size doesn't match
|
| 354 |
+
the training size.
|
| 355 |
+
Args:
|
| 356 |
+
patched_img (torch.Tensor): The patched image, it should be
|
| 357 |
+
shape of [B, L1, C].
|
| 358 |
+
hw_shape (tuple): The downsampled image resolution.
|
| 359 |
+
pos_embed (torch.Tensor): The pos_embed weighs, it should be
|
| 360 |
+
shape of [B, L2, c].
|
| 361 |
+
Return:
|
| 362 |
+
torch.Tensor: The pos encoded image feature.
|
| 363 |
+
"""
|
| 364 |
+
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
| 365 |
+
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
| 366 |
+
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
| 367 |
+
if x_len != pos_len:
|
| 368 |
+
if pos_len == (self.img_size[0] // self.patch_size) * (
|
| 369 |
+
self.img_size[1] // self.patch_size) + 1:
|
| 370 |
+
pos_h = self.img_size[0] // self.patch_size
|
| 371 |
+
pos_w = self.img_size[1] // self.patch_size
|
| 372 |
+
else:
|
| 373 |
+
raise ValueError(
|
| 374 |
+
'Unexpected shape of pos_embed, got {}.'.format(
|
| 375 |
+
pos_embed.shape))
|
| 376 |
+
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
| 377 |
+
(pos_h, pos_w),
|
| 378 |
+
self.interpolate_mode)
|
| 379 |
+
return self.drop_after_pos(patched_img + pos_embed)
|
| 380 |
+
|
| 381 |
+
@staticmethod
|
| 382 |
+
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
|
| 383 |
+
"""Resize pos_embed weights.
|
| 384 |
+
|
| 385 |
+
Resize pos_embed using bicubic interpolate method.
|
| 386 |
+
Args:
|
| 387 |
+
pos_embed (torch.Tensor): Position embedding weights.
|
| 388 |
+
input_shpae (tuple): Tuple for (downsampled input image height,
|
| 389 |
+
downsampled input image width).
|
| 390 |
+
pos_shape (tuple): The resolution of downsampled origin training
|
| 391 |
+
image.
|
| 392 |
+
mode (str): Algorithm used for upsampling:
|
| 393 |
+
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
| 394 |
+
``'trilinear'``. Default: ``'nearest'``
|
| 395 |
+
Return:
|
| 396 |
+
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
| 397 |
+
"""
|
| 398 |
+
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
| 399 |
+
pos_h, pos_w = pos_shape
|
| 400 |
+
cls_token_weight = pos_embed[:, 0]
|
| 401 |
+
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
| 402 |
+
pos_embed_weight = pos_embed_weight.reshape(
|
| 403 |
+
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
| 404 |
+
pos_embed_weight = resize(
|
| 405 |
+
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
| 406 |
+
cls_token_weight = cls_token_weight.unsqueeze(1)
|
| 407 |
+
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
| 408 |
+
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
| 409 |
+
return pos_embed
|
| 410 |
+
|
| 411 |
+
def forward(self, inputs):
|
| 412 |
+
B = inputs.shape[0]
|
| 413 |
+
|
| 414 |
+
x, hw_shape = self.patch_embed(inputs)
|
| 415 |
+
|
| 416 |
+
# stole cls_tokens impl from Phil Wang, thanks
|
| 417 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 418 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 419 |
+
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
| 420 |
+
|
| 421 |
+
if not self.with_cls_token:
|
| 422 |
+
# Remove class token for transformer encoder input
|
| 423 |
+
x = x[:, 1:]
|
| 424 |
+
|
| 425 |
+
if self.pre_norm:
|
| 426 |
+
x = self.norm0(x)
|
| 427 |
+
|
| 428 |
+
outs = []
|
| 429 |
+
for i, layer in enumerate(self.layers):
|
| 430 |
+
x, q, k, v = layer(x, self.return_qkv[i] \
|
| 431 |
+
or (i == len(self.layers) - 1 and self.skip_last_attn))
|
| 432 |
+
if i == len(self.layers) - 1:
|
| 433 |
+
if self.final_norm:
|
| 434 |
+
x = self.norm1(x)
|
| 435 |
+
if self.return_qkv[i]:
|
| 436 |
+
v = self.norm1(v)
|
| 437 |
+
if self.skip_last_attn:
|
| 438 |
+
if self.with_cls_token:
|
| 439 |
+
x[:, 1:] = v[:, 1:]
|
| 440 |
+
else:
|
| 441 |
+
x = v
|
| 442 |
+
if i in self.out_indices:
|
| 443 |
+
if self.with_cls_token:
|
| 444 |
+
# Remove class token and reshape token for decoder head
|
| 445 |
+
out = x[:, 1:]
|
| 446 |
+
else:
|
| 447 |
+
out = x
|
| 448 |
+
B, _, C = out.shape
|
| 449 |
+
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
| 450 |
+
C).permute(0, 3, 1, 2).contiguous()
|
| 451 |
+
if self.output_cls_token:
|
| 452 |
+
out = [out, x[:, 0]]
|
| 453 |
+
if self.return_qkv[i]:
|
| 454 |
+
if self.with_cls_token:
|
| 455 |
+
q = q[:, 1:]
|
| 456 |
+
k = k[:, 1:]
|
| 457 |
+
v = v[:, 1:]
|
| 458 |
+
v = v.reshape(B, hw_shape[0], hw_shape[1],
|
| 459 |
+
C).permute(0, 3, 1, 2).contiguous()
|
| 460 |
+
out = [out, q, k, v]
|
| 461 |
+
outs.append(out)
|
| 462 |
+
|
| 463 |
+
return tuple(outs)
|
| 464 |
+
|
| 465 |
+
def train(self, mode=True):
|
| 466 |
+
super(VisionTransformer, self).train(mode)
|
| 467 |
+
if mode and self.norm_eval:
|
| 468 |
+
for m in self.modules():
|
| 469 |
+
if isinstance(m, nn.LayerNorm):
|
| 470 |
+
m.eval()
|
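A minimal sketch (not part of the commit) of how the `resize_pos_embed` helper above behaves when a checkpoint trained at 224x224 with 16x16 patches is evaluated at 448x448; the embedding width and the import path are assumptions based on the files in this diff.

import torch
from models.maskclip.vit import VisionTransformer

# 224/16 = 14 patches per side at training time, plus one cls token
pretrained_pos = torch.zeros(1, 14 * 14 + 1, 768)
# 448/16 = 28 patches per side at evaluation time; bicubic resize of the patch grid
resized = VisionTransformer.resize_pos_embed(pretrained_pos, (28, 28), (14, 14), 'bicubic')
print(resized.shape)  # torch.Size([1, 785, 768]) -> 28*28 patch tokens + cls token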
segmentation/configs/_base_/custom_import.py
ADDED
@@ -0,0 +1,12 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------

custom_imports = dict(
    imports=["segmentation.datasets.coco_object", "segmentation.datasets.pascal_voc", "datasets.transforms", "segmentation.datasets.pascal_voc20"],
    allow_failed_imports=False,
)
segmentation/configs/_base_/datasets/ade20k.py
ADDED
@@ -0,0 +1,58 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "ADE20KDataset"
data_root = "./data"

train_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=True,
        transforms=[
            dict(type='LoadImageFromFile'),
            dict(type='ToRGB'),
            dict(type='Resize', img_scale=(2048, 448)),
            dict(type='RandomCrop', crop_size=(448, 448)),
            dict(type='RandomFlip', prob=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type='Collect', keys=['img'], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]

test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="ADEChallengeData2016/images/validation",
        ann_dir="ADEChallengeData2016/annotations/validation",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
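As a rough usage sketch (not part of the commit), a config like the one above is normally read with `mmcv.Config.fromfile`, which merges the `_base_` file and triggers its `custom_imports`; this assumes mmcv/mmseg 1.x-style APIs and that the repository root is on `PYTHONPATH`.

import mmcv

cfg = mmcv.Config.fromfile('segmentation/configs/_base_/datasets/ade20k.py')
print(cfg.data.test.type)  # 'ADE20KDataset'
print(cfg.test_cfg)        # slide mode: 448x448 crops with stride 224, i.e. half-window overlap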
segmentation/configs/_base_/datasets/cityscapes.py
ADDED
@@ -0,0 +1,37 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "CityscapesDataset"
data_root = "./data/cityscapes"
test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="leftImg8bit/val",
        ann_dir="gtFine/val",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/coco.py
ADDED
@@ -0,0 +1,39 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "COCOObjectDataset"
data_root = "./data/coco_stuff164k"

test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="images/val2017",
        ann_dir="annotations/val2017",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_context.py
ADDED
@@ -0,0 +1,38 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "PascalContextDataset"
data_root = "./data/VOCdevkit/VOC2010"
test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="JPEGImages",
        ann_dir="SegmentationClassContext",
        split="ImageSets/SegmentationContext/val.txt",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_context59.py
ADDED
@@ -0,0 +1,38 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
# dataset settings
dataset_type = "PascalContextDataset59"
data_root = "./data/VOCdevkit/VOC2010"
test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"],
                 meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="JPEGImages",
        ann_dir="SegmentationClassContext",
        split="ImageSets/SegmentationContext/val.txt",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_voc12.py
ADDED
@@ -0,0 +1,40 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "PascalVOCDataset"
data_root = "./data/VOCdevkit/VOC2012"

test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="JPEGImages",
        ann_dir="SegmentationClass",
        split="ImageSets/Segmentation/val.txt",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/pascal_voc12_20.py
ADDED
@@ -0,0 +1,40 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from GroupViT (https://github.com/NVlabs/GroupViT)
# Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "PascalVOCDataset20"
data_root = "./data/VOCdevkit/VOC2012"
test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"],
                 meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="JPEGImages",
        ann_dir="SegmentationClass",
        split="ImageSets/Segmentation/val.txt",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/configs/_base_/datasets/stuff.py
ADDED
@@ -0,0 +1,39 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
_base_ = ["../custom_import.py"]
# dataset settings
dataset_type = "COCOStuffDataset"
data_root = "./data/coco_stuff164k"

test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type='ToRGB'),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(2048, 448),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="ImageToTensorV2", keys=["img"]),
            dict(type="Collect", keys=["img"], meta_keys=['ori_shape', 'img_shape', 'pad_shape', 'flip', 'img_info']),
        ],
    ),
]
data = dict(
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir="images/val2017",
        ann_dir="annotations/val2017",
        pipeline=test_pipeline,
    )
)

test_cfg = dict(mode="slide", stride=(224, 224), crop_size=(448, 448))
segmentation/datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .coco_object import *
from .pascal_voc import *
from .pascal_voc20 import *
from .pascal_context import *
from .coco_stuff import *
segmentation/datasets/coco_object.py
ADDED
@@ -0,0 +1,42 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
from mmseg.datasets import DATASETS, CustomDataset


@DATASETS.register_module()
class COCOObjectDataset(CustomDataset):
    """COCO-Object dataset.

    1 bg class + first 80 classes from the COCO-Stuff dataset.
    """

    CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'aeroplane', 'bus', 'train', 'truck', 'boat',
               'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
               'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
               'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
               'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
               'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
               'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
               'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

    PALETTE = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
               [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
               [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
               [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
               [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
               [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
               [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
               [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
               [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
               [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]

    def __init__(self, **kwargs):
        super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)
segmentation/datasets/coco_stuff.py
ADDED
@@ -0,0 +1,97 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology


from mmseg.datasets import DATASETS, CustomDataset


@DATASETS.register_module(force=True)
class COCOStuffDataset(CustomDataset):
    """COCO-Stuff dataset.

    In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version
    are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff
    164k is from 0 to 170, where 255 is the ignore index. So, they are all 171
    semantic categories. ``reduce_zero_label`` is set to True and False for the
    10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg',
    and ``seg_map_suffix`` is fixed to '.png'.
    """
    CLASSES = (
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',
        'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',
        'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',
        'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
        'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood')

    PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
               [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
               [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
               [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
               [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
               [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
               [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
               [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
               [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
               [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
               [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
               [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
               [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
               [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
               [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
               [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
               [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
               [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
               [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
               [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
               [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
               [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
               [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
               [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
               [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
               [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
               [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
               [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
               [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
               [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
               [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
               [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
               [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
               [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
               [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
               [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
               [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
               [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
               [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
               [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
               [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
               [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
               [64, 192, 96], [64, 160, 64], [64, 64, 0]]

    def __init__(self, **kwargs):
        super(COCOStuffDataset, self).__init__(
            img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)
segmentation/datasets/pascal_context.py
ADDED
@@ -0,0 +1,108 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# from MaskCLIP
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
from mmseg.datasets import DATASETS, CustomDataset
import os.path as osp


@DATASETS.register_module(force=True)
class PascalContextDataset(CustomDataset):
    """PascalContext dataset.

    In segmentation map annotation for PascalContext, 0 stands for background,
    which is included in 60 categories. ``reduce_zero_label`` is fixed to
    False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
    fixed to '.png'.

    Args:
        split (str): Split txt file for PascalContext.
    """

    CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
               'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
               'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
               'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
               'floor', 'flower', 'food', 'grass', 'ground', 'horse',
               'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
               'plate', 'platform', 'potted plant', 'road', 'rock', 'sheep',
               'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
               'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
               'window', 'wood')

    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]

    def __init__(self, split, **kwargs):
        super(PascalContextDataset, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            split=split,
            reduce_zero_label=False,
            **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None


@DATASETS.register_module(force=True)
class PascalContextDataset59(CustomDataset):
    """PascalContext dataset (59 classes, background removed).

    In segmentation map annotation for PascalContext, 0 stands for background,
    which is excluded here, so ``reduce_zero_label`` is fixed to True. The
    ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
    '.png'.

    Args:
        split (str): Split txt file for PascalContext.
    """

    CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
               'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
               'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
               'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
               'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
               'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
               'potted plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
               'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
               'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')

    PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
               [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
               [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
               [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
               [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
               [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
               [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
               [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
               [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
               [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
               [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
               [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
               [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
               [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
               [0, 235, 255], [0, 173, 255], [31, 0, 255]]

    def __init__(self, split, **kwargs):
        super(PascalContextDataset59, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            split=split,
            reduce_zero_label=True,
            **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None
segmentation/datasets/pascal_voc.py
ADDED
@@ -0,0 +1,40 @@
# ------------------------------------------------------------------------------
# TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------
# Modified from GroupViT (https://github.com/NVlabs/GroupViT)
# Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
# ------------------------------------------------------------------------------
import os
from mmseg.datasets import DATASETS
from mmseg.datasets import CustomDataset


@DATASETS.register_module(force=True)
class PascalVOCDataset(CustomDataset):
    """Pascal VOC dataset (the background class is ignored).
    Borrowed from MaskCLIP.

    Args:
        split (str): Split txt file for Pascal VOC.
    """

    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog',
               'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa',
               'train', 'tvmonitor')

    PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]

    def __init__(self, split, **kwargs):
        super(PascalVOCDataset, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            split=split,
            reduce_zero_label=False,
            **kwargs)
        assert os.path.exists(self.img_dir) and self.split is not None
segmentation/datasets/pascal_voc20.py
ADDED
@@ -0,0 +1,40 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ----------------------------------------------------------------------------------------------------
# Modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ------------------------------------------------------------------------------

import os.path as osp
from mmseg.datasets import DATASETS
from mmseg.datasets import CustomDataset


@DATASETS.register_module()
class PascalVOCDataset20(CustomDataset):
    """Pascal VOC dataset (the background class is ignored).

    Args:
        split (str): Split txt file for Pascal VOC.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'dining table', 'dog',
               'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa',
               'train', 'tvmonitor')

    PALETTE = [[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]

    def __init__(self, split, **kwargs):
        super(PascalVOCDataset20, self).__init__(
            img_suffix='.jpg',
            seg_map_suffix='.png',
            split=split,
            reduce_zero_label=True,
            **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None
segmentation/evaluation/__init__.py
ADDED
@@ -0,0 +1 @@
from .builder import build_seg_dataloader, build_seg_dataset, build_seg_inference
segmentation/evaluation/builder.py
ADDED
@@ -0,0 +1,66 @@
# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology
# ---------------------------------------------------------------------------------------------------
# modified from TCL
# Copyright (c) 2023 Kakao Brain. All Rights Reserved.
# ---------------------------------------------------------------------------------------------------

import mmcv
from mmseg.datasets import build_dataloader, build_dataset
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('models', parent=MMCV_MODELS)
SEGMENTORS = MODELS
from .clip_dinoiser_eval import DinoCLIP_Infrencer


def build_seg_dataset(config):
    """Build a dataset from config."""
    cfg = mmcv.Config.fromfile(config)
    dataset = build_dataset(cfg.data.test)
    return dataset


def build_seg_dataloader(dataset, dist=True):
    # batch size is set to 1 to handle varying image size (due to different aspect ratio)
    if dist:
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=dist,
            shuffle=False,
            persistent_workers=True,
            pin_memory=False,
        )
    else:
        data_loader = build_dataloader(
            dataset=dataset,
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=dist,
            shuffle=False,
            persistent_workers=True,
            pin_memory=False,
        )
    return data_loader


def build_seg_inference(
    model,
    dataset,
    config,
    seg_config,
):
    dset_cfg = mmcv.Config.fromfile(seg_config)  # dataset config
    classnames = dataset.CLASSES
    kwargs = dict()
    if hasattr(dset_cfg, "test_cfg"):
        kwargs["test_cfg"] = dset_cfg.test_cfg

    seg_model = DinoCLIP_Infrencer(model, num_classes=len(classnames), **kwargs, **config.evaluate)
    seg_model.CLASSES = dataset.CLASSES
    seg_model.PALETTE = dataset.PALETTE

    return seg_model
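A hedged end-to-end sketch (not part of the commit) of how the three builders above fit together; `model` is assumed to be a built CLIP-DINOiser network and `cfg` a config object exposing an `evaluate` section, both assumptions here.

from segmentation.evaluation import build_seg_dataset, build_seg_dataloader, build_seg_inference

voc_config = 'segmentation/configs/_base_/datasets/pascal_voc12.py'  # assumed path from this diff
dataset = build_seg_dataset(voc_config)
loader = build_seg_dataloader(dataset, dist=False)                   # batch size 1, no distributed sampler
seg_model = build_seg_inference(model, dataset, cfg, voc_config)
# seg_model is an mmseg EncoderDecoder-style wrapper, ready for slide-mode evaluation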
segmentation/evaluation/clip_dinoiser_eval.py
ADDED
@@ -0,0 +1,34 @@
# ------------------------------------------------------------------------------
import torch
import logging
log = logging.getLogger(__name__)
from mmseg.ops import resize
from mmseg.models import EncoderDecoder

class DinoCLIP_Infrencer(EncoderDecoder):
    def __init__(
        self,
        model,
        num_classes,
        test_cfg=dict(),
        **kwargs,
    ):
        super(EncoderDecoder, self).__init__()
        self.mode = test_cfg['mode']
        self.num_classes = num_classes
        self.model = model
        self.test_cfg = test_cfg
        self.align_corners = False

    @torch.no_grad()
    def encode_decode(self, img, meta_data):
        """Run the wrapped model and upsample its masks to the input size."""
        masks = self.model(img)
        masks = resize(
            input=masks,
            size=img.shape[-2:],
            mode='bilinear',
            align_corners=self.align_corners)
        return masks
visualization.py
ADDED
@@ -0,0 +1,7 @@
import numpy as np

def mask2rgb(mask, palette):
    img = np.zeros((mask.shape[0], mask.shape[1], 3))
    for l in np.unique(mask):
        img[mask == int(l)] = palette[int(l)]
    return img.astype(int)
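A small usage sketch (not part of the commit) for `mask2rgb`, pairing it with the Pascal VOC palette defined earlier in this diff; the random dummy mask is an assumption.

import numpy as np
from segmentation.datasets import PascalVOCDataset
from visualization import mask2rgb

mask = np.random.randint(0, len(PascalVOCDataset.CLASSES), size=(448, 448))  # dummy label map
rgb = mask2rgb(mask, PascalVOCDataset.PALETTE)  # (448, 448, 3) int array, one colour per class id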