# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

from typing import Any, Dict, List, Tuple

from .sam.image_encoder import ImageEncoderViT
from .sam.mask_decoder import MaskDecoder
from .sam.prompt_encoder import PromptEncoder
from .detr import box_ops
from .detr.segmentation import dice_loss, sigmoid_focal_loss
from .detr.misc import nested_tensor_from_tensor_list, interpolate
from . import axis_ops, ilnr_loss  # , pwnp_loss
from .vnl_loss import VNL_Loss
from .midas_loss import MidasLoss


class SamTransformer(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"
    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        affordance_decoder: MaskDecoder,
        depth_decoder: MaskDecoder,
        transformer_hidden_dim: int,
        backbone_name: str,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          affordance_decoder (MaskDecoder): Predicts affordance heatmaps from the
            image embeddings and encoded prompts.
          depth_decoder (MaskDecoder): Predicts a depth map from the image
            embeddings, prompted by a learned query instead of user prompts.
          transformer_hidden_dim (int): Hidden dimension of the decoder
            transformer; used to size the learned depth query embedding.
          backbone_name (str): SAM ViT backbone variant ('vit_h', 'vit_l', or
            'vit_b'); only referenced by the commented-out checkpoint loading below.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.affordance_decoder = affordance_decoder

        # depth head
        self.depth_decoder = depth_decoder
        self.depth_query = nn.Embedding(2, transformer_hidden_dim)
        fov = torch.tensor(1.0)
        image_size = (768, 1024)
        focal_length = (image_size[1] / 2 / torch.tan(fov / 2)).item()
        self.vnl_loss = VNL_Loss(focal_length, focal_length, image_size)
        self.midas_loss = MidasLoss(alpha=0.1)

        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # if backbone_name == 'vit_h':
        #     checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'checkpoints', 'sam_vit_h_4b8939.pth')
        # elif backbone_name == 'vit_l':
        #     checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'checkpoints', 'sam_vit_l_0b3195.pth')
        # elif backbone_name == 'vit_b':
        #     checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'checkpoints', 'sam_vit_b_01ec64.pth')
        # else:
        #     raise ValueError
        # with open(checkpoint_path, "rb") as f:
        #     state_dict = torch.load(f)
        # self.load_state_dict(state_dict, strict=False)
        # self.affordance_decoder.load_state_dict(self.mask_decoder.state_dict(), strict=False)
        # self.depth_decoder.load_state_dict(self.mask_decoder.state_dict(), strict=False)

        self.num_queries = 15
        self._affordance_focal_alpha = 0.95
        self._ignore_index = -100
    def device(self) -> Any:
        return self.pixel_mean.device

    def freeze_layers(self, names):
        """
        Freeze layers in 'names'.
        """
        for name, param in self.named_parameters():
            for freeze_name in names:
                if freeze_name in name:
                    param.requires_grad = False
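    # Usage sketch (hypothetical, not from the original code): freeze_layers
    # matches parameter names by substring, so e.g.
    #   model.freeze_layers(["image_encoder", "prompt_encoder"])
    # would leave only the decoder heads and the depth query embedding trainable.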
    def forward(
        self,
        image: torch.Tensor,
        valid: torch.Tensor,
        keypoints: torch.Tensor,
        bbox: torch.Tensor,
        masks: torch.Tensor,
        movable: torch.Tensor,
        rigid: torch.Tensor,
        kinematic: torch.Tensor,
        action: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.FloatTensor,
        depth: torch.Tensor,
        axis: torch.Tensor,
        fov: torch.Tensor,
        backward: bool = True,
        **kwargs,
    ):
        device = image.device
        multimask_output = False

        # image encoder
        # pad image to square
        h, w = image.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(image, (0, padw, 0, padh))
        image_embeddings = self.image_encoder(x)
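        # With the default SAM configuration the encoder yields a (256, 64, 64)
        # embedding per image; the zero dense embedding created for the depth
        # decoder below assumes this same shape.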
        outputs_seg_masks = []
        outputs_movable = []
        outputs_rigid = []
        outputs_kinematic = []
        outputs_action = []
        outputs_axis = []
        outputs_boxes = []
        outputs_aff_masks = []
        outputs_depth = []
        for idx, curr_embedding in enumerate(image_embeddings):
            point_coords = keypoints[idx].unsqueeze(1)
            point_labels = torch.ones_like(point_coords[:, :, 0])
            points = (point_coords, point_labels)
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=None,
                masks=None,
            )

            # mask decoder
            low_res_masks, iou_predictions, output_movable, output_rigid, output_kinematic, output_action, output_axis = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_mask = self.postprocess_masks(
                low_res_masks,
                input_size=image.shape[-2:],
                original_size=(768, 1024),
            )
            outputs_seg_masks.append(output_mask[:, 0])
            outputs_movable.append(output_movable[:, 0])
            outputs_rigid.append(output_rigid[:, 0])
            outputs_kinematic.append(output_kinematic[:, 0])
            outputs_action.append(output_action[:, 0])
            outputs_axis.append(output_axis[:, 0])

            # convert masks to boxes for evaluation
            pred_mask_bbox = (output_mask[:, 0].clone() > 0.0).long()
            empty_mask = pred_mask_bbox.sum(dim=-1).sum(dim=-1)
            pred_mask_bbox[empty_mask == 0] += 1
            pred_boxes = torchvision.ops.masks_to_boxes(pred_mask_bbox)
            # pred_boxes = box_ops.rescale_bboxes(pred_boxes, [1 / self._image_size[1], 1 / self._image_size[0]])
            pred_boxes = box_ops.rescale_bboxes(pred_boxes, [1 / 768, 1 / 1024])
            outputs_boxes.append(pred_boxes)
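            # Note: masks_to_boxes requires at least one positive pixel per mask,
            # so queries whose thresholded mask is empty are filled in before the
            # boxes are extracted and rescaled by the fixed (768, 1024) working
            # resolution.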
            # affordance decoder
            low_res_masks, iou_predictions = self.affordance_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_aff_masks = self.postprocess_masks(
                low_res_masks,
                input_size=image.shape[-2:],
                original_size=(192, 256),
            )
            outputs_aff_masks.append(output_aff_masks[:, 0])

            # depth decoder
            bs = keypoints.shape[0]
            # depth_sparse_embeddings = self.depth_query.weight.unsqueeze(0).repeat(bs, 1, 1)
            depth_sparse_embeddings = self.depth_query.weight.unsqueeze(0)
            # depth_dense_embeddings = torch.zeros((bs, 256, 64, 64)).to(dense_embeddings.device)
            depth_dense_embeddings = torch.zeros((1, 256, 64, 64)).to(dense_embeddings.device)
            low_res_masks, iou_predictions = self.depth_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=depth_sparse_embeddings,
                dense_prompt_embeddings=depth_dense_embeddings,
                multimask_output=multimask_output,
            )
            output_depth = self.postprocess_masks(
                low_res_masks,
                input_size=image.shape[-2:],
                original_size=(768, 1024),
            )
            outputs_depth.append(output_depth[:, 0])
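            # Unlike the mask and affordance decoders, the depth decoder is not
            # driven by user points: it receives the learned depth query as the
            # sparse prompt and an all-zero dense prompt, so it produces a single
            # depth map per image rather than one per query point.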
        outputs_seg_masks = torch.stack(outputs_seg_masks)
        outputs_movable = torch.stack(outputs_movable)
        outputs_rigid = torch.stack(outputs_rigid)
        outputs_kinematic = torch.stack(outputs_kinematic)
        outputs_action = torch.stack(outputs_action)
        outputs_axis = torch.stack(outputs_axis)
        outputs_boxes = torch.stack(outputs_boxes)
        outputs_aff_masks = torch.stack(outputs_aff_masks)
        outputs_depth = torch.stack(outputs_depth)

        out = {
            'pred_boxes': outputs_boxes,
            'pred_movable': outputs_movable,
            'pred_rigid': outputs_rigid,
            'pred_kinematic': outputs_kinematic,
            'pred_action': outputs_action,
            'pred_masks': outputs_seg_masks,
            'pred_axis': outputs_axis,
            'pred_depth': outputs_depth,
            # 'pred_depth': outputs_seg_masks[:, :1].sigmoid(),
            'pred_affordance': outputs_aff_masks,
        }
        if not backward:
            return out

        # backward
        src_boxes = outputs_boxes
        target_boxes = bbox
        target_boxes = box_ops.box_xyxy_to_cxcywh(target_boxes)
        bbox_valid = bbox[:, :, 0] > -0.5
        num_boxes = bbox_valid.sum()
        out['loss_bbox'] = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_giou'] = torch.tensor(0.0, requires_grad=True).to(device)

        # affordance
        # out['loss_affordance'] = torch.tensor(0.0, requires_grad=True).to(device)
        affordance_valid = affordance[:, :, 0] > -0.5
        if affordance_valid.sum() == 0:
            out['loss_affordance'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            src_aff_masks = outputs_aff_masks[affordance_valid]
            tgt_aff_masks = affordance_map[affordance_valid]
            src_aff_masks = src_aff_masks.flatten(1)
            tgt_aff_masks = tgt_aff_masks.flatten(1)
            loss_aff = sigmoid_focal_loss(
                src_aff_masks,
                tgt_aff_masks,
                affordance_valid.sum(),
                alpha=self._affordance_focal_alpha,
            )
            out['loss_affordance'] = loss_aff
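        # The focal alpha of 0.95 weights positive pixels heavily, presumably
        # because valid affordance regions cover only a small fraction of the
        # target heatmap.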
        # axis
        axis_valid = axis[:, :, 0] > 0.0
        num_axis = axis_valid.sum()
        if num_axis == 0:
            out['loss_axis_angle'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_axis_offset'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_eascore'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            # regress angle
            src_axis_angle = outputs_axis[axis_valid]
            src_axis_angle_norm = F.normalize(src_axis_angle[:, :2])
            src_axis_angle = torch.cat((src_axis_angle_norm, src_axis_angle[:, 2:]), dim=-1)
            target_axis_xyxy = axis[axis_valid]
            axis_center = target_boxes[axis_valid].clone()
            axis_center[:, 2:] = axis_center[:, :2]
            target_axis_angle = axis_ops.line_xyxy_to_angle(target_axis_xyxy, center=axis_center)

            loss_axis_angle = F.l1_loss(src_axis_angle[:, :2], target_axis_angle[:, :2], reduction='sum') / num_axis
            loss_axis_offset = F.l1_loss(src_axis_angle[:, 2:], target_axis_angle[:, 2:], reduction='sum') / num_axis
            out['loss_axis_angle'] = loss_axis_angle
            out['loss_axis_offset'] = loss_axis_offset

            src_axis_xyxy = axis_ops.line_angle_to_xyxy(src_axis_angle, center=axis_center)
            target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis_angle, center=axis_center)
            axis_eascore, _, _ = axis_ops.ea_score(src_axis_xyxy, target_axis_xyxy)
            loss_eascore = 1 - axis_eascore
            out['loss_eascore'] = loss_eascore.mean()
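        # Axis supervision combines three terms: an L1 loss on the normalized
        # direction, an L1 loss on the offset relative to the box center, and a
        # (1 - EA score) term computed after converting both lines back to
        # endpoint form.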
        loss_movable = F.cross_entropy(outputs_movable.permute(0, 2, 1), movable, ignore_index=self._ignore_index)
        if torch.isnan(loss_movable):
            loss_movable = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_movable'] = loss_movable

        loss_rigid = F.cross_entropy(outputs_rigid.permute(0, 2, 1), rigid, ignore_index=self._ignore_index)
        if torch.isnan(loss_rigid):
            loss_rigid = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_rigid'] = loss_rigid

        loss_kinematic = F.cross_entropy(outputs_kinematic.permute(0, 2, 1), kinematic, ignore_index=self._ignore_index)
        if torch.isnan(loss_kinematic):
            loss_kinematic = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_kinematic'] = loss_kinematic

        loss_action = F.cross_entropy(outputs_action.permute(0, 2, 1), action, ignore_index=self._ignore_index)
        if torch.isnan(loss_action):
            loss_action = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_action'] = loss_action

        # depth backward
        out['loss_depth'] = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_vnl'] = torch.tensor(0.0, requires_grad=True).to(device)
        # (bs, 1, H, W)
        src_depths = interpolate(outputs_depth, size=depth.shape[-2:], mode='bilinear', align_corners=False)
        src_depths = src_depths.clamp(min=0.0, max=1.0)
        tgt_depths = depth.unsqueeze(1)  # (bs, H, W) -> (bs, 1, H, W)
        # a sample is treated as having ground-truth depth only if its top-left depth value is positive
        valid_depth = depth[:, 0, 0] > 0
        if valid_depth.any():
            src_depths = src_depths[valid_depth]
            tgt_depths = tgt_depths[valid_depth]
            depth_mask = tgt_depths > 1e-8
            midas_loss, ssi_loss, reg_loss = self.midas_loss(src_depths, tgt_depths, depth_mask)
            loss_vnl = self.vnl_loss(tgt_depths, src_depths)
            out['loss_depth'] = midas_loss
            out['loss_vnl'] = loss_vnl
        else:
            out['loss_depth'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_vnl'] = torch.tensor(0.0, requires_grad=True).to(device)
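        # Depth is supervised with two complementary terms: MidasLoss, a
        # scale-and-shift-invariant loss in the style of MiDaS relative-depth
        # training, and VNL_Loss, a virtual-normal style geometric loss built
        # from the focal length configured in __init__.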
        # mask backward
        tgt_masks = masks
        src_masks = interpolate(outputs_seg_masks, size=tgt_masks.shape[-2:], mode='bilinear', align_corners=False)
        valid_mask = tgt_masks.sum(dim=-1).sum(dim=-1) > 10
        if valid_mask.sum() == 0:
            out['loss_mask'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_dice'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            num_masks = valid_mask.sum()
            src_masks = src_masks[valid_mask]
            tgt_masks = tgt_masks[valid_mask]
            src_masks = src_masks.flatten(1)
            tgt_masks = tgt_masks.flatten(1)
            tgt_masks = tgt_masks.view(src_masks.shape)
            out['loss_mask'] = sigmoid_focal_loss(src_masks, tgt_masks.float(), num_masks)
            out['loss_dice'] = dice_loss(src_masks, tgt_masks, num_masks)

        return out
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks
    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
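

# ---------------------------------------------------------------------------
# Minimal sketch of the pad / unpad round trip performed by `preprocess` and
# `postprocess_masks`. This block is illustrative only and not part of the
# original training pipeline; it assumes the standard SAM encoder input size
# of 1024 and the fixed (768, 1024) working resolution used above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    img_size = 1024                      # assumed ImageEncoderViT.img_size
    image = torch.rand(2, 3, 768, 1024)  # (B, C, H, W) batch at the working resolution

    # Pad to a square of side img_size, as done at the start of `forward`.
    h, w = image.shape[-2:]
    padded = F.pad(image, (0, img_size - w, 0, img_size - h))
    assert padded.shape[-2:] == (img_size, img_size)

    # Pretend the decoder produced 256x256 low-resolution masks.
    low_res_masks = torch.rand(2, 1, 256, 256)

    # Mirror `postprocess_masks`: upsample to the padded square, crop away the
    # padding, then resize to the original resolution.
    masks = F.interpolate(low_res_masks, (img_size, img_size), mode="bilinear", align_corners=False)
    masks = masks[..., :h, :w]
    masks = F.interpolate(masks, (768, 1024), mode="bilinear", align_corners=False)
    print(masks.shape)  # torch.Size([2, 1, 768, 1024])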