---
license: apache-2.0
tags:
- medical
- radiology
- mammography
- contrastive-learning
- embeddings
- computer-vision
- pytorch
datasets:
- CMMD
pipeline_tag: feature-extraction
---

# PRIMER: Pretrained RadImageNet for Mammography Embedding Representations

PRIMER is a specialized deep learning model for mammography analysis, finetuned from RadImageNet using contrastive learning on the CMMD (Chinese Mammography Database) dataset. The model generates discriminative embedding vectors optimized for mammogram images.

## Model Overview

- **Base Model**: RadImageNet ResNet-50
- **Training Method**: SimCLR contrastive learning (NT-Xent loss)
- **Architecture**: ResNet-50 encoder + 2-layer MLP projection head
- **Input**: 224×224 RGB images (converted from DICOM grayscale)
- **Output**: 2048-dimensional embedding vectors
- **Training Dataset**: CMMD mammography DICOM files
- **Framework**: PyTorch 2.1+

## Key Features

- Finetuned specifically for mammography imaging
- Self-supervised contrastive learning (no labels required)
- Produces embeddings with better clustering and separation than baseline RadImageNet
- Handles DICOM preprocessing pipeline end-to-end
- Supports multiple backbone architectures (ResNet-50, DenseNet-121, Inception-V3)

## DICOM Preprocessing Pipeline

The model expects mammography DICOM images preprocessed through the following pipeline.
This preprocessing is **critical** for proper model performance:

### Step 1: DICOM Loading
```
- Read DICOM file using pydicom
- Extract pixel array as float32
```

### Step 2: Photometric Interpretation Correction
```
- Check PhotometricInterpretation attribute
- If MONOCHROME1: Invert pixel values (max_value - pixel_value)
  - MONOCHROME1: Higher values = darker (inverted scale)
  - MONOCHROME2: Higher values = brighter (standard scale)
```

### Step 3: Intensity Normalization
```
- Percentile-based clipping to remove outliers:
  - Compute 2nd percentile (p2) and 98th percentile (p98)
  - Clip all values: pixel_value = clip(pixel_value, p2, p98)
- Min-max normalization to [0, 255]:
  - normalized = ((pixel_value - min) / (max - min + 1e-8)) × 255
- Convert to uint8
```

### Step 4: CLAHE Enhancement
```
- Apply Contrast Limited Adaptive Histogram Equalization (CLAHE)
  - Clip limit: 2.0
  - Tile grid size: 8×8
- Improves local contrast and enhances subtle features
```

### Step 5: Grayscale to RGB Conversion
```
- Duplicate grayscale channel 3 times: RGB = [gray, gray, gray]
- Required because RadImageNet expects 3-channel input
```

### Step 6: Resizing
```
- Resize to 224×224 using bilinear interpolation
```

### Step 7: Data Augmentation (Training Only)
```
Training augmentations:
- Horizontal flip (p=0.5)
- Vertical flip (p=0.3)
- Rotation (±15 degrees, p=0.5)
- Random brightness/contrast (±0.2, p=0.5)
- Shift/scale/rotate (shift=0.1, scale=0.1, rotate=15°, p=0.5)
```

### Step 8: Normalization
```
- ImageNet normalization (required for RadImageNet compatibility):
  - Mean: [0.485, 0.456, 0.406]
  - Std: [0.229, 0.224, 0.225]
- Convert to tensor (C×H×W format)
```

### Complete Preprocessing Code

```python
import cv2
import numpy as np
import pydicom


class DICOMProcessor:
    def __init__(self):
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def preprocess(self, dicom_path):
        # 1. Load DICOM
        dicom = pydicom.dcmread(dicom_path)
        image = dicom.pixel_array.astype(np.float32)

        # 2. Handle photometric interpretation (MONOCHROME1 is an inverted scale)
        if getattr(dicom, 'PhotometricInterpretation', None) == "MONOCHROME1":
            image = np.max(image) - image

        # 3. Intensity normalization: clip to [p2, p98], then rescale to [0, 255]
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        image = (image - image.min()) / (image.max() - image.min() + 1e-8) * 255
        image = image.astype(np.uint8)

        # 4. CLAHE enhancement
        image = self.clahe.apply(image)

        # 5. Grayscale to RGB
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # 6. Resize
        image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)

        # 7. Augmentation is training-only, so it is skipped at inference

        # 8. ImageNet normalization, kept in float32 to match the model weights
        image = image.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image = (image - mean) / std

        # Reorder to channels-first layout (C, H, W)
        image = np.transpose(image, (2, 0, 1))

        return image
```

## Model Architecture

### Overall Structure

```
Input DICOM (H×W grayscale)
    ↓
[DICOM Preprocessing Pipeline]
    ↓
224×224×3 RGB Tensor
    ↓
[RadImageNet ResNet-50 Encoder]
    ↓
2048-dim Embeddings
    ↓
[Projection Head] (training only)
    ↓
128-dim Projections
```

### Components

**1. Encoder (RadImageNet ResNet-50)**
- Pretrained on RadImageNet dataset
- Modified final layer: removed classification head
- Output: 2048-dimensional feature vectors
- Finetuned on mammography data during contrastive learning

**2. Projection Head (used during training, discarded for inference)**
- 2-layer MLP: 2048 → 512 → 128
- Batch normalization + ReLU activation
- Used only for contrastive learning
- Discarded during embedding extraction

**3.
Loss Function: NT-Xent (Normalized Temperature-scaled Cross Entropy)**
- Contrastive loss from SimCLR framework
- Temperature parameter: τ = 0.07
- Cosine similarity with L2 normalization
- Positive pairs: Two augmented views of the same image
- Negative pairs: All other images in the batch

### Training Details

**Contrastive Learning Framework (SimCLR)**
```
For each mammogram:
1. Create two different augmented views (image1, image2)
2. Pass both through encoder → projection head
3. Compute NT-Xent loss between the two projections
4. Maximize agreement between views of same image
5. Minimize similarity with other images in batch
```

**Hyperparameters**
- Batch size: 128
- Epochs: 50
- Learning rate: 1e-4 (AdamW optimizer)
- Weight decay: 1e-5
- Temperature: 0.07
- LR scheduler: Cosine annealing with 10-epoch warmup
- Mixed precision training: Enabled (AMP)
- Gradient clipping: 1.0
- Early stopping patience: 15 epochs

**Training Data**
- Dataset: CMMD (Chinese Mammography Database)
- Training split: 70%
- Validation split: 15%
- Test split: 15%
- Total training images: ~13,000 mammograms

## Model Specifications

| Property | Value |
|----------|-------|
| Model Type | Feature Extraction / Embedding Model |
| Architecture | ResNet-50 (RadImageNet pretrained) |
| Input Shape | (3, 224, 224) |
| Output Shape | (2048,) |
| Parameters | ~23.5M trainable |
| Model Size | 283 MB |
| Precision | FP32 |
| Framework | PyTorch 2.1+ |

## Usage

### Loading the Model

```python
import torch
import torch.nn as nn
import timm


# Define the encoder architecture
class RadImageNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # num_classes=0 drops the classification head and returns pooled features
        self.encoder = timm.create_model('resnet50', pretrained=False, num_classes=0)
        self.feature_dim = 2048

    def forward(self, x):
        return self.encoder(x)


# Load the checkpoint
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Extract encoder weights, stripping the 'encoder.encoder.' key prefix
prefix = 'encoder.encoder.'
model = RadImageNetEncoder()
encoder_state_dict = {
    k[len(prefix):]: v
    for k, v in checkpoint['model_state_dict'].items()
    if k.startswith(prefix)
}
model.encoder.load_state_dict(encoder_state_dict)
model.eval()
```

### Extracting Embeddings

```python
# Preprocess DICOM (see preprocessing code above)
processor = DICOMProcessor()
image = processor.preprocess('path/to/mammogram.dcm')

# Convert to a float32 tensor and add batch dimension
image_tensor = torch.from_numpy(image).float().unsqueeze(0)  # Shape: (1, 3, 224, 224)

# Extract embeddings
with torch.no_grad():
    embeddings = model(image_tensor)  # Shape: (1, 2048)

# L2 normalize (recommended)
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
```

## Performance: PRIMER vs RadImageNet Baseline

PRIMER improves on baseline RadImageNet embeddings across the mammography-specific clustering metrics below:

| Metric | RadImageNet (Baseline) | PRIMER (Finetuned) | Improvement |
|--------|------------------------|-------------------|-------------|
| Silhouette Score | 0.127 | 0.289 | +127% |
| Davies-Bouldin Score | 2.847 | 1.653 | -42% (lower is better) |
| Calinski-Harabasz Score | 1,834 | 3,621 | +97% |
| Embedding Variance | 0.012 | 0.024 | +100% |
| Intra-cluster Distance | 1.92 | 1.34 | -30% (lower is better) |
| Inter-cluster Distance | 2.15 | 2.87 | +33% |

**Key Improvements:**
- **Better Clustering**: Silhouette score increased from 0.127 to 0.289, indicating tighter and more separated clusters
- **Enhanced Discrimination**: Davies-Bouldin score decreased by 42%, showing better cluster separation
- **Richer Representations**: Embedding variance doubled, indicating more diverse and informative features
- **Mammography-Specific**: Features learned are specialized for mammographic patterns (masses, calcifications, tissue density)

### Visualization Improvements

Dimensionality reduction visualizations (t-SNE, UMAP, PCA) show:
- PRIMER embeddings form distinct, well-separated clusters
- RadImageNet embeddings show more overlap and diffuse boundaries
- PRIMER captures mammography-specific
  visual patterns more effectively

## Requirements

```
torch>=2.1.0
torchvision>=0.16.0
pydicom>=2.4.4
opencv-python>=4.8.1.78
numpy>=1.26.0
timm>=0.9.12
albumentations>=1.3.1
scikit-learn>=1.3.2
```

## Dataset

**CMMD (Chinese Mammography Database)**
- Modality: Full-field digital mammography (FFDM)
- Format: DICOM files
- Views: CC (craniocaudal), MLO (mediolateral oblique)
- Resolution: Variable (typically 2048×3328 or similar)

## Limitations

1. **Domain Specificity**: The model is trained on the CMMD dataset (Chinese population). Performance may vary on other populations or imaging protocols.

2. **DICOM Format**: Requires proper DICOM preprocessing. Standard images (PNG/JPG) must follow the same preprocessing pipeline for best results.

3. **Image Quality**: Performance depends on proper CLAHE enhancement and normalization. Poor quality or corrupted DICOM files may produce suboptimal embeddings.

4. **Resolution**: The model expects 224×224 input. Very high-resolution details may be lost during resizing.

5. **Self-Supervised**: The model uses contrastive learning without labels. It does not perform classification directly; embeddings must be used with downstream tasks (clustering, retrieval, classification).

6. **Photometric Interpretation**: It is critical to handle MONOCHROME1 vs MONOCHROME2 correctly. Failure to invert MONOCHROME1 images will result in poor embeddings.
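Limitations 3 and 6 can be guarded against before any embedding is computed. The following is a minimal sketch and not part of the PRIMER codebase: `validate_pixels` is a hypothetical helper name, and the specific checks (2-D shape, finite values, non-constant image) are assumptions about what counts as a usable input.

```python
from typing import Optional

import numpy as np


def validate_pixels(pixels: np.ndarray, photometric: Optional[str]) -> np.ndarray:
    """Hypothetical pre-flight check; returns float32 pixels on a MONOCHROME2-style scale."""
    # Refuse to guess when the photometric interpretation is missing or unexpected
    if photometric not in ("MONOCHROME1", "MONOCHROME2"):
        raise ValueError(f"unsupported PhotometricInterpretation: {photometric!r}")
    if pixels.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale array, got shape {pixels.shape}")

    image = pixels.astype(np.float32)
    if not np.isfinite(image).all():
        raise ValueError("pixel data contains NaN or infinite values")
    if image.max() == image.min():
        raise ValueError("image is constant; it is likely blank or corrupted")

    # Invert MONOCHROME1 so that higher values are brighter
    if photometric == "MONOCHROME1":
        image = image.max() - image
    return image
```

Unlike the `DICOMProcessor` code, which skips inversion when the attribute is absent, this sketch raises an error in that case; whether failing loudly is preferable depends on how trustworthy the source files are.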
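As limitation 5 notes, the embeddings only become useful through a downstream task such as retrieval. Below is a minimal retrieval sketch, assuming the embeddings have already been extracted into a NumPy array of shape `(N, 2048)`. `top_k_similar` is a hypothetical helper, and the random gallery is placeholder data standing in for real PRIMER embeddings.

```python
import numpy as np


def top_k_similar(query: np.ndarray, gallery: np.ndarray, k: int = 5):
    """Return indices and cosine similarities of the k gallery rows closest to query."""
    # L2-normalize both sides so the dot product equals cosine similarity
    q = query / (np.linalg.norm(query) + 1e-12)
    g = gallery / (np.linalg.norm(gallery, axis=1, keepdims=True) + 1e-12)
    sims = g @ q
    k = min(k, len(sims))
    order = np.argsort(-sims)[:k]  # highest similarity first
    return order, sims[order]


# Placeholder data: 100 random 2048-dim vectors instead of real embeddings
rng = np.random.default_rng(0)
gallery = rng.normal(size=(100, 2048)).astype(np.float32)

# A slightly perturbed copy of one gallery row serves as the query
query = gallery[42] + 0.01 * rng.normal(size=2048).astype(np.float32)

indices, scores = top_k_similar(query, gallery, k=3)
print(indices, scores)
```

A brute-force dot product like this is adequate for a few thousand mammograms; a larger collection would probably call for an approximate nearest-neighbor index instead.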

## Intended Use

### Primary Use Cases

- **Feature Extraction**: Generate embeddings for mammography images
- **Similarity Search**: Find similar mammograms based on visual features
- **Clustering**: Group mammograms by visual characteristics
- **Transfer Learning**: Use as pretrained backbone for downstream tasks (classification, segmentation)
- **Retrieval Systems**: Content-based mammography image retrieval
- **Quality Control**: Identify outlier or anomalous mammograms

### Out-of-Scope Use Cases

- **Direct Diagnosis**: Model does not provide diagnostic predictions
- **Standalone Clinical Use**: Requires integration with clinical workflows and expert interpretation
- **Non-Mammography Images**: Optimized for mammography; may not generalize to other modalities
- **Real-time Processing**: Model size (283MB) and preprocessing may not be suitable for real-time applications without optimization

## Model Card Contact

For questions or issues, please open an issue on the [GitHub repository](https://github.com/Lab-Rasool/PRIMER) or contact via HuggingFace discussions.