Commit e5cd79c · 0 parent(s)
Update model: capsnet-4class-lung-disease-classifier
- .gitattributes +7 -0
- README.md +95 -0
- capsnet_training_metrics_all_runs.csv +3 -0
- image.png +3 -0
- model.keras +3 -0
- test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv +3 -0
.gitattributes
ADDED
@@ -0,0 +1,7 @@
*.keras filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text
*.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,95 @@
---
license: apache-2.0
metrics:
- accuracy
---
# Model Card: CapsNet 4-Class Lung-Disease Classifier

**Model name:** `capsnet_4class_lung_disease_classifier`
**Version:** 1.0
**Date:** 2025-09-17

## Overview
A Capsule Network (CapsNet) implemented in TensorFlow/Keras that classifies masked chest X-ray images into **four classes** (COVID, Lung_Opacity, Normal, Viral Pneumonia). The model uses routing-by-agreement and **margin loss**, and training runs were tracked with MLflow. Input images are resized to **256×256** with 3 channels.

> ⚠️ **Not a medical device.** Outputs are for research and education only. Clinician review is required before any clinical use.

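For readers new to capsule networks, the sketch below shows the squash nonlinearity that routing-by-agreement usually applies to each capsule vector, together with the vector-length readout that serves as a class score. It assumes the standard formulation from the original CapsNet paper; the names `squash` and `capsule_length` are illustrative and are not taken from this model's code, which may differ.

```python
import math


def capsule_length(vector):
    """Euclidean norm of a capsule's output vector."""
    return math.sqrt(sum(v * v for v in vector))


def squash(vector, eps=1e-9):
    """Rescale a vector so its length lies in [0, 1) and its direction is kept.

    Standard form (an assumption here): v = (|s|^2 / (1 + |s|^2)) * (s / |s|).
    """
    sq_norm = sum(v * v for v in vector)
    # eps avoids a division by zero for an all-zero vector
    scale = sq_norm / (1.0 + sq_norm) / math.sqrt(sq_norm + eps)
    return [scale * v for v in vector]


short = squash([0.1, 0.0])   # short vectors are pushed toward length 0
long_ = squash([10.0, 0.0])  # long vectors approach length 1
print(capsule_length(short), capsule_length(long_))
```

Because squashed lengths stay below 1, they can be read as per-class presence scores, but they are not probabilities and need not sum to 1.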
## Intended Use
- **Primary use:** Educational/research experiments on lung-disease image classification.
- **Users:** ML practitioners and students familiar with Python/TensorFlow.
- **Out-of-scope:** Direct clinical decision-making; deployment on patient data without formal validation and regulatory clearance.

## Model Details
- **Architecture:** CapsNet with `PrimaryCaps` and `DigitCaps`; **routing iterations:** 3; **first Conv2D kernel size:** tuned over [5, 7, 9, 10, 11] after an exploratory sweep (3-epoch runs over [3, 5, 7, 9, 10, 11]).
- **Loss:** `margin_loss` (capsule margin loss).
- **Optimizer:** `Adam` with a learning-rate scheduler (`lr_scheduler`).
- **Metrics:** accuracy.
- **Input shape:** `(256, 256, 3)`; **number of classes:** 4; **batch size:** 32.
- **Training schedule:** up to **50 epochs** with callbacks: `EarlyStopping` (mode='min'), `ReduceLROnPlateau`, custom `StopAtValAccuracy(target=0.95)`, and `ModelCheckpoint(save_best_only=True)`.
- **Framework:** TensorFlow/Keras.

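As a rough illustration of the capsule margin loss named above, here is a framework-free sketch for a single sample. The margins `m_plus=0.9` and `m_minus=0.1` and the down-weighting factor `lam=0.5` are the values commonly used in the CapsNet literature and are assumptions here; the model's own `margin_loss` may use different constants or reductions. The name `margin_loss_sketch` is hypothetical.

```python
def margin_loss_sketch(y_true, lengths, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Capsule margin loss for one sample.

    y_true:  one-hot list of class indicators (0 or 1).
    lengths: capsule output lengths, one per class, each in [0, 1).
    """
    total = 0.0
    for t, v in zip(y_true, lengths):
        # Penalise a present class whose capsule is shorter than m_plus
        present = t * max(0.0, m_plus - v) ** 2
        # Penalise an absent class whose capsule is longer than m_minus
        absent = lam * (1.0 - t) * max(0.0, v - m_minus) ** 2
        total += present + absent
    return total


# A confident, correct prediction incurs no loss
print(margin_loss_sketch([1, 0, 0, 0], [0.95, 0.05, 0.05, 0.05]))
# An uncertain prediction is penalised on both terms
print(margin_loss_sketch([1, 0, 0, 0], [0.5, 0.3, 0.1, 0.1]))
```

Unlike cross-entropy, each class is scored independently, so the loss does not require the outputs to form a probability distribution.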
## Data
- **Training/validation:** Balanced, masked chest X-ray (CXR) images (loaded as `train_dataset` and `val_dataset`). Long-run training used `validation_split = 0.2`; kernel exploration used `validation_split = 0.5`.
- **Preprocessing:** resize to 256×256; masking is performed upstream. Any additional normalization should match the notebook pipeline.
- **Label schema:** 4 classes (variable `disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']`).

> **Licenses & provenance:**

- **Training/validation dataset:** https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (License: Data files © Original Authors; data size: 1.15 GB)
- **Datasets for external tests:**
  - https://www.kaggle.com/datasets/pranavraikokte/covid19-image-dataset (License: https://creativecommons.org/licenses/by-sa/4.0/)
  - https://www.kaggle.com/datasets/omkarmanohardalvi/lungs-disease-dataset-4-types (License: Unknown)

## Evaluation
- **Protocol:** train on the balanced training set and validate on a held-out split; select the first-layer kernel size from [5, 7, 9, 10, 11] based on validation performance; evaluate the best checkpoint on the unseen test set.
- **Reported metrics:** metrics for every trained model are in `capsnet_training_metrics_all_runs.csv`.

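The kernel-size selection step in the protocol above could be scripted roughly as follows. This is a sketch only: the column names `kernel_size`, `epoch`, and `val_accuracy` are hypothetical (the actual schema of `capsnet_training_metrics_all_runs.csv` is not documented here), and the inline numbers are made-up placeholders, not results from this model.

```python
import csv
import io

# Made-up placeholder rows with hypothetical column names;
# these are NOT measurements from the real metrics file.
SAMPLE_CSV = """kernel_size,epoch,val_accuracy
5,1,0.50
5,2,0.60
7,1,0.55
7,2,0.70
9,1,0.52
9,2,0.65
"""


def best_kernel_size(csv_text):
    """Return (kernel_size, accuracy) for the run with the highest validation accuracy."""
    best = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        kernel = int(row["kernel_size"])
        acc = float(row["val_accuracy"])
        # Keep the best epoch seen so far for each kernel size
        best[kernel] = max(acc, best.get(kernel, 0.0))
    winner = max(best, key=best.get)
    return winner, best[winner]


print(best_kernel_size(SAMPLE_CSV))
```

To use it on the real file, read the file's text and adjust the two column names to whatever the CSV actually contains.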
## External test
> I excluded the lung opacity class from the external tests because lung opacity often co-occurs with other diseases, which makes it hard to classify accurately. The accuracy reported for the remaining classes is still representative.

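One possible way to score a four-output model on data that lacks the lung opacity class is to ignore that output and take the best of the remaining three. The sketch below assumes the output order matches `disease_labels` from the Data section; it is not necessarily how the external results in this repository were computed, and the helper names are hypothetical.

```python
DISEASE_LABELS = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]
EXCLUDED = {"Lung_Opacity"}


def predict_excluding(scores, labels=DISEASE_LABELS, excluded=EXCLUDED):
    """Return the highest-scoring label among the classes that are not excluded."""
    candidates = [(s, name) for s, name in zip(scores, labels) if name not in excluded]
    return max(candidates, key=lambda pair: pair[0])[1]


def accuracy_excluding(true_labels, score_rows):
    """Fraction of samples whose restricted prediction matches the true label."""
    hits = sum(predict_excluding(row) == t for t, row in zip(true_labels, score_rows))
    return hits / len(true_labels)


# Toy scores: Lung_Opacity scores highest but is ignored
print(predict_excluding([0.2, 0.7, 0.6, 0.1]))
```

The alternative, counting a `Lung_Opacity` prediction as an error, gives a stricter accuracy figure; which convention is used should be stated alongside any reported number.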
Results CSV: `test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv`



## Risks, Bias, and Limitations
- **Domain shift:** performance may degrade on images from other hospitals, scanners, or populations.
- **Label noise / class imbalance:** training is balanced, but real-world prevalence may differ.
- **Confounders:** text markers, devices, or preprocessing differences can leak non-pathology signals.
- **Fairness:** if patient demographics were not controlled, subgroup performance may vary.
- **Regulatory:** not cleared for clinical use.

## Recommendations
- Always use **human-in-the-loop** review.
- Report **per-class metrics** and **confidence scores**; calibrate if needed.
- Perform **external validation** on multiple sites before any operational use.
- Track experiments with MLflow (`mlruns_capsnet`) and save confusion matrices.

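To make the per-class reporting recommendation concrete, the following dependency-free sketch builds a confusion matrix and derives precision and recall for each class. The label indices in the example are toy values chosen for illustration, not predictions from this model, and the function names are hypothetical.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_metrics(matrix):
    """Return a list of {'precision', 'recall'} dicts, one per class."""
    n = len(matrix)
    report = []
    for k in range(n):
        tp = matrix[k][k]
        actual = sum(matrix[k])                          # samples truly in class k
        predicted = sum(matrix[r][k] for r in range(n))  # samples predicted as k
        report.append({
            "precision": tp / predicted if predicted else 0.0,
            "recall": tp / actual if actual else 0.0,
        })
    return report


# Toy integer labels (0..3), for illustration only
y_true = [0, 0, 1, 2, 2, 3]
y_pred = [0, 1, 1, 2, 3, 3]
cm = confusion_matrix(y_true, y_pred, 4)
for k, row in enumerate(per_class_metrics(cm)):
    print(k, row)
```

Reporting these numbers per class, rather than a single accuracy figure, shows which classes are being confused with each other.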
## How to Use
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array

# NOTE: margin_loss, PrimaryCaps, DigitCaps and Length are custom objects from
# the training code; define or import them before loading the model.

# Load the trained Keras model together with its custom loss and layers
model = tf.keras.models.load_model(
    "path/to/model.keras",
    custom_objects={
        "margin_loss": margin_loss,
        "PrimaryCaps": PrimaryCaps,
        "DigitCaps": DigitCaps,
        "Length": Length,
    },
)

def preprocess(path):
    """Load an image, resize it to 256x256 RGB, and add a batch dimension."""
    img = load_img(path, target_size=(256, 256))
    x = img_to_array(img)      # shape: (256, 256, 3)
    x = np.expand_dims(x, 0)   # shape: (1, 256, 256, 3)
    return x

x = preprocess("example_cxr.png")
pred = model.predict(x)[0]         # shape: (4,)
pred_label = int(np.argmax(pred))  # index of the highest-scoring class
print(pred, pred_label)
```
capsnet_training_metrics_all_runs.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:82846b797ad0e81aadb82327792c30b0f09a6ffabb5f8b80b1a9874b1ccdd244
size 19310
image.png
ADDED
model.keras
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eccbc4e258ee7481cb1223db771b39fabdc4ac71a3c03ea8908e910089a727d7
size 338601839
test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6134bf0bc50e1b5234cba77c0ebb6f3464ca2c8b4067fcc63cb02f6dc9ef8fc8
size 212837