import torch
import gradio as gr
from torchvision.transforms import v2 as transforms
from PIL import Image
import numpy as np
import cv2
from torchvision.transforms.v2 import functional
from huggingface_hub import hf_hub_download

# Constants
RESIZE_DIM = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
# Width of the backbone embedding fed to every linear head (matches the
# in_features the heads were built with).
EMBED_DIM = 1536

# BreakHis tumor type labels (classes: ["TA", "MC", "F", "DC"])
BREAKHIS_LABELS = {
    0: "Tubular Adenoma (TA) - Benign",
    1: "Mucinous Carcinoma (MC) - Malignant",
    2: "Fibroadenoma (F) - Benign",
    3: "Ductal Carcinoma (DC) - Malignant",
}

GLEASON_LABELS = {
    0: "Benign",
    1: "Gleason 3",
    2: "Gleason 4",
    3: "Gleason 5",
}

BACH_LABELS = {0: "Benign", 1: "In Situ", 2: "Invasive", 3: "Normal"}

CRC_LABELS = {
    0: "ADI",
    1: "BACK",
    2: "DEB",
    3: "LYM",
    4: "MUC",
    5: "MUS",
    6: "NORM",
    7: "STR",
    8: "TUM",
}

BRACS_LABELS = {
    0: "Normal",
    1: "Pathological Benign",
    2: "Usual Ductal Hyperplasia",
    3: "Flat Epithelial Atypia",
    4: "Atypical Ductal Hyperplasia",
    5: "Ductal Carcinoma In Situ",
    6: "Invasive Carcinoma",
}

# Dataset name -> class-index-to-name mapping used by predict_class.
LABELS_BY_DATASET = {
    "breakhis": BREAKHIS_LABELS,
    "gleason": GLEASON_LABELS,
    "bach": BACH_LABELS,
    "crc": CRC_LABELS,
    "bracs": BRACS_LABELS,
}

# Downloads to hf cache location
download_location = hf_hub_download(
    repo_id="SophontAI/OpenMidnight", filename="teacher_checkpoint_load.pt"
)
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14_reg", weights=None)

# Load OpenMidnight weights. map_location="cpu" lets this run on CPU-only hosts
# even if the tensors were saved from a GPU; the model is moved to `device` below.
checkpoint = torch.load(download_location, map_location="cpu")
# Required because dinov2 is baseline 392 and we are baseline 224 resolution:
# replace the positional embedding parameter with the checkpoint's so the
# shapes agree before load_state_dict.
pos_embed = checkpoint["pos_embed"]
model.pos_embed = torch.nn.parameter.Parameter(pos_embed)
model.load_state_dict(checkpoint)
model.eval()
print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")


def setup_linear(path, classes=4, embed_dim=EMBED_DIM):
    """Build a linear classification head from a saved checkpoint.

    Args:
        path: Path to a checkpoint whose ``"state_dict"`` contains
            ``"head.weight"`` and ``"head.bias"``.
        classes: Number of output classes the head must produce.
        embed_dim: Width of the backbone embedding the head consumes.

    Returns:
        torch.nn.Linear: The head, in eval mode, with parameters on CPU.

    Raises:
        RuntimeError: If the stored weights do not have shape
            ``(classes, embed_dim)`` / ``(classes,)``.
    """
    print(f"Loading {path} linear classifier...")
    # Load the best checkpoint from the latest run
    linear_checkpoint = torch.load(path, map_location="cpu")
    state_dict = linear_checkpoint["state_dict"]

    linear = torch.nn.Linear(embed_dim, classes)
    # load_state_dict validates shapes, so a head whose class count disagrees
    # with `classes` fails here instead of silently reshaping the parameters.
    linear.load_state_dict(
        {"weight": state_dict["head.weight"], "bias": state_dict["head.bias"]}
    )
    linear.eval()
    return linear


# Move models to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dinov2 = model.to(device)

