---
license: apache-2.0
metrics:
- accuracy
---

# Model Card — CapsNet 4-Class Lung-Disease Classifier

**Model name:** `capsnet_4class_lung_disease_classifier`

**Version:** 1.0

**Date:** 2025-09-17

## Overview

A Capsule Network (CapsNet) implemented in TensorFlow/Keras that classifies masked chest X-ray (CXR) images into **four classes** (three lung-disease categories plus Normal). It uses routing-by-agreement and a **margin loss**; training runs were tracked with MLflow. Input images are resized to **256×256×3**.

> ⚠️ **Not a medical device.** Outputs are for research and education only. Clinician review is required before any clinical use.

## Intended Use

- **Primary use:** Educational/research experiments on lung-disease image classification.
- **Users:** ML practitioners and students familiar with Python/TensorFlow.
- **Out-of-scope:** Direct clinical decision-making; deployment on patient data without formal validation and regulatory clearance.

## Model Details

- **Architecture:** CapsNet with `PrimaryCaps` and `DigitCaps` layers.
  - Built-in augmentation and rescaling layers (the random augmentation layers are active only during training and pass images through unchanged at inference):
    - `layers.RandomRotation(0.1)`, `layers.RandomTranslation(height_factor=0.1, width_factor=0.1)`, `layers.RandomZoom(0.1)`
    - `Rescaling(1./255)`
  - Routing iterations: 3.
  - First `Conv2D` kernel size: tuned over [5, 7, 9, 10, 11] after an exploratory sweep (3-epoch runs over [3, 5, 7, 9, 10, 11]).
- **Loss:** `margin_loss` (capsule margin loss).
- **Optimizer:** `Adam` with a learning-rate scheduler (`lr_scheduler`).
- **Metrics:** accuracy.
- **Input shape:** `(256, 256, 3)`; **number of classes:** 4; **batch size:** 32.
- **Training schedule:** up to **50 epochs** with the callbacks `EarlyStopping(mode='min')`, `ReduceLROnPlateau`, a custom `StopAtValAccuracy(target=0.95)`, and `ModelCheckpoint(save_best_only=True)`.
- **Framework:** TensorFlow/Keras.

## Data

- **Training/validation:** Balanced masked CXR dataset (loaded as `train_dataset` and `val_dataset`).
  Long-run training used `validation_split = 0.2`; the kernel-size exploration used `validation_split = 0.5`.
- **Preprocessing:** resize to 256×256; masking is performed upstream. Scaling to [0, 1] happens inside the model (`Rescaling(1./255)`), so inputs should be raw 0-255 pixel values; any additional normalization should match the notebook pipeline.
- **Label schema:** 4 classes (variable `disease_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']`).

> **Licenses & provenance:**
> - **Training/validation dataset:** https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (license: data files © original authors; size: 1.15 GB)
> - **External test datasets:**
>   - https://www.kaggle.com/datasets/pranavraikokte/covid19-image-dataset (license: CC BY-SA 4.0, https://creativecommons.org/licenses/by-sa/4.0/)
>   - https://www.kaggle.com/datasets/omkarmanohardalvi/lungs-disease-dataset-4-types (license: unknown)

## Evaluation

- **Protocol:** train on the balanced training set and validate on a held-out split; select the first-layer kernel size from [5, 7, 9, 10, 11] based on validation performance; evaluate the best checkpoint on the external test sets.
- **Reported metrics:** metrics for every trained model are in `capsnet_training_metrics_all_runs.csv`.

## External Test and Winning Model

> I excluded the Lung Opacity class from the external tests because it often co-occurs with other diseases, which makes it hard to classify accurately. The reported accuracy therefore covers only the remaining three classes, which I still consider representative.

- **krnl9** (best COVID F1, balanced performance): the model presented here. Results CSV: `test_on_external_dataset_capsnet_lung_disease_classifier_krnl9.csv`
- **krnl11** (highest accuracy, strongest Viral Pneumonia detection)
- **krnl10** (optional; solid performance, close to the leaders)

![External test results](image.png)

## Risks, Bias, and Limitations

- **Domain shift:** performance may degrade on images from other hospitals, scanners, or populations.
- **Label noise / class imbalance:** training is balanced, but real-world prevalence may differ.
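- **Confounders:** text markers, devices, or preprocessing differences can leak non-pathology signals.
- **Fairness:** if patient demographics were not controlled, subgroup performance may vary.
- **Regulatory:** not cleared for clinical use.

## Recommendations

- Always keep a **human in the loop** for review.
- Report **per-class metrics** and **confidence scores**; calibrate if needed.
- Perform **external validation** on multiple sites before any operational use.
- Track experiments with MLflow (`mlruns_capsnet`) and save confusion matrices.

## How to Use

`margin_loss`, `PrimaryCaps`, `DigitCaps`, and `Length` are the custom objects defined in the training notebook; import or define them before loading the model. Inputs should be masked CXRs, as in training.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array

# margin_loss, PrimaryCaps, DigitCaps and Length come from the training code, e.g.
# (hypothetical module name):
# from capsnet_layers import margin_loss, PrimaryCaps, DigitCaps, Length

# Load the trained Keras model together with its custom objects
model = tf.keras.models.load_model(
    "your/path/to/model.keras",
    custom_objects={
        "margin_loss": margin_loss,
        "PrimaryCaps": PrimaryCaps,
        "DigitCaps": DigitCaps,
        "Length": Length,
    },
)


def preprocess(path):
    """Load an image as a (1, 256, 256, 3) float32 batch of raw 0-255 pixels.

    No division by 255 here: the model's built-in Rescaling(1./255) layer does it.
    """
    img = load_img(path, target_size=(256, 256))  # RGB by default
    x = img_to_array(img)                         # (256, 256, 3)
    return np.expand_dims(x, axis=0)              # add the batch dimension


x = preprocess("example_cxr.png")
pred = model.predict(x)[0]         # shape: (4,), one score per class
pred_label = int(np.argmax(pred))  # index of the highest-scoring class
print(pred, pred_label)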
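
# Optional sketch: rank the four class scores by name. Assumptions (not taken
# from the original notebook): the output order matches `disease_labels` from
# the Data section, and the model ends in the `Length` layer, so each score is
# a capsule length in [0, 1]. Capsule lengths are not softmax probabilities and
# need not sum to 1, so treat them as uncalibrated confidence scores.
# `rank_capsule_scores` is a hypothetical helper name.
import numpy as np

DISEASE_LABELS = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']


def rank_capsule_scores(scores, labels=DISEASE_LABELS):
    """Return (label, score) pairs sorted from highest to lowest score."""
    scores = np.asarray(scores, dtype=float).reshape(-1)  # accepts (4,) or (1, 4)
    if scores.shape[0] != len(labels):
        raise ValueError(f"expected {len(labels)} scores, got {scores.shape[0]}")
    order = np.argsort(scores)[::-1]  # indices of the largest scores first
    return [(labels[i], float(scores[i])) for i in order]


# Usage with the `pred` vector computed above:
# for name, score in rank_capsule_scores(pred):
#     print(f"{name:>16}: {score:.3f}")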