---
license: mit
datasets:
- WinterSchool/MedificsDataset
language:
- en
metrics:
- accuracy
base_model:
- microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
tags:
- medical
- clip
- fine-tuned
- zero-shot
---
# OpenCLIP-BiomedCLIP-Finetuned

This repository contains a fine-tuned version of BiomedCLIP (specifically the PubMedBERT_256-vit_base_patch16_224 variant) loaded through OpenCLIP. The model recognizes and classifies various medical images (e.g., chest X-rays, histopathology slides) in a zero-shot manner. It was further adapted on a subset of medical data (e.g., from the WinterSchool/MedificsDataset) to enhance performance on specific image classes.
## Model Details

- **Architecture:** Vision Transformer (ViT-B/16) image encoder + PubMedBERT-based text encoder, loaded through `open_clip`.
- **Training Objective:** CLIP-style contrastive learning that aligns medical text prompts with images (see the sketch after this list).
- **Fine-Tuned On:** Selected medical image-text pairs, including X-rays, histopathology images, and more.
- **Intended Use:**
  - Zero-shot classification of medical images (e.g., "This is a photo of a chest X-ray").
  - Exploratory research and educational demos showcasing multi-modal (image-text) alignment in the medical domain.
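To make the training objective concrete, here is a minimal, illustrative sketch of the symmetric contrastive (InfoNCE-style) loss that CLIP-style models optimize. It assumes L2-normalized image and text embeddings from a matched batch of image-text pairs; it is not the exact training code used for this checkpoint.

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_feats, text_feats, logit_scale):
    """Symmetric CLIP-style contrastive loss (illustrative sketch).

    image_feats, text_feats: (batch, dim) tensors, assumed L2-normalized,
    where row i of each tensor comes from the same image-text pair.
    """
    # Cosine-similarity logits between every image and every text in the batch
    logits_per_image = logit_scale * image_feats @ text_feats.t()
    logits_per_text = logits_per_image.t()
    # The matching pair sits on the diagonal, so the target for row i is i
    targets = torch.arange(image_feats.size(0), device=image_feats.device)
    loss_images = F.cross_entropy(logits_per_image, targets)
    loss_texts = F.cross_entropy(logits_per_text, targets)
    return (loss_images + loss_texts) / 2
```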
## Usage

Below is a minimal Python snippet using OpenCLIP. Adjust the labels and text prompts as needed.
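Note that the OpenCLIP library is published on PyPI as `open_clip_torch`, so make sure it (along with `torch` and `Pillow`) is installed before running the snippet.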
```python
import torch
import open_clip
from PIL import Image

# 1) Load the fine-tuned model and its preprocessing transforms
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned",
    pretrained=None
)
tokenizer = open_clip.get_tokenizer("hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# 2) Example labels
labels = [
    "chest X-ray",
    "brain MRI",
    "bone X-ray",
    "squamous cell carcinoma histopathology",
    "adenocarcinoma histopathology",
    "immunohistochemistry histopathology",
]

# 3) Load and preprocess an image
image_path = "path/to/your_image.jpg"
image = Image.open(image_path).convert("RGB")
image_tensor = preprocess_val(image).unsqueeze(0).to(device)

# 4) Create text prompts & tokenize
text_prompts = [f"This is a photo of a {label}" for label in labels]
tokens = tokenizer(text_prompts).to(device)

# 5) Forward pass
with torch.no_grad():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(tokens)
    # Normalize features so the dot products below are cosine similarities
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp()
    logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)

# 6) Get predictions
probs = logits[0].cpu().tolist()
for label, prob in zip(labels, probs):
    print(f"{label}: {prob:.4f}")
```
## Example Gradio App

You can also deploy a simple Gradio demo:
```python
import gradio as gr
import torch
import open_clip
from PIL import Image

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned",
    pretrained=None
)
tokenizer = open_clip.get_tokenizer("hf-hub:mgbam/OpenCLIP-BiomedCLIP-Finetuned")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

labels = ["chest X-ray", "brain MRI", "histopathology"]  # extend with your own classes

def classify_image(img):
    if img is None:
        return {}
    image_tensor = preprocess_val(img).unsqueeze(0).to(device)
    prompts = [f"This is a photo of a {label}" for label in labels]
    tokens = tokenizer(prompts).to(device)
    with torch.no_grad():
        image_feats = model.encode_image(image_tensor)
        text_feats = model.encode_text(tokens)
        # Normalize features so the dot products below are cosine similarities
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        logit_scale = model.logit_scale.exp()
        logits = (logit_scale * image_feats @ text_feats.T).softmax(dim=-1)
    probs = logits.squeeze(0).cpu().numpy().tolist()
    return {label: float(prob) for label, prob in zip(labels, probs)}

demo = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="label")
demo.launch()
```
## Performance

- **Accuracy:** Varies with your specific dataset. The model can effectively classify medical images such as chest X-rays or histopathology slides, but performance depends heavily on how well the fine-tuning data covers those classes (see the evaluation sketch after this list).
- **Potential Limitations:**
  - Ultrasound, CT, MRI, or other modalities might not be recognized if they were not included in the training data.
  - The model may incorrectly label images that fall outside its known categories.
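If you want a concrete accuracy number for your own data, the following minimal zero-shot evaluation loop is one way to get it. It is a sketch under stated assumptions: `model`, `preprocess_val`, `tokenizer`, `device`, and `labels` are defined as in the Usage section, and `samples` is a hypothetical list of `(image_path, true_label)` pairs that you supply, with each `true_label` drawn from `labels`.

```python
import torch
from PIL import Image

def zero_shot_accuracy(samples):
    """Top-1 zero-shot accuracy over (image_path, true_label) pairs (illustrative sketch)."""
    prompts = [f"This is a photo of a {label}" for label in labels]
    tokens = tokenizer(prompts).to(device)
    correct = 0
    with torch.no_grad():
        text_feats = model.encode_text(tokens)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        for image_path, true_label in samples:
            image = Image.open(image_path).convert("RGB")
            image_tensor = preprocess_val(image).unsqueeze(0).to(device)
            image_feats = model.encode_image(image_tensor)
            image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
            # Pick the label whose prompt embedding is most similar to the image embedding
            pred_idx = (image_feats @ text_feats.t()).argmax(dim=-1).item()
            correct += int(labels[pred_idx] == true_label)
    return correct / len(samples)
```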
## Limitations & Caveats

- **Not a Medical Device:** This model is not FDA-approved or clinically validated. It is intended for research and educational purposes only.
- **Data Bias:** If the training dataset lacked certain pathologies or modalities, the model may systematically misclassify them.
- **Security:** This model uses standard PyTorch and `open_clip`. Be mindful of potential vulnerabilities when loading models or code from untrusted sources.
- **Privacy:** If you use patient data, comply with local regulations (HIPAA, GDPR, etc.).
## Citation & Acknowledgements

- **Base Model:** [BiomedCLIP](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) by Microsoft
- **OpenCLIP:** [open_clip on GitHub](https://github.com/mlfoundations/open_clip)
- **Fine-tuning dataset:** [WinterSchool/MedificsDataset](https://huggingface.co/datasets/WinterSchool/MedificsDataset)

If you use this model in your research or demos, please cite the above works accordingly.
## License

This model is released under the MIT License, as declared in the `license` field of the metadata above.

**Note:** Always include disclaimers that this model is not a substitute for professional medical advice and that it may not generalize to all imaging modalities or patient populations.