```python
import torch
import torch.nn as nn
from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download


class CLIP2MT5_CrossAttention(nn.Module):
    def __init__(self, clip_name='openai/clip-vit-base-patch32',
                 t5_name='mukayese/mt5-base-turkish-summarization'):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name)
        self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)
        # Project CLIP vision hidden states into the mT5 embedding space (d_model).
        self.vis_proj = nn.Linear(
            self.clip.config.vision_config.hidden_size,
            self.t5.config.d_model
        )

    def forward(self, images, input_ids, attention_mask, labels=None):
        # Encode the image with the CLIP vision tower and project it to d_model.
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)
        # Embed the text prompt and prepend the visual tokens to it.
        text_embeds = self.t5.encoder.embed_tokens(input_ids)
        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)
        if labels is not None:
            # Mask padding positions so they are ignored by the loss.
            labels = labels.clone()
            labels[labels == self.tokenizer.pad_token_id] = -100
        return self.t5(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            labels=labels,
            return_dict=True
        )

    def generate(self, images, input_ids, attention_mask, **gen_kwargs):
        # Same image/text fusion as forward(), then delegate to mT5 generation.
        vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
        vision_embeds = self.vis_proj(vision_outputs)
        text_embeds = self.t5.encoder.embed_tokens(input_ids)
        extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
        extended_attention_mask = torch.cat([
            torch.ones(vision_embeds.size(0), vision_embeds.size(1),
                       dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)
        return self.t5.generate(
            inputs_embeds=extended_input_embeds,
            attention_mask=extended_attention_mask,
            **gen_kwargs
        )
```
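As a minimal sketch (not from the original card) of how the model can be fed during training, the snippet below assumes a local image file, a `CLIPImageProcessor` to produce `pixel_values`, and placeholder Turkish prompt/target strings:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor

# Sketch only: "example.jpg", the prompt and the target summary are placeholders.
model = CLIP2MT5_CrossAttention()
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("example.jpg").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

enc = model.tokenizer("özetle:", return_tensors="pt", padding=True)
targets = model.tokenizer("örnek özet metni", return_tensors="pt", padding=True)

outputs = model(
    images=pixel_values,
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    labels=targets.input_ids,
)
outputs.loss.backward()  # standard seq2seq cross-entropy loss returned by mT5
```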
Loader for the saved `state_dict` checkpoint hosted on the Hugging Face Hub:

```python
def load_model(
    repo_id: str,
    filename: str = "model.pt",
    clip_name="openai/clip-vit-base-patch32",
    t5_name="mukayese/mt5-base-turkish-summarization",
    device=None
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Download the checkpoint (a plain state_dict) from the Hub and rebuild the model.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
```
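A hedged end-to-end inference sketch follows; the `repo_id` and image path are placeholders, and calling `generate` with `inputs_embeds` assumes a reasonably recent `transformers` release:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("your-username/clip2mt5-turkish", device=device)  # placeholder repo_id

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
image = Image.open("example.jpg").convert("RGB")  # placeholder image
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

enc = model.tokenizer("özetle:", return_tensors="pt").to(device)  # placeholder prompt
with torch.no_grad():
    gen_ids = model.generate(
        images=pixel_values,
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        max_new_tokens=64,
        num_beams=4,
    )
print(model.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0])
```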