tmabraham's picture
Update app.py
641bdbe verified
raw
history blame
11.7 kB
import torch
import gradio as gr
from torchvision.transforms import v2 as transforms
from PIL import Image
import numpy as np
import cv2
from torchvision.transforms.v2 import functional
# Constants
# Preprocessing parameters: square crop size and the standard ImageNet
# per-channel mean/std used by the Normalize step.
RESIZE_DIM = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Index -> display-name tables, one per classifier head. predict_class maps the
# i-th softmax output to entry i, so each list's order must match the order of
# its head's outputs.

# BreakHis tumor type labels (classes: ["TA", "MC", "F", "DC"])
BREAKHIS_LABELS = dict(enumerate([
    "Tubular Adenoma (TA) - Benign",
    "Mucinous Carcinoma (MC) - Malignant",
    "Fibroadenoma (F) - Benign",
    "Ductal Carcinoma (DC) - Malignant",
]))

GLEASON_LABELS = dict(enumerate(["Benign", "Gleason 3", "Gleason 4", "Gleason 5"]))

BACH_LABELS = dict(enumerate(["Benign", "In Situ", "Invasive", "Normal"]))

# Kather colorectal tissue abbreviations, in head-output order.
CRC_LABELS = dict(enumerate("ADI BACK DEB LYM MUC MUS NORM STR TUM".split()))

BRACS_LABELS = dict(enumerate([
    "Normal",
    "Pathological Benign",
    "Usual Ductal Hyperplasia",
    "Flat Epithelial Atypia",
    "Atypical Ductal Hyperplasia",
    "Ductal Carcinoma In Situ",
    "Invasive Carcinoma",
]))
import torch
from huggingface_hub import hf_hub_download
# Download the OpenMidnight teacher checkpoint (stored in the Hugging Face hub cache).
download_location = hf_hub_download(repo_id="SophontAI/OpenMidnight", filename="teacher_checkpoint_load.pt")
# Build only the DINOv2 ViT-g/14-with-registers architecture. pretrained=False skips
# downloading the stock DINOv2 weights, which load_state_dict below would fully
# overwrite anyway.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg', pretrained=False, weights=None)
# Load the OpenMidnight weights. map_location="cpu" lets a checkpoint saved from a GPU
# also load on CPU-only hosts; the model is moved to the chosen device afterwards.
checkpoint = torch.load(download_location, map_location="cpu")
# The checkpoint's positional embedding is shaped for OpenMidnight's 224px base
# resolution, not the hub model's default. Swap the parameter in first so that
# load_state_dict sees matching shapes.
pos_embed = checkpoint["pos_embed"]
model.pos_embed = torch.nn.parameter.Parameter(pos_embed)
model.load_state_dict(checkpoint)
model.eval()
print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")
def setup_linear(path, classes=4, embed_dim=1536):
    """Build a frozen linear classification head from a training checkpoint.

    Args:
        path: Path to a checkpoint dict whose ``"state_dict"`` entry holds
            ``"head.weight"`` and ``"head.bias"`` tensors.
        classes: Number of output classes the head is expected to produce.
        embed_dim: Size of the backbone embedding the head consumes.

    Returns:
        torch.nn.Linear on the CPU, in eval mode, with gradients disabled.

    Raises:
        KeyError: If the checkpoint lacks the expected keys.
        ValueError: If the stored weights are not shaped ``(classes, embed_dim)``
            and ``(classes,)``.
    """
    print(f"Loading {path} linear classifier...")
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines.
    state_dict = torch.load(path, map_location="cpu")["state_dict"]
    weight = state_dict["head.weight"]
    bias = state_dict["head.bias"]

    # Fail loudly on a class-count / embedding-size mismatch instead of silently
    # producing a head whose outputs don't line up with the label tables.
    expected_weight = (classes, embed_dim)
    if tuple(weight.shape) != expected_weight or tuple(bias.shape) != (classes,):
        raise ValueError(
            f"{path}: head has weight {tuple(weight.shape)} and bias {tuple(bias.shape)}, "
            f"expected {expected_weight} and {(classes,)}"
        )

    linear = torch.nn.Linear(embed_dim, classes)
    # load_state_dict copies values (and checks shapes) instead of rebinding .data.
    linear.load_state_dict({"weight": weight, "bias": bias})
    # Inference only: frozen parameters keep outputs free of autograd history.
    linear.requires_grad_(False)
    linear.eval()
    return linear
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dinov2 = model.to(device)

