Aakash-Tripathi committed on
Commit ed863e8 · verified · 1 Parent(s): f320dc9

Upload 5 files
Files changed (5)
  1. .gitattributes +6 -28
  2. README.md +368 -0
  3. config.json +110 -0
  4. model_card.json +185 -0
  5. pytorch_model.bin +3 -0
.gitattributes CHANGED
@@ -1,35 +1,13 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,368 @@
---
license: apache-2.0
tags:
- medical
- radiology
- mammography
- contrastive-learning
- embeddings
- computer-vision
- pytorch
datasets:
- CMMD
pipeline_tag: feature-extraction
---

# PRIMER: Pretrained RadImageNet for Mammography Embedding Representations

PRIMER is a specialized deep learning model for mammography analysis, finetuned from RadImageNet using contrastive learning on the CMMD (Chinese Mammography Mass Database) dataset. The model generates discriminative embedding vectors specifically optimized for mammogram images.

## Model Overview

- **Base Model**: RadImageNet ResNet-50
- **Training Method**: SimCLR contrastive learning (NT-Xent loss)
- **Architecture**: ResNet-50 encoder + 2-layer MLP projection head
- **Input**: 224×224 RGB images (converted from DICOM grayscale)
- **Output**: 2048-dimensional embedding vectors
- **Training Dataset**: CMMD mammography DICOM files
- **Framework**: PyTorch 2.1+

## Key Features

- Finetuned specifically for mammography imaging
- Self-supervised contrastive learning (no labels required)
- Produces embeddings with better clustering and separation than baseline RadImageNet
- Handles the DICOM preprocessing pipeline end-to-end
- Supports multiple backbone architectures (ResNet-50, DenseNet-121, Inception-V3)

## DICOM Preprocessing Pipeline

The model expects mammography DICOM images preprocessed through the following pipeline. This preprocessing is **critical** for proper model performance:

### Step 1: DICOM Loading
```
- Read DICOM file using pydicom
- Extract pixel array as float32
```

### Step 2: Photometric Interpretation Correction
```
- Check PhotometricInterpretation attribute
- If MONOCHROME1: Invert pixel values (max_value - pixel_value)
  - MONOCHROME1: Higher values = darker (inverted scale)
  - MONOCHROME2: Higher values = brighter (standard scale)
```

### Step 3: Intensity Normalization
```
- Percentile-based clipping to remove outliers:
  - Compute 2nd percentile (p2) and 98th percentile (p98)
  - Clip all values: pixel_value = clip(pixel_value, p2, p98)

- Min-max normalization to [0, 255]:
  - normalized = ((pixel_value - min) / (max - min + 1e-8)) × 255
  - Convert to uint8
```

### Step 4: CLAHE Enhancement
```
- Apply Contrast Limited Adaptive Histogram Equalization (CLAHE)
- Clip limit: 2.0
- Tile grid size: 8×8
- Improves local contrast and enhances subtle features
```

### Step 5: Grayscale to RGB Conversion
```
- Duplicate grayscale channel 3 times: RGB = [gray, gray, gray]
- Required because RadImageNet expects 3-channel input
```

### Step 6: Resizing
```
- Resize to 224×224 using bilinear interpolation
```

### Step 7: Data Augmentation (Training Only)
```
Training augmentations:
- Horizontal flip (p=0.5)
- Vertical flip (p=0.3)
- Rotation (±15 degrees, p=0.5)
- Random brightness/contrast (±0.2, p=0.5)
- Shift/scale/rotate (shift=0.1, scale=0.1, rotate=15°, p=0.5)
```
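
For concreteness, this augmentation policy can be expressed with albumentations (listed in the requirements below). The parameters mirror the list above, but the original training script is not part of this upload, so treat this as a sketch:

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Sketch of the training-time augmentation policy described above,
# applied to the uint8 RGB image produced by steps 1-6.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Two independently augmented views of the same image form a positive pair:
# view1 = train_transform(image=rgb_uint8)["image"]
# view2 = train_transform(image=rgb_uint8)["image"]
```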

### Step 8: Normalization
```
- ImageNet normalization (required for RadImageNet compatibility):
  - Mean: [0.485, 0.456, 0.406]
  - Std: [0.229, 0.224, 0.225]
- Convert to tensor (C×H×W format)
```

### Complete Preprocessing Code

```python
import cv2
import numpy as np
import pydicom


class DICOMProcessor:
    def __init__(self):
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def preprocess(self, dicom_path):
        # 1. Load DICOM
        dicom = pydicom.dcmread(dicom_path)
        image = dicom.pixel_array.astype(np.float32)

        # 2. Handle photometric interpretation (invert MONOCHROME1)
        if hasattr(dicom, 'PhotometricInterpretation'):
            if dicom.PhotometricInterpretation == "MONOCHROME1":
                image = np.max(image) - image

        # 3. Intensity normalization (percentile clip + min-max to uint8)
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        image = ((image - image.min()) / (image.max() - image.min() + 1e-8) * 255)
        image = image.astype(np.uint8)

        # 4. CLAHE enhancement
        image = self.clahe.apply(image)

        # 5. Grayscale to RGB
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # 6. Resize
        image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)

        # 7. ImageNet normalization (float32 stats keep the result float32,
        #    so torch.from_numpy later yields a FloatTensor, not a DoubleTensor)
        image = image.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image = (image - mean) / std

        # 8. Convert to tensor layout (C, H, W)
        image = np.transpose(image, (2, 0, 1))

        return image
```

## Model Architecture

### Overall Structure
```
Input DICOM (H×W grayscale)
            ↓
[DICOM Preprocessing Pipeline]
            ↓
   224×224×3 RGB Tensor
            ↓
[RadImageNet ResNet-50 Encoder]
            ↓
   2048-dim Embeddings
            ↓
[Projection Head] (training only)
            ↓
   128-dim Projections
```

### Components

**1. Encoder (RadImageNet ResNet-50)**
- Pretrained on the RadImageNet dataset
- Modified final layer: removed classification head
- Output: 2048-dimensional feature vectors
- Finetuned on mammography data during contrastive learning

**2. Projection Head (used during training, discarded for inference)**
- 2-layer MLP: 2048 → 512 → 128
- Batch normalization + ReLU activation
- Used only for contrastive learning
- Discarded during embedding extraction
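
A minimal PyTorch sketch of this head, with dimensions taken from the bullets above (the exact layer ordering in the original training code is an assumption):

```python
import torch.nn as nn

# 2-layer MLP projection head: 2048 -> 512 -> 128, BatchNorm + ReLU in between.
# Used only during contrastive training and discarded at inference time.
projection_head = nn.Sequential(
    nn.Linear(2048, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 128),
)
```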

**3. Loss Function: NT-Xent (Normalized Temperature-scaled Cross Entropy)**
- Contrastive loss from the SimCLR framework
- Temperature parameter: τ = 0.07
- Cosine similarity with L2 normalization
- Positive pairs: two augmented views of the same image
- Negative pairs: all other images in the batch
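
A reference implementation of this loss, following the SimCLR formulation (the training code itself is not included in this upload, so treat it as a sketch):

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.07):
    """NT-Xent loss for a batch of projection pairs.

    z1, z2: (N, D) projections of two augmented views of the same N images.
    """
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), L2-normalized
    sim = z @ z.t() / temperature                       # scaled cosine similarities
    sim.fill_diagonal_(float('-inf'))                   # exclude self-similarity

    # The positive for row i is the other view of the same image: i <-> i + n.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```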

### Training Details

**Contrastive Learning Framework (SimCLR)**
```
For each mammogram:
1. Create two different augmented views (image1, image2)
2. Pass both through encoder → projection head
3. Compute NT-Xent loss between the two projections
4. Maximize agreement between views of the same image
5. Minimize similarity with other images in the batch
```
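
Wiring the pieces together, one training step could look like this (`encoder`, `projection_head`, and `nt_xent_loss` are the hypothetical names used in the sketches above, not verbatim code from the training run):

```python
import torch

def train_step(encoder, projection_head, batch, optimizer):
    view1, view2 = batch                           # two augmented views per mammogram
    z1 = projection_head(encoder(view1))           # (N, 128)
    z2 = projection_head(encoder(view2))           # (N, 128)
    loss = nt_xent_loss(z1, z2, temperature=0.07)

    optimizer.zero_grad()
    loss.backward()
    params = list(encoder.parameters()) + list(projection_head.parameters())
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # gradient clipping: 1.0
    optimizer.step()
    return loss.item()
```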

**Hyperparameters**
- Batch size: 128
- Epochs: 50
- Learning rate: 1e-4 (AdamW optimizer)
- Weight decay: 1e-5
- Temperature: 0.07
- LR scheduler: Cosine annealing with 10-epoch warmup (sketched below)
- Mixed precision training: Enabled (AMP)
- Gradient clipping: 1.0
- Early stopping patience: 15 epochs
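
One way to realize this optimizer and schedule in PyTorch; the exact scheduler construction used for training is not included in this upload, so the warmup details below are an assumption:

```python
import torch

params = list(encoder.parameters()) + list(projection_head.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-5)

# 10-epoch linear warmup followed by cosine annealing over the remaining 40 epochs.
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=10)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40, eta_min=1e-6)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[10])

scaler = torch.cuda.amp.GradScaler()  # mixed-precision (AMP) training
```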

**Training Data**
- Dataset: CMMD (Chinese Mammography Mass Database)
- Training split: 70%
- Validation split: 15%
- Test split: 15%
- Total training images: ~13,000 mammograms

## Model Specifications

| Property | Value |
|----------|-------|
| Model Type | Feature Extraction / Embedding Model |
| Architecture | ResNet-50 (RadImageNet pretrained) |
| Input Shape | (3, 224, 224) |
| Output Shape | (2048,) |
| Parameters | ~23.5M trainable |
| Model Size | 283 MB |
| Precision | FP32 |
| Framework | PyTorch 2.1+ |

## Usage

### Loading the Model

```python
import torch
import torch.nn as nn
import timm


# Define the encoder architecture
class RadImageNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model('resnet50', pretrained=False, num_classes=0)
        self.feature_dim = 2048

    def forward(self, x):
        return self.encoder(x)


# Load the checkpoint
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Extract the encoder weights
model = RadImageNetEncoder()
encoder_state_dict = {
    k.replace('encoder.encoder.', ''): v
    for k, v in checkpoint['model_state_dict'].items()
    if k.startswith('encoder.encoder.')
}
model.encoder.load_state_dict(encoder_state_dict)

model.eval()
```

### Extracting Embeddings

```python
# Preprocess DICOM (see preprocessing code above)
processor = DICOMProcessor()
image = processor.preprocess('path/to/mammogram.dcm')

# Convert to tensor and add batch dimension
image_tensor = torch.from_numpy(image).unsqueeze(0)  # Shape: (1, 3, 224, 224)

# Extract embeddings
with torch.no_grad():
    embeddings = model(image_tensor)  # Shape: (1, 2048)

# L2 normalize (recommended)
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
```
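
With L2-normalized embeddings, similarity search reduces to a dot product. A small sketch, where `gallery` is a hypothetical (M, 2048) tensor of embeddings extracted the same way from a reference set:

```python
# Cosine-similarity retrieval over a gallery of L2-normalized embeddings.
scores = embeddings @ gallery.t()        # (1, M): dot product = cosine similarity
top5 = torch.topk(scores, k=5, dim=1)    # the 5 most visually similar mammograms
print(top5.indices, top5.values)
```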

## Performance: PRIMER vs RadImageNet Baseline

PRIMER demonstrates significant improvements over baseline RadImageNet embeddings on mammography-specific evaluation metrics:

| Metric | RadImageNet (Baseline) | PRIMER (Finetuned) | Improvement |
|--------|------------------------|--------------------|-------------|
| Silhouette Score | 0.127 | 0.289 | +127% |
| Davies-Bouldin Score | 2.847 | 1.653 | -42% (lower is better) |
| Calinski-Harabasz Score | 1,834 | 3,621 | +97% |
| Embedding Variance | 0.012 | 0.024 | +100% |
| Intra-cluster Distance | 1.92 | 1.34 | -30% |
| Inter-cluster Distance | 2.15 | 2.87 | +33% |

**Key Improvements:**
- **Better Clustering**: Silhouette score increased from 0.127 to 0.289, indicating much tighter and more separated clusters
- **Enhanced Discrimination**: Davies-Bouldin score decreased by 42%, showing better cluster separation
- **Richer Representations**: Embedding variance doubled, indicating more diverse and informative features
- **Mammography-Specific**: Learned features are specialized for mammographic patterns (masses, calcifications, tissue density)
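
These metrics can be recomputed for any embedding matrix with scikit-learn (listed in the requirements). A sketch, where `X` is a (num_images, 2048) array of PRIMER embeddings; the cluster labels here come from k-means with an illustrative choice of k, since no evaluation labels ship with this upload:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import (
    silhouette_score, davies_bouldin_score, calinski_harabasz_score)

labels = KMeans(n_clusters=2, n_init=10, random_state=42).fit_predict(X)

print("Silhouette:", silhouette_score(X, labels))                # higher is better
print("Davies-Bouldin:", davies_bouldin_score(X, labels))        # lower is better
print("Calinski-Harabasz:", calinski_harabasz_score(X, labels))  # higher is better
print("Embedding variance:", np.var(X, axis=0).mean())
```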

### Visualization Improvements

Dimensionality reduction visualizations (t-SNE, UMAP, PCA) show:
- PRIMER embeddings form distinct, well-separated clusters
- RadImageNet embeddings show more overlap and diffuse boundaries
- PRIMER captures mammography-specific visual patterns more effectively

## Requirements

```
torch>=2.1.0
torchvision>=0.16.0
pydicom>=2.4.4
opencv-python>=4.8.1.78
numpy>=1.26.0
timm>=0.9.12
albumentations>=1.3.1
scikit-learn>=1.3.2
```

## Dataset

**CMMD (Chinese Mammography Mass Database)**
- Modality: Full-field digital mammography (FFDM)
- Format: DICOM files
- Views: CC (craniocaudal), MLO (mediolateral oblique)
- Resolution: Variable (typically 2048×3328 or similar)

## Limitations

1. **Domain Specificity**: The model is trained on the CMMD dataset (Chinese population). Performance may vary on other populations or imaging protocols.

2. **DICOM Format**: Requires proper DICOM preprocessing. Standard images (PNG/JPG) must follow the same preprocessing pipeline for best results.

3. **Image Quality**: Performance depends on proper CLAHE enhancement and normalization. Poor-quality or corrupted DICOM files may produce suboptimal embeddings.

4. **Resolution**: The model expects 224×224 input. Very high-resolution details may be lost during resizing.

5. **Self-Supervised**: The model uses contrastive learning without labels. It does not perform classification directly; embeddings must be used with downstream tasks (clustering, retrieval, classification).

6. **Photometric Interpretation**: It is critical to handle MONOCHROME1 vs. MONOCHROME2 correctly. Failure to invert MONOCHROME1 images will result in poor embeddings.

## Intended Use

### Primary Use Cases
- **Feature Extraction**: Generate embeddings for mammography images
- **Similarity Search**: Find similar mammograms based on visual features
- **Clustering**: Group mammograms by visual characteristics
- **Transfer Learning**: Use as a pretrained backbone for downstream tasks (classification, segmentation); see the linear-probe sketch after this list
- **Retrieval Systems**: Content-based mammography image retrieval
- **Quality Control**: Identify outlier or anomalous mammograms
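
For the transfer-learning case, a minimal linear probe on top of the frozen encoder might look like this (the classification head and class count are illustrative, not part of the released model):

```python
import torch.nn as nn

# Freeze the PRIMER encoder and train only a linear classification head.
for p in model.parameters():
    p.requires_grad = False

classifier = nn.Sequential(
    model,               # frozen encoder, outputs (N, 2048) embeddings
    nn.Linear(2048, 2),  # e.g. benign vs. malignant; head is hypothetical
)
```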

### Out-of-Scope Use Cases
- **Direct Diagnosis**: The model does not provide diagnostic predictions
- **Standalone Clinical Use**: Requires integration with clinical workflows and expert interpretation
- **Non-Mammography Images**: Optimized for mammography; may not generalize to other modalities
- **Real-time Processing**: Model size (283 MB) and preprocessing may not be suitable for real-time applications without optimization

## Model Card Contact

For questions or issues, please open an issue on the [GitHub repository](https://github.com/Lab-Rasool/PRIMER) or contact us via HuggingFace discussions.

config.json ADDED
@@ -0,0 +1,110 @@
{
  "model_type": "primer",
  "architecture": "resnet50",
  "task": "feature-extraction",
  "framework": "pytorch",

  "model_config": {
    "architecture": "resnet50",
    "pretrained_radimagenet": true,
    "embedding_dim": 2048,
    "projection_dim": 128,
    "dropout": 0.2,
    "freeze_backbone": false,
    "use_projection_head": true
  },

  "input_config": {
    "image_size": 224,
    "num_channels": 3,
    "input_shape": [3, 224, 224],
    "data_format": "channels_first",
    "color_mode": "rgb"
  },

  "preprocessing_config": {
    "use_clahe": true,
    "clahe_clip_limit": 2.0,
    "clahe_tile_grid_size": [8, 8],
    "normalize_hu": true,
    "percentile_clip": [2, 98],
    "normalization": {
      "mean": [0.485, 0.456, 0.406],
      "std": [0.229, 0.224, 0.225],
      "description": "ImageNet normalization (required for RadImageNet compatibility)"
    }
  },

  "training_config": {
    "method": "contrastive",
    "framework": "simclr",
    "batch_size": 128,
    "num_epochs": 50,
    "learning_rate": 0.0001,
    "weight_decay": 0.00001,
    "warmup_epochs": 10,
    "patience": 15,
    "gradient_clip": 1.0,
    "optimizer": {
      "name": "adamw",
      "betas": [0.9, 0.999]
    },
    "scheduler": {
      "name": "cosine",
      "min_lr": 0.000001
    }
  },

  "contrastive_learning": {
    "loss": "nt_xent",
    "temperature": 0.07,
    "use_cosine_similarity": true,
    "negative_samples": "all",
    "description": "NT-Xent (Normalized Temperature-scaled Cross Entropy) loss from SimCLR"
  },

  "augmentation_config": {
    "horizontal_flip": 0.5,
    "vertical_flip": 0.3,
    "rotate_limit": 15,
    "brightness_limit": 0.2,
    "contrast_limit": 0.2,
    "shift_scale_rotate": true,
    "elastic_transform": false,
    "grid_distortion": false
  },

  "data_config": {
    "dataset": "CMMD",
    "train_split": 0.7,
    "val_split": 0.15,
    "test_split": 0.15,
    "random_seed": 42,
    "num_training_samples": 13000,
    "modality": "mammography",
    "format": "dicom"
  },

  "output_config": {
    "embedding_dim": 2048,
    "normalize_embeddings": true,
    "normalization_type": "l2"
  },

  "hardware_config": {
    "mixed_precision": true,
    "gpu_memory_required": "12GB",
    "recommended_batch_size": 128
  },

  "metrics": {
    "silhouette_score": 0.289,
    "davies_bouldin_score": 1.653,
    "calinski_harabasz_score": 3621,
    "embedding_variance": 0.024
  },

  "version": "1.0.0",
  "pytorch_version": "2.1.0",
  "timm_version": "0.9.12"
}
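
As a usage note, downstream code can read the preprocessing and model constants from this file rather than hard-coding them; a minimal sketch:

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

norm = cfg["preprocessing_config"]["normalization"]
print(norm["mean"], norm["std"])              # ImageNet statistics
print(cfg["input_config"]["input_shape"])     # [3, 224, 224]
print(cfg["output_config"]["embedding_dim"])  # 2048
```
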
model_card.json ADDED
@@ -0,0 +1,185 @@
{
  "model_name": "PRIMER",
  "full_name": "Pretrained RadImageNet for Mammography Embedding Representations",
  "version": "1.0.0",
  "release_date": "2024-10-17",

  "model_details": {
    "organization": "Lab-Rasool",
    "architecture": "ResNet-50",
    "base_model": "RadImageNet ResNet-50",
    "training_method": "SimCLR Contrastive Learning",
    "model_type": "Feature Extraction / Embedding Model",
    "modality": "Medical Imaging - Mammography",
    "parameters": "23.5M",
    "model_size_mb": 283,
    "license": "Apache-2.0"
  },

  "intended_use": {
    "primary_uses": [
      "Feature extraction for mammography images",
      "Similarity search and retrieval",
      "Clustering and grouping mammograms",
      "Transfer learning backbone for downstream tasks",
      "Content-based image retrieval systems",
      "Quality control and anomaly detection"
    ],
    "out_of_scope": [
      "Direct clinical diagnosis",
      "Standalone diagnostic tool",
      "Non-mammography medical images",
      "Real-time processing without optimization"
    ]
  },

  "training_data": {
    "dataset": "CMMD (Chinese Mammography Mass Database)",
    "dataset_url": "https://doi.org/10.7937/tcia.eqde-3b16",
    "num_training_samples": 13000,
    "data_splits": {
      "train": 0.7,
      "validation": 0.15,
      "test": 0.15
    },
    "image_format": "DICOM",
    "views": ["CC (craniocaudal)", "MLO (mediolateral oblique)"],
    "population": "Chinese population"
  },

  "training_procedure": {
    "method": "Self-supervised contrastive learning (SimCLR)",
    "loss_function": "NT-Xent (Normalized Temperature-scaled Cross Entropy)",
    "epochs": 50,
    "batch_size": 128,
    "optimizer": "AdamW",
    "learning_rate": 0.0001,
    "scheduler": "Cosine annealing with warmup",
    "temperature": 0.07,
    "mixed_precision": true,
    "hardware": "NVIDIA RTX 3090 (24GB VRAM)"
  },

  "performance_metrics": {
    "embedding_quality": {
      "silhouette_score": {
        "radimagenet_baseline": 0.127,
        "primer_finetuned": 0.289,
        "improvement_percent": 127
      },
      "davies_bouldin_score": {
        "radimagenet_baseline": 2.847,
        "primer_finetuned": 1.653,
        "improvement_percent": -42,
        "note": "Lower is better"
      },
      "calinski_harabasz_score": {
        "radimagenet_baseline": 1834,
        "primer_finetuned": 3621,
        "improvement_percent": 97
      },
      "embedding_variance": {
        "radimagenet_baseline": 0.012,
        "primer_finetuned": 0.024,
        "improvement_percent": 100
      }
    }
  },

  "input_output": {
    "input": {
      "format": "DICOM or preprocessed image tensor",
      "shape": [3, 224, 224],
      "dtype": "float32",
      "color_space": "RGB",
      "normalization": "ImageNet (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])"
    },
    "output": {
      "format": "Embedding vector",
      "shape": [2048],
      "dtype": "float32",
      "normalization": "L2 normalization recommended"
    }
  },

  "preprocessing_requirements": {
    "critical_steps": [
      "Photometric interpretation correction (MONOCHROME1 inversion)",
      "Percentile-based intensity clipping (2nd-98th percentile)",
      "Min-max normalization to [0, 255]",
      "CLAHE enhancement (clipLimit=2.0, tileGridSize=8x8)",
      "Grayscale to RGB conversion",
      "Resize to 224x224",
      "ImageNet normalization"
    ],
    "dependencies": [
      "pydicom>=2.4.4",
      "opencv-python>=4.8.1.78",
      "numpy>=1.26.0"
    ]
  },

  "limitations": {
    "domain_specificity": "Trained on CMMD dataset (Chinese population); performance may vary on other populations",
    "dicom_dependency": "Requires proper DICOM preprocessing for optimal results",
    "resolution_loss": "High-resolution details may be lost at 224x224 input size",
    "self_supervised": "No direct classification output; requires downstream task integration",
    "photometric_interpretation": "Critical to handle MONOCHROME1/MONOCHROME2 correctly"
  },

  "ethical_considerations": {
    "bias": "Model trained on Chinese population data; may not generalize equally to all demographics",
    "clinical_use": "Not FDA approved; requires clinical validation before medical use",
    "privacy": "DICOM files may contain PHI; ensure proper de-identification",
    "interpretability": "Embeddings are learned representations; clinical interpretation required"
  },

  "citations": {
    "primer": {
      "title": "PRIMER: Pretrained RadImageNet for Mammography Embedding Representations",
      "authors": "Lab-Rasool",
      "year": 2024,
      "url": "https://huggingface.co/Lab-Rasool/PRIMER"
    },
    "radimagenet": {
      "title": "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning",
      "authors": "Mei et al.",
      "journal": "Radiology: Artificial Intelligence",
      "year": 2022,
      "doi": "10.1148/ryai.210315"
    },
    "simclr": {
      "title": "A Simple Framework for Contrastive Learning of Visual Representations",
      "authors": "Chen et al.",
      "conference": "ICML",
      "year": 2020,
      "arxiv": "2002.05709"
    },
    "cmmd": {
      "title": "Chinese Mammography Database (CMMD)",
      "source": "The Cancer Imaging Archive",
      "doi": "10.7937/tcia.eqde-3b16"
    }
  },

  "contact": {
    "organization": "Lab-Rasool",
    "huggingface": "https://huggingface.co/Lab-Rasool",
    "model_repository": "https://huggingface.co/Lab-Rasool/PRIMER",
    "issues": "https://huggingface.co/Lab-Rasool/PRIMER/discussions"
  },

  "technical_specifications": {
    "framework": "PyTorch 2.1+",
    "required_libraries": [
      "torch>=2.1.0",
      "torchvision>=0.16.0",
      "timm>=0.9.12",
      "pydicom>=2.4.4",
      "opencv-python>=4.8.1.78",
      "albumentations>=1.3.1"
    ],
    "gpu_requirements": "12GB+ VRAM recommended for inference",
    "inference_speed": "~50ms per image on RTX 3090"
  }
}
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1560f1bb53cd03c7891bba8689e73dd4b8f4c4aeaac3df3fa931db8680a664f
size 295966599