import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
from transformers.models.mask2former.modeling_mask2former import (
    Mask2FormerMaskedAttentionDecoderOutput,
    Mask2FormerModelOutput,
    Mask2FormerForUniversalSegmentationOutput,
    Mask2FormerMLPPredictionHead,
    sample_point,
    pair_wise_sigmoid_cross_entropy_loss,
    pair_wise_dice_loss,
    sigmoid_cross_entropy_loss,
    dice_loss,
)
from torch import Tensor
import torch.nn.functional as F
from transformers.file_utils import is_scipy_available

if is_scipy_available():
    from scipy.optimize import linear_sum_assignment


def get_classification_logits(x, text_classifier, logit_scale):
    """Score embeddings against text class embeddings with scaled cosine similarity.

    Args:
        x: Embeddings of shape ``[B, *, C]``.
        text_classifier: Class embeddings of shape ``[num_classes, C]``.
        logit_scale: Learnable scalar holding the log of the temperature, see
            https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201

    Returns:
        Logits of shape ``[B, *, num_classes]``.
    """
    x = F.normalize(x, dim=-1)
    text_classifier = F.normalize(text_classifier, dim=-1)
    # The scale is stored in log space; cap the exponentiated value at 100.
    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    pred_logits = logit_scale * x @ text_classifier.T  # B, *, N + 1
    return pred_logits


def _post_init(self):
    """Attach the open-vocabulary classification head and its logit scale to ``self``."""
    self.class_embed = Mask2FormerMLPPredictionHead(
        self.config.hidden_dim, self.config.hidden_dim, self.config.hidden_dim, 3
    )
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))


def ov_class_predictor(self, x, text_classifier):
    """Compute per-sample class logits against per-sample text classifiers.

    ``x`` is first projected by ``self.class_embed``. ``x`` and ``text_classifier``
    are then iterated in lockstep, so every sample may use a different classifier
    (and therefore a different number of classes).

    Returns:
        A list with one logits tensor per sample.
    """
    x = self.class_embed(x)
    all_pred_logits = []
    for per_x, per_text_classifier in zip(x, text_classifier):
        per_pred_logits = get_classification_logits(per_x.unsqueeze(0), per_text_classifier, self.logit_scale)
        all_pred_logits.append(per_pred_logits.squeeze(0))
    return all_pred_logits


def Mask2FormerLoss_loss_labels(
    self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.ndarray]
) -> Dict[str, Tensor]:
    """Cross-entropy classification loss computed sample by sample.

    The loss is computed per sample because the size of the last logits dimension
    is read separately for every sample. The last class index of each sample is
    used as the fill value for unmatched queries and is weighted by ``self.eos_coef``.

    Args:
        class_queries_logits: Per-sample logits; entry ``i`` is indexed as
            ``(num_queries, num_labels + 1)``.
        class_labels: Per-sample target class indices.
        indices: Per-sample ``(prediction_indices, target_indices)`` pairs.

    Returns:
        ``{"loss_cross_entropy": ...}`` holding the plain mean of the weighted
        per-query losses over all samples.
    """
    num_queries = class_queries_logits[0].shape[0]
    all_ce_loss = []
    for i, logits in enumerate(class_queries_logits):
        num_labels_plus1 = logits.shape[-1]
        # Build the class weights directly on the logits' device and dtype.
        empty_weight = torch.ones(num_labels_plus1, device=logits.device, dtype=logits.dtype)
        empty_weight[-1] = self.eos_coef
        criterion = nn.CrossEntropyLoss(weight=empty_weight, reduction='none')

        pred_idx, tgt_idx = indices[i][0], indices[i][1]
        target_classes_o = class_labels[i][tgt_idx]
        # Every query defaults to the last class; matched queries are overwritten.
        target_classes = torch.full(
            (num_queries,), fill_value=num_labels_plus1 - 1, dtype=torch.int64, device=logits.device
        )
        target_classes[pred_idx] = target_classes_o.to(logits.device)
        target_classes = target_classes.unsqueeze(0)

        # CrossEntropyLoss expects the class dimension at position 1.
        pred_logits = logits.unsqueeze(0).transpose(1, 2)
        all_ce_loss.append(criterion(pred_logits, target_classes))

    losses = {"loss_cross_entropy": torch.cat(all_ce_loss, dim=-1).mean()}
    return losses