# Checkpoint locations for the per-dataset linear heads.
breakhis_path = "./breakhis_best.ckpt"
gleason_path = "./gleason_best.ckpt"
bach_path = "./bach_best.ckpt"
crc_path = "./crc_best.ckpt"
bracs_path = "./bracs_best.ckpt"

# One classification head per dataset tab. BreakHis, Gleason and BACH use the
# default 4 classes; CRC has 9 and BRACS has 7.
breakhis_linear = setup_linear(breakhis_path).to(device)
gleason_linear = setup_linear(gleason_path).to(device)
bach_linear = setup_linear(bach_path).to(device)
crc_linear = setup_linear(crc_path, classes=9).to(device)
bracs_linear = setup_linear(bracs_path, classes=7).to(device)

print(f"Models loaded on {device}")
# Preprocessing for every uploaded image:
#   1. resize so the shorter side is RESIZE_DIM,
#   2. center-crop to a RESIZE_DIM x RESIZE_DIM square,
#   3. convert to float32 with value scaling (uint8 input -> [0, 1]),
#   4. normalize per channel with NORMALIZE_MEAN / NORMALIZE_STD.
model_transforms = transforms.Compose(
    [
        transforms.Resize(size=RESIZE_DIM),
        transforms.CenterCrop(size=RESIZE_DIM),
        transforms.ToDtype(dtype=torch.float32, scale=True),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)
def cv_path(path):
    """Load an image file as an RGB, uint8, channels-first torchvision ``Image``.

    Args:
        path: Filesystem path of the image, as supplied by
            ``gr.Image(type="filepath")``.

    Returns:
        ``torchvision.tv_tensors.Image`` with 3 channels and dtype uint8.

    Raises:
        ValueError: If no path is given or OpenCV cannot decode the file.
    """
    if not path:
        raise ValueError("No image was provided.")
    image = cv2.imread(path, flags=cv2.IMREAD_COLOR)
    # cv2.imread signals an unreadable/unsupported file by returning None
    # rather than raising, so check explicitly.
    if image is None:
        raise ValueError(f"Could not read image file: {path}")
    if image.ndim == 2:
        # Defensive: IMREAD_COLOR should always yield 3 channels, but if a
        # single-channel array appears, expand it to RGB so it matches the
        # 3-channel normalization and backbone input.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:
        # OpenCV decodes color images in BGR order; the pipeline expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = np.asarray(image, dtype=np.uint8)
    return functional.to_image(image)
def predict_breakhis(image):
    """Gradio callback for the BreakHis tab: returns {class name: probability}."""
    return predict_class(image=image, linear=breakhis_linear, dataset="breakhis")
def predict_gleason(image):
    """Gradio callback for the Gleason tab: returns {class name: probability}."""
    return predict_class(image=image, linear=gleason_linear, dataset="gleason")
def predict_bach(image):
    """Gradio callback for the BACH tab: returns {class name: probability}."""
    return predict_class(image=image, linear=bach_linear, dataset="bach")
def predict_crc(image):
    """Gradio callback for the CRC tab: returns {tissue class: probability}."""
    return predict_class(image=image, linear=crc_linear, dataset="crc")
def predict_bracs(image):
    """Gradio callback for the BRACS tab: returns {class name: probability}."""
    return predict_class(image=image, linear=bracs_linear, dataset="bracs")
def predict_class(image, linear, dataset):
    """Classify a histopathology image with the backbone plus a dataset-specific head.

    Args:
        image: Path to the image file on disk (the Gradio inputs use
            ``type="filepath"``).
        linear: Linear classification head applied to the backbone embedding.
        dataset: Which label table to use: "breakhis", "gleason", "bach",
            "crc" or "bracs".

    Returns:
        dict: Mapping from human-readable class name to softmax probability
        (Python float), the format ``gr.Label`` displays.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    labels_by_dataset = {
        "breakhis": BREAKHIS_LABELS,
        "gleason": GLEASON_LABELS,
        "bach": BACH_LABELS,
        "crc": CRC_LABELS,
        "bracs": BRACS_LABELS,
    }
    # Resolve the label table up front: an unknown name is a programming error
    # and should not silently yield an empty prediction.
    try:
        labels = labels_by_dataset[dataset]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(labels_by_dataset)}"
        ) from None

    image = cv_path(image)
    # Preprocess and add a batch dimension of size 1.
    image_tensor = model_transforms(image).unsqueeze(0).to(device)

    # Keep the whole forward pass (backbone AND head) out of autograd: a head
    # whose parameters require grad would otherwise produce a grad-tracking
    # tensor, and calling .numpy()/.tolist()-style conversions on it is wasteful
    # (and .numpy() raises).
    with torch.no_grad():
        embedding = dinov2(image_tensor)
        logits = linear(embedding)
        probs = torch.nn.functional.softmax(logits, dim=1)

    # Output i of the head corresponds to entry i of the label table.
    return {labels[idx]: float(prob) for idx, prob in enumerate(probs[0].cpu().tolist())}
# Create Gradio interfaces: one gr.Interface per dataset, each wired to its
# predict_* callback and fed the uploaded image as a filepath.
# BreakHis tab: 4-way breast tumor subtype classification.
breakhis = gr.Interface(
    fn=predict_breakhis,
    inputs=gr.Image(type="filepath", label="Upload Breast Histopathology Image"),
    # 4 matches the BreakHis head's class count, so every class is listed.
    outputs=gr.Label(num_top_classes=4, label="BreakHis Breast Cancer Classification"),
    title="BreakHis Breast Cancer Classification",
    description="""
Upload a breast histopathology image to predict the breast cancer subtype. Your image must be at 40X magnification, and ideally between 224x224 and 700x460 resolution. Do not otherwise modify your image.
This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight)
with a linear classifier for BreakHis breast cancer classification.
**Tumor Types:**
- **Benign tumors:** Tubular Adenoma (TA), Fibroadenoma (F)
- **Malignant tumors:** Mucinous Carcinoma (MC), Ductal Carcinoma (DC)
These 4 classes were selected from the full BreakHis dataset as they have sufficient patient counts (≥7 patients) for robust evaluation.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    # Example images (paths relative to the working directory): TA, MC and F.
    # No DC example is provided.
    examples=["./SOB_B_TA-14-13200-40-001.png",
              "./SOB_M_MC-14-10147-40-001.png",
              "./SOB_B_F-14-14134-40-001.png",
              ],
    theme=gr.themes.Soft()
)
# Gleason tab: prostate patches classified as benign or Gleason pattern 3/4/5.
gleason = gr.Interface(
    fn=predict_gleason,
    inputs=gr.Image(type="filepath", label="Upload Prostate Cancer Image"),
    # 4 matches the Gleason head's class count, so every class is listed.
    outputs=gr.Label(num_top_classes=4, label="Gleason Grading"),
    title="Gleason Grading",
    # The task is grading (benign / pattern 3 / 4 / 5), not tumor typing; the
    # first sentence now says so, consistent with the rest of the description.
    description="""
Upload a prostate cancer image to predict the Gleason pattern. Your image must be at 40X magnification, and ideally between 224x224 and 750x750 resolution. Do not otherwise modify your image.
This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight)
with a linear classifier for Gleason grading.
Images are classified as benign, Gleason pattern 3, 4 or 5.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    # Example patches (paths relative to the working directory).
    examples=[
        "./ZT111_4_A_1_12_patch_13_class_2.jpg",
        "./ZT204_6_A_1_10_patch_10_class_3.jpg",
    ],
    theme=gr.themes.Soft()
)
# CRC tab: 9-way colorectal tissue classification.
crc = gr.Interface(
    fn=predict_crc,
    inputs=gr.Image(type="filepath", label="Upload Colorectal Cancer Image"),
    # 9 matches the CRC head's class count, so every class is listed. The label
    # says "tissue" because most classes (ADI, BACK, DEB, ...) are not tumor types;
    # this matches the title and description.
    outputs=gr.Label(num_top_classes=9, label="CRC Tissue Classification"),
    title="Colorectal Cancer Tissue Classification",
    description="""
Upload a colorectal cancer image to predict the tissue class. Your image must be at 20X magnification, and ideally at 224x224. Do not otherwise modify your image.
This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight)
with a linear classifier for colorectal cancer tissue classification.
The tissue classes are: Adipose (ADI), background (BACK), debris (DEB), lymphocytes (LYM), mucus (MUC), smooth muscle (MUS), normal colon mucosa (NORM), cancer-associated stroma (STR) and colorectal adenocarcinoma epithelium (TUM)
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    # Example tiles (paths relative to the working directory): ADI, BACK, DEB.
    examples=[
        "./ADI-TCGA-AAICEQFN.png",
        "./BACK-TCGA-AARRNSTS.png",
        "./DEB-TCGA-AANNAWLE.png",
    ],
    theme=gr.themes.Soft()
)
# BACH tab: 4-way breast tissue classification (Benign / In Situ / Invasive / Normal).
bach = gr.Interface(
    fn=predict_bach,
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    # 4 matches the BACH head's class count, so every class is listed.
    outputs=gr.Label(num_top_classes=4, label="BACH Breast Cancer Classification"),
    title="BACH Breast Cancer Classification",
    description="""
Upload a breast cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally between 224x224 and 1536x2048 resolution. Do not otherwise modify your image.
This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight)
with a linear classifier for tumor classification.
Images are classified as benign, normal, invasive, in-situ.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    # One example per class, judging by the filename prefixes
    # (b=benign, n=normal, is=in situ, iv=invasive) -- confirm against the files.
    examples=["./b001.png",
              "./n001.png",
              "./is001.png",
              "./iv001.png"
              ],
    theme=gr.themes.Soft()
)
# BRACS tab: 7-way breast lesion subtyping.
bracs = gr.Interface(
    fn=predict_bracs,
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    # 7 matches the BRACS head's class count, so every class is listed.
    outputs=gr.Label(num_top_classes=7, label="BRACS Tumor Subtyping"),
    title="BRACS Tumor Subtyping",
    description="""
Upload a breast cancer image to predict the tumor type. Your image must be at 40X magnification. Do not otherwise modify your image.
This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight)
with a linear classifier for tumor classification.
Images are classified as Normal, Pathological Benign, Usual Ductal Hyperplasia, Flat Epithelial Atypia,
Atypical Ductal Hyperplasia, Ductal Carcinoma In Situ, Invasive Carcinoma
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    # NOTE(review): no example images are provided for BRACS, although the
    # description refers to "the sample classes" -- consider adding some.
    examples=[
    ],
    theme=gr.themes.Soft()
)
# Combine the per-dataset interfaces into one tabbed app; tab names pair with
# interfaces by position.
demo = gr.TabbedInterface(
    [breakhis, gleason, crc, bach, bracs],
    tab_names=["BreakHis", "Gleason", "CRC", "Bach", "BRACS"],
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()