from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.utils import ModelOutput

from .location_encoder import LocationEncoder


class CLOSPConfig(PretrainedConfig):
    """
    Configuration class for CLOSPModel.

    This class stores the configuration of a CLOSPModel, which is used to
    instantiate the model according to the specified parameters.
    """

    model_type = "closp"

    def __init__(
        self,

        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,

        text_model_name_or_path: str = "distilbert-base-uncased",

        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,

        projection_dim: int = 768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim = s1_embedding_dim
        self.s2_embedding_dim = s2_embedding_dim
        self.text_model_name_or_path = text_model_name_or_path
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim
        self.projection_dim = projection_dim
        self.s1_head_dim = s1_head_dim
        self.s2_head_dim = s2_head_dim


@dataclass
class CLOSPOutput(ModelOutput):
    """
    Base class for the CLOSP model's outputs.
    """

    loss: torch.FloatTensor = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    logits_per_loc_img: torch.FloatTensor = None
    logits_per_img_loc: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    location_embeds: torch.FloatTensor = None


class CLOSPModel(PreTrainedModel):
    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)

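        # Two vision towers built from the same timm backbone key: a 2-channel
        # encoder and a 13-channel encoder (likely Sentinel-1 SAR and Sentinel-2
        # multispectral inputs, respectively). With the default head_dim of 0,
        # timm returns pooled backbone features, which the linear projections
        # below map into the shared embedding space.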
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

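        # Text tower: built from the config alone, so its weights start out
        # randomly initialized here and are populated when a CLOSP checkpoint
        # is loaded via `from_pretrained`.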
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

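        # Optional location tower: encodes coordinates and projects them into
        # the same embedding space as images and text.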
        if config.use_location_encoder:
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """Tokenizes input text using the model's tokenizer."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encodes an image tensor into features."""
        image = image.float()
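        # Route by channel count: 2-channel batches go through the S1 tower,
        # everything else through the S2 tower.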
        if image.shape[1] == 2:
            image_features = self.s1_projection(self.s1_encoder(image))
        else:
            image_features = self.s2_projection(self.s2_encoder(image))

        return F.normalize(image_features, p=2, dim=-1)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encodes text tokens into features."""
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
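        # Use the first token (e.g. [CLS]) of the last hidden state as the
        # sentence-level representation.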
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """Encodes coordinates into features."""
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: torch.Tensor = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

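        # Image-text similarity logits. Both embeddings are L2-normalized, so
        # these are cosine similarities (no learned temperature is applied).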
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None

        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

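        # Contrastive loss: matching pairs share the same batch index, so the
        # targets are the diagonal indices; the cross-entropy is averaged over
        # whichever logit matrices are available.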
        loss = None
        if return_loss:
            outputs = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            ground_truth = torch.arange(len(input_ids)).to(self.device)
            loss = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
            loss = sum(loss) / len(loss)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )
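

# Minimal smoke-test sketch (assumptions, not part of the original module):
# the concrete timm backbone name, the image size, and the prompt below are
# placeholders. The default `vision_model_key` ("vit-s") is a repo-internal
# key, so "vit_small_patch16_224" (384-dim features) is substituted here, and
# the location branch is disabled because the coordinate format expected by
# `LocationEncoder` is not shown in this file. Run as a module (e.g.
# `python -m <package>.<this_module>`) so the relative import resolves.
if __name__ == "__main__":
    config = CLOSPConfig(
        vision_model_key="vit_small_patch16_224",  # assumed timm backbone
        use_location_encoder=False,  # skip the location tower in this sketch
    )
    model = CLOSPModel(config).eval()

    # One fake 13-channel image and one caption.
    image = torch.randn(1, 13, 224, 224)
    tokens = model.tokenize_text("a satellite image of a river delta")

    with torch.no_grad():
        out = model(
            image=image,
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            return_loss=True,
        )
    print(out.image_embeds.shape, out.text_embeds.shape, out.loss)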