Spaces: Running on L4
Commit 5e2bf3b · 1 Parent(s): 9d0b2aa
load from HF
Browse files
- hugging_face/app.py +8 -3
- matanyone/__init__.py +0 -0
- matanyone/inference/inference_core.py +1 -2
- matanyone/inference/memory_manager.py +1 -5
- matanyone/model/big_modules.py +13 -6
- matanyone/model/matanyone.py +18 -8
- matanyone/model/modules.py +3 -24
- matanyone/model/transformer/object_summarizer.py +1 -1
- matanyone/model/transformer/object_transformer.py +1 -1
- matanyone/model/utils/resnet.py +1 -1
- matanyone/utils/get_default_model.py +8 -4
hugging_face/app.py
CHANGED
|
@@ -416,9 +416,14 @@ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type]
|
|
| 416 |
model = MaskGenerator(sam_checkpoint, args)
|
| 417 |
|
| 418 |
# initialize matanyone
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
matanyone_model = matanyone_model.to(args.device).eval()
|
| 423 |
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
|
| 424 |
|
|
|
|
| 416 |
model = MaskGenerator(sam_checkpoint, args)
|
| 417 |
|
| 418 |
# initialize matanyone
|
| 419 |
+
# load from ckpt
|
| 420 |
+
# pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
|
| 421 |
+
# ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
|
| 422 |
+
# matanyone_model = get_matanyone_model(ckpt_path, args.device)
|
| 423 |
+
# load from Hugging Face
|
| 424 |
+
from matanyone.model.matanyone import MatAnyone
|
| 425 |
+
matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
|
| 426 |
+
|
| 427 |
matanyone_model = matanyone_model.to(args.device).eval()
|
| 428 |
matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
|
| 429 |
|
matanyone/__init__.py
ADDED
|
File without changes
|
matanyone/inference/inference_core.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import List, Optional, Iterable
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
|
|
@@ -302,7 +302,6 @@ class InferenceCore:
|
|
| 302 |
|
| 303 |
mask, _ = pad_divide_by(mask, 16)
|
| 304 |
if need_segment:
|
| 305 |
-
print("HERE!!!!!!!!!!!")
|
| 306 |
# merge predicted mask with the incomplete input mask
|
| 307 |
pred_prob_no_bg = pred_prob_with_bg[1:]
|
| 308 |
# use the mutual exclusivity of segmentation
|
|
|
|
| 1 |
+
from typing import List, Optional, Iterable
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
|
|
|
|
| 302 |
|
| 303 |
mask, _ = pad_divide_by(mask, 16)
|
| 304 |
if need_segment:
|
|
|
|
| 305 |
# merge predicted mask with the incomplete input mask
|
| 306 |
pred_prob_no_bg = pred_prob_with_bg[1:]
|
| 307 |
# use the mutual exclusivity of segmentation
|
matanyone/inference/memory_manager.py
CHANGED
|
@@ -2,12 +2,11 @@ import logging
|
|
| 2 |
from omegaconf import DictConfig
|
| 3 |
from typing import List, Dict
|
| 4 |
import torch
|
| 5 |
-
import cv2
|
| 6 |
|
| 7 |
from matanyone.inference.object_manager import ObjectManager
|
| 8 |
from matanyone.inference.kv_memory_store import KeyValueMemoryStore
|
| 9 |
from matanyone.model.matanyone import MatAnyone
|
| 10 |
-
from matanyone.model.utils.memory_utils import
|
| 11 |
|
| 12 |
log = logging.getLogger()
|
| 13 |
|
|
@@ -128,8 +127,6 @@ class MemoryManager:
|
|
| 128 |
bs = pix_feat.shape[0]
|
| 129 |
assert last_mask.shape[0] == bs
|
| 130 |
|
| 131 |
-
uncert_mask = uncert_output["mask"] if uncert_output is not None else None
|
| 132 |
-
|
| 133 |
"""
|
| 134 |
Compute affinity and perform readout
|
| 135 |
"""
|
|
@@ -374,7 +371,6 @@ class MemoryManager:
|
|
| 374 |
self.engaged = False
|
| 375 |
|
| 376 |
def compress_features(self, bucket_id: int) -> None:
|
| 377 |
-
HW = self.HW
|
| 378 |
|
| 379 |
# perform memory consolidation
|
| 380 |
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
|
|
|
| 2 |
from omegaconf import DictConfig
|
| 3 |
from typing import List, Dict
|
| 4 |
import torch
|
|
|
|
| 5 |
|
| 6 |
from matanyone.inference.object_manager import ObjectManager
|
| 7 |
from matanyone.inference.kv_memory_store import KeyValueMemoryStore
|
| 8 |
from matanyone.model.matanyone import MatAnyone
|
| 9 |
+
from matanyone.model.utils.memory_utils import get_similarity, do_softmax
|
| 10 |
|
| 11 |
log = logging.getLogger()
|
| 12 |
|
|
|
|
| 127 |
bs = pix_feat.shape[0]
|
| 128 |
assert last_mask.shape[0] == bs
|
| 129 |
|
|
|
|
|
|
|
| 130 |
"""
|
| 131 |
Compute affinity and perform readout
|
| 132 |
"""
|
|
|
|
| 371 |
self.engaged = False
|
| 372 |
|
| 373 |
def compress_features(self, bucket_id: int) -> None:
|
|
|
|
| 374 |
|
| 375 |
# perform memory consolidation
|
| 376 |
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
matanyone/model/big_modules.py
CHANGED
|
@@ -8,14 +8,15 @@ g - usually denotes features that are not shared between objects
|
|
| 8 |
The trailing number of a variable usually denotes the stride
|
| 9 |
"""
|
| 10 |
|
|
|
|
| 11 |
from omegaconf import DictConfig
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
import torch.nn.functional as F
|
| 15 |
|
| 16 |
-
from matanyone.model.group_modules import
|
| 17 |
from matanyone.model.utils import resnet
|
| 18 |
-
from matanyone.model.modules import
|
| 19 |
|
| 20 |
class UncertPred(nn.Module):
|
| 21 |
def __init__(self, model_cfg: DictConfig):
|
|
@@ -51,11 +52,14 @@ class PixelEncoder(nn.Module):
|
|
| 51 |
super().__init__()
|
| 52 |
|
| 53 |
self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
|
|
|
|
|
|
|
|
|
|
| 54 |
if self.is_resnet:
|
| 55 |
if model_cfg.pixel_encoder.type == 'resnet18':
|
| 56 |
-
network = resnet.resnet18(pretrained=
|
| 57 |
elif model_cfg.pixel_encoder.type == 'resnet50':
|
| 58 |
-
network = resnet.resnet50(pretrained=
|
| 59 |
else:
|
| 60 |
raise NotImplementedError
|
| 61 |
self.conv1 = network.conv1
|
|
@@ -127,10 +131,13 @@ class MaskEncoder(nn.Module):
|
|
| 127 |
self.single_object = single_object
|
| 128 |
extra_dim = 1 if single_object else 2
|
| 129 |
|
|
|
|
|
|
|
|
|
|
| 130 |
if model_cfg.mask_encoder.type == 'resnet18':
|
| 131 |
-
network = resnet.resnet18(pretrained=
|
| 132 |
elif model_cfg.mask_encoder.type == 'resnet50':
|
| 133 |
-
network = resnet.resnet50(pretrained=
|
| 134 |
else:
|
| 135 |
raise NotImplementedError
|
| 136 |
self.conv1 = network.conv1
|
|
|
|
| 8 |
The trailing number of a variable usually denotes the stride
|
| 9 |
"""
|
| 10 |
|
| 11 |
+
from typing import Iterable
|
| 12 |
from omegaconf import DictConfig
|
| 13 |
import torch
|
| 14 |
import torch.nn as nn
|
| 15 |
import torch.nn.functional as F
|
| 16 |
|
| 17 |
+
from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
|
| 18 |
from matanyone.model.utils import resnet
|
| 19 |
+
from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
|
| 20 |
|
| 21 |
class UncertPred(nn.Module):
|
| 22 |
def __init__(self, model_cfg: DictConfig):
|
|
|
|
| 52 |
super().__init__()
|
| 53 |
|
| 54 |
self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
|
| 55 |
+
# if model_cfg.pretrained_resnet is set in the model_cfg we get the value
|
| 56 |
+
# else default to True
|
| 57 |
+
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
|
| 58 |
if self.is_resnet:
|
| 59 |
if model_cfg.pixel_encoder.type == 'resnet18':
|
| 60 |
+
network = resnet.resnet18(pretrained=is_pretrained_resnet)
|
| 61 |
elif model_cfg.pixel_encoder.type == 'resnet50':
|
| 62 |
+
network = resnet.resnet50(pretrained=is_pretrained_resnet)
|
| 63 |
else:
|
| 64 |
raise NotImplementedError
|
| 65 |
self.conv1 = network.conv1
|
|
|
|
| 131 |
self.single_object = single_object
|
| 132 |
extra_dim = 1 if single_object else 2
|
| 133 |
|
| 134 |
+
# if model_cfg.pretrained_resnet is set in the model_cfg we get the value
|
| 135 |
+
# else default to True
|
| 136 |
+
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
|
| 137 |
if model_cfg.mask_encoder.type == 'resnet18':
|
| 138 |
+
network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
|
| 139 |
elif model_cfg.mask_encoder.type == 'resnet50':
|
| 140 |
+
network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
|
| 141 |
else:
|
| 142 |
raise NotImplementedError
|
| 143 |
self.conv1 = network.conv1
|
matanyone/model/matanyone.py
CHANGED
|
@@ -1,21 +1,31 @@
|
|
| 1 |
-
from typing import List, Dict
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
from matanyone.model.
|
| 8 |
-
from matanyone.model.big_modules import *
|
| 9 |
from matanyone.model.aux_modules import AuxComputer
|
| 10 |
-
from matanyone.model.utils.memory_utils import
|
| 11 |
from matanyone.model.transformer.object_transformer import QueryTransformer
|
| 12 |
from matanyone.model.transformer.object_summarizer import ObjectSummarizer
|
| 13 |
from matanyone.utils.tensor_utils import aggregate
|
| 14 |
|
| 15 |
log = logging.getLogger()
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def __init__(self, cfg: DictConfig, *, single_object=False):
|
| 21 |
super().__init__()
|
|
@@ -304,7 +314,7 @@ class MatAnyone(nn.Module):
|
|
| 304 |
finetune a trained model with single object datasets.
|
| 305 |
"""
|
| 306 |
if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
|
| 307 |
-
log.warning(
|
| 308 |
'This is not supposed to happen in standard training.')
|
| 309 |
src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
|
| 310 |
src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
|
|
|
|
| 1 |
+
from typing import List, Dict, Iterable
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
|
| 10 |
+
from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
|
|
|
|
| 11 |
from matanyone.model.aux_modules import AuxComputer
|
| 12 |
+
from matanyone.model.utils.memory_utils import get_affinity, readout
|
| 13 |
from matanyone.model.transformer.object_transformer import QueryTransformer
|
| 14 |
from matanyone.model.transformer.object_summarizer import ObjectSummarizer
|
| 15 |
from matanyone.utils.tensor_utils import aggregate
|
| 16 |
|
| 17 |
log = logging.getLogger()
|
| 18 |
+
class MatAnyone(nn.Module,
|
| 19 |
+
PyTorchModelHubMixin,
|
| 20 |
+
library_name="matanyone",
|
| 21 |
+
repo_url="https://github.com/pq-yang/MatAnyone",
|
| 22 |
+
coders={
|
| 23 |
+
DictConfig: (
|
| 24 |
+
lambda x: OmegaConf.to_container(x),
|
| 25 |
+
lambda data: OmegaConf.create(data),
|
| 26 |
+
)
|
| 27 |
+
},
|
| 28 |
+
):
|
| 29 |
|
| 30 |
def __init__(self, cfg: DictConfig, *, single_object=False):
|
| 31 |
super().__init__()
|
|
|
|
| 314 |
finetune a trained model with single object datasets.
|
| 315 |
"""
|
| 316 |
if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
|
| 317 |
+
log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.'
|
| 318 |
'This is not supposed to happen in standard training.')
|
| 319 |
src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
|
| 320 |
src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
|
matanyone/model/modules.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
from typing import List, Iterable
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
|
|
| 4 |
|
| 5 |
-
from matanyone.model.group_modules import
|
| 6 |
|
| 7 |
|
| 8 |
class UpsampleBlock(nn.Module):
|
|
@@ -145,26 +146,4 @@ class ResBlock(nn.Module):
|
|
| 145 |
|
| 146 |
g = self.downsample(g)
|
| 147 |
|
| 148 |
-
return out_g + g
|
| 149 |
-
|
| 150 |
-
def __init__(self, in_dim, reduction_dim, bins):
|
| 151 |
-
super(PPM, self).__init__()
|
| 152 |
-
self.features = []
|
| 153 |
-
for bin in bins:
|
| 154 |
-
self.features.append(nn.Sequential(
|
| 155 |
-
nn.AdaptiveAvgPool2d(bin),
|
| 156 |
-
nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
|
| 157 |
-
nn.PReLU()
|
| 158 |
-
))
|
| 159 |
-
self.features = nn.ModuleList(self.features)
|
| 160 |
-
self.fuse = nn.Sequential(
|
| 161 |
-
nn.Conv2d(in_dim+reduction_dim*4, in_dim, kernel_size=3, padding=1, bias=False),
|
| 162 |
-
nn.PReLU())
|
| 163 |
-
|
| 164 |
-
def forward(self, x):
|
| 165 |
-
x_size = x.size()
|
| 166 |
-
out = [x]
|
| 167 |
-
for f in self.features:
|
| 168 |
-
out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
|
| 169 |
-
out_feat = self.fuse(torch.cat(out, 1))
|
| 170 |
-
return out_feat
|
|
|
|
| 1 |
from typing import List, Iterable
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
|
| 6 |
+
from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
|
| 7 |
|
| 8 |
|
| 9 |
class UpsampleBlock(nn.Module):
|
|
|
|
| 146 |
|
| 147 |
g = self.downsample(g)
|
| 148 |
|
| 149 |
+
return out_g + g
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
matanyone/model/transformer/object_summarizer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import
|
| 2 |
from omegaconf import DictConfig
|
| 3 |
|
| 4 |
import torch
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
from omegaconf import DictConfig
|
| 3 |
|
| 4 |
import torch
|
matanyone/model/transformer/object_transformer.py
CHANGED
|
@@ -6,7 +6,7 @@ import torch.nn as nn
|
|
| 6 |
from matanyone.model.group_modules import GConv2d
|
| 7 |
from matanyone.utils.tensor_utils import aggregate
|
| 8 |
from matanyone.model.transformer.positional_encoding import PositionalEncoding
|
| 9 |
-
from matanyone.model.transformer.transformer_layers import
|
| 10 |
|
| 11 |
|
| 12 |
class QueryTransformerBlock(nn.Module):
|
|
|
|
| 6 |
from matanyone.model.group_modules import GConv2d
|
| 7 |
from matanyone.utils.tensor_utils import aggregate
|
| 8 |
from matanyone.model.transformer.positional_encoding import PositionalEncoding
|
| 9 |
+
from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
|
| 10 |
|
| 11 |
|
| 12 |
class QueryTransformerBlock(nn.Module):
|
matanyone/model/utils/resnet.py
CHANGED
|
@@ -15,7 +15,7 @@ def load_weights_add_extra_dim(target, source_state, extra_dim=1):
|
|
| 15 |
new_dict = OrderedDict()
|
| 16 |
|
| 17 |
for k1, v1 in target.state_dict().items():
|
| 18 |
-
if
|
| 19 |
if k1 in source_state:
|
| 20 |
tar_v = source_state[k1]
|
| 21 |
|
|
|
|
| 15 |
new_dict = OrderedDict()
|
| 16 |
|
| 17 |
for k1, v1 in target.state_dict().items():
|
| 18 |
+
if 'num_batches_tracked' not in k1:
|
| 19 |
if k1 in source_state:
|
| 20 |
tar_v = source_state[k1]
|
| 21 |
|
matanyone/utils/get_default_model.py
CHANGED
|
@@ -6,9 +6,8 @@ from hydra import compose, initialize
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from matanyone.model.matanyone import MatAnyone
|
| 9 |
-
from matanyone.inference.utils.args_utils import get_dataset_cfg
|
| 10 |
|
| 11 |
-
def get_matanyone_model(ckpt_path, device) -> MatAnyone:
|
| 12 |
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
|
| 13 |
cfg = compose(config_name="eval_matanyone_config")
|
| 14 |
|
|
@@ -16,8 +15,13 @@ def get_matanyone_model(ckpt_path, device) -> MatAnyone:
|
|
| 16 |
cfg['weights'] = ckpt_path
|
| 17 |
|
| 18 |
# Load the network weights
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
matanyone.load_weights(model_weights)
|
| 22 |
|
| 23 |
return matanyone
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from matanyone.model.matanyone import MatAnyone
|
|
|
|
| 9 |
|
| 10 |
+
def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
|
| 11 |
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
|
| 12 |
cfg = compose(config_name="eval_matanyone_config")
|
| 13 |
|
|
|
|
| 15 |
cfg['weights'] = ckpt_path
|
| 16 |
|
| 17 |
# Load the network weights
|
| 18 |
+
if device is not None:
|
| 19 |
+
matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
|
| 20 |
+
model_weights = torch.load(cfg.weights, map_location=device)
|
| 21 |
+
else: # if device is not specified, `.cuda()` by default
|
| 22 |
+
matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
|
| 23 |
+
model_weights = torch.load(cfg.weights)
|
| 24 |
+
|
| 25 |
matanyone.load_weights(model_weights)
|
| 26 |
|
| 27 |
return matanyone
|