narinzar committed
Commit 6d91736 · verified · 1 Parent(s): cd83e63

Upload app.py with huggingface_hub
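(For reference, an upload like this can be scripted with huggingface_hub. A minimal sketch follows; the repo id and repo type are illustrative assumptions, not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()  # authenticates via HF_TOKEN or a cached `huggingface-cli login`
api.upload_file(
    path_or_fileobj="app.py",               # local file to upload
    path_in_repo="app.py",                  # destination path in the repo
    repo_id="narinzar/xray-analyzer-demo",  # hypothetical repo id
    repo_type="space",                      # assumption: the app targets a Gradio Space
)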

Files changed (1)
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
+ #!/usr/bin/env python
+ # 08_app.py
+ # Purpose: Gradio app for X-ray analysis with Gemma 3
+
+ import os
+ import time
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from dotenv import load_dotenv
+ from PIL import Image
+ import traceback
+
+ # Load environment variables
+ load_dotenv()
+ HF_USERNAME = os.getenv("HF_USERNAME", "your_username")
+ HF_MODEL_NAME = os.getenv("HF_MODEL_NAME", "GemmaXRayAnalyzer_Finetune_Gemma_3_4b")
+
+ # Model repository ID
+ MODEL_ID = f"{HF_USERNAME}/{HF_MODEL_NAME}"
+
+ # Demo instruction/prompt
+ INSTRUCTION = "You are an expert radiologist. Analyze this X-ray image and describe what you see in detail."
+
+ # Function to load model and tokenizer
+ def load_model():
+     print(f"Loading model from {MODEL_ID}...")
+
+     # Get the device upfront to ensure the model loads on the right device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     # First try loading from the user's HF repository
+     try:
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             device_map="auto",  # Let transformers decide the device mapping
+             torch_dtype="auto"  # Let transformers decide the dtype
+         )
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+         print("Model loaded successfully from Hugging Face Hub")
+     except Exception as e:
+         print(f"Error loading from {MODEL_ID}: {e}")
+         print("Falling back to base Gemma model")
+
+         # Fall back to the base Gemma model
+         try:
+             model = AutoModelForCausalLM.from_pretrained(
+                 "unsloth/gemma-3-4b-it",
+                 device_map="auto",  # Let transformers decide the device mapping
+                 torch_dtype="auto"  # Let transformers decide the dtype
+             )
+             tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
+             print("Base Gemma model loaded successfully as fallback")
+         except Exception as e:
+             print(f"Error loading fallback model: {e}")
+             raise
+
+     return model, tokenizer
+
+ # Load model at startup
+ print("Initializing model...")
+ model, tokenizer = load_model()
+
+ # Function to analyze X-ray image and text
+ def analyze_xray(image, prompt, max_tokens=256, temperature=0.7, top_p=0.9):
+     try:
+         if not prompt:
+             prompt = INSTRUCTION
+
+         # Handle the image if provided
+         image_description = ""
+         if image is not None:
+             # Save the image temporarily for display
+             temp_img_path = "temp_xray.jpg"
+             if isinstance(image, Image.Image):
+                 image.save(temp_img_path)
+             else:
+                 # If it's already a path
+                 temp_img_path = image
+
+             image_description = "\n\nImage uploaded: X-ray image received for analysis."
+
+         # Combine prompt with image notification
+         full_text_prompt = prompt + image_description
+
+         # Format the prompt using Gemma's chat format
+         full_prompt = f"<start_of_turn>user\n{full_text_prompt}<end_of_turn>\n<start_of_turn>model\n"
+
+         # Tokenize the prompt
+         inputs = tokenizer(full_prompt, return_tensors="pt")
+
+         # Move inputs to the correct device - the model should already be on it
+         try:
+             # Try to get the model's device directly
+             device = next(model.parameters()).device
+         except Exception:
+             # If that fails, default to CUDA if available
+             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Move inputs to the device
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Start timer
+         start_time = time.time()
+
+         # Generate response
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 do_sample=True,  # temperature/top_p are ignored under greedy decoding
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+
+         # Compute generation time
+         gen_time = time.time() - start_time
+
+         # Decode only the newly generated tokens (everything after the prompt);
+         # skip_special_tokens=True strips the <start_of_turn> marker, so splitting
+         # the full decoded text on that marker would be unreliable
+         prompt_len = inputs["input_ids"].shape[-1]
+         response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
+
+         # Return the image notice if one was provided, along with the response
+         result = ""
+         if image is not None:
+             # Create the response for the image case
+             result = f"**X-ray Analysis:**\n\n{response}\n\n_Generated in {gen_time:.2f} seconds_"
+         else:
+             # Just return the text response
+             result = f"{response}\n\n_Generated in {gen_time:.2f} seconds_"
+
+         return result
+     except Exception as e:
+         print(f"Error in analyze_xray: {e}")
+         traceback.print_exc()
+         return f"Error generating response: {str(e)}\n\nPlease try a different prompt or check the console for detailed error information."
+
+ # Create the Gradio interface with image upload
+ demo = gr.Interface(
+     fn=analyze_xray,
+     inputs=[
+         gr.Image(type="pil", label="Upload X-ray Image (Optional)"),
+         gr.Textbox(
+             label="Prompt",
+             placeholder="Analyze this chest X-ray showing...",
+             value=INSTRUCTION,
+             lines=4
+         ),
+         gr.Slider(
+             minimum=50, maximum=512, value=256, step=1,
+             label="Maximum Tokens"
+         ),
+         gr.Slider(
+             minimum=0.1, maximum=1.5, value=0.7, step=0.1,
+             label="Temperature"
+         ),
+         gr.Slider(
+             minimum=0.1, maximum=1.0, value=0.9, step=0.1,
+             label="Top-p"
+         )
+     ],
+     outputs=gr.Markdown(),
+     title="🩻 X-ray Analysis with Gemma 3",
+     description="This demo showcases the Gemma 3 model for medical X-ray analysis. Upload an X-ray image and enter a prompt describing what you'd like to analyze.",
+     examples=[
+         [None, "Analyze this chest X-ray showing opacity in the lower right lung"],
+         [None, "Describe the findings in this X-ray of a patient with suspected pneumonia"],
+         [None, "What can you tell me about this X-ray showing a possible fracture in the wrist?"],
+         [None, "Generate a detailed report for this abdominal X-ray showing bowel obstruction"],
+     ],
+     article="""
+ ## How to Use
+
+ 1. (Optional) Upload an X-ray image using the image upload area
+ 2. Enter a prompt describing what you want the model to analyze
+ 3. Adjust the generation parameters if desired
+ 4. Click "Submit" to generate the analysis
+
+ ## Example Prompts
+
+ - "Analyze this chest X-ray and describe any abnormalities"
+ - "What pathologies are visible in this X-ray image?"
+ - "Is there evidence of pneumonia in this chest X-ray?"
+ - "Generate a radiological report for this X-ray"
+
+ ## Disclaimer
+
+ This is a demonstration tool and should not be used for actual medical diagnosis.
+ Always consult a qualified healthcare professional for medical advice.
+
+ Note: The model has been fine-tuned on radiological text data but may not directly
+ analyze the uploaded image. The image upload feature is provided for reference and context.
+ """
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch(share=True)  # Set share=True to create a public link
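
Usage note: app.py reads its model location from environment variables via python-dotenv, so a local run expects a .env file next to the script. A minimal sketch, mirroring the placeholder defaults already in the code:

HF_USERNAME=your_username
HF_MODEL_NAME=GemmaXRayAnalyzer_Finetune_Gemma_3_4b

With that in place, running `python app.py` loads the model once at startup and launches the Gradio UI; because share=True is set, it also prints a temporary public link.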