"""Gradio demo: pathology image classification with a DINOv2 backbone.

A single custom-trained DINOv2 (ViT-g/14 with registers) backbone produces an
image embedding; a dataset-specific linear head maps that embedding to class
logits for BreakHis, Gleason, BACH or CRC.
"""
import torch
import gradio as gr
from torchvision.transforms import v2 as transforms
from PIL import Image
import numpy as np
import cv2
from torchvision.transforms.v2 import functional

# Constants
RESIZE_DIM = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
# Width of the backbone embedding that every linear head consumes.
EMBED_DIM = 1536

# BreakHis tumor type labels (classes: ["TA", "MC", "F", "DC"])
BREAKHIS_LABELS = {
    0: "Tubular Adenoma (TA) - Benign",
    1: "Mucinous Carcinoma (MC) - Malignant",
    2: "Fibroadenoma (F) - Benign",
    3: "Ductal Carcinoma (DC) - Malignant",
}

GLEASON_LABELS = {
    0: "Benign",
    1: "Gleason 3",
    2: "Gleason 4",
    3: "Gleason 5",
}

BACH_LABELS = {0: "Benign", 1: "InSitu", 2: "Invasive", 3: "Normal"}

CRC_LABELS = {
    0: "ADI",
    1: "BACK",
    2: "DEB",
    3: "LYM",
    4: "MUC",
    5: "MUS",
    6: "NORM",
    7: "STR",
    8: "TUM",
}

# Maps the `dataset` key accepted by predict_class to its index -> label table.
DATASET_LABELS = {
    "breakhis": BREAKHIS_LABELS,
    "gleason": GLEASON_LABELS,
    "bach": BACH_LABELS,
    "crc": CRC_LABELS,
}

# Move models to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading DinoV2 base model...")
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

print("Loading custom pathology checkpoint...")
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# CPU-only hosts; the model is moved to `device` once loading is done.
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/SophontAI/OpenMidnight/resolve/main/teacher_checkpoint_load.pt",
    map_location="cpu",
)
# NOTE(review): the checkpoint's pos_embed presumably has a different shape than
# the hub model's default, so the parameter is swapped in before
# load_state_dict (which rejects size mismatches).
dinov2.pos_embed = torch.nn.parameter.Parameter(checkpoint["pos_embed"])
dinov2.load_state_dict(checkpoint)
dinov2.eval()
# load_state_dict copied the weights into the model; drop the raw state dict so
# a duplicate of the (multi-GB) backbone weights is not kept alive in RAM.
del checkpoint
dinov2 = dinov2.to(device)


def setup_linear(path, num_classes=4, in_features=EMBED_DIM):
    """Load a linear classification head from a training checkpoint.

    Args:
        path: Checkpoint file whose ``"state_dict"`` entry holds
            ``"head.weight"`` and ``"head.bias"``.
        num_classes: Number of output classes of the head.
        in_features: Width of the backbone embedding fed to the head.

    Returns:
        A ``torch.nn.Linear`` in eval mode, on the CPU (callers move it to
        ``device``).

    Raises:
        RuntimeError: if the stored weights do not have shape
            ``(num_classes, in_features)`` / ``(num_classes,)``.
    """
    print(f"Loading {path} linear classifier...")
    # map_location="cpu" so GPU-saved checkpoints also load on CPU-only hosts.
    state_dict = torch.load(path, map_location="cpu")["state_dict"]
    linear = torch.nn.Linear(in_features, num_classes)
    # load_state_dict validates shapes (and casts dtype), unlike overwriting
    # `.data`, which would silently accept a mismatched tensor.
    linear.load_state_dict({
        "weight": state_dict["head.weight"],
        "bias": state_dict["head.bias"],
    })
    linear.eval()
    return linear


def setup_linear_crc(path):
    """Load the 9-class CRC linear head (see ``setup_linear``)."""
    return setup_linear(path, num_classes=len(CRC_LABELS))


breakhis_path = "./breakhis_best.ckpt"
breakhis_linear = setup_linear(breakhis_path).to(device)

gleason_path = "./gleason_best.ckpt"
gleason_linear = setup_linear(gleason_path).to(device)

bach_path = "./bach_best.ckpt"
bach_linear = setup_linear(bach_path).to(device)

crc_path = "./crc_best.ckpt"
crc_linear = setup_linear_crc(crc_path).to(device)

print(f"Models loaded on {device}")

model_transforms = transforms.Compose([
    transforms.Resize(RESIZE_DIM),
    transforms.CenterCrop(RESIZE_DIM),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])


def cv_path(path):
    """Read an image file into an RGB uint8 torchvision ``Image`` (C, H, W).

    Args:
        path: Filesystem path of the uploaded image, or None if nothing was
            uploaded.

    Raises:
        gr.Error: if no path was given or OpenCV cannot decode the file, so the
            message is shown in the Gradio UI instead of a stack trace.
    """
    if path is None:
        raise gr.Error("Please upload an image first.")
    # IMREAD_COLOR always decodes to 3-channel BGR, and returns None (rather
    # than raising) when the file cannot be read or decoded.
    image = cv2.imread(path, flags=cv2.IMREAD_COLOR)
    if image is None:
        raise gr.Error("Could not read the uploaded image.")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = np.asarray(image, dtype=np.uint8)
    # to_image converts the HWC numpy array into a CHW tv_tensors.Image.
    return functional.to_image(image)


