valste committed on
Commit
e5cd79c
·
0 Parent(s):

Update model: capsnet-4class-lung-disease-classifier

.gitattributes ADDED
@@ -0,0 +1,7 @@
1
+ *.keras filter=lfs diff=lfs merge=lfs -text
2
+ *.h5 filter=lfs diff=lfs merge=lfs -text
3
+ *.png filter=lfs diff=lfs merge=lfs -text
4
+ *.jpg filter=lfs diff=lfs merge=lfs -text
5
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
6
+ *.csv filter=lfs diff=lfs merge=lfs -text
7
+ *.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,95 @@
1
+ ---
2
+ license: apache-2.0
3
+ metrics:
4
+ - accuracy
5
+ ---
6
+ # Model Card — CapsNet 4-Class Lung-Disease Classifier
7
+
8
+ **Model name:** `capsnet_4class_lung_disease_classifier`
9
+ **Version:** 1.0
10
+ **Date:** 2025-09-17
11
+
12
+ ## Overview
13
+ A Capsule Network (CapsNet) implemented in TensorFlow/Keras to classify **four lung-disease categories** from masked chest X-ray images. The model uses routing-by-agreement and **margin loss**, and was trained with MLflow tracking. Input images are resized to **256×256×3**.
14
+
15
+ > ⚠️ **Not a medical device.** Outputs are for research/education. Clinician review is required before any clinical use.
16
+
17
+ ## Intended Use
18
+ - **Primary use:** Educational/research experiments on lung-disease image classification.
19
+ - **Users:** ML practitioners and students familiar with Python/TensorFlow.
20
+ - **Out-of-scope:** Direct clinical decision-making; deployment on patient data without formal validation and regulatory clearance.
21
+
22
+ ## Model Details
23
+ - **Architecture:** CapsNet with `PrimaryCaps` and `DigitCaps`; **routing iterations:** 3; **first Conv2D kernel size:** tuned over [5, 7, 9, 10, 11] after an exploratory sweep (3-epoch runs over [3,5,7,9,10,11]).
24
+ - **Loss:** `margin_loss` (capsule margin).
25
+ - **Optimizer:** `Adam` with a learning-rate scheduler (`lr_scheduler`).
26
+ - **Metrics:** accuracy.
27
+ - **Input shape:** `(256, 256, 3)`; **#classes:** 4; **batch size:** 32.
28
+ - **Training schedule:** up to **50 epochs** with callbacks: `EarlyStopping` (mode='min'), `ReduceLROnPlateau`, custom `StopAtValAccuracy(target=0.95)`, and `ModelCheckpoint(save_best_only=True)`.
29
+ - **Framework:** TensorFlow/Keras.
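
The `margin_loss` listed above is not reproduced in this card. As a rough illustration, here is a minimal pure-Python sketch of the standard capsule margin loss for a single sample. The constants `m_plus=0.9`, `m_minus=0.1`, and `lam=0.5` are the commonly used defaults and are an assumption here; the values used for this model are not stated. `margin_loss_single` is a hypothetical helper, not the training implementation.

```python
def margin_loss_single(one_hot_target, capsule_lengths,
                       m_plus=0.9, m_minus=0.1, lam=0.5):
    """Capsule margin loss for one sample.

    one_hot_target: list of 0/1 flags, one per class.
    capsule_lengths: lengths of the output capsule vectors, each in [0, 1].
    The default constants are assumptions, not values confirmed by this card.
    """
    total = 0.0
    for t, length in zip(one_hot_target, capsule_lengths):
        # Penalize a short capsule for the class that is present.
        present = t * max(0.0, m_plus - length) ** 2
        # Penalize a long capsule for classes that are absent (down-weighted).
        absent = lam * (1 - t) * max(0.0, length - m_minus) ** 2
        total += present + absent
    return total


# Illustrative capsule lengths only (not model outputs).
confident = margin_loss_single([1, 0, 0, 0], [0.95, 0.05, 0.05, 0.05])
wrong = margin_loss_single([1, 0, 0, 0], [0.20, 0.90, 0.05, 0.05])
print(confident, wrong)
```

A confident, correct prediction falls inside both margins and contributes no loss, while a confident wrong prediction is penalized by both terms.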
30
+
31
+
32
+
33
+ ## Data
34
+ - **Training/validation:** balanced, masked chest X-ray (CXR) images (loaded as `train_dataset` and `val_dataset`). Long-run training used `validation_split = 0.2`; kernel exploration used `validation_split = 0.5`.
35
+ - **Preprocessing:** resize to 256×256; masking performed upstream. Any additional normalization should match the notebook pipeline.
36
+ - **Label schema:** 4 disease classes (variable `disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']`).
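
To turn a 4-element score vector into a readable label, something like the following sketch may help. It assumes the model's output index order follows `disease_labels` as listed above, which this card does not explicitly confirm; `decode_prediction` is a hypothetical helper and the scores are illustrative.

```python
disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']


def decode_prediction(scores, labels=disease_labels):
    """Return (label, score) for the highest-scoring class.

    Assumes scores[i] corresponds to labels[i]; verify this against the
    training pipeline before relying on it.
    """
    if len(scores) != len(labels):
        raise ValueError("expected one score per label")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best], scores[best]


# Illustrative scores only (not model outputs).
label, score = decode_prediction([0.08, 0.12, 0.71, 0.09])
print(label, score)
```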
37
+
38
+ > **Licenses & provenance:**
39
+
40
+ **Training/Validation dataset:** https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (License: Data files © Original Authors)
41
+ (Data size: 1.15 GB)
42
+
43
+ **Dataset for external tests:** https://www.kaggle.com/datasets/pranavraikokte/covid19-image-dataset (License: https://creativecommons.org/licenses/by-sa/4.0/)
44
+ and
45
+ https://www.kaggle.com/datasets/omkarmanohardalvi/lungs-disease-dataset-4-types (License: Unknown)
46
+
47
+ ## Evaluation
48
+ - **Protocol:** train on the balanced training set, validate with a held-out split; select first-layer kernel size from [5,7,9,10,11] based on validation performance; evaluate best checkpoint on the unseen test set.
49
+ - **Reported metrics:** per-run training metrics for every trained model are recorded in `capsnet_training_metrics_all_runs.csv`.
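
The kernel-size selection step in the protocol above could be sketched roughly as follows. The column names `kernel_size`, `epoch`, and `val_accuracy` are assumptions (the real layout of `capsnet_training_metrics_all_runs.csv` is not documented here), and the rows are placeholder values, not results from this model.

```python
import csv
import io

# Placeholder rows for illustration only; NOT real results, and the real
# CSV's column names may differ.
toy_csv = """kernel_size,epoch,val_accuracy
5,1,0.50
5,2,0.55
7,1,0.52
7,2,0.60
9,1,0.51
9,2,0.58
"""


def best_kernel(csv_text):
    """Return (winning kernel size, best validation accuracy per kernel)."""
    best = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        k = int(row["kernel_size"])
        acc = float(row["val_accuracy"])
        # Keep the best epoch seen so far for each kernel size.
        best[k] = max(acc, best.get(k, 0.0))
    winner = max(best, key=best.get)
    return winner, best


winner, per_kernel = best_kernel(toy_csv)
print(winner, per_kernel)
```

Selecting on validation accuracy alone is one possible criterion; validation loss or a combination of the two would work the same way with a different key.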
50
+
51
+ ## External test
52
+ > I excluded the lung opacity class from external tests because it often co-occurs with other diseases, which makes it difficult to classify accurately.
53
+ However, the accuracy reported for the remaining three classes (COVID, Normal, Viral Pneumonia) is still representative.
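
The card does not say how the exclusion was implemented. One possible reading, sketched below under that assumption, is to drop samples whose true label is the excluded class and to take the argmax over the remaining classes only. `predict_without` and `accuracy_without` are hypothetical helpers, and the samples are illustrative values, not data from the external test sets.

```python
disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']
EXCLUDED = 'Lung_Opacity'


def predict_without(scores, excluded=EXCLUDED, labels=disease_labels):
    """Pick the highest-scoring class among the non-excluded labels."""
    allowed = [i for i, name in enumerate(labels) if name != excluded]
    best = max(allowed, key=lambda i: scores[i])
    return labels[best]


def accuracy_without(samples, excluded=EXCLUDED):
    """Accuracy over (true_label, scores) pairs, ignoring the excluded class."""
    kept = [(t, s) for t, s in samples if t != excluded]
    if not kept:
        return None
    hits = sum(1 for t, s in kept if predict_without(s, excluded) == t)
    return hits / len(kept)


# Illustrative samples only (not real test data).
toy_samples = [
    ('COVID', [0.6, 0.3, 0.05, 0.05]),            # hit
    ('Normal', [0.1, 0.5, 0.3, 0.1]),             # top score is excluded; falls back to Normal
    ('Lung_Opacity', [0.1, 0.7, 0.1, 0.1]),       # dropped from the evaluation
    ('Viral Pneumonia', [0.4, 0.1, 0.2, 0.3]),    # miss (predicted COVID)
]
acc = accuracy_without(toy_samples)
print(acc)
```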
54
+
55
+
56
+ Results CSV: `test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv`
57
+
58
+ ![alt text](image.png)
59
+
60
+ ## Risks, Bias, and Limitations
61
+ - **Domain shift:** performance may degrade on images from other hospitals, scanners, or populations.
62
+ - **Label noise / class imbalance:** training is balanced, but real-world prevalence may differ.
63
+ - **Confounders:** text markers, devices, or preprocessing differences can leak non-pathology signals.
64
+ - **Fairness:** if patient demographics were not controlled, subgroup performance may vary.
65
+ - **Regulatory:** not cleared for clinical use.
66
+
67
+ ## Recommendations
68
+ - Always use **human-in-the-loop** review.
69
+ - Report **per-class metrics** and **confidence scores**; calibrate if needed.
70
+ - Perform **external validation** on multiple sites before any operational use.
71
+ - Track experiments with MLflow (`mlruns_capsnet`) and save confusion matrices.
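
The per-class reporting recommended above can be sketched with a small confusion matrix. This is a minimal pure-Python illustration; `confusion_matrix` and `per_class_recall` are hypothetical helpers written for this example (not functions from the notebook), and the labels below are toy values.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Rows are true classes, columns are predicted classes."""
    index = {name: i for i, name in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def per_class_recall(matrix, labels):
    """Fraction of each true class that was predicted correctly."""
    recalls = {}
    for i, name in enumerate(labels):
        support = sum(matrix[i])
        recalls[name] = matrix[i][i] / support if support else None
    return recalls


# Toy labels for illustration only (not real evaluation data).
labels = ['COVID', 'Normal', 'Viral Pneumonia']
y_true = ['COVID', 'COVID', 'Normal', 'Normal', 'Viral Pneumonia', 'Viral Pneumonia']
y_pred = ['COVID', 'Normal', 'Normal', 'Normal', 'Viral Pneumonia', 'COVID']

cm = confusion_matrix(y_true, y_pred, labels)
recall = per_class_recall(cm, labels)
print(cm)
print(recall)
```

Overall accuracy can hide a weak class, so reporting the per-class values alongside it gives a clearer picture.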
72
+
73
+ ## How to Use
74
+ ```python
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.utils import load_img, img_to_array
+
+ # margin_loss, PrimaryCaps, DigitCaps, and Length are the custom objects from the
+ # training code; import or define them before loading, or Keras cannot
+ # deserialize the model.
+ model = tf.keras.models.load_model(
+     "path/to/model.keras",
+     custom_objects={
+         "margin_loss": margin_loss,
+         "PrimaryCaps": PrimaryCaps,
+         "DigitCaps": DigitCaps,
+         "Length": Length,
+     },
+ )
+
+ def preprocess(path):
+     """Load an image, resize it to 256x256, and add a batch dimension."""
+     img = load_img(path, target_size=(256, 256))
+     x = img_to_array(img)        # shape: (256, 256, 3)
+     return np.expand_dims(x, 0)  # shape: (1, 256, 256, 3)
+
+ x = preprocess("example_cxr.png")
+ pred = model.predict(x)[0]  # shape: (4,)
+ pred_label = int(np.argmax(pred))  # index of the highest-scoring class
+ print(pred, pred_label)
+ ```
capsnet_training_metrics_all_runs.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82846b797ad0e81aadb82327792c30b0f09a6ffabb5f8b80b1a9874b1ccdd244
3
+ size 19310
image.png ADDED

Git LFS Details

  • SHA256: 979243bf3dc233b8717b3647b8ad9b2891448a897bf1c546a87152596f51f123
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
model.keras ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eccbc4e258ee7481cb1223db771b39fabdc4ac71a3c03ea8908e910089a727d7
3
+ size 338601839
test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6134bf0bc50e1b5234cba77c0ebb6f3464ca2c8b4067fcc63cb02f6dc9ef8fc8
3
+ size 212837