"""CLIP vision features prepended to an (m)T5 encoder-decoder's input.

Defines :class:`CLIP2MT5_CrossAttention` and :func:`load_model`, which rebuilds
the model and loads a state dict stored in a Hugging Face Hub repository.
"""

import torch
import torch.nn as nn
from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download


class CLIP2MT5_CrossAttention(nn.Module):
    """Condition a seq2seq (m)T5 model on CLIP image features.

    The CLIP vision tower's ``last_hidden_state`` is linearly projected to the
    T5 model width (``d_model``) and *prepended* to the text token embeddings.
    The concatenated sequence is passed to T5 through ``inputs_embeds``, so the
    image tokens are processed by the T5 encoder together with the text.

    Despite the class name, no additional cross-attention layer is created;
    the name is kept for backward compatibility with existing callers.

    NOTE(review): the full ``CLIPModel`` is loaded although only
    ``clip.vision_model`` is used. Its text tower is still part of this
    module's state dict, so switching to a vision-only model would change the
    checkpoint keys expected by :func:`load_model`.
    """

    def __init__(self, clip_name='openai/clip-vit-base-patch32',
                 t5_name='mukayese/mt5-base-turkish-summarization'):
        """Load the pretrained CLIP model, T5 tokenizer/model and build the projection.

        Args:
            clip_name: Hub id or path of the CLIP checkpoint.
            t5_name: Hub id or path of the seq2seq (m)T5 checkpoint and tokenizer.
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)
        # Maps CLIP vision hidden states to the T5 embedding width so they can
        # be concatenated with the token embeddings along the sequence axis.
        self.vis_proj = nn.Linear(
            self.clip.config.vision_config.hidden_size,
            self.t5.config.d_model,
        )

    def _fuse_inputs(self, images, input_ids, attention_mask=None):
        """Build the joint image+text encoder inputs shared by forward/generate.

        Args:
            images: Pixel values for the CLIP vision tower — presumably
                ``(batch, 3, H, W)`` from a CLIP image processor; TODO confirm.
            input_ids: Text token ids, ``(batch, seq_len)``.
            attention_mask: Mask for ``input_ids``; all ones when ``None``.

        Returns:
            ``(inputs_embeds, attention_mask)``: image tokens first, then text
            tokens; image positions are always attended to.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        vision_states = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_states)
        text_embeds = self.t5.encoder.embed_tokens(input_ids)
        inputs_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

        # Same dtype/device as the text mask; one "1" per image token.
        vision_mask = attention_mask.new_ones(vision_embeds.shape[:2])
        fused_mask = torch.cat([vision_mask, attention_mask], dim=1)
        return inputs_embeds, fused_mask

    def forward(self, images, input_ids, attention_mask=None, labels=None):
        """Run a (teacher-forced) seq2seq pass on image + text inputs.

        Args:
            images: Pixel values for the CLIP vision tower.
            input_ids: Text token ids, ``(batch, seq_len)``.
            attention_mask: Mask for ``input_ids``; all ones when ``None``.
            labels: Optional target token ids. A copy is made and pad tokens
                are set to ``-100`` so the seq2seq loss ignores them.

        Returns:
            The T5 model output object (``return_dict=True``).
        """
        inputs_embeds, fused_mask = self._fuse_inputs(images, input_ids, attention_mask)

        if labels is not None:
            labels = labels.clone()  # never mutate the caller's tensor
            pad_id = self.tokenizer.pad_token_id
            if pad_id is not None:
                labels[labels == pad_id] = -100

        return self.t5(
            inputs_embeds=inputs_embeds,
            attention_mask=fused_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(self, images, input_ids, attention_mask=None, **gen_kwargs):
        """Generate text conditioned on images and a text prompt.

        Args:
            images: Pixel values for the CLIP vision tower.
            input_ids: Text token ids, ``(batch, seq_len)``.
            attention_mask: Mask for ``input_ids``; all ones when ``None``.
            **gen_kwargs: Forwarded unchanged to ``self.t5.generate`` (must not
                contain ``inputs_embeds`` or ``attention_mask``).

        Returns:
            Whatever ``self.t5.generate`` returns (generated token ids by default).
        """
        inputs_embeds, fused_mask = self._fuse_inputs(images, input_ids, attention_mask)
        return self.t5.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=fused_mask,
            **gen_kwargs,
        )


# HF Loader for STATE_DICT
def load_model(
    repo_id: str,
    filename: str = "model.pt",
    clip_name="openai/clip-vit-base-patch32",
    t5_name="mukayese/mt5-base-turkish-summarization",
    device=None,
):
    """Download a state dict from the Hub and return the model in eval mode.

    Args:
        repo_id: Hugging Face Hub repository holding the checkpoint.
        filename: State-dict file inside ``repo_id``.
        clip_name: CLIP checkpoint used to rebuild the architecture.
        t5_name: (m)T5 checkpoint/tokenizer used to rebuild the architecture.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        A :class:`CLIP2MT5_CrossAttention` with the loaded weights, moved to
        ``device`` and switched to ``eval()`` mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)

    # Load onto CPU: the freshly built model lives on CPU, so mapping straight
    # to `device` would keep a second full copy of the weights on the GPU while
    # `model.to(device)` runs.
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # trusted repos (or pass weights_only=True on PyTorch versions supporting it).
    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)
    del state  # release the duplicate copy before moving the model

    model.to(device)
    model.eval()
    return model