breakhis_path = "./breakhis_best.ckpt"
breakhis_linear = setup_linear(breakhis_path, len(BREAKHIS_LABELS)).to(device)

gleason_path = "./gleason_best.ckpt"
gleason_linear = setup_linear(gleason_path, len(GLEASON_LABELS)).to(device)

bach_path = "./bach_best.ckpt"
bach_linear = setup_linear(bach_path, len(BACH_LABELS)).to(device)

crc_path = "./crc_best.ckpt"
crc_linear = setup_linear(crc_path, len(CRC_LABELS)).to(device)

bracs_path = "./bracs_best.ckpt"
bracs_linear = setup_linear(bracs_path, len(BRACS_LABELS)).to(device)

print(f"Models loaded on {device}")

model_transforms = transforms.Compose([
    transforms.Resize(RESIZE_DIM),
    transforms.CenterCrop(RESIZE_DIM),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])


def cv_path(path):
    """Read an image file into an RGB torchvision ``Image`` (CHW, uint8).

    Args:
        path: Filesystem path to the image.

    Returns:
        torchvision.tv_tensors.Image: The decoded image with 3 RGB channels.

    Raises:
        gr.Error: If OpenCV cannot read/decode the file.
    """
    flags = cv2.IMREAD_COLOR
    image = cv2.imread(path, flags=flags)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise gr.Error(f"Could not read image file: {path}")
    if image.ndim == 3:
        # OpenCV decodes colour images in BGR channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif image.ndim == 2:
        # Expand a single-channel image to 3 channels so the 3-channel
        # Normalize in model_transforms applies.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image = np.asarray(image, dtype=np.uint8)
    # HWC numpy array -> CHW tv_tensors.Image for the v2 transforms.
    image = functional.to_image(image)
    return image


def predict_breakhis(image):
    return predict_class(image, breakhis_linear, "breakhis")


def predict_gleason(image):
    return predict_class(image, gleason_linear, "gleason")


def predict_bach(image):
    return predict_class(image, bach_linear, "bach")


def predict_crc(image):
    return predict_class(image, crc_linear, "crc")


def predict_bracs(image):
    return predict_class(image, bracs_linear, "bracs")


def predict_class(image, linear, dataset):
    """Classify a histopathology image with the backbone and a linear head.

    Args:
        image: Filesystem path to the uploaded image (the Gradio inputs use
            ``type="filepath"``), or None when nothing was uploaded.
        linear: Linear head mapping the backbone embedding to class logits.
        dataset: Key into LABELS_BY_DATASET selecting the class names.

    Returns:
        dict: Class name -> probability, the format ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was uploaded or the file cannot be read.
        KeyError: If ``dataset`` is not a known dataset name.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    labels = LABELS_BY_DATASET[dataset]

    image = cv_path(image)
    # Preprocess image and add a batch dimension
    image_tensor = model_transforms(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Get embedding from DinoV2
        embedding = dinov2(image_tensor)
        # Get logits from linear classifier
        logits = linear(embedding)
    print(logits)

    # Convert to probabilities
    probs = torch.nn.functional.softmax(logits, dim=1)
    print(probs)

    return {labels[idx]: float(prob) for idx, prob in enumerate(probs[0].cpu().numpy())}


# Create Gradio interface
breakhis = gr.Interface(
    fn=predict_breakhis,
    inputs=gr.Image(type="filepath", label="Upload Breast Histopathology Image"),
    outputs=gr.Label(num_top_classes=4, label="BreakHis Breast Cancer Classification"),
    title="BreakHis Breast Cancer Classification",
    description="""
Upload a breast histopathology image to predict the breast cancer subtype.
Your image must be at 40X magnification, and ideally between 224x224 and 700x460 resolution. Do not otherwise modify your image.

This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight) with a linear classifier for BreakHis breast cancer classification.

**Tumor Types:**
- **Benign tumors:** Tubular Adenoma (TA), Fibroadenoma (F)
- **Malignant tumors:** Mucinous Carcinoma (MC), Ductal Carcinoma (DC)

