import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from transformers import PreTrainedModel, AutoTokenizer, AutoModel, AutoProcessor
from typing import Dict, List, Tuple, Optional, Any, Union
import numpy as np
import os
import cv2
from collections import defaultdict
import builtins
import sys

from laser.models import llava_clip_model_v3

# Alias the module under its bare name so that pickled checkpoints which
# reference "llava_clip_model_v3" directly can still be resolved by torch.load.
sys.modules["llava_clip_model_v3"] = llava_clip_model_v3

import inspect
from transformers.models.clip import modeling_clip
import transformers

from .vine_config import VineConfig
from laser.models.model_utils import (
    extract_single_object,
    extract_object_subject,
    crop_image_contain_bboxes,
    segment_list,
)
from .flattening import (
    extract_valid_object_pairs,
    flatten_segments_for_batch,
)

from .vis_utils import save_mask_one_image


class VineModel(PreTrainedModel):
    """
    VINE (Video Understanding with Natural Language) Model

    This model processes videos along with categorical, unary, and binary keywords
    to return probability distributions over those keywords for detected objects
    and their relationships in the video.
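
    Example (illustrative sketch only; the repository id, frames, masks,
    bboxes, and keyword lists below are placeholders, not values shipped
    with this module):

        config = VineConfig(use_hf_repo=True, model_repo="org/vine-checkpoint")
        model = VineModel(config)
        predictions = model.predict(
            video_frames=frames,              # (num_frames, H, W, 3)
            masks=masks,                      # {frame_id: {obj_id: mask}}
            bboxes=bboxes,                    # {frame_id: {obj_id: [x1, y1, x2, y2]}}
            categorical_keywords=["person", "dog"],
            unary_keywords=["running"],
            binary_keywords=["next to"],
            object_pairs=[(0, 1)],
        )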
""" |
|
|
|
|
|
config_class = VineConfig |
|
|
|
|
|
def __init__(self, config: VineConfig): |
|
|
super().__init__(config) |
|
|
|
|
|
self.config = config |
|
|
self.visualize = getattr(config, "visualize", False) |
|
|
self.visualization_dir = getattr(config, "visualization_dir", None) |
|
|
self.debug_visualizations = getattr(config, "debug_visualizations", False) |
|
|
self._device = getattr(config, "_device") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name) |
|
|
if self.clip_tokenizer.pad_token is None: |
|
|
self.clip_tokenizer.pad_token = ( |
|
|
self.clip_tokenizer.unk_token |
|
|
if self.clip_tokenizer.unk_token |
|
|
else self.clip_tokenizer.eos_token |
|
|
) |
|
|
self.clip_processor = AutoProcessor.from_pretrained(config.model_name) |
|
|
self.clip_cate_model = AutoModel.from_pretrained(config.model_name) |
|
|
self.clip_unary_model = AutoModel.from_pretrained(config.model_name) |
|
|
self.clip_binary_model = AutoModel.from_pretrained(config.model_name) |
|
|
|
|
|
|
|
|
|
|
|
if config.use_hf_repo: |
|
|
self._load_huggingface_vine_weights(config.model_repo, config.model_file) |
|
|
else: |
|
|
self._load_local_pretrained_vine_weights(config.local_dir, config.local_filename) |
|
|
|
|
|
|
|
|
self.to(self._device) |
|
|
|
|
|
    def _load_huggingface_vine_weights(self, model_repo: str, model_file: Optional[str] = None):
        """
        Load pretrained VINE weights from HuggingFace Hub.
        """
        try:
            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
            vine_model = AutoModel.from_pretrained(
                model_repo,
                trust_remote_code=True,
                revision=model_file if model_file else "main"
            )
            self.clip_cate_model = vine_model.clip_cate_model
            self.clip_unary_model = vine_model.clip_unary_model
            self.clip_binary_model = vine_model.clip_binary_model
            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
            return True
        except Exception as e:
            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
            print("Using base CLIP models instead")
            return False

    def _load_local_pretrained_vine_weights(self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0):
        """
        Load pretrained VINE weights from a saved .pkl/.pt/.pth file or from an
        ensemble directory containing per-epoch model files.
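
        Example ensemble layout (illustrative; the directory and file names
        below are placeholders):

            local_dir/
                vine.0.model    # picked up when epoch=0
                vine.1.model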
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
full_path = os.path.join(local_dir, local_filename) if local_filename else local_dir |
|
|
|
|
|
if full_path.endswith(".pkl"): |
|
|
print(f"Loading VINE weights from: {full_path}") |
|
|
loaded_vine_model = torch.load(full_path, map_location=self._device, weights_only=False) |
|
|
|
|
|
print(f"Loaded state type: {type(loaded_vine_model)}") |
|
|
if not isinstance(loaded_vine_model, dict): |
|
|
if hasattr(loaded_vine_model, 'clip_cate_model'): |
|
|
self.clip_cate_model.load_state_dict(loaded_vine_model.clip_cate_model.state_dict()) |
|
|
if hasattr(loaded_vine_model, 'clip_unary_model'): |
|
|
self.clip_unary_model.load_state_dict(loaded_vine_model.clip_unary_model.state_dict()) |
|
|
if hasattr(loaded_vine_model, 'clip_binary_model'): |
|
|
self.clip_binary_model.load_state_dict(loaded_vine_model.clip_binary_model.state_dict()) |
|
|
return True |
|
|
|
|
|
elif full_path.endswith(".pt") or full_path.endswith(".pth"): |
|
|
state = torch.load(full_path, map_location=self._device, weights_only=True) |
|
|
print(f"Loaded state type: {type(state)}") |
|
|
self.load_state_dict(state) |
|
|
return True |
|
|
|
|
|
|
|
|
if os.path.isdir(full_path): |
|
|
model_files = [f for f in os.listdir(full_path) if f.endswith(f'.{epoch}.model')] |
|
|
if model_files: |
|
|
model_file = os.path.join(full_path, model_files[0]) |
|
|
print(f"Loading VINE weights from: {model_file}") |
|
|
pretrained_model = torch.load(model_file, map_location="cpu") |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(pretrained_model, 'clip_cate_model'): |
|
|
self.clip_cate_model.load_state_dict(pretrained_model.clip_cate_model.state_dict()) |
|
|
if hasattr(pretrained_model, 'clip_unary_model'): |
|
|
self.clip_unary_model.load_state_dict(pretrained_model.clip_unary_model.state_dict()) |
|
|
if hasattr(pretrained_model, 'clip_binary_model'): |
|
|
self.clip_binary_model.load_state_dict(pretrained_model.clip_binary_model.state_dict()) |
|
|
print("✓ Loaded all sub-model weights from ensemble format") |
|
|
return True |
|
|
else: |
|
|
print(f"No model file found for epoch {epoch} in {full_path}") |
|
|
return False |
|
|
|
|
|
print("Unsupported format for pretrained_vine_path") |
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    @classmethod
    def from_pretrained_vine(
        cls,
        model_path: str,
        config: Optional[VineConfig] = None,
        epoch: int = 0,
        **kwargs
    ):
        """
        Create a VineModel from pretrained VINE weights.

        Args:
            model_path: Path to a local checkpoint (file or directory) or a
                HuggingFace Hub repo id
            config: Optional config; a default one is created if None
            epoch: Epoch number to load
            **kwargs: Additional arguments

        Returns:
            VineModel instance with loaded weights
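
        Example (illustrative; the repo id and local path are placeholders):

            # HuggingFace Hub repo id
            model = VineModel.from_pretrained_vine("org/vine-checkpoint")

            # Local checkpoint file (the path must exist on disk)
            model = VineModel.from_pretrained_vine("checkpoints/vine_weights.pt")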
""" |
|
|
|
|
|
if config is None: |
|
|
|
|
|
|
|
|
if model_path and ("/" in model_path and not os.path.exists(model_path)): |
|
|
config = VineConfig(use_hf_repo=True, model_repo=model_path) |
|
|
else: |
|
|
|
|
|
if os.path.isdir(model_path): |
|
|
config = VineConfig(use_hf_repo=False, local_dir=model_path) |
|
|
else: |
|
|
config = VineConfig( |
|
|
use_hf_repo=False, |
|
|
local_dir=os.path.dirname(model_path) or None, |
|
|
local_filename=os.path.basename(model_path) or None, |
|
|
) |
|
|
else: |
|
|
|
|
|
if model_path and ("/" in model_path and not os.path.exists(model_path)): |
|
|
config.use_hf_repo = True |
|
|
config.model_repo = model_path |
|
|
config.model_file = None |
|
|
config.local_dir = None |
|
|
config.local_filename = None |
|
|
else: |
|
|
config.use_hf_repo = False |
|
|
if os.path.isdir(model_path): |
|
|
config.local_dir = model_path |
|
|
config.local_filename = None |
|
|
else: |
|
|
config.local_dir = os.path.dirname(model_path) or None |
|
|
config.local_filename = os.path.basename(model_path) or None |
|
|
|
|
|
|
|
|
model = cls(config, **kwargs) |
|
|
|
|
|
return model |
|
|
|
|
|
    def _text_features_checkpoint(self, model, tokens):
        """Extract text features with gradient checkpointing."""
        token_keys = list(tokens.keys())

        def get_text_features_wrapped(*inputs):
            kwargs = {key: value for key, value in zip(token_keys, inputs)}
            return model.get_text_features(**kwargs)

        token_values = [tokens[key] for key in token_keys]
        return cp.checkpoint(get_text_features_wrapped, *token_values, use_reentrant=False)

    def _image_features_checkpoint(self, model, images):
        """Extract image features with gradient checkpointing."""
        return cp.checkpoint(model.get_image_features, images, use_reentrant=False)

    def clip_sim(self, model, nl_feat, img_feat):
        """Cosine-similarity logits between text and image features, scaled by
        the model's logit_scale when available."""
        img_feat = img_feat / img_feat.norm(p=2, dim=-1, keepdim=True)
        nl_feat = nl_feat / nl_feat.norm(p=2, dim=-1, keepdim=True)
        logits = torch.matmul(img_feat, nl_feat.T)
        if hasattr(model, "logit_scale"):
            logits = logits * model.logit_scale.exp()
        return logits

    def forward(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Forward pass of the VINE model.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names to classify objects
            unary_keywords: Optional list of unary predicates (actions on single objects)
            binary_keywords: Optional list of binary predicates (relations between objects)
            object_pairs: Optional list of (obj1_id, obj2_id) pairs for binary classification
            return_flattened_segments: Whether to include flattened mask/bbox tensors
            return_valid_pairs: Whether to compute valid object pairs per frame
            interested_object_pairs: Optional subset of object pairs to track
            debug_visualizations: Optional override for the config's debug visualization flag

        Returns:
            Dict containing probability distributions for categorical, unary, and binary predictions
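
            For example (illustrative values only):

                {
                    "categorical_probs": {0: {(obj_id, "person"): 0.93, ...}},
                    "unary_probs": {0: {(frame_id, obj_id, "running"): 0.71, ...}},
                    "binary_probs": [{(frame_id, (obj1_id, obj2_id), "next to"): 0.64, ...}],
                    "dummy_prob": 0.25,
                }

            plus "flattened_segments" / "valid_pairs" entries when requested.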
""" |
|
|
if unary_keywords is None: |
|
|
unary_keywords = [] |
|
|
if binary_keywords is None: |
|
|
binary_keywords = [] |
|
|
if object_pairs is None: |
|
|
object_pairs = [] |
|
|
if return_flattened_segments is None: |
|
|
return_flattened_segments = self.config.return_flattened_segments |
|
|
if return_valid_pairs is None: |
|
|
return_valid_pairs = self.config.return_valid_pairs |
|
|
if interested_object_pairs is None or len(interested_object_pairs) == 0: |
|
|
interested_object_pairs = getattr(self.config, "interested_object_pairs", []) or [] |
|
|
if debug_visualizations is None: |
|
|
debug_visualizations = self.debug_visualizations |
|
|
|
|
|
|
|
|
dummy_str = "" |
|
|
|
|
|
|
|
|
if len(categorical_keywords) == 0: |
|
|
categorical_keywords = [dummy_str] |
|
|
if len(unary_keywords) == 0: |
|
|
unary_keywords = [dummy_str] |
|
|
if len(binary_keywords) == 0: |
|
|
binary_keywords = [dummy_str] |
|
|
|
|
|
|
|
|
categorical_features = self._extract_text_features( |
|
|
self.clip_cate_model, categorical_keywords |
|
|
) |
|
|
unary_features = self._extract_text_features( |
|
|
self.clip_unary_model, unary_keywords |
|
|
) |
|
|
binary_features = self._extract_text_features( |
|
|
self.clip_binary_model, binary_keywords |
|
|
) |
|
|
|
|
|
|
|
|
categorical_probs = {} |
|
|
unary_probs = {} |
|
|
binary_probs = {} |
|
|
|
|
|
|
|
|
for frame_id, frame_masks in masks.items(): |
|
|
if frame_id >= len(video_frames): |
|
|
continue |
|
|
|
|
|
frame = self._frame_to_numpy(video_frames[frame_id]) |
|
|
frame_bboxes = bboxes.get(frame_id, {}) |
|
|
|
|
|
|
|
|
for obj_id, mask in frame_masks.items(): |
|
|
if obj_id not in frame_bboxes: |
|
|
continue |
|
|
|
|
|
bbox = frame_bboxes[obj_id] |
|
|
|
|
|
|
|
|
mask_np = self._mask_to_numpy(mask) |
|
|
|
|
|
obj_image = extract_single_object( |
|
|
frame, mask_np, alpha=self.config.alpha |
|
|
) |
|
|
|
|
|
|
|
|
obj_features = self._extract_image_features( |
|
|
self.clip_cate_model, obj_image |
|
|
) |
|
|
|
|
|
|
|
|
cat_similarities = self.clip_sim( |
|
|
self.clip_cate_model, categorical_features, obj_features |
|
|
) |
|
|
cat_probs = F.softmax(cat_similarities, dim=-1) |
|
|
|
|
|
|
|
|
for i, keyword in enumerate(categorical_keywords): |
|
|
if keyword != dummy_str: |
|
|
categorical_probs[(obj_id, keyword)] = cat_probs[0, i].item() |
|
|
|
|
|
|
|
|
if len(unary_keywords) > 0 and unary_keywords[0] != dummy_str: |
|
|
unary_similarities = self.clip_sim( |
|
|
self.clip_unary_model, unary_features, obj_features |
|
|
) |
|
|
unary_probs_tensor = F.softmax(unary_similarities, dim=-1) |
|
|
|
|
|
for i, keyword in enumerate(unary_keywords): |
|
|
if keyword != dummy_str: |
|
|
unary_probs[(frame_id, obj_id, keyword)] = unary_probs_tensor[0, i].item() |
|
|
|
|
|
|
|
|
if len(binary_keywords) > 0 and binary_keywords[0] != dummy_str and len(object_pairs) > 0: |
|
|
for obj1_id, obj2_id in object_pairs: |
|
|
for frame_id, frame_masks in masks.items(): |
|
|
if frame_id >= len(video_frames): |
|
|
continue |
|
|
if (obj1_id in frame_masks and obj2_id in frame_masks and |
|
|
obj1_id in bboxes.get(frame_id, {}) and obj2_id in bboxes.get(frame_id, {})): |
|
|
|
|
|
frame = self._frame_to_numpy(video_frames[frame_id]) |
|
|
mask1 = frame_masks[obj1_id] |
|
|
mask2 = frame_masks[obj2_id] |
|
|
|
|
|
mask1_np = self._mask_to_numpy(mask1) |
|
|
mask2_np = self._mask_to_numpy(mask2) |
|
|
|
|
|
|
|
|
pair_image = extract_object_subject( |
|
|
frame, mask1_np[..., None], mask2_np[..., None], |
|
|
alpha=self.config.alpha, |
|
|
white_alpha=self.config.white_alpha |
|
|
) |
|
|
|
|
|
|
|
|
bbox1 = bboxes[frame_id][obj1_id] |
|
|
bbox2 = bboxes[frame_id][obj2_id] |
|
|
|
|
|
|
|
|
if bbox1[0] >= bbox2[2] or bbox2[1] >= bbox1[3] or \ |
|
|
bbox2[0] >= bbox1[2] or bbox1[1] >= bbox2[3]: |
|
|
continue |
|
|
|
|
|
cropped_image = crop_image_contain_bboxes( |
|
|
pair_image, [bbox1, bbox2], f"frame_{frame_id}" |
|
|
) |
|
|
|
|
|
|
|
|
pair_features = self._extract_image_features( |
|
|
self.clip_binary_model, cropped_image |
|
|
) |
|
|
|
|
|
|
|
|
binary_similarities = self.clip_sim( |
|
|
self.clip_binary_model, binary_features, pair_features |
|
|
) |
|
|
binary_probs_tensor = F.softmax(binary_similarities, dim=-1) |
|
|
|
|
|
for i, keyword in enumerate(binary_keywords): |
|
|
if keyword != dummy_str: |
|
|
binary_probs[(frame_id, (obj1_id, obj2_id), keyword)] = binary_probs_tensor[0, i].item() |
|
|
|
|
|
|
|
|
dummy_prob = 1.0 / max(len(categorical_keywords), len(unary_keywords), len(binary_keywords)) |
|
|
|
|
|
result: Dict[str, Any] = { |
|
|
"categorical_probs": {0: categorical_probs}, |
|
|
"unary_probs": {0: unary_probs}, |
|
|
"binary_probs": [binary_probs], |
|
|
"dummy_prob": dummy_prob |
|
|
} |
|
|
|
|
|
if return_flattened_segments or return_valid_pairs: |
|
|
flattened = flatten_segments_for_batch( |
|
|
video_id=0, |
|
|
segments=masks, |
|
|
bbox_min_dim=self.config.bbox_min_dim, |
|
|
) |
|
|
if return_flattened_segments: |
|
|
result["flattened_segments"] = flattened |
|
|
if return_valid_pairs: |
|
|
interested_pairs = interested_object_pairs if interested_object_pairs else None |
|
|
result["valid_pairs"] = extract_valid_object_pairs( |
|
|
flattened["object_ids"], |
|
|
interested_pairs, |
|
|
) |
|
|
if interested_pairs is None: |
|
|
|
|
|
result["valid_pairs_metadata"] = {"pair_source": "all_pairs"} |
|
|
else: |
|
|
result["valid_pairs_metadata"] = {"pair_source": "filtered", "requested_pairs": interested_pairs} |
|
|
|
|
|
return result |
|
|
|
|
|
    def _frame_to_numpy(self, frame: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a frame tensor/array to a contiguous numpy array."""
        if torch.is_tensor(frame):
            frame_np = frame.detach().cpu().numpy()
        else:
            frame_np = np.asarray(frame)
        return np.ascontiguousarray(frame_np)

    def _mask_to_numpy(self, mask: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a mask tensor/array to a 2D boolean numpy array."""
        if torch.is_tensor(mask):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)

        # Drop a singleton channel dimension, whether leading or trailing.
        if mask_np.ndim == 3:
            if mask_np.shape[0] == 1:
                mask_np = mask_np.squeeze(0)
            elif mask_np.shape[2] == 1:
                mask_np = mask_np.squeeze(2)

        if mask_np.ndim != 2:
            raise ValueError(f"Mask must be 2D after squeezing, got shape {mask_np.shape}")

        return mask_np.astype(bool, copy=False)

    def _extract_text_features(self, model, keywords):
        """Extract text features for given keywords."""
        tokens = self.clip_tokenizer(
            keywords,
            return_tensors="pt",
            max_length=75,
            truncation=True,
            padding='max_length'
        ).to(self._device)

        return self._text_features_checkpoint(model, tokens)

    def _extract_image_features(self, model, image):
        """Extract image features for given image."""
        if isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                image = image.astype(np.uint8)

            # Input frames are assumed to be BGR (OpenCV convention); convert
            # to RGB before handing them to the CLIP processor.
            if len(image.shape) == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        inputs = self.clip_processor(
            images=image,
            return_tensors="pt"
        ).to(self._device)

        return self._image_features_checkpoint(model, inputs['pixel_values'])

    def predict(
        self,
        video_frames: torch.Tensor,
        masks: Dict[int, Dict[int, torch.Tensor]],
        bboxes: Dict[int, Dict[int, List]],
        categorical_keywords: List[str],
        unary_keywords: Optional[List[str]] = None,
        binary_keywords: Optional[List[str]] = None,
        object_pairs: Optional[List[Tuple[int, int]]] = None,
        return_top_k: int = 3,
        return_flattened_segments: Optional[bool] = None,
        return_valid_pairs: Optional[bool] = None,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        High-level prediction method that returns formatted results.

        Args:
            video_frames: Tensor of shape (num_frames, height, width, 3)
            masks: Dict mapping frame_id -> object_id -> mask tensor
            bboxes: Dict mapping frame_id -> object_id -> [x1, y1, x2, y2]
            categorical_keywords: List of category names
            unary_keywords: Optional list of unary predicates
            binary_keywords: Optional list of binary predicates
            object_pairs: Optional list of object pairs for binary relations
            return_top_k: Number of top predictions to return
            return_flattened_segments: Whether to include flattened mask/bbox tensors
            return_valid_pairs: Whether to compute valid object pairs per frame
            interested_object_pairs: Optional subset of object pairs to track
            debug_visualizations: Optional override for the config's debug visualization flag

        Returns:
            Formatted prediction results
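
            For example (illustrative values only):

                {
                    "categorical_predictions": {obj_id: [(0.93, "person"), ...]},
                    "unary_predictions": {(frame_id, obj_id): [(0.71, "running"), ...]},
                    "binary_predictions": {(frame_id, (obj1_id, obj2_id)): [(0.64, "next to"), ...]},
                    "confidence_scores": {"categorical": 0.93, "unary": 0.71, "binary": 0.64},
                }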
""" |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.forward( |
|
|
video_frames=video_frames, |
|
|
masks=masks, |
|
|
bboxes=bboxes, |
|
|
categorical_keywords=categorical_keywords, |
|
|
unary_keywords=unary_keywords, |
|
|
binary_keywords=binary_keywords, |
|
|
object_pairs=object_pairs, |
|
|
return_flattened_segments=return_flattened_segments, |
|
|
return_valid_pairs=return_valid_pairs, |
|
|
interested_object_pairs=interested_object_pairs, |
|
|
debug_visualizations=debug_visualizations, |
|
|
) |
|
|
|
|
|
|
|
|
formatted_categorical = {} |
|
|
for (obj_id, category), prob in outputs["categorical_probs"][0].items(): |
|
|
if obj_id not in formatted_categorical: |
|
|
formatted_categorical[obj_id] = [] |
|
|
formatted_categorical[obj_id].append((prob, category)) |
|
|
|
|
|
|
|
|
for obj_id in formatted_categorical: |
|
|
formatted_categorical[obj_id] = sorted( |
|
|
formatted_categorical[obj_id], reverse=True |
|
|
)[:return_top_k] |
|
|
|
|
|
|
|
|
formatted_unary = {} |
|
|
for (frame_id, obj_id, predicate), prob in outputs["unary_probs"][0].items(): |
|
|
key = (frame_id, obj_id) |
|
|
if key not in formatted_unary: |
|
|
formatted_unary[key] = [] |
|
|
formatted_unary[key].append((prob, predicate)) |
|
|
|
|
|
|
|
|
for key in formatted_unary: |
|
|
formatted_unary[key] = sorted( |
|
|
formatted_unary[key], reverse=True |
|
|
)[:return_top_k] |
|
|
|
|
|
|
|
|
formatted_binary = {} |
|
|
if len(outputs["binary_probs"]) > 0: |
|
|
for (frame_id, obj_pair, predicate), prob in outputs["binary_probs"][0].items(): |
|
|
key = (frame_id, obj_pair) |
|
|
if key not in formatted_binary: |
|
|
formatted_binary[key] = [] |
|
|
formatted_binary[key].append((prob, predicate)) |
|
|
|
|
|
|
|
|
for key in formatted_binary: |
|
|
formatted_binary[key] = sorted( |
|
|
formatted_binary[key], reverse=True |
|
|
)[:return_top_k] |
|
|
|
|
|
result: Dict[str, Any] = { |
|
|
"categorical_predictions": formatted_categorical, |
|
|
"unary_predictions": formatted_unary, |
|
|
"binary_predictions": formatted_binary, |
|
|
"confidence_scores": { |
|
|
"categorical": max([max([p for p, _ in preds], default=0.0) |
|
|
for preds in formatted_categorical.values()], default=0.0), |
|
|
"unary": max([max([p for p, _ in preds], default=0.0) |
|
|
for preds in formatted_unary.values()], default=0.0), |
|
|
"binary": max([max([p for p, _ in preds], default=0.0) |
|
|
for preds in formatted_binary.values()], default=0.0) |
|
|
} |
|
|
} |
|
|
|
|
|
if "flattened_segments" in outputs: |
|
|
result["flattened_segments"] = outputs["flattened_segments"] |
|
|
if "valid_pairs" in outputs: |
|
|
result["valid_pairs"] = outputs["valid_pairs"] |
|
|
if "valid_pairs_metadata" in outputs: |
|
|
result["valid_pairs_metadata"] = outputs["valid_pairs_metadata"] |
|
|
|
|
|
return result |
|
|
|