Mueris committed
Commit 3662ec7 · verified · 1 Parent(s): 206db56

Upload 4 files

Files changed (4)
  1. app.py +19 -70
  2. config.json +8 -0
  3. inference.py +46 -0
  4. model.py +91 -0
app.py CHANGED
@@ -1,70 +1,19 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()
-
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import gradio as gr
+ from inference import load_for_inference, predict
+
+ REPO_ID = "MUERIS/TurkishVLMTAMGA"
+
+ model, tokenizer, device = load_for_inference(REPO_ID)
+
+ def answer(image, question):
+     return predict(model, tokenizer, device, image, question)
+
+ gr.Interface(
+     fn=answer,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Textbox(label="Question")
+     ],
+     outputs="text",
+     title="CLIP2MT5 Visual Question Answering"
+ ).launch()
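Once the Space is live, the same endpoint can also be queried programmatically. A minimal sketch with gradio_client (assuming a recent gradio_client release, that the Space is published under the same id as the model repo, and that the default gr.Interface endpoint name "/predict" is used; the image path ornek.jpg and the question are hypothetical):

from gradio_client import Client, handle_file

# Connect to the Space and call its single gr.Interface endpoint.
client = Client("MUERIS/TurkishVLMTAMGA")
answer = client.predict(
    handle_file("ornek.jpg"),   # image input (hypothetical local file)
    "Resimde ne var?",          # question input ("What is in the image?")
    api_name="/predict",
)
print(answer)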
 
config.json ADDED
@@ -0,0 +1,8 @@
+ {
+     "model_type": "clip2mt5-crossattention",
+     "library": "pytorch",
+     "architectures": ["CLIP2MT5_CrossAttention"],
+     "pipeline_tag": "image-text-to-text",
+     "description": "CLIP + mT5 VQA Model using cross-attention.",
+     "author": "MUERIS"
+ }
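config.json is descriptive repository metadata; neither inference.py nor model.py reads it. A small sketch for fetching and inspecting it from the Hub (assuming huggingface_hub is installed):

import json
from huggingface_hub import hf_hub_download

# Download the metadata file from the model repo and print its contents.
config_path = hf_hub_download(repo_id="MUERIS/TurkishVLMTAMGA", filename="config.json")
with open(config_path, encoding="utf-8") as f:
    print(json.load(f))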
inference.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ from model import load_model
+
+
+ # Preprocessing
+ _transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=[0.4815, 0.4578, 0.4082],
+         std=[0.2686, 0.2613, 0.2758]
+     )
+ ])
+
+
+ def load_for_inference(repo_id, filename="model.pt"):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = load_model(repo_id=repo_id, filename=filename, device=device)
+     tokenizer = model.tokenizer
+     return model, tokenizer, device
+
+
+ def predict(model, tokenizer, device, image: Image.Image, question: str):
+     image_tensor = _transform(image).unsqueeze(0).to(device)
+
+     q = tokenizer(
+         question,
+         return_tensors='pt',
+         padding=True,
+         truncation=True,
+         max_length=64
+     ).to(device)
+
+     with torch.no_grad():
+         output_ids = model.generate(
+             images=image_tensor,
+             input_ids=q.input_ids,
+             attention_mask=q.attention_mask,
+             max_length=64,
+             num_beams=4
+         )
+
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
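A minimal local-usage sketch for these helpers (assuming torch, torchvision, transformers, huggingface_hub, and Pillow are installed and the repo's model.pt checkpoint is downloadable; ornek.jpg and the question are hypothetical):

from PIL import Image

from inference import load_for_inference, predict

# Load the checkpoint once, then answer questions about local images.
model, tokenizer, device = load_for_inference("MUERIS/TurkishVLMTAMGA")
image = Image.open("ornek.jpg").convert("RGB")  # hypothetical example image
print(predict(model, tokenizer, device, image, "Resimde ne var?"))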
model.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPModel, AutoTokenizer, AutoModelForSeq2SeqLM
+ from huggingface_hub import hf_hub_download
+
+
+ class CLIP2MT5_CrossAttention(nn.Module):
+     def __init__(self, clip_name='openai/clip-vit-base-patch32',
+                  t5_name='mukayese/mt5-base-turkish-summarization'):
+         super().__init__()
+
+         self.clip = CLIPModel.from_pretrained(clip_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(t5_name)
+         self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)
+
+         self.vis_proj = nn.Linear(
+             self.clip.config.vision_config.hidden_size,
+             self.t5.config.d_model
+         )
+
+     def forward(self, images, input_ids, attention_mask, labels=None):
+         vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
+         vision_embeds = self.vis_proj(vision_outputs)
+
+         text_embeds = self.t5.encoder.embed_tokens(input_ids)
+
+         extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
+
+         extended_attention_mask = torch.cat([
+             torch.ones(vision_embeds.size(0), vision_embeds.size(1),
+                        dtype=attention_mask.dtype, device=attention_mask.device),
+             attention_mask
+         ], dim=1)
+
+         if labels is not None:
+             labels = labels.clone()
+             labels[labels == self.tokenizer.pad_token_id] = -100
+
+         return self.t5(
+             inputs_embeds=extended_input_embeds,
+             attention_mask=extended_attention_mask,
+             labels=labels,
+             return_dict=True
+         )
+
+     @torch.no_grad()
+     def generate(self, images, input_ids, attention_mask, **gen_kwargs):
+         vision_outputs = self.clip.vision_model(pixel_values=images).last_hidden_state
+         vision_embeds = self.vis_proj(vision_outputs)
+
+         text_embeds = self.t5.encoder.embed_tokens(input_ids)
+
+         extended_input_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
+
+         extended_attention_mask = torch.cat([
+             torch.ones(vision_embeds.size(0), vision_embeds.size(1),
+                        dtype=attention_mask.dtype, device=attention_mask.device),
+             attention_mask
+         ], dim=1)
+
+         return self.t5.generate(
+             inputs_embeds=extended_input_embeds,
+             attention_mask=extended_attention_mask,
+             **gen_kwargs
+         )
+
+
+
+ # HF Loader for STATE_DICT
+
+
+ def load_model(
+     repo_id: str,
+     filename: str = "model.pt",
+     clip_name="openai/clip-vit-base-patch32",
+     t5_name="mukayese/mt5-base-turkish-summarization",
+     device=None
+ ):
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+
+     model = CLIP2MT5_CrossAttention(clip_name=clip_name, t5_name=t5_name)
+
+     state = torch.load(model_path, map_location=device)
+     model.load_state_dict(state)
+
+     model.to(device)
+     model.eval()
+     return model
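load_model expects a plain state_dict checkpoint named model.pt in the model repo. A sketch of how such a checkpoint could be produced and uploaded (assumptions: the model instance has already been fine-tuned elsewhere, and the caller has write access to the repo; training itself is not part of this commit):

import torch
from huggingface_hub import HfApi

from model import CLIP2MT5_CrossAttention

# Serialize the weights in the exact format load_model() reads back.
model = CLIP2MT5_CrossAttention()  # assumed to be fine-tuned already
torch.save(model.state_dict(), "model.pt")

# Push the checkpoint to the model repo referenced by app.py.
HfApi().upload_file(
    path_or_fileobj="model.pt",
    path_in_repo="model.pt",
    repo_id="MUERIS/TurkishVLMTAMGA",
)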