These 4 classes were selected from the full BreakHis dataset as they have sufficient patient counts (≥7 patients) for robust evaluation.

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.

This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    examples=[
        "./SOB_B_TA-14-13200-40-001.png",
        "./SOB_M_MC-14-10147-40-001.png",
        "./SOB_B_F-14-14134-40-001.png",
    ],
    theme=gr.themes.Soft(),
)

gleason = gr.Interface(
    fn=predict_gleason,
    inputs=gr.Image(type="filepath", label="Upload Prostate Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Gleason Grading"),
    title="Gleason Grading",
    description="""
Upload a prostate cancer image to predict the tumor type.
Your image must be at 40X magnification, and ideally between 224x224 and 750x750 resolution. Do not otherwise modify your image.

This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight) with a linear classifier for Gleason grading.

Images are classified as benign, Gleason pattern 3, 4 or 5.

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.

This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    examples=[
        "./ZT111_4_A_1_12_patch_13_class_2.jpg",
        "./ZT204_6_A_1_10_patch_10_class_3.jpg",
    ],
    theme=gr.themes.Soft(),
)

crc = gr.Interface(
    fn=predict_crc,
    inputs=gr.Image(type="filepath", label="Upload Colorectal Cancer Image"),
    outputs=gr.Label(num_top_classes=9, label="CRC Tumor Type Prediction"),
    title="Colorectal Cancer Tissue Classification",
    description="""
Upload a colorectal cancer image to predict the tissue class.
Your image must be at 20X magnification, and ideally at 224x224. Do not otherwise modify your image.

This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight) with a linear classifier for colorectal cancer tissue classification.

The tissue classes are: Adipose (ADI), background (BACK), debris (DEB), lymphocytes (LYM), mucus (MUC), smooth muscle (MUS), normal colon mucosa (NORM), cancer-associated stroma (STR) and colorectal adenocarcinoma epithelium (TUM)

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.

This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    examples=[
        "./ADI-TCGA-AAICEQFN.png",
        "./BACK-TCGA-AARRNSTS.png",
        "./DEB-TCGA-AANNAWLE.png",
    ],
    theme=gr.themes.Soft(),
)

bach = gr.Interface(
    fn=predict_bach,
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="BACH Breast Cancer Classification"),
    title="BACH Breast Cancer Classification",
    description="""
Upload a breast cancer image to predict the tumor type.
Your image must be at 20X magnification, and ideally between 224x224 and 1536x2048 resolution. Do not otherwise modify your image.

This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight) with a linear classifier for tumor classification.

Images are classified as benign, normal, invasive, in-situ.

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.

This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    examples=[
        "./b001.png",
        "./n001.png",
        "./is001.png",
        "./iv001.png",
    ],
    theme=gr.themes.Soft(),
)

bracs = gr.Interface(
    fn=predict_bracs,
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    outputs=gr.Label(num_top_classes=7, label="BRACS Tumor Subtyping"),
    title="BRACS Tumor Subtyping",
    description="""
Upload a breast cancer image to predict the tumor type.
Your image must be at 40X magnification. Do not otherwise modify your image.

This demo uses a custom-trained DINOv2 foundation model for pathology images called [OpenMidnight](https://sophont.med/blog/openmidnight) with a linear classifier for tumor classification.

Images are classified as Normal, Pathological Benign, Usual Ductal Hyperplasia, Flat Epithelial Atypia, Atypical Ductal Hyperplasia, Ductal Carcinoma In Situ, Invasive Carcinoma

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.

This demonstration is for illustrative purposes only and should not be used for diagnostic/clinical purposes.
""",
    examples=[],  # You can add example image paths here
    theme=gr.themes.Soft(),
)

demo = gr.TabbedInterface(
    [breakhis, gleason, crc, bach, bracs],
    ["BreakHis", "Gleason", "CRC", "Bach", "BRACS"],
)

if __name__ == "__main__":
    demo.launch()