---
license: mit
datasets:
- WinterSchool/MedificsDataset
language:
- en
metrics:
- accuracy
base_model:
- microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
tags:
- medical
- clip
- fine-tuned
- zero-shot
---
This repository contains a fine-tuned version of BiomedCLIP (specifically the PubMedBERT_256-vit_base_patch16_224 variant), trained with OpenCLIP. The model classifies medical images (e.g., chest X-rays, histopathology slides) zero-shot: it scores each image against free-text label prompts rather than relying on a fixed output layer. It was further adapted on a subset of medical data (e.g., from the WinterSchool/MedificsDataset) to improve performance on specific image classes.
Model Details
Architecture: Vision Transformer (ViT-B/16) + PubMedBERT-based text encoder, loaded through open_clip.
Training Objective: CLIP-style contrastive learning to align medical text prompts with images.
Fine-Tuned On: Selected medical image-text pairs, including X-rays and histopathology images.
Intended Use:
Zero-shot classification of medical images against text prompts (e.g., “This is a photo of a chest X-ray”).
Exploratory research or educational demos showcasing multi-modal (image-text) alignment in the medical domain.
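The CLIP-style training objective listed above can be sketched as a symmetric cross-entropy over a batch's image-text similarity matrix. Each image should score highest against its own caption, and each caption should score highest against its own image. The NumPy sketch below is illustrative only. It assumes the standard CLIP loss as used by OpenCLIP, and the function name `clip_contrastive_loss` is hypothetical. It is not the training code behind this checkpoint.

```python
import numpy as np


def clip_contrastive_loss(image_emb: np.ndarray, text_emb: np.ndarray, logit_scale: float) -> float:
    """Symmetric InfoNCE loss for N matched image-text pairs (row i pairs with row i)."""
    # L2-normalize so each dot product is a cosine similarity
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)

    # (N, N) similarity matrix scaled by the (learned) temperature
    logits = logit_scale * image_emb @ text_emb.T
    targets = np.arange(logits.shape[0])  # the correct match sits on the diagonal

    def cross_entropy(rows: np.ndarray) -> float:
        # Numerically stable per-row cross-entropy: logsumexp(row) - row[target]
        shifted = rows - rows.max(axis=1, keepdims=True)
        log_sum_exp = np.log(np.exp(shifted).sum(axis=1))
        return float((log_sum_exp - shifted[targets, targets]).mean())

    # Average the image->text (rows) and text->image (columns) directions
    return (cross_entropy(logits) + cross_entropy(logits.T)) / 2


# Toy embeddings: four orthogonal "images"; captions are either matched or shifted by one row
images = np.eye(4)
matched = clip_contrastive_loss(images, np.eye(4), logit_scale=100.0)
shifted = clip_contrastive_loss(images, np.roll(np.eye(4), 1, axis=0), logit_scale=100.0)
print(f"matched: {matched:.4f}, shifted: {shifted:.4f}")
```

Matched pairs give a loss of essentially zero. Shifting the captions by one row makes every image's highest-scoring text the wrong one, and the loss rises to about `logit_scale` (here 100).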
Usage
Below is a minimal Python snippet using OpenCLIP. Adjust the labels and text prompts as needed:
```python
import torch
import open_clip
from PIL import Image

# 1) Load the fine-tuned model and its matching tokenizer from the Hugging Face Hub
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned",
    pretrained=None,  # the hf-hub: prefix already loads the fine-tuned weights
)
tokenizer = open_clip.get_tokenizer("hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# 2) Example labels
labels = [
    "chest X-ray",
    "brain MRI",
    "bone X-ray",
    "squamous cell carcinoma histopathology",
    "adenocarcinoma histopathology",
    "immunohistochemistry histopathology",
]

# 3) Load and preprocess an image (use the evaluation transform, not the training one)
image_path = "path/to/your_image.jpg"
image = Image.open(image_path).convert("RGB")
image_tensor = preprocess_val(image).unsqueeze(0).to(device)

# 4) Create text prompts & tokenize
text_prompts = [f"This is a photo of a {label}" for label in labels]
tokens = tokenizer(text_prompts).to(device)

# 5) Forward pass
with torch.no_grad():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(tokens)
    # L2-normalize so the dot product is a cosine similarity, as in CLIP training
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp()
    probs = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)

# 6) Get predictions (one probability per label, summing to 1)
for label, prob in zip(labels, probs[0].cpu().tolist()):
    print(f"{label}: {prob:.4f}")
```
Example Gradio App
You can also deploy a simple Gradio demo:
```python
import gradio as gr
import torch
import open_clip
from PIL import Image

MODEL_ID = "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned"
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(MODEL_ID, pretrained=None)
tokenizer = open_clip.get_tokenizer(MODEL_ID)  # must come from the same repository as the model

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Replace or extend with your own classes
labels = ["chest X-ray", "brain MRI", "histopathology"]


def classify_image(img: Image.Image) -> dict:
    if img is None:
        return {}
    image_tensor = preprocess_val(img.convert("RGB")).unsqueeze(0).to(device)
    prompts = [f"This is a photo of a {label}" for label in labels]
    tokens = tokenizer(prompts).to(device)
    with torch.no_grad():
        image_feats = model.encode_image(image_tensor)
        text_feats = model.encode_text(tokens)
        # Normalize embeddings before computing cosine-similarity logits
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        logit_scale = model.logit_scale.exp()
        probs = (logit_scale * image_feats @ text_feats.T).softmax(dim=-1)
    # The "label" output expects a {label: confidence} mapping
    return {label: float(p) for label, p in zip(labels, probs[0].cpu().tolist())}


demo = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="label")
demo.launch()
```
Performance
Accuracy: No benchmark figures are reported in this card, and accuracy will vary with your dataset. The model targets images such as chest X-rays and histopathology slides, but its performance depends heavily on how well your classes are covered by the fine-tuning data.
Potential Limitations:
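Ultrasound, CT, MRI, or other modalities may not be recognized reliably if they were not included in the fine-tuning data.
The model may incorrectly label images that fall outside its known categories. Because the softmax is taken only over the labels you supply, such an image is still assigned to one of them, possibly with high confidence.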
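Because accuracy depends on your data, it is worth measuring it on a labeled sample from your own setting before relying on the model. The helper below is a hypothetical sketch (`zero_shot_accuracy` is not part of open_clip). It takes one probability row per image, ordered like `labels`, such as the per-label probabilities printed by the usage snippet. It reports top-1 accuracy and per-class accuracy (recall for each ground-truth class), which can surface the modality gaps listed above. The toy probabilities are placeholders, not results from this model.

```python
from collections import Counter


def zero_shot_accuracy(prob_rows, true_labels, labels):
    """Return (top-1 accuracy, per-class accuracy) for zero-shot predictions.

    prob_rows[k] holds the probabilities for image k, ordered like `labels`;
    true_labels[k] is that image's ground-truth label string.
    """
    if not true_labels or len(prob_rows) != len(true_labels):
        raise ValueError("need one probability row per ground-truth label")
    correct, total = Counter(), Counter()
    for probs, truth in zip(prob_rows, true_labels):
        # Prediction = label with the highest probability
        best = max(range(len(labels)), key=lambda i: probs[i])
        total[truth] += 1
        correct[truth] += int(labels[best] == truth)
    overall = sum(correct.values()) / len(true_labels)
    per_class = {label: correct[label] / total[label] for label in total}
    return overall, per_class


labels = ["chest X-ray", "brain MRI", "histopathology"]
# Placeholder probabilities standing in for model outputs (not measured results)
prob_rows = [
    [0.90, 0.05, 0.05],
    [0.20, 0.70, 0.10],
    [0.60, 0.10, 0.30],
]
true_labels = ["chest X-ray", "brain MRI", "histopathology"]
overall, per_class = zero_shot_accuracy(prob_rows, true_labels, labels)
print(f"top-1 accuracy: {overall:.2f}")
print(per_class)
```

In the toy run the third image is predicted as a chest X-ray. Overall accuracy is therefore 2/3, and the histopathology class scores 0.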
Limitations & Caveats
Not a Medical Device: This model is not FDA-approved or clinically validated. It’s intended for research and educational purposes only.
Data Bias: If the training dataset lacked certain pathologies or modalities, the model may systematically misclassify them.
Security: This model uses standard PyTorch and open_clip. Loading checkpoints or code from untrusted sources can be unsafe, because pickle-based PyTorch checkpoints can execute arbitrary code. Only load files from sources you trust.
Privacy: If you use patient data, comply with local regulations (HIPAA, GDPR, etc.).
Citation & Acknowledgements
Base Model: BiomedCLIP by Microsoft
OpenCLIP: GitHub – open_clip
Fine-tuning dataset: WinterSchool/MedificsDataset
If you use this model in your research or demos, please cite the above works accordingly.
License
This model is released under the MIT license, as declared in the model card metadata.
Note: This model is not a substitute for professional medical advice, and it may not generalize to all imaging modalities or patient populations.