def predict_breakhis(image):
    return predict_class(image, breakhis_linear, "breakhis")


def predict_gleason(image):
    return predict_class(image, gleason_linear, "gleason")


def predict_bach(image):
    return predict_class(image, bach_linear, "bach")


def predict_crc(image):
    return predict_class(image, crc_linear, "crc")


def predict_class(image, linear, dataset):
    """Classify a histopathology image with the shared backbone and a linear head.

    Args:
        image: Filepath of the uploaded image (the Gradio inputs use
            ``type="filepath"``).
        linear: Linear head mapping the backbone embedding to class logits.
        dataset: Key of ``DATASET_LABELS`` ("breakhis", "gleason", "bach",
            "crc") selecting the class names.

    Returns:
        dict mapping each class label to its softmax probability, as consumed
        by ``gr.Label``.

    Raises:
        ValueError: if ``dataset`` is not a known key.
    """
    labels = DATASET_LABELS.get(dataset)
    if labels is None:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    image_tensor = model_transforms(cv_path(image)).unsqueeze(0).to(device)

    with torch.no_grad():
        embedding = dinov2(image_tensor)
        logits = linear(embedding)
        probs = torch.nn.functional.softmax(logits, dim=1)

    # Batch of one: take row 0 and pair each probability with its label.
    return {labels[idx]: float(prob) for idx, prob in enumerate(probs[0].cpu().tolist())}


# Create Gradio interface
breakhis = gr.Interface(
    fn=predict_breakhis,
    inputs=gr.Image(type="filepath", label="Upload Breast Histopathology Image"),
    outputs=gr.Label(num_top_classes=4, label="Tumor Type Prediction"),
    title="BreakHis Breast Tumor Classification",
    description="""
Upload a breast histopathology image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 700x460 resolution. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images with a linear classifier for BreakHis tumor classification.

**Tumor Types:**
- **Benign tumors:** Tubular Adenoma (TA), Fibroadenoma (F)
- **Malignant tumors:** Mucinous Carcinoma (MC), Ductal Carcinoma (DC)

These 4 classes were selected from the full BreakHis dataset as they have sufficient patient counts (≥7 patients) for robust evaluation.

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    examples=[
        "./SOB_B_TA-14-13200-40-001.png",
        "./SOB_M_MC-14-10147-40-001.png",
        "./SOB_B_F-14-14134-40-001.png",
    ],
    theme=gr.themes.Soft(),
)

gleason = gr.Interface(
    fn=predict_gleason,
    inputs=gr.Image(type="filepath", label="Upload Prostate Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Gleason Tumor Type Prediction"),
    title="Gleason Prostate Tumor Classification",
    description="""
Upload a prostate cancer image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 750x750 resolution. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images with a linear classifier for gleason tumor classification.

Images are classified as benign, Gleason pattern 3, 4 or 5.

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    examples=[
        "./ZT111_4_A_1_12_patch_13_class_2.jpg",
        "./ZT204_6_A_1_10_patch_10_class_3.jpg",
    ],
    theme=gr.themes.Soft(),
)

crc = gr.Interface(
    fn=predict_crc,
    inputs=gr.Image(type="filepath", label="Upload Colorectal Cancer Image"),
    outputs=gr.Label(num_top_classes=9, label="CRC Tumor Type Prediction"),
    title="Colorectal Tumor Classification",
    description="""
Upload a colorectal cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally at 224x224. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images with a linear classifier for colorectal tumor classification.

The tissue classes are: Adipose (ADI), background (BACK), debris (DEB), lymphocytes (LYM), mucus (MUC), smooth muscle (MUS), normal colon mucosa (NORM), cancer-associated stroma (STR) and colorectal adenocarcinoma epithelium (TUM)

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    examples=[
        "./ADI-TCGA-AAICEQFN.png",
        "./BACK-TCGA-AARRNSTS.png",
        "./DEB-TCGA-AANNAWLE.png",
    ],
    theme=gr.themes.Soft(),
)

bach = gr.Interface(
    fn=predict_bach,
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Bach Tumor Type Prediction"),
    title="Tumor Classification",
    # BACH is a breast cancer histology dataset; the text previously said
    # "prostate" (copied from the Gleason tab).
    description="""
Upload a breast cancer histology image to predict the tumor type. Your image must be at 20X magnification, and ideally between 224x224 and 1536x2048 resolution. Do not otherwise modify your image.

This model uses a custom-trained DinoV2 foundation model for pathology images with a linear classifier for tumor classification.

Images are classified as benign, normal, invasive, inSitu

For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    examples=[
        "./b001.png",
        "./n001.png",
        "./is001.png",
        "./iv001.png",
    ],
    theme=gr.themes.Soft(),
)

demo = gr.TabbedInterface(
    [breakhis, gleason, crc, bach],
    ["BreakHis", "Gleason", "CRC", "Bach"],
)

if __name__ == "__main__":
    demo.launch()