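"""DermBOT: Streamlit skin-health assistant.

Pipeline: two ViT image classifiers (multi-label lesion findings + single-label
diagnosis category) -> Qdrant retrieval with a local gte-Qwen2 embedding model ->
Cohere reranking -> answer generation with the selected LLM (GPT-4o, LLaMA 4
Maverick, Gemini 2.5 Pro, or a ranked-and-fused combination of all three) ->
optional PDF export of the chat.
"""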
import streamlit as st
from PIL import Image
import torch
import cohere
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
import pandas as pd
from huggingface_hub import hf_hub_download
from langchain_huggingface import HuggingFaceEmbeddings
import io
import os
import base64
from fpdf import FPDF
from sqlalchemy import create_engine
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer
#from langchain_community.vectorstores.pgvector import PGVector
#from langchain_postgres import PGVector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import SentenceTransformerEmbeddings
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
import nest_asyncio

torch.cuda.empty_cache()
nest_asyncio.apply()  # allow nested event loops inside Streamlit's script runner

co = cohere.Client(st.secrets["COHERE_API_KEY"])
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
# === Model Selection ===
available_models = ["GPT-4o", "LLaMA 4 Maverick", "Gemini 2.5 Pro", "All"]
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
# === Qdrant DB Setup ===
qdrant_client = QdrantClient(
    url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
    api_key=st.secrets["QDRANT_API_KEY"],  # hardcoded key removed; assumes a QDRANT_API_KEY entry in .streamlit/secrets.toml
)
collection_name = "ks_collection_1.5BE"

#embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
#embedding_model.max_seq_length = 8192
#local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| def get_safe_embedding_model(): | |
| model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" | |
| try: | |
| print("Trying to load embedding model on CUDA...") | |
| embedding = HuggingFaceEmbeddings( | |
| model_name=model_name, | |
| model_kwargs={ | |
| "trust_remote_code": True, | |
| "device": "cuda" | |
| } | |
| ) | |
| print("Loaded embedding model on GPU.") | |
| return embedding | |
| except RuntimeError as e: | |
| if "CUDA out of memory" in str(e): | |
| print("CUDA OOM. Falling back to CPU.") | |
| else: | |
| print(" Error loading model on CUDA:", str(e)) | |
| print("Loading embedding model on CPU...") | |
| return HuggingFaceEmbeddings( | |
| model_name=model_name, | |
| model_kwargs={ | |
| "trust_remote_code": True, | |
| "device": "cpu" | |
| } | |
| ) | |
| # Replace your old local_embedding line with this | |
| local_embedding = get_safe_embedding_model() | |
| print(" Qwen2-1.5B local embedding model loaded.") | |
vector_store = Qdrant(
    client=qdrant_client,
    collection_name=collection_name,
    embeddings=local_embedding
)
retriever = vector_store.as_retriever()
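
# PairRM ranker and GenFuser are used only in "All" mode: the three candidate
# answers are scored, sorted, and fused into a single response (see rank_and_fuse).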
pair_ranker = pipeline(
    "text-classification",
    model="llm-blender/PairRM",
    tokenizer="llm-blender/PairRM",
    return_all_scores=True
)
gen_fuser = pipeline(
    "text-generation",
    model="llm-blender/gen_fuser_3b",
    tokenizer="llm-blender/gen_fuser_3b",
    max_length=2048,
    do_sample=False
)
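
# === LLM Selection: route prompts to the provider chosen in the sidebar ===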
selected_model = st.session_state["selected_model"]

if "GPT-4o" in selected_model:
    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])

elif "LLaMA" in selected_model:
    from groq import Groq
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])  # Store in `.streamlit/secrets.toml`

    def get_llama_response(prompt):
        completion = client.chat.completions.create(
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=1024,
            top_p=1,
            stream=False
        )
        return completion.choices[0].message.content

    llm = get_llama_response  # use this in place of llm.invoke()

elif "Gemini" in selected_model:
    import google.generativeai as genai
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])  # Store in `.streamlit/secrets.toml`
    gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")

    def get_gemini_response(prompt):
        response = gemini_model.generate_content(prompt)
        return response.text

    llm = get_gemini_response

elif "All" in selected_model:
    from groq import Groq
    import google.generativeai as genai
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])

    def get_all_model_responses(prompt):
        openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]).invoke(
            [{"role": "system", "content": prompt}]).content
        gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
        gemini_resp = gemini.generate_content(prompt).text
        llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
        llama_resp = llama.chat.completions.create(
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=1, max_completion_tokens=1024, top_p=1, stream=False
        ).choices[0].message.content
        return [openai_resp, gemini_resp, llama_resp]

    def rank_and_fuse(prompt, responses):
        # Score each candidate with PairRM, sort best-first, then fuse with gen_fuser
        ranked = [(resp, pair_ranker(f"{prompt}\n\n{resp}")[0][1]['score']) for resp in responses]
        ranked.sort(key=lambda x: x[1], reverse=True)
        fusion_input = "\n\n".join([f"[Answer {i+1}]: {ans}" for i, (ans, _) in enumerate(ranked)])
        return gen_fuser(f"Fuse these responses:\n{fusion_input}", return_full_text=False)[0]['generated_text']

else:
    st.error("Unsupported model selected.")
    st.stop()
#retriever = vector_store.as_retriever()

AI_PROMPT_TEMPLATE = """
You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.
Your goal is to respond like a well-informed human expert, balancing professionalism with warmth and reassurance.

When crafting responses:
- Begin with a clear, engaging summary of the condition or concern.
- Use short paragraphs for readability.
- Include bullet points or numbered lists where appropriate.
- Avoid overly technical terms unless explained simply.
- End with a helpful next step, such as lifestyle advice or when to see a doctor.

🩺 Response Structure:
1. **Overview** – Briefly introduce the condition or concern.
2. **Common Symptoms** – Describe noticeable signs in simple terms.
3. **Causes & Risk Factors** – Include genetic, lifestyle, and environmental aspects.
4. **Treatment Options** – Outline common OTC and prescription treatments.
5. **When to Seek Help** – Warn about symptoms that require urgent care.

Always encourage consulting a licensed dermatologist for personal diagnosis and treatment. For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.

---
Query: {question}
Relevant Context: {context}

Your Response:
"""
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])

#rag_chain = RetrievalQA.from_chain_type(
#    llm=llm,
#    retriever=retriever,
#    chain_type="stuff",
#    chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
#)
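
# Retrieval + generation is done manually in get_reranked_response()
# (retrieve -> Cohere rerank -> LLM) instead of the RetrievalQA chain above.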
# === Class Names ===
multilabel_class_names = [
    "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
    "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
    "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
    "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
    "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
    "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
    "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
]
multiclass_class_names = [
    "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
    "autoimmune", "papulosquamous", "eczema", "skincancer",
    "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
]
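
# SkinViT predicts multi-label morphology findings (sigmoid over multilabel_class_names);
# DermNetViT predicts a single disease category (argmax over multiclass_class_names).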
# === Load Models ===
class SkinViT(nn.Module):
    def __init__(self, num_classes):
        super(SkinViT, self).__init__()
        self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


class DermNetViT(nn.Module):
    def __init__(self, num_classes):
        super(DermNetViT, self).__init__()
        self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
        in_features = self.model.heads[0].in_features
        self.model.heads[0] = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.model(x)


#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')

# === Load Model State Dicts ===
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")


def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
    try:
        print(f"Loading {model_name} on GPU...")
        model = model_class(num_classes)
        model.load_state_dict(torch.load(weight_path, map_location="cuda"))
        model.to("cuda")
        print(f"{model_name} loaded on GPU.")
        return model
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print(f"{model_name} OOM. Falling back to CPU.")
        else:
            print(f"Error loading {model_name} on CUDA: {e}")
        print(f"Loading {model_name} on CPU...")
        model = model_class(num_classes)
        model.load_state_dict(torch.load(weight_path, map_location="cpu"))
        model.to("cpu")
        return model


# Load both models with fallback
multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names), "DermNetViT")
multilabel_model.eval()
multiclass_model.eval()
# === Session Init ===
if "messages" not in st.session_state:
    st.session_state.messages = []
# === Image Processing Function ===
def run_inference(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    input_tensor = transform(image).unsqueeze(0)
    # Automatically match model device (GPU or CPU); assumes both models ended up on the same device after fallback
    model_device = next(multilabel_model.parameters()).device
    input_tensor = input_tensor.to(model_device)
    with torch.no_grad():
        probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().cpu().numpy()
        pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
    predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
    predicted_single = multiclass_class_names[pred_idx]
    return predicted_multi, predicted_single
# === PDF Export ===
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # The core Arial font only covers latin-1; replace anything it cannot encode
        text = f"{role}: {msg['content']}\n".encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, text)
    pdf_bytes = pdf.output(dest="S")  # fpdf2 returns a bytearray; classic PyFPDF returns a latin-1 str
    if isinstance(pdf_bytes, str):
        pdf_bytes = pdf_bytes.encode("latin-1")
    return io.BytesIO(bytes(pdf_bytes))
# Reranker utility
def rerank_with_cohere(query, documents, top_n=5):
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
    return [documents[result.index] for result in results.results]
# Final answer generation using reranked context
def get_reranked_response(query):
    docs = retriever.get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs)
    context = "\n\n".join([doc.page_content for doc in reranked_docs])
    prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
    if selected_model == "All":
        responses = get_all_model_responses(prompt)
        fused = rank_and_fuse(prompt, responses)
        return type("Obj", (), {"content": fused})
    if selected_model == "GPT-4o":
        return llm.invoke([{"role": "system", "content": prompt}])
    # LLaMA and Gemini branches bind llm to a plain function that takes a prompt string
    return type("Obj", (), {"content": llm(prompt)})
# === App UI ===
st.title("🧬 DermBOT – Skin AI Assistant")
st.caption(f"🧠 Using model: {selected_model}")

uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
    image = Image.open(uploaded_file).convert("RGB")
    predicted_multi, predicted_single = run_inference(image)

    # Show predictions clearly to the user
    st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_multi)}")
    st.markdown(f"**Most Likely Diagnosis**: {predicted_single}")

    query = f"What are my treatment options for {', '.join(predicted_multi)} and {predicted_single}?"
    st.session_state.messages.append({"role": "user", "content": query})

    with st.spinner("Analyzing and retrieving context..."):
        response = get_reranked_response(query)
    st.session_state.messages.append({"role": "assistant", "content": response.content})
    with st.chat_message("assistant"):
        st.markdown(response.content)
# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    response = get_reranked_response(prompt)
    st.session_state.messages.append({"role": "assistant", "content": response.content})
    with st.chat_message("assistant"):
        st.markdown(response.content)
# === PDF Button ===
if st.button("Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")