def Mask2FormerLoss_loss_masks(
    self,
    masks_queries_logits: torch.Tensor,
    mask_labels: List[torch.Tensor],
    indices: Tuple[np.ndarray],
    num_masks: int,
) -> Dict[str, torch.Tensor]:
    """Point-sampled binary cross-entropy and dice losses for the matched masks.

    Args:
        masks_queries_logits: Predicted mask logits, indexed as
            ``(batch_size, num_queries, height, width)``.
        mask_labels: Per-sample target masks.
        indices: Per-sample ``(prediction_indices, target_indices)`` pairs.
        num_masks: Normalisation factor forwarded to both loss functions.

    Returns:
        A dict with the keys ``"loss_mask"`` and ``"loss_dice"``.
    """
    src_idx = self._get_predictions_permutation_indices(indices)
    tgt_idx = self._get_targets_permutation_indices(indices)
    pred_masks = masks_queries_logits[src_idx]
    # Pad all targets to the largest size in the batch, then pick the matched ones.
    target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
    target_masks = target_masks[tgt_idx]

    # No need to upsample predictions as we are using normalized coordinates.
    pred_masks = pred_masks[:, None]
    target_masks = target_masks[:, None]

    # Sample point coordinates and labels without tracking gradients.
    with torch.no_grad():
        point_coordinates = self.sample_points_using_uncertainty(
            pred_masks,
            self.calculate_uncertainty,
            self.num_points,
            self.oversample_ratio,
            self.importance_sample_ratio,
        )
        # Follow the prediction dtype instead of a fixed bfloat16 so that the
        # sampled labels match the logits' precision (the matcher does the same).
        point_coordinates = point_coordinates.to(pred_masks.dtype)
        point_labels = sample_point(
            target_masks.to(pred_masks.dtype), point_coordinates, align_corners=False
        ).squeeze(1)

    point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)

    losses = {
        "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
        "loss_dice": dice_loss(point_logits, point_labels, num_masks),
    }
    return losses


def Mask2FormerLoss_sample_points_using_uncertainty(
    self,
    logits: torch.Tensor,
    uncertainty_function,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """Sample point coordinates, favouring points with high uncertainty.

    ``num_points * oversample_ratio`` random points are drawn, the
    ``importance_sample_ratio * num_points`` most uncertain ones are kept, and
    the remainder is filled with fresh uniformly random points.

    Args:
        logits: Predictions whose first dimension is the number of masks.
        uncertainty_function: Callable mapping sampled logits to uncertainties.
        num_points: Number of points returned per mask.
        oversample_ratio: Oversampling factor for the candidate pool.
        importance_sample_ratio: Fraction of points chosen by uncertainty.

    Returns:
        Coordinates of shape ``(num_boxes, num_points, 2)``.
    """
    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)

    # Get random point coordinates.
    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    # Get sampled prediction values for the point coordinates.
    point_logits = sample_point(logits, point_coordinates.to(logits.dtype), align_corners=False)
    # Calculate the uncertainties based on the sampled prediction values.
    point_uncertainties = uncertainty_function(point_logits)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points

    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    # Offset per-mask indices so they address the flattened coordinate table.
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)

    if num_random_points > 0:
        point_coordinates = torch.cat(
            [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
            dim=1,
        )
    return point_coordinates


@torch.no_grad()
def Mask2FormerHungarianMatcher_forward(
    self,
    masks_queries_logits: torch.Tensor,
    class_queries_logits: torch.Tensor,
    mask_labels: torch.Tensor,
    class_labels: torch.Tensor,
) -> List[Tuple[Tensor]]:
    """Match predictions to targets per sample with the Hungarian algorithm.

    The cost matrix is the weighted sum of a classification cost, a point-sampled
    sigmoid cross-entropy cost and a point-sampled dice cost.

    Returns:
        One ``(prediction_indices, target_indices)`` pair of int64 tensors per sample.
    """
    indices: List[Tuple[np.ndarray]] = []

    # Iterate through the batch.
    batch_size = masks_queries_logits.shape[0]
    for i in range(batch_size):
        pred_probs = class_queries_logits[i].softmax(-1)
        pred_mask = masks_queries_logits[i]

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it by 1 - proba[target class]. The 1 is a constant that
        # doesn't change the matching, so it can be omitted.
        cost_class = -pred_probs[:, class_labels[i]]
        target_mask = mask_labels[i].to(pred_mask)
        target_mask = target_mask[:, None]
        pred_mask = pred_mask[:, None]

        # Sample ground truth and predicted masks at the same random points.
        point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)

        target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1).to(target_mask.dtype)
        target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)

        pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1).to(pred_mask.dtype)
        pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)

        # Cross entropy loss between each pair of masks -> (num_queries, num_labels).
        cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
        # Dice loss between each pair of masks -> (num_queries, num_labels).
        cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
        # Final cost matrix.
        cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
        # Eliminate infinite values in cost_matrix to avoid the error
        # ``ValueError: cost matrix is infeasible``.
        cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
        cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
        cost_matrix = torch.nan_to_num(cost_matrix, 0)
        # Do the assignment using the Hungarian algorithm in scipy.
        assigned_indices: Tuple[np.ndarray] = linear_sum_assignment(cost_matrix.to(torch.float32).cpu())
        indices.append(assigned_indices)

    # It could be stacked in one tensor.
    matched_indices = [
        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
    ]
    return matched_indices


