---
license: apache-2.0
tags:
- medical-imaging
- glaucoma
- federated-learning
- semantic-segmentation
- ophthalmology
- computer-vision
- healthcare
- mask2former
- swin-transformer
datasets:
- chaksu
- refuge
- g1020
- rim-one-dl
- messidor
- origa
library_name: transformers
pipeline_tag: image-segmentation
---

# Federated Learning for Glaucoma Segmentation: Model Checkpoints

## Overview

This repository contains trained model checkpoints from the research project **"A Federated Learning-based Optic Disc and Cup Segmentation Model for Glaucoma Monitoring in Color Fundus Photographs"**.

### Key Information

- **Task**: Automated optic disc and cup segmentation for glaucoma assessment
- **Architecture**: Mask2Former with Swin Transformer backbone
- **Pre-training**: ADE20K semantic segmentation dataset
- **Training Data**: 5,550 color fundus photographs from 9 datasets across 7 countries
- **Approach**: Privacy-preserving federated learning with site-specific fine-tuning

## Clinical Context

Glaucoma is a leading cause of irreversible blindness worldwide, affecting 3.54% of the population aged 40-80 and projected to affect 111.8 million people by 2040. A key indicator of glaucoma severity is the vertical cup-to-disc ratio (CDR); ratios ≥0.6 are suggestive of glaucoma. This work addresses the need for accurate automated segmentation while preserving patient data privacy across multiple clinical sites, enabling HIPAA/GDPR-compliant multi-institutional collaboration.

## Models Included

This repository contains **22 trained models**, organized into the following categories:

### Baseline Models

- **Central Model** (1 model): Trained on pooled multi-site data, representing upper-bound performance
- **Local Models** (9 models): Site-specific models trained on individual datasets, representing lower-bound performance

### Federated Learning Models

- **Pipeline 1** (1 model): Global Validation
- **Pipeline 2** (1 model): Weighted Global Validation
- **Pipeline 3** (1 model): Onsite Validation
- **Pipeline 4** (9 models): Fine-Tuned Onsite Validation

(A minimal server-side weight-aggregation sketch is shown at the end of the Usage section below.)

## Usage

### Download a Specific Model

```python
from huggingface_hub import hf_hub_download

# Download the central model
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/central/best_model.pt"
)

# Download the fine-tuned model for a specific dataset
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/pipeline4/chaksu/best_model.pt"
)
```

### Download All Models

```python
from huggingface_hub import snapshot_download

# Download the entire models directory
local_dir = snapshot_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    allow_patterns="models/**"
)
print(f"Models downloaded to: {local_dir}")
```

### Load and Perform Inference

```python
import torch
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
from PIL import Image

# Load the preprocessor
processor = Mask2FormerImageProcessor.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic"
)

# Load the model architecture; ignore_mismatched_sizes swaps the 150-class
# ADE20K head for a fresh 4-class head (the trained weights are loaded next)
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic",
    num_labels=4,  # background, unlabeled, optic disc, optic cup
    ignore_mismatched_sizes=True
)

# Load the trained weights
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Perform inference on a fundus image
image = Image.open("fundus_image.jpg")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Post-process into a per-pixel segmentation map
predicted_segmentation = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
```
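Since the vertical CDR is the clinically relevant quantity (see Clinical Context above), a common next step is to derive it from the predicted segmentation map. The sketch below continues from the snippet above; the class indices (2 = optic disc, 3 = optic cup) are an assumption based on the label order listed there, so verify them against the checkpoint's label mapping.

```python
import torch

# Hedged sketch: compute the vertical cup-to-disc ratio (CDR) from the
# predicted per-pixel class map. Class indices 2 (optic disc) and 3 (optic
# cup) are assumed from the label order above -- verify before relying on it.
def vertical_cdr(segmentation: torch.Tensor) -> float:
    disc_mask = (segmentation == 2) | (segmentation == 3)  # cup lies inside the disc
    cup_mask = segmentation == 3

    def vertical_extent(mask: torch.Tensor) -> int:
        rows = torch.any(mask, dim=1).nonzero()  # image rows touched by the mask
        return 0 if rows.numel() == 0 else int(rows.max() - rows.min() + 1)

    disc_height = vertical_extent(disc_mask)
    cup_height = vertical_extent(cup_mask)
    return cup_height / disc_height if disc_height > 0 else float("nan")

cdr = vertical_cdr(predicted_segmentation)
print(f"Vertical CDR: {cdr:.2f}")  # ratios >= 0.6 are suggestive of glaucoma
```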
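For readers curious how the federated pipelines listed under Models Included combine site models: this card does not reproduce the paper's exact aggregation and validation rules, but a minimal FedAvg-style weighted average of client state dicts, shown purely as an illustrative assumption, looks like this:

```python
# Illustrative FedAvg-style aggregation (an assumption, not necessarily the
# authors' exact rule): the server averages the clients' state dicts,
# weighted by each site's local dataset size.
def aggregate_weights(client_states, client_sizes):
    total = float(sum(client_sizes))
    merged = {}
    for key in client_states[0]:
        merged[key] = sum(
            state[key].float() * (size / total)
            for state, size in zip(client_states, client_sizes)
        )
    return merged

# Hypothetical usage with three sites (sizes from the dataset table below):
# global_state = aggregate_weights([sd_chaksu, sd_refuge, sd_g1020],
#                                  [1345, 1200, 1020])
# model.load_state_dict(global_state)
```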
## Datasets

Training was performed across 9 public datasets spanning 7 countries, comprising 5,550 color fundus photographs from at least 917 patients:

| Dataset | Total Images | Test Images | Country | Characteristics |
|---------|-------------|-------------|---------|-----------------|
| Chaksu | 1,345 | 135 | India | Multi-center research dataset |
| REFUGE | 1,200 | 120 | China | Glaucoma challenge dataset |
| G1020 | 1,020 | 102 | Germany | Benchmark retinal fundus dataset |
| RIM-ONE DL | 485 | 49 | Spain | Glaucoma assessment dataset |
| MESSIDOR | 460 | 46 | France | Diabetic retinopathy screening |
| ORIGA | 650 | 65 | Singapore | Multi-ethnic Asian population |
| Bin Rushed | 195 | 20 | Saudi Arabia | RIGA dataset collection |
| DRISHTI-GS | 101 | 11 | India | Optic nerve head segmentation |
| Magrabi | 94 | 10 | Saudi Arabia | RIGA dataset collection |

**Data Split**: Each dataset was divided into training (80%), validation (10%), and testing (10%) subsets. For datasets with multiple expert annotations, the STAPLE (Simultaneous Truth and Performance Level Estimation) method was used to generate consensus segmentation labels.

## Model Architecture

- **Base Model**: Mask2Former
- **Backbone**: Swin Transformer (Swin-Base)
- **Pre-training**: ADE20K semantic segmentation dataset
- **Input Resolution**: 512×512 pixels
- **Output Classes**: 4 (background, unlabeled, optic disc, optic cup)
- **Optimizer**: AdamW (learning rate: 2×10⁻⁵)
- **Loss Function**: Multi-class cross-entropy
- **Early Stopping**: Patience of 7 epochs/rounds

## Training Configuration

### Common Hyperparameters

- Batch size: 8
- Learning rate: 2×10⁻⁵
- Optimizer: AdamW with weight decay
- Maximum epochs: 100 (with early stopping)
- Early stopping patience: 7 epochs/rounds
- Input size: 512×512 pixels (normalized)

A minimal training-loop sketch illustrating these settings is given at the end of this card.

---
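The following is a minimal sketch of a local training loop matching the hyperparameters above. It is not the authors' script: the weight-decay value, the `train_loader` batches (assumed to carry `pixel_values`, `mask_labels`, and `class_labels` as produced by `Mask2FormerImageProcessor` with segmentation maps), and the `evaluate` helper returning a validation loss are all assumptions.

```python
import torch

# Hyperparameters from the Training Configuration section above
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)  # decay value assumed

best_val, patience, wait = float("inf"), 7, 0
for epoch in range(100):  # maximum epochs, with early stopping below
    model.train()
    for batch in train_loader:  # batch size 8; 512x512 normalized inputs (assumed loader)
        outputs = model(**batch)  # Mask2Former returns a loss when labels are supplied
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    val_loss = evaluate(model)  # hypothetical helper: mean validation loss
    if val_loss < best_val:
        best_val, wait = val_loss, 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        wait += 1
        if wait >= patience:  # stop after 7 epochs without improvement
            break
```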