def _decoder_forward_layer_range(
    self,
    first_layer: int,
    last_layer: Optional[int],
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run the decoder layers ``self.layers[first_layer:last_layer]`` on the given queries.

    Args:
        first_layer (`int`):
            Index of the first decoder layer to run.
        last_layer (`int`, *optional*):
            Index one past the last decoder layer to run; `None` runs to the end.
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder.
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross(masked)-attention of the decoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
        feature_size_list (`List[torch.Size]`):
            This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is not None:
        hidden_states = inputs_embeds

    # Intermediate hidden states with layernorm applied - required for predicting class logits.
    intermediate = ()

    # Decoder layers.
    all_hidden_states = () if output_hidden_states else None
    attentions = () if output_attentions else None

    # Intermediate mask predictions from transformer decoder layers.
    intermediate_mask_predictions = ()

    intermediate_hidden_states = self.layernorm(inputs_embeds)
    intermediate += (intermediate_hidden_states,)

    # The initial attention mask is always predicted at the first feature size.
    predicted_mask, attention_mask = self.mask_predictor(
        intermediate_hidden_states, pixel_embeddings, feature_size_list[0]
    )
    intermediate_mask_predictions += (predicted_mask,)

    # ``idx`` is the absolute layer index, so the feature level cycling stays
    # aligned with the full decoder even when only a sub-range is run.
    for idx, decoder_layer in enumerate(self.layers[first_layer:last_layer], start=first_layer):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # A random number is drawn on every iteration, also outside training.
        dropout_probability = torch.rand([])

        if self.training and (dropout_probability < self.layerdrop):
            continue

        if self.gradient_checkpointing and self.training:
            # NOTE(review): this branch passes its arguments positionally and does not
            # refresh the mask predictions - confirm it matches the decoder layer signature.
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                None,
                None,
                output_attentions,
            )
        else:
            level_index = idx % self.num_feature_levels

            where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
            # Multiply the attention mask instead of indexing to avoid issue in torch.export.
            attention_mask = attention_mask * where.unsqueeze(-1)

            layer_outputs = decoder_layer(
                hidden_states,
                level_index=level_index,
                position_embeddings=multi_stage_positional_embeddings,
                query_position_embeddings=query_position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

            intermediate_hidden_states = self.layernorm(layer_outputs[0])

            predicted_mask, attention_mask = self.mask_predictor(
                intermediate_hidden_states,
                pixel_embeddings,
                feature_size_list[(idx + 1) % self.num_feature_levels],
            )

            intermediate_mask_predictions += (predicted_mask,)

            # Add intermediate hidden states with layer norm applied which will be
            # used for predicting class logits.
            intermediate += (intermediate_hidden_states,)

        hidden_states = layer_outputs[0]

        if output_attentions:
            attentions += (layer_outputs[1],)

    # Add hidden states from the last decoder layer.
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    hidden_states = hidden_states.transpose(1, 0)
    if not return_dict:
        outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
        return tuple(v for v in outputs if v is not None)

    return Mask2FormerMaskedAttentionDecoderOutput(
        last_hidden_state=hidden_states,
        hidden_states=all_hidden_states,
        attentions=attentions,
        intermediate_hidden_states=intermediate,
        masks_queries_logits=intermediate_mask_predictions,
    )


def Mask2FormerMaskedAttentionDecoder_forward_first3layers(
    self,
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run only the first three decoder layers (`self.layers[:3]`).

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder.
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
        feature_size_list (`List[torch.Size]`):
            Shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    return _decoder_forward_layer_range(
        self,
        0,
        3,
        inputs_embeds=inputs_embeds,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=pixel_embeddings,
        encoder_hidden_states=encoder_hidden_states,
        query_position_embeddings=query_position_embeddings,
        feature_size_list=feature_size_list,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )


def Mask2FormerMaskedAttentionDecoder_forward_last3layers(
    self,
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run the decoder layers from index 3 onwards (`self.layers[3:]`).

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder.
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
            Position embeddings that are added to the queries and keys in each self-attention layer.
        feature_size_list (`List[torch.Size]`):
            Shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    return _decoder_forward_layer_range(
        self,
        3,
        None,
        inputs_embeds=inputs_embeds,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=pixel_embeddings,
        encoder_hidden_states=encoder_hidden_states,
        query_position_embeddings=query_position_embeddings,
        feature_size_list=feature_size_list,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )


def _flatten_multi_scale_features(self, multi_scale_features: List[Tensor]):
    """Project and flatten the first ``self.num_feature_levels`` feature maps.

    Returns:
        ``(multi_stage_features, multi_stage_positional_embeddings, size_list)``
        where the first two hold one ``(height*width, batch_size, num_channels)``
        tensor per level and ``size_list`` holds each level's ``(height, width)``.
    """
    multi_stage_features = []
    multi_stage_positional_embeddings = []
    size_list = []

    for i in range(self.num_feature_levels):
        level_features = multi_scale_features[i]
        size_list.append(level_features.shape[-2:])

        # Flatten (batch_size, num_channels, height, width)
        # -> (height*width, batch_size, num_channels).
        positional = self.position_embedder(level_features, None).flatten(2)
        multi_stage_positional_embeddings.append(positional.permute(2, 0, 1))

        projected = self.input_projections[i](level_features).flatten(2) + self.level_embed.weight[i][None, :, None]
        multi_stage_features.append(projected.permute(2, 0, 1))

    return multi_stage_features, multi_stage_positional_embeddings, size_list


def Mask2FormerTransformerModule_forward_first_part(
    self,
    multi_scale_features: List[Tensor],
    mask_features: Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
    """Run the first three decoder layers starting from the learned queries.

    Args:
        multi_scale_features: Multi-scale feature maps from the pixel decoder.
        mask_features: Features passed to the decoder as ``pixel_embeddings``.
        output_hidden_states: Forwarded to the decoder.
        output_attentions: Forwarded to the decoder.

    Returns:
        The decoder output of the first three layers.
    """
    multi_stage_features, multi_stage_positional_embeddings, size_list = _flatten_multi_scale_features(
        self, multi_scale_features
    )

    _, batch_size, _ = multi_stage_features[0].shape

    # [num_queries, batch_size, num_channels]
    query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
    query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)

    decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers(
        inputs_embeds=query_features,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=mask_features,
        encoder_hidden_states=multi_stage_features,
        query_position_embeddings=query_embeddings,
        feature_size_list=size_list,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
    )

    return decoder_output


def Mask2FormerTransformerModule_forward_second_part(
    self,
    query_features: Tensor,
    query_embeddings: Tensor,
    multi_scale_features: List[Tensor],
    mask_features: Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
    """Run the decoder layers from index 3 onwards on caller-supplied queries.

    Unlike the first part, the query features and query position embeddings come
    from the caller instead of the module's learned embeddings.

    Args:
        query_features: Queries passed to the decoder as ``inputs_embeds``.
        query_embeddings: Passed to the decoder as ``query_position_embeddings``.
        multi_scale_features: Multi-scale feature maps from the pixel decoder.
        mask_features: Features passed to the decoder as ``pixel_embeddings``.
        output_hidden_states: Forwarded to the decoder.
        output_attentions: Forwarded to the decoder.

    Returns:
        The decoder output of the remaining layers.
    """
    multi_stage_features, multi_stage_positional_embeddings, size_list = _flatten_multi_scale_features(
        self, multi_scale_features
    )

    decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers(
        inputs_embeds=query_features,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=mask_features,
        encoder_hidden_states=multi_stage_features,
        query_position_embeddings=query_embeddings,
        feature_size_list=size_list,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
    )

    return decoder_output


def Mask2FormerModel_forward_first_part(
    self,
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Tuple[Tensor, object]:
    r"""
    Run the pixel-level module and the first part of the transformer module.

    Args:
        pixel_values (`torch.Tensor`):
            Input images, forwarded to `self.pixel_level_module`.
        pixel_mask (`torch.Tensor`, *optional*):
            Accepted for interface compatibility; not used by this function.
        output_hidden_states (`bool`, *optional*):
            Forwarded to the pixel-level module; defaults to `self.config.output_hidden_states`.
        output_attentions (`bool`, *optional*):
            Forwarded to the transformer module; defaults to `self.config.output_attentions`.
        return_dict (`bool`, *optional*):
            Accepted for interface compatibility; not used by this function.

    Returns:
        `(query_features, pixel_level_module_output)`: the last hidden state of the first decoder part and the
        unmodified output of the pixel-level module.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    pixel_level_module_output = self.pixel_level_module(
        pixel_values=pixel_values, output_hidden_states=output_hidden_states
    )

    transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_first_part(
        multi_scale_features=pixel_level_module_output.decoder_hidden_states,
        mask_features=pixel_level_module_output.decoder_last_hidden_state,
        output_hidden_states=True,
        output_attentions=output_attentions,
    )

    query_features = transformer_module_output.last_hidden_state
    return query_features, pixel_level_module_output


def Mask2FormerModel_forward_second_part(
    self,
    query_features: Tensor,
    query_embeddings: Tensor,
    pixel_level_module_output,
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerModelOutput:
    r"""
    Run the second part of the transformer module and assemble the model output.

    Args:
        query_features (`torch.Tensor`):
            Queries forwarded to the second part of the transformer module.
        query_embeddings (`torch.Tensor`):
            Query position embeddings forwarded to the second part of the transformer module.
        pixel_level_module_output:
            Output of the pixel-level module, as returned by the first part.
        pixel_values (`torch.Tensor`):
            Accepted for interface compatibility; not used by this function.
        pixel_mask (`torch.Tensor`, *optional*):
            Accepted for interface compatibility; not used by this function.
        output_hidden_states (`bool`, *optional*):
            Whether hidden and intermediate states are included in the output.
        output_attentions (`bool`, *optional*):
            Forwarded to the transformer module.
        return_dict (`bool`, *optional*):
            Whether to return a `Mask2FormerModelOutput` instead of a plain tuple.

    Returns:
        `Mask2FormerModelOutput`, or a tuple of its non-`None` values when `return_dict` is false.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_second_part(
        query_features=query_features,
        query_embeddings=query_embeddings,
        multi_scale_features=pixel_level_module_output.decoder_hidden_states,
        mask_features=pixel_level_module_output.decoder_last_hidden_state,
        output_hidden_states=True,
        output_attentions=output_attentions,
    )

    encoder_hidden_states = None
    pixel_decoder_hidden_states = None
    transformer_decoder_hidden_states = None
    transformer_decoder_intermediate_states = None

    if output_hidden_states:
        encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
        pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
        transformer_decoder_hidden_states = transformer_module_output.hidden_states
        transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states

    output = Mask2FormerModelOutput(
        encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state,
        pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state,
        transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state,
        encoder_hidden_states=encoder_hidden_states,
        pixel_decoder_hidden_states=pixel_decoder_hidden_states,
        transformer_decoder_hidden_states=transformer_decoder_hidden_states,
        transformer_decoder_intermediate_states=transformer_decoder_intermediate_states,
        attentions=transformer_module_output.attentions,
        masks_queries_logits=transformer_module_output.masks_queries_logits,
    )

    if not return_dict:
        output = tuple(v for v in output.values() if v is not None)

    return output


def Mask2FormerForUniversalSegmentation_forward_first_part(
    self,
    pixel_values: Tensor,
    mask_labels: Optional[List[Tensor]] = None,
    class_labels: Optional[List[Tensor]] = None,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_auxiliary_logits: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Tuple[Tensor, object]:
    """Run the first part of the wrapped model.

    ``mask_labels``, ``class_labels``, ``output_auxiliary_logits`` and
    ``return_dict`` are accepted for interface compatibility and not used here.

    Returns:
        ``(query_features, pixel_level_module_output)`` exactly as returned by
        ``self.model.Mask2FormerModel_forward_first_part``.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    query_features, pixel_level_module_output = self.model.Mask2FormerModel_forward_first_part(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
        output_attentions=output_attentions,
        return_dict=True,
    )
    return query_features, pixel_level_module_output


def Mask2FormerForUniversalSegmentation_forward_second_part(
    self,
    query_features,
    query_embeddings,
    pixel_level_module_output,
    text_classifier,
    pixel_values: Tensor,
    mask_labels: Optional[List[Tensor]] = None,
    class_labels: Optional[List[Tensor]] = None,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_auxiliary_logits: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerForUniversalSegmentationOutput:
    """Run the second model part, classify queries against text, and compute the loss.

    Args:
        query_features: Queries forwarded to the second part of the model.
        query_embeddings: Query position embeddings forwarded to the second part.
        pixel_level_module_output: Output of the pixel-level module from the first part.
        text_classifier: Per-sample text class embeddings for ``ov_class_predictor``.
        pixel_values: Forwarded to the wrapped model.
        mask_labels: Target masks; the loss is computed only when both label
            arguments are given.
        class_labels: Target classes; see ``mask_labels``.
        pixel_mask: Forwarded to the wrapped model.
        output_hidden_states: Whether hidden states are included in the output.
        output_auxiliary_logits: Whether auxiliary logits are included in the
            output; defaults to ``self.config.output_auxiliary_logits``.
        output_attentions: Forwarded to the wrapped model.
        return_dict: Whether to return an output object instead of a tuple.

    Returns:
        ``Mask2FormerForUniversalSegmentationOutput``, or a tuple of its
        non-``None`` values (with the loss first, when present) if
        ``return_dict`` is false.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.model.Mask2FormerModel_forward_second_part(
        query_features=query_features,
        query_embeddings=query_embeddings,
        pixel_level_module_output=pixel_level_module_output,
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
        output_attentions=output_attentions,
        return_dict=True,
    )

    loss, loss_dict, auxiliary_logits = None, None, None
    class_queries_logits = ()

    # NOTE: the intermediate states are only populated when hidden states were
    # requested from the model (``output_hidden_states or use_auxiliary_loss``).
    for decoder_output in outputs.transformer_decoder_intermediate_states:
        class_prediction = self.ov_class_predictor(decoder_output.transpose(0, 1), text_classifier)
        class_queries_logits += (class_prediction,)

    masks_queries_logits = outputs.masks_queries_logits

    auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits)

    if mask_labels is not None and class_labels is not None:
        loss_dict = self.get_loss_dict(
            masks_queries_logits=masks_queries_logits[-1],
            class_queries_logits=class_queries_logits[-1],
            mask_labels=mask_labels,
            class_labels=class_labels,
            auxiliary_predictions=auxiliary_logits,
        )
        loss = self.get_loss(loss_dict)

    encoder_hidden_states = None
    pixel_decoder_hidden_states = None
    transformer_decoder_hidden_states = None

    if output_hidden_states:
        encoder_hidden_states = outputs.encoder_hidden_states
        pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states
        transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states

    output_auxiliary_logits = (
        self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
    )
    if not output_auxiliary_logits:
        auxiliary_logits = None

    output = Mask2FormerForUniversalSegmentationOutput(
        loss=loss,
        class_queries_logits=class_queries_logits[-1],
        masks_queries_logits=masks_queries_logits[-1],
        auxiliary_logits=auxiliary_logits,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state,
        transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state,
        encoder_hidden_states=encoder_hidden_states,
        pixel_decoder_hidden_states=pixel_decoder_hidden_states,
        transformer_decoder_hidden_states=transformer_decoder_hidden_states,
        attentions=outputs.attentions,
    )

    if not return_dict:
        output = tuple(v for v in output.values() if v is not None)
        if loss is not None:
            # ``(loss)`` is just ``loss``; a one-element tuple is needed so the
            # loss is prepended to the tuple instead of being added to it.
            output = (loss,) + output
    return output