Upload 6 files
.gitattributes CHANGED
@@ -1,2 +1,6 @@
 # Auto detect text files and perform LF normalization
 * text=auto
+derm_foundation_embeddings.npz filter=lfs diff=lfs merge=lfs -text
+easi_severity_model_derm_foundation_individual_fixed.keras filter=lfs diff=lfs merge=lfs -text
+easi_severity_model_derm_foundation_individual.pkl filter=lfs diff=lfs merge=lfs -text
+training_history_fixed.png filter=lfs diff=lfs merge=lfs -text
dataset_scin_labels.csv ADDED
The diff for this file is too large to render (see raw diff).
derm_foundation_embeddings.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9fe7208829940a3ad6d9d0964948d3004c3d7fcafa17b98db6402a638ce5ac61
+size 17779095
easi_severity_model_derm_foundation_individual.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf635500ebe9f4bf8a02b3d9de17fcc75595e845113ae46fd40c710cb2bf71a7
+size 149217
easi_severity_model_derm_foundation_individual_fixed.keras ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b66fa5443ac22a54c74c6a154ea5dcacc2c9856839a17044043ea77f41560177
+size 40255675
train_easi_model.py ADDED
@@ -0,0 +1,1525 @@
#!/usr/bin/env python3

"""
Derm Foundation Neural Network Classifier Training Script - Fixed Version

PURPOSE:
This script trains a multi-output neural network to predict dermatological
conditions and their associated metadata from pre-computed embeddings. It
addresses the challenging problem of multi-label medical diagnosis where:
1. Multiple conditions can co-exist (multi-label classification)
2. Each diagnosis has an associated confidence level (regression)
3. Each diagnosis has a weight indicating relative importance (regression)

WHY NEURAL NETWORKS FOR THIS TASK:
Neural networks are the optimal choice for this problem for several reasons:

1. **Non-linear Relationship Learning**: The relationship between image
   embeddings and skin conditions is highly non-linear. Neural networks excel
   at learning complex, non-linear mappings that simpler models (like logistic
   regression) cannot capture.

2. **Multi-task Learning**: This problem requires predicting three related but
   distinct outputs (conditions, confidence, weights). Neural networks can
   share learned representations across these tasks through shared layers,
   improving generalization and efficiency.

3. **High-dimensional Input**: Embeddings are typically 512-1024 dimensional
   vectors. Neural networks are designed to handle high-dimensional inputs
   effectively through dimensionality reduction in hidden layers.

4. **Multi-label Classification**: Medical diagnosis often involves multiple
   co-existing conditions. Neural networks with sigmoid activation can model
   the independent probability of each condition, unlike single-label methods.

5. **Flexibility**: The architecture can be customized with task-specific
   heads (branches) for different prediction types, allowing specialized
   processing for classification vs regression outputs.

WHY HAMMING LOSS IS VALID:
Hamming loss is an appropriate metric for multi-label classification because:

1. **Accounts for Partial Correctness**: Unlike exact match accuracy, hamming
   loss gives credit for partially correct predictions. Predicting 3 out of 4
   conditions correctly is better than 0 out of 4.

2. **Label-wise Evaluation**: It measures the fraction of incorrectly predicted
   labels, treating each label independently - appropriate when conditions can
   co-occur independently.

3. **Bounded and Interpretable**: Ranges from 0 (perfect) to 1 (completely
   wrong). A hamming loss of 0.1 means 10% of label predictions were incorrect.

4. **Balanced for Sparse Labels**: In medical diagnosis, most samples have few
   positive labels (sparse multi-label). Hamming loss naturally handles this
   imbalance by computing the fraction across all labels.

5. **Clinically Relevant**: In medical applications, both false positives and
   false negatives matter. Hamming loss penalizes both equally, unlike metrics
   that focus on one type of error.

MATHEMATICAL JUSTIFICATION:
For a sample with true labels y and predicted labels ŷ:
    Hamming Loss = (1/n_labels) × Σ(y_i XOR ŷ_i)

This averages the disagreement across all possible labels, making it suitable
for scenarios where:
- The label space is large (many possible conditions)
- Label correlations exist but aren't perfectly predictable
- Clinical accuracy matters at the individual label level

FIXES APPLIED IN THIS VERSION:
- Changed confidence activation from ReLU to softplus (prevents zero outputs)
- Improved confidence scaler fitting (uses only non-zero values)
- Increased confidence loss weight (1.5x for better learning signal)
- Enhanced data validation and preprocessing
- Better handling of sparse confidence/weight matrices

Requirements:
- pandas
- numpy
- tensorflow>=2.13.0
- scikit-learn
- matplotlib
- pickle (standard library)
- os (standard library)
- derm_foundation_embeddings.npz: Pre-computed embeddings from images
- dataset_scin_labels.csv: Ground truth labels with conditions, confidences, weights

OUTPUT:
- Trained neural network model (.keras file)
- Preprocessing components (scalers, label encoder) in .pkl file
- Training history plots showing convergence
- Evaluation metrics on test set
"""
import numpy as np
import pandas as pd
import pickle
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import hamming_loss, mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

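# --- Editor's illustrative sketch (not part of the original pipeline) ---
# The module docstring above defines Hamming loss as
# (1/n_labels) * sum(y_i XOR y_hat_i). The hypothetical helper below is a
# minimal demonstration that the manual XOR formula matches
# sklearn.metrics.hamming_loss on a toy multi-label matrix; the toy labels
# are invented purely for illustration.
def _demo_hamming_loss():
    """Compare the docstring's XOR formula with sklearn's implementation."""
    # Toy ground truth and predictions: 3 samples x 4 possible conditions.
    y_true = np.array([[1, 0, 1, 0],
                       [0, 1, 0, 0],
                       [1, 1, 0, 1]])
    y_pred = np.array([[1, 0, 0, 0],   # one label wrong
                       [0, 1, 0, 0],   # all labels correct
                       [1, 0, 0, 1]])  # one label wrong
    # Manual formula: average label-wise disagreement (XOR).
    manual = np.mean(y_true ^ y_pred)
    # Library implementation used later in evaluate_model().
    from_sklearn = hamming_loss(y_true, y_pred)
    print(f"manual={manual:.4f}, sklearn={from_sklearn:.4f}")  # both 0.1667
    return manual, from_sklearn
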
"""
Main class implementing the multi-output neural network classifier.

ARCHITECTURE OVERVIEW:
1. **Shared Feature Extraction**: 3 dense layers (512→256→128) with batch
   normalization and dropout. These layers learn a shared representation
   useful for all prediction tasks.

2. **Task-Specific Heads**: Three separate output branches:
   - Condition classification: Sigmoid activation for multi-label prediction
   - Confidence regression: Softplus activation for positive continuous values
   - Weight regression: Sigmoid activation for [0,1] bounded values

WHY MULTI-TASK LEARNING:
- Conditions, confidence, and weights are related but distinct
- Sharing early layers allows the model to learn features useful for all tasks
- Task-specific heads allow specialized processing for each output type
- Improves generalization by preventing overfitting to any single task

TRAINING STRATEGY:
- Multi-task loss: Weighted combination of classification and regression losses
- Early stopping: Prevents overfitting by monitoring validation loss
- Learning rate reduction: Adapts learning rate when progress plateaus
- Batch normalization: Stabilizes training and allows higher learning rates
"""


class DermFoundationNeuralNetwork:
    """
    Initialize the classifier with preprocessing components.

    PREPROCESSING COMPONENTS:
    - mlb (MultiLabelBinarizer): Converts condition names to binary vectors
      Example: ['Eczema', 'Psoriasis'] → [0,1,0,1,0,...,0]

    - embedding_scaler (StandardScaler): Normalizes embeddings to mean=0, std=1
      Why: Neural networks train faster with normalized inputs

    - confidence_scaler (StandardScaler): Normalizes confidence values
      Why: Brings continuous values to similar scale as other outputs

    - weighted_scaler (StandardScaler): Normalizes weight values
      Why: Ensures balanced gradient contributions during training

    DESIGN DECISION:
    Separate scalers for each output type allow independent normalization,
    which is crucial when outputs have different scales and distributions.
    """

    def __init__(self):
        self.model = None
        self.mlb = MultiLabelBinarizer()
        self.embedding_scaler = StandardScaler()
        self.confidence_scaler = StandardScaler()
        self.weighted_scaler = StandardScaler()
        self.history = None
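
    # --- Editor's illustrative sketch (not part of the original pipeline) ---
    # The class docstring above gives the example
    # ['Eczema', 'Psoriasis'] -> [0,1,0,1,0,...,0]. The hypothetical helper
    # below shows, on invented condition names, how MultiLabelBinarizer
    # produces that binary multi-label matrix; fit_transform is used the same
    # way later in prepare_training_data().
    @staticmethod
    def _demo_multilabel_binarizer():
        """Show how condition-name lists become a binary label matrix."""
        demo_labels = [
            ['Eczema', 'Psoriasis'],   # two co-existing conditions
            ['Acne'],                  # single condition
            ['Eczema'],                # overlaps with the first sample
        ]
        mlb = MultiLabelBinarizer()
        matrix = mlb.fit_transform(demo_labels)
        # Columns are sorted alphabetically: ['Acne', 'Eczema', 'Psoriasis']
        print("classes:", list(mlb.classes_))
        print(matrix)
        # [[0 1 1]
        #  [1 0 0]
        #  [0 1 0]]
        return mlb.classes_, matrix
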
    """
    Load pre-computed Derm Foundation embeddings from NPZ file.

    WHAT ARE EMBEDDINGS:
    Embeddings are dense vector representations of images extracted from a
    pre-trained vision model (Derm Foundation model). They capture high-level
    visual features learned from large-scale dermatology image datasets.

    WHY USE PRE-COMPUTED EMBEDDINGS:
    1. **Efficiency**: Computing embeddings is expensive. Pre-computing them
       allows rapid experimentation with different classifier architectures.

    2. **Transfer Learning**: Derm Foundation was trained on massive dermatology
       datasets. Its embeddings encode domain-specific visual patterns.

    3. **Separation of Concerns**: Image processing and classification are
       separated, allowing independent optimization of each component.

    FORMAT:
    NPZ file contains a dictionary where:
    - Keys: case_id (string identifiers)
    - Values: embedding vectors (typically 512 or 768 dimensions)
    """
    def load_embeddings(self, npz_file_path):
        """Load embeddings from NPZ file"""
        print(f"Loading embeddings from {npz_file_path}...")

        if not os.path.exists(npz_file_path):
            print(f"ERROR: Embeddings file not found: {npz_file_path}")
            return None

        embeddings_data = {}
        with open(npz_file_path, 'rb') as f:
            npz_file = np.load(f, allow_pickle=True)
            for key in npz_file.files:
                embeddings_data[key] = npz_file[key]

        print(f"Loaded {len(embeddings_data)} embeddings")

        # Print info about first embedding for debugging
        first_key = list(embeddings_data.keys())[0]
        first_embedding = embeddings_data[first_key]
        print(f"Embedding shape: {first_embedding.shape}")

        return embeddings_data

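    # --- Editor's illustrative sketch (not part of the original pipeline) ---
    # The docstring above describes the expected NPZ layout: one array per
    # case_id key. The hypothetical helper below writes and reloads a tiny
    # file in that layout; the file name "demo_embeddings.npz" and the
    # 6-dimensional vectors are invented for illustration (per the docstring,
    # real embeddings are typically 512 or 768 dimensional).
    @staticmethod
    def _demo_npz_roundtrip(path="demo_embeddings.npz"):
        """Create a toy embeddings NPZ keyed by case_id and read it back."""
        rng = np.random.default_rng(0)
        toy = {
            "case_001": rng.normal(size=6).astype(np.float32),
            "case_002": rng.normal(size=6).astype(np.float32),
        }
        # np.savez stores each keyword argument as a named array.
        np.savez(path, **toy)
        loaded = {}
        with open(path, 'rb') as f:
            npz_file = np.load(f, allow_pickle=True)
            for key in npz_file.files:   # same loop as load_embeddings()
                loaded[key] = npz_file[key]
        print(f"{len(loaded)} embeddings, shape {loaded['case_001'].shape}")
        return loaded
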
    """
    Load ground truth labels from CSV file.

    REQUIRED COLUMNS:
    1. case_id: Unique identifier matching embedding keys
    2. dermatologist_skin_condition_on_label_name: List of condition names
    3. dermatologist_skin_condition_confidence: Confidence scores (typically 1-5)
    4. weighted_skin_condition_label: Importance weights (0-1 range)

    DATA TYPES:
    - case_id must be string to match embedding keys
    - Lists stored as strings (e.g., "['Eczema', 'Psoriasis']") are evaluated
    - Handles various formats: lists, dicts, single values
    """
    def load_dataset(self, csv_path):
        """Load dataset from CSV file"""
        print(f"Loading dataset from {csv_path}...")

        if not os.path.exists(csv_path):
            print(f"ERROR: Dataset file not found: {csv_path}")
            return None

        try:
            df = pd.read_csv(csv_path, dtype={'case_id': str})
            df['case_id'] = df['case_id'].astype(str)

            print(f"Loaded dataset: {len(df)} samples")

            # Verify required columns
            required_columns = [
                'case_id',
                'dermatologist_skin_condition_on_label_name',
                'dermatologist_skin_condition_confidence',
                'weighted_skin_condition_label'
            ]

            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                print(f"ERROR: Missing required columns: {missing_columns}")
                return None

            return df

        except Exception as e:
            print(f"Error loading dataset: {e}")
            return None

    """
    Prepare training data with comprehensive validation and preprocessing.

    COMPLEXITY HANDLING:
    This method handles several challenging data characteristics:

    1. **SPARSE MULTI-LABEL MATRICES**: Most samples have few positive labels
       Solution: Track and report sparsity statistics for validation

    2. **VARIABLE-LENGTH LISTS**: Different samples have different numbers of
       conditions, confidences, and weights
       Solution: Parse and align lists carefully, use mean values for mismatches

    3. **RARE CONDITIONS**: Some conditions appear in very few samples
       Solution: Filter to top N conditions and minimum sample requirements

    4. **ZERO VALUES**: Confidence/weight matrices are mostly zeros (sparse)
       Solution: Track zero vs non-zero ratios, fit scalers only on non-zeros

    FILTERING STRATEGY:
    - min_condition_samples: Removes rare conditions with insufficient data
    - max_conditions: Limits to most frequent conditions to prevent overfitting
    - Both filters ensure model focuses on well-represented, learnable patterns

    WHY FILTER CONDITIONS:
    1. **Statistical Validity**: Need sufficient examples to learn patterns
    2. **Generalization**: Rare conditions lead to overfitting
    3. **Computational Efficiency**: Fewer output nodes = faster training
    4. **Clinical Relevance**: Common conditions are higher priority

    MULTI-LABEL MATRIX STRUCTURE:
    Shape: (n_samples, n_conditions)
    - Rows: Individual patient cases
    - Columns: Binary indicators for each condition (1=present, 0=absent)
    - Multiple 1s per row: Multi-label (multiple conditions co-exist)

    CONFIDENCE/WEIGHT MATRICES:
    Shape: (n_samples, n_conditions)
    - Values at (i,j): Confidence/weight for condition j in sample i
    - Zero when condition j not present in sample i (sparse structure)
    - Non-zero only where corresponding multi-label entry is 1

    DATA VALIDATION:
    Extensive logging of:
    - Sample counts (processed vs skipped)
    - Value ranges (min/max/mean)
    - Sparsity statistics (% non-zero)
    - Top conditions by frequency

    This validation is crucial for:
    - Detecting data quality issues early
    - Understanding model input characteristics
    - Debugging training problems
    """
    def prepare_training_data(self, df, embeddings, min_condition_samples=5, max_conditions=30):
        """Prepare training data with improved confidence and weight handling"""
        print("Preparing training data with enhanced validation...")

        X = []                        # Embeddings
        condition_labels = []         # For multi-label classification
        individual_confidences = []   # Individual confidence per condition
        individual_weights = []       # Individual weight per condition

        skipped_count = 0
        processed_count = 0
        confidence_stats = []         # Track confidence values for validation
        weight_stats = []             # Track weight values for validation

        for idx, row in df.iterrows():
            try:
                case_id = row['case_id']

                if case_id not in embeddings:
                    skipped_count += 1
                    continue

                # Parse the label data
                try:
                    # Parse condition names
                    if isinstance(row['dermatologist_skin_condition_on_label_name'], str):
                        condition_names = eval(row['dermatologist_skin_condition_on_label_name'])
                    else:
                        condition_names = row['dermatologist_skin_condition_on_label_name']

                    # Ensure condition_names is a list
                    if not isinstance(condition_names, list):
                        condition_names = [condition_names] if condition_names else []

                    # Parse confidence scores
                    if isinstance(row['dermatologist_skin_condition_confidence'], str):
                        confidences = eval(row['dermatologist_skin_condition_confidence'])
                    else:
                        confidences = row['dermatologist_skin_condition_confidence']

                    # Ensure confidences is a list and matches conditions
                    if not isinstance(confidences, list):
                        confidences = [confidences] if confidences is not None else []

                    # Match confidence length to conditions
                    if len(confidences) != len(condition_names):
                        if len(confidences) == 1:
                            confidences = confidences * len(condition_names)
                        else:
                            print(f"Warning: Confidence length mismatch for {case_id}, using mean")
                            mean_conf = np.mean(confidences) if confidences else 3.0
                            confidences = [mean_conf] * len(condition_names)

                    # Parse weighted labels
                    if isinstance(row['weighted_skin_condition_label'], str):
                        weighted_labels = eval(row['weighted_skin_condition_label'])
                    else:
                        weighted_labels = row['weighted_skin_condition_label']

                    # Handle different weight formats
                    if isinstance(weighted_labels, dict):
                        # Convert dict to list matching condition order
                        weights = []
                        for condition in condition_names:
                            weights.append(weighted_labels.get(condition, 0.0))
                    elif isinstance(weighted_labels, list):
                        weights = weighted_labels
                        if len(weights) != len(condition_names):
                            if len(weights) == 1:
                                weights = weights * len(condition_names)
                            else:
                                mean_weight = np.mean(weights) if weights else 0.3
                                weights = [mean_weight] * len(condition_names)
                    else:
                        # Single value
                        weights = [weighted_labels] * len(condition_names) if weighted_labels else [0.3] * len(condition_names)

                except Exception as e:
                    print(f"Error parsing data for {case_id}: {e}")
                    skipped_count += 1
                    continue

                # Validate data ranges
                try:
                    confidences = [max(0.0, float(c)) for c in confidences]      # Ensure non-negative
                    weights = [max(0.0, min(1.0, float(w))) for w in weights]    # Clamp to [0,1]
                except:
                    print(f"Error converting values for {case_id}, skipping")
                    skipped_count += 1
                    continue

                # Add to training data
                X.append(embeddings[case_id])
                condition_labels.append(condition_names)

                # Store individual confidences and weights
                individual_confidences.append({
                    'conditions': condition_names,
                    'confidences': confidences
                })

                individual_weights.append({
                    'conditions': condition_names,
                    'weights': weights
                })

                # Track statistics
                confidence_stats.extend(confidences)
                weight_stats.extend(weights)

                processed_count += 1

            except Exception as e:
                print(f"Error processing row {idx}: {e}")
                skipped_count += 1
                continue

        print(f"Training data prepared: {processed_count} samples, {skipped_count} skipped")

        if len(X) == 0:
            print("ERROR: No training samples found!")
            return None, None, None, None

        # Print data statistics
        print(f"\nData validation:")
        print(f"  Confidence values - min: {min(confidence_stats):.3f}, max: {max(confidence_stats):.3f}, mean: {np.mean(confidence_stats):.3f}")
        print(f"  Weight values - min: {min(weight_stats):.3f}, max: {max(weight_stats):.3f}, mean: {np.mean(weight_stats):.3f}")
        print(f"  Non-zero confidences: {sum(1 for c in confidence_stats if c > 0.001)}/{len(confidence_stats)} ({100*sum(1 for c in confidence_stats if c > 0.001)/len(confidence_stats):.1f}%)")
        print(f"  Non-zero weights: {sum(1 for w in weight_stats if w > 0.001)}/{len(weight_stats)} ({100*sum(1 for w in weight_stats if w > 0.001)/len(weight_stats):.1f}%)")

        # Convert to numpy arrays
        X = np.array(X)

        # Prepare condition labels - focus on top conditions only
        y_conditions_raw = self.mlb.fit_transform(condition_labels)
        condition_counts = y_conditions_raw.sum(axis=0)

        # Get top conditions by frequency
        sorted_indices = np.argsort(condition_counts)[::-1]
        top_condition_indices = sorted_indices[:max_conditions]

        # Also ensure minimum samples
        frequent_conditions = condition_counts >= min_condition_samples
        final_indices = np.intersect1d(top_condition_indices, np.where(frequent_conditions)[0])

        print(f"Total condition classes: {len(self.mlb.classes_)}")
        print(f"Top {max_conditions} most frequent conditions selected")
        print(f"Conditions with at least {min_condition_samples} examples: {frequent_conditions.sum()}")

        # Keep only selected conditions
        selected_classes = self.mlb.classes_[final_indices]
        y_conditions = y_conditions_raw[:, final_indices]

        # Update MultiLabelBinarizer
        self.mlb = MultiLabelBinarizer()
        self.mlb.classes_ = selected_classes

        print(f"Final condition classes: {len(selected_classes)}")
        print(f"Multi-label matrix shape: {y_conditions.shape}")

        # Create individual confidence and weight matrices
        y_confidences = np.zeros((len(X), len(selected_classes)))
        y_weights = np.zeros((len(X), len(selected_classes)))

        for i, (conf_data, weight_data) in enumerate(zip(individual_confidences, individual_weights)):
            # Map confidences to selected conditions
            for condition, confidence in zip(conf_data['conditions'], conf_data['confidences']):
                if condition in selected_classes:
                    condition_idx = np.where(selected_classes == condition)[0]
                    if len(condition_idx) > 0:
                        y_confidences[i, condition_idx[0]] = confidence

            # Map weights to selected conditions
            for condition, weight in zip(weight_data['conditions'], weight_data['weights']):
                if condition in selected_classes:
                    condition_idx = np.where(selected_classes == condition)[0]
                    if len(condition_idx) > 0:
                        y_weights[i, condition_idx[0]] = weight

        # Print matrix statistics
        nonzero_conf = (y_confidences > 0.001).sum()
        nonzero_weight = (y_weights > 0.001).sum()
        total_elements = y_confidences.size

        print(f"\nMatrix statistics:")
        print(f"  Confidence matrix - non-zero: {nonzero_conf}/{total_elements} ({100*nonzero_conf/total_elements:.1f}%)")
        print(f"  Weight matrix - non-zero: {nonzero_weight}/{total_elements} ({100*nonzero_weight/total_elements:.1f}%)")
        print(f"  Confidence range: {y_confidences[y_confidences > 0].min():.3f} - {y_confidences[y_confidences > 0].max():.3f}")
        print(f"  Weight range: {y_weights[y_weights > 0].min():.3f} - {y_weights[y_weights > 0].max():.3f}")

        # Print top conditions
        condition_counts_filtered = y_conditions.sum(axis=0)
        print("\nTop conditions selected:")
        for i, (condition, count) in enumerate(zip(selected_classes, condition_counts_filtered)):
            print(f"  {i+1:2d}. {condition}: {count} samples")

        return X, y_conditions, y_confidences, y_weights
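
    # --- Editor's illustrative sketch (not part of the original pipeline) ---
    # The prepare_training_data docstring describes the sparse
    # (n_samples, n_conditions) confidence/weight matrices: cell (i, j) is
    # non-zero only when condition j was labelled for sample i. The
    # hypothetical helper below builds such a matrix for two invented cases
    # using the same np.where lookup as the real method.
    @staticmethod
    def _demo_sparse_confidence_matrix():
        """Build a toy per-condition confidence matrix."""
        selected_classes = np.array(['Acne', 'Eczema', 'Psoriasis'])
        cases = [
            {'conditions': ['Eczema', 'Psoriasis'], 'confidences': [4.0, 2.0]},
            {'conditions': ['Acne'], 'confidences': [5.0]},
        ]
        y_conf = np.zeros((len(cases), len(selected_classes)))
        for i, case in enumerate(cases):
            for condition, confidence in zip(case['conditions'], case['confidences']):
                idx = np.where(selected_classes == condition)[0]
                if len(idx) > 0:
                    y_conf[i, idx[0]] = confidence
        print(y_conf)
        # [[0. 4. 2.]
        #  [5. 0. 0.]]
        return y_conf
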
    """
    Build multi-output neural network architecture.

    ARCHITECTURE RATIONALE:

    **SHARED LAYERS (512→256→128)**:
    - Purpose: Learn general features useful for all prediction tasks
    - Size progression: Gradual dimensionality reduction (embeddings→features)
    - Batch Normalization: Stabilizes training, allows higher learning rates
    - Dropout (0.3, 0.3, 0.2): Prevents overfitting, forces robust features

    Why this depth:
    - 3 layers balances capacity (can learn complex patterns) vs simplicity
    - Too shallow: Can't learn complex patterns
    - Too deep: Overfits, slower training, harder to optimize

    **TASK-SPECIFIC BRANCHES**:
    Each branch has 2 layers (64→output) for specialized processing:

    1. **CONDITION CLASSIFICATION BRANCH**:
       - Activation: Sigmoid (outputs independent probabilities per condition)
       - Why sigmoid: Allows multiple conditions to be predicted simultaneously
       - Loss: Binary cross-entropy (standard for multi-label classification)

    2. **CONFIDENCE REGRESSION BRANCH**:
       - Activation: Softplus (ensures positive outputs, smooth gradients)
       - Why softplus not ReLU: ReLU outputs exactly zero for negative inputs,
         causing gradient issues. Softplus outputs small positive values instead.
       - Formula: softplus(x) = log(1 + exp(x))
       - Loss: MSE (Mean Squared Error for continuous values)
       - Loss weight: 1.5x (increased to prioritize confidence learning)

    3. **WEIGHT REGRESSION BRANCH**:
       - Activation: Sigmoid (ensures [0,1] bounded output)
       - Why sigmoid: Weights represent proportions/probabilities, must be 0-1
       - Loss: MSE (Mean Squared Error for continuous values)
       - Loss weight: 1.2x (slightly increased priority)

    **LOSS WEIGHTING**:
    Different loss scales require weighting for balanced training:
    - Condition loss: Binary cross-entropy, typically ~0.3-0.7
    - Confidence loss: MSE on scaled values, typically ~0.01-0.1
    - Weight loss: MSE on scaled values, typically ~0.01-0.1

    Weights (1.0, 1.5, 1.2) ensure:
    - All tasks contribute meaningfully to total loss
    - Confidence gets extra emphasis (was underfitting in previous versions)
    - Gradient magnitudes are balanced across tasks

    **WHY ADAM OPTIMIZER**:
    - Adaptive learning rates per parameter (handles different loss scales)
    - Momentum for faster convergence
    - Robust to hyperparameter choices
    - Industry standard for multi-task learning

    **MODEL COMPILATION**:
    The model uses a dictionary output format allowing:
    - Clear separation of different predictions
    - Easy access to specific outputs during inference
    - Flexible loss and metric assignment per output
    """
    def build_model(self, input_dim, num_conditions, learning_rate=0.001):
        """Build neural network with improved confidence and weight outputs"""
        print("Building improved neural network model...")

        # Input layer
        inputs = keras.Input(shape=(input_dim,), name='embeddings')

        # Shared feature extraction layers
        x = layers.Dense(512, activation='relu', name='dense1')(inputs)  # Increased capacity
        x = layers.BatchNormalization(name='bn1')(x)
        x = layers.Dropout(0.3, name='dropout1')(x)

        x = layers.Dense(256, activation='relu', name='dense2')(x)
        x = layers.BatchNormalization(name='bn2')(x)
        x = layers.Dropout(0.3, name='dropout2')(x)

        x = layers.Dense(128, activation='relu', name='dense3')(x)
        x = layers.BatchNormalization(name='bn3')(x)
        x = layers.Dropout(0.2, name='dropout3')(x)

        # Multi-label condition classification head
        condition_branch = layers.Dense(64, activation='relu', name='condition_dense')(x)
        condition_branch = layers.Dropout(0.2, name='condition_dropout')(condition_branch)
        condition_output = layers.Dense(num_conditions, activation='sigmoid',
                                        name='conditions')(condition_branch)

        # Individual confidence regression head - FIXED ACTIVATION
        confidence_branch = layers.Dense(64, activation='relu', name='confidence_dense1')(x)
        confidence_branch = layers.Dropout(0.2, name='confidence_dropout1')(confidence_branch)
        confidence_branch = layers.Dense(32, activation='relu', name='confidence_dense2')(confidence_branch)
        confidence_branch = layers.Dropout(0.1, name='confidence_dropout2')(confidence_branch)
        # Changed from ReLU to softplus - ensures positive, non-zero outputs
        confidence_output = layers.Dense(num_conditions, activation='softplus',
                                         name='individual_confidences')(confidence_branch)

        # Individual weight regression head
        weighted_branch = layers.Dense(64, activation='relu', name='weighted_dense1')(x)
        weighted_branch = layers.Dropout(0.2, name='weighted_dropout1')(weighted_branch)
        weighted_branch = layers.Dense(32, activation='relu', name='weighted_dense2')(weighted_branch)
        weighted_branch = layers.Dropout(0.1, name='weighted_dropout2')(weighted_branch)
        # Use sigmoid to ensure 0-1 range
        weighted_output = layers.Dense(num_conditions, activation='sigmoid',
                                       name='individual_weights')(weighted_branch)

        # Create model
        model = keras.Model(
            inputs=inputs,
            outputs={
                'conditions': condition_output,
                'individual_confidences': confidence_output,
                'individual_weights': weighted_output
            }
        )

        # Compile model with improved loss weights
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss={
                'conditions': 'binary_crossentropy',
                'individual_confidences': 'mse',
                'individual_weights': 'mse'
            },
            loss_weights={
                'conditions': 1.0,
                'individual_confidences': 1.5,  # Increased weight for confidence
                'individual_weights': 1.2       # Increased weight for weights
            },
            metrics={
                'conditions': ['accuracy'],
                'individual_confidences': ['mae'],
                'individual_weights': ['mae']
            }
        )

        return model

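    # --- Editor's illustrative sketch (not part of the original pipeline) ---
    # The build_model docstring explains the switch from ReLU to softplus for
    # the confidence head: relu(x) = max(0, x) is exactly zero for any
    # negative pre-activation, while softplus(x) = log(1 + exp(x)) stays
    # strictly positive with non-zero gradients. The hypothetical helper
    # below simply prints both activations on a few sample values.
    @staticmethod
    def _demo_softplus_vs_relu():
        """Contrast ReLU and softplus on negative, zero and positive inputs."""
        x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0])
        relu_out = tf.nn.relu(x)          # [0., 0., 0., 1., 3.]
        softplus_out = tf.nn.softplus(x)  # [~0.049, ~0.313, ~0.693, ~1.313, ~3.049]
        print("relu:    ", relu_out.numpy())
        print("softplus:", softplus_out.numpy())
        return relu_out.numpy(), softplus_out.numpy()
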
    """
    Main training orchestration method with improved confidence handling.

    TRAINING PIPELINE:
    1. Load data (embeddings + labels)
    2. Prepare training matrices (parse, filter, validate)
    3. Scale features and outputs
    4. Split train/validation sets
    5. Build neural network architecture
    6. Train with callbacks (early stopping, LR reduction, checkpointing)
    7. Evaluate performance
    8. Save trained model

    IMPROVED SCALING STRATEGY (KEY FIX):
    Problem: Previous version scaled all values including zeros
    Solution: Fit scalers only on non-zero values

    Why this matters:
    - Sparse matrices have many structural zeros (condition not present)
    - Including zeros in scaler fitting shifts the mean artificially low
    - Model learns to predict near-zero for everything
    - Confidence predictions collapsed to ~0 (major bug)

    New approach:
    ```python
    conf_nonzero = y_confidences[y_confidences > 0.001]
    self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
    ```
    - Only non-zero values determine scale
    - Model learns the actual confidence distribution (1-5 range)
    - Predictions are meaningful positive values

    FALLBACK HANDLING:
    If too few non-zero values exist:
    - Use sensible dummy values (1-5 for confidence, 0-1 for weights)
    - Prevents scaler failure on edge cases
    - Ensures training can proceed

    TRAIN/TEST SPLIT:
    - 80/20 split is standard for medical ML
    - Stratification not used (multi-label makes it complex)
    - Random state fixed for reproducibility

    CALLBACKS:
    Early Stopping (patience=12):
    - Monitors validation loss
    - Stops if no improvement for 12 epochs
    - Restores best weights (not final weights)
    - Why: Prevents overfitting to the training set

    ReduceLROnPlateau (factor=0.5, patience=5):
    - Monitors confidence loss specifically (was problematic)
    - Reduces LR by 50% if the loss plateaus
    - Allows fine-tuning in late training
    - Min LR: 1e-7 prevents excessive reduction

    ModelCheckpoint:
    - Saves best model weights during training
    - Insurance against training divergence
    - Cleaned up after successful training

    TRAINING DURATION:
    - 60 epochs maximum (increased from 50)
    - Early stopping typically triggers around epoch 30-40
    - Batch size 32 balances memory vs convergence speed

    HYPERPARAMETERS:
    - Learning rate: 0.001 (standard for Adam)
    - Batch size: 32 (good for datasets of this size)
    - Test split: 0.2 (20% validation, standard practice)

    POST-TRAINING:
    - Comprehensive evaluation on test set
    - Detailed metrics for all three outputs
    - Analysis of confidence prediction quality
    """
|
| 729 |
+
def train(self, npz_file_path="derm_foundation_embeddings.npz",
|
| 730 |
+
csv_file_path="dataset_scin_labels.csv",
|
| 731 |
+
test_size=0.2, random_state=42, epochs=50, batch_size=32,
|
| 732 |
+
learning_rate=0.001):
|
| 733 |
+
"""Train the neural network with improved confidence handling"""
|
| 734 |
+
|
| 735 |
+
# Load data
|
| 736 |
+
embeddings = self.load_embeddings(npz_file_path)
|
| 737 |
+
if embeddings is None:
|
| 738 |
+
return False
|
| 739 |
+
|
| 740 |
+
df = self.load_dataset(csv_file_path)
|
| 741 |
+
if df is None:
|
| 742 |
+
return False
|
| 743 |
+
|
| 744 |
+
# Prepare training data
|
| 745 |
+
X, y_conditions, y_confidences, y_weights = self.prepare_training_data(df, embeddings)
|
| 746 |
+
if X is None:
|
| 747 |
+
return False
|
| 748 |
+
|
| 749 |
+
# IMPROVED SCALING - fit only on non-zero values
|
| 750 |
+
print("\nFitting scalers...")
|
| 751 |
+
X_scaled = self.embedding_scaler.fit_transform(X)
|
| 752 |
+
|
| 753 |
+
# Fit confidence scaler on non-zero values only
|
| 754 |
+
conf_nonzero = y_confidences[y_confidences > 0.001]
|
| 755 |
+
if len(conf_nonzero) > 50: # Ensure we have enough data
|
| 756 |
+
print(f"Fitting confidence scaler on {len(conf_nonzero)} non-zero values")
|
| 757 |
+
self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
|
| 758 |
+
else:
|
| 759 |
+
print("WARNING: Too few non-zero confidence values, using default scaling")
|
| 760 |
+
# Use a reasonable range for confidence values (e.g., 1-5 scale)
|
| 761 |
+
dummy_conf = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1)
|
| 762 |
+
self.confidence_scaler.fit(dummy_conf)
|
| 763 |
+
|
| 764 |
+
# Fit weight scaler on non-zero values only
|
| 765 |
+
weight_nonzero = y_weights[y_weights > 0.001]
|
| 766 |
+
if len(weight_nonzero) > 50:
|
| 767 |
+
print(f"Fitting weight scaler on {len(weight_nonzero)} non-zero values")
|
| 768 |
+
self.weighted_scaler.fit(weight_nonzero.reshape(-1, 1))
|
| 769 |
+
else:
|
| 770 |
+
print("WARNING: Too few non-zero weight values, using default scaling")
|
| 771 |
+
# Use a reasonable range for weight values (0-1 scale)
|
| 772 |
+
dummy_weight = np.array([0.1, 0.3, 0.5, 0.7, 0.9]).reshape(-1, 1)
|
| 773 |
+
self.weighted_scaler.fit(dummy_weight)
|
| 774 |
+
|
| 775 |
+
# Apply scaling to the matrices (preserve structure)
|
| 776 |
+
y_confidences_scaled = np.zeros_like(y_confidences)
|
| 777 |
+
y_weights_scaled = np.zeros_like(y_weights)
|
| 778 |
+
|
| 779 |
+
# Scale only non-zero values
|
| 780 |
+
for i in range(y_confidences.shape[0]):
|
| 781 |
+
for j in range(y_confidences.shape[1]):
|
| 782 |
+
if y_confidences[i, j] > 0.001:
|
| 783 |
+
y_confidences_scaled[i, j] = self.confidence_scaler.transform([[y_confidences[i, j]]])[0, 0]
|
| 784 |
+
if y_weights[i, j] > 0.001:
|
| 785 |
+
y_weights_scaled[i, j] = self.weighted_scaler.transform([[y_weights[i, j]]])[0, 0]
|
| 786 |
+
|
| 787 |
+
print(f"Scaled confidence range: {y_confidences_scaled[y_confidences_scaled != 0].min():.3f} - {y_confidences_scaled[y_confidences_scaled != 0].max():.3f}")
|
| 788 |
+
print(f"Scaled weight range: {y_weights_scaled[y_weights_scaled != 0].min():.3f} - {y_weights_scaled[y_weights_scaled != 0].max():.3f}")
|
| 789 |
+
|
| 790 |
+
# Split data
|
| 791 |
+
X_train, X_test, y_cond_train, y_cond_test, y_conf_train, y_conf_test, y_weight_train, y_weight_test = train_test_split(
|
| 792 |
+
X_scaled, y_conditions, y_confidences_scaled, y_weights_scaled,
|
| 793 |
+
test_size=test_size, random_state=random_state
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
print(f"\nTraining/test split:")
|
| 797 |
+
print(f" Training samples: {X_train.shape[0]}")
|
| 798 |
+
print(f" Test samples: {X_test.shape[0]}")
|
| 799 |
+
|
| 800 |
+
# Build model
|
| 801 |
+
self.model = self.build_model(
|
| 802 |
+
input_dim=X_scaled.shape[1],
|
| 803 |
+
num_conditions=y_conditions.shape[1],
|
| 804 |
+
learning_rate=learning_rate
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
print(f"\nModel architecture:")
|
| 808 |
+
self.model.summary()
|
| 809 |
+
|
| 810 |
+
# Prepare training data
|
| 811 |
+
train_data = {
|
| 812 |
+
'conditions': y_cond_train,
|
| 813 |
+
'individual_confidences': y_conf_train,
|
| 814 |
+
'individual_weights': y_weight_train
|
| 815 |
+
}
|
| 816 |
+
|
| 817 |
+
val_data = {
|
| 818 |
+
'conditions': y_cond_test,
|
| 819 |
+
'individual_confidences': y_conf_test,
|
| 820 |
+
'individual_weights': y_weight_test
|
| 821 |
+
}
|
| 822 |
+
|
| 823 |
+
# Enhanced callbacks
|
| 824 |
+
early_stopping = keras.callbacks.EarlyStopping(
|
| 825 |
+
monitor='val_loss',
|
| 826 |
+
patience=12, # Increased patience
|
| 827 |
+
restore_best_weights=True,
|
| 828 |
+
verbose=1
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
reduce_lr = keras.callbacks.ReduceLROnPlateau(
|
| 832 |
+
monitor='val_individual_confidences_loss', # Monitor confidence loss specifically
|
| 833 |
+
factor=0.5,
|
| 834 |
+
patience=5,
|
| 835 |
+
min_lr=1e-7,
|
| 836 |
+
mode='min', # We want to minimize the loss
|
| 837 |
+
verbose=1
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
model_checkpoint = keras.callbacks.ModelCheckpoint(
|
| 841 |
+
filepath='best_model_fixed.weights.h5',
|
| 842 |
+
monitor='val_loss',
|
| 843 |
+
save_best_only=True,
|
| 844 |
+
save_weights_only=True,
|
| 845 |
+
verbose=1
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
print(f"\nStarting training for {epochs} epochs...")
|
| 849 |
+
|
| 850 |
+
# Train model
|
| 851 |
+
self.history = self.model.fit(
|
| 852 |
+
X_train, train_data,
|
| 853 |
+
validation_data=(X_test, val_data),
|
| 854 |
+
epochs=epochs,
|
| 855 |
+
batch_size=batch_size,
|
| 856 |
+
callbacks=[early_stopping, reduce_lr, model_checkpoint],
|
| 857 |
+
verbose=1
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
# Evaluate model
|
| 861 |
+
self.evaluate_model(X_test, y_cond_test, y_conf_test, y_weight_test)
|
| 862 |
+
|
| 863 |
+
return True
|
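    # --- Editor's illustrative sketch (not part of the original pipeline) ---
    # train() fits StandardScaler only on the non-zero entries of the sparse
    # confidence matrix and then scales only those entries, so structural
    # zeros stay zero. The hypothetical helper below repeats that pattern on
    # a toy 1-5 confidence matrix so the effect is easy to inspect.
    @staticmethod
    def _demo_nonzero_scaling():
        """Fit a scaler on non-zero values only and scale them in place."""
        y_conf = np.array([[4.0, 0.0, 2.0],
                           [0.0, 5.0, 0.0],
                           [3.0, 0.0, 0.0]])
        scaler = StandardScaler()
        nonzero = y_conf[y_conf > 0.001]        # [4., 2., 5., 3.]
        scaler.fit(nonzero.reshape(-1, 1))      # mean/std from real labels only
        scaled = np.zeros_like(y_conf)
        for i in range(y_conf.shape[0]):
            for j in range(y_conf.shape[1]):
                if y_conf[i, j] > 0.001:
                    scaled[i, j] = scaler.transform([[y_conf[i, j]]])[0, 0]
        print("mean of non-zero values:", scaler.mean_[0])  # 3.5
        print(scaled)  # zeros preserved, other entries standardized
        return scaled
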
    """
    Comprehensive model evaluation with enhanced confidence analysis.

    EVALUATION METRICS:

    1. MULTI-LABEL CLASSIFICATION (Conditions):
       Hamming Loss:
       - Definition: Fraction of incorrectly predicted labels
       - Range: [0, 1] where 0 is perfect
       - Formula: (1/n_labels) × Σ|y_true ⊕ y_pred|
       - Example: If 2 out of 30 labels are wrong, hamming loss = 0.067
       - Clinical interpretation: Lower is better, <0.1 is excellent

       Exact Match Accuracy:
       - Strictest metric: Requires ALL labels perfectly correct
       - Range: [0, 1] where 1 is perfect
       - Why include: Shows complete prediction correctness
       - Medical context: Exact match is ideal but rarely achievable
         (even expert dermatologists disagree on some cases)

       Average Conditions per Sample:
       - Describes label distribution complexity
       - Higher values → harder multi-label problem
       - Typical range: 1-3 conditions per sample

    2. CONFIDENCE REGRESSION:
       Why evaluate only non-zero targets:
       - Zeros are structural (condition not present)
       - Including zeros conflates two problems:
         a) Predicting which conditions exist (classification task)
         b) Predicting confidence for existing conditions (regression task)
       - We want to evaluate (b) separately

       Inverse Transform:
       - Converts scaled predictions back to original scale
       - Necessary for interpretable metrics
       - Example: Scaled 0.3 → Original 3.2 (on 1-5 scale)

       MSE (Mean Squared Error):
       - Sensitive to large errors (squared penalty)
       - Unit: (confidence units)²
       - Lower is better

       MAE (Mean Absolute Error):
       - Average absolute difference from ground truth
       - Same units as original values
       - More robust to outliers than MSE
       - Clinical interpretation: If MAE=0.5, average error is ±0.5 points

       RMSE (Root Mean Squared Error):
       - Square root of MSE
       - Same units as original values (easier to interpret than MSE)
       - Emphasizes larger errors more than MAE

       Prediction Range Analysis:
       - Verifies predictions are in a sensible range
       - Example: If ground truth is 1-5, predictions should be similar
       - Out-of-range predictions indicate scaling or activation issues

    3. WEIGHT REGRESSION:
       Same metrics as confidence but for weight values (0-1 range)

    DIAGNOSTIC CHECKS:
    - "Predictions > 0.1" percentage: Ensures model isn't predicting near-zero
    - Range comparison: Truth vs prediction ranges should align
    - Non-zero count: Verifies sparse structure is respected

    WHY THIS EVALUATION IS COMPREHENSIVE:
    - Multiple metrics cover different aspects (classification + regression)
    - Separate evaluation of sparse vs dense regions
    - Original scale metrics (clinically interpretable)
    - Diagnostic checks for common failure modes
    - Both aggregate (MSE) and per-sample (MAE) metrics
    """
    def evaluate_model(self, X_test, y_cond_test, y_conf_test, y_weight_test):
        """Evaluate the trained model with enhanced confidence analysis"""
        print("\n" + "="*70)
        print("MODEL EVALUATION - ENHANCED CONFIDENCE ANALYSIS")
        print("="*70)

        # Make predictions
        predictions = self.model.predict(X_test)
        y_cond_pred = predictions['conditions']
        y_conf_pred = predictions['individual_confidences']
        y_weight_pred = predictions['individual_weights']

        # Condition classification evaluation
        y_cond_pred_binary = (y_cond_pred > 0.5).astype(int)
        hamming = hamming_loss(y_cond_test, y_cond_pred_binary)
        exact_match = (y_cond_pred_binary == y_cond_test).all(axis=1).mean()

        print(f"Multi-label Condition Classification:")
        print(f"  Hamming Loss: {hamming:.4f}")
        print(f"  Exact Match Accuracy: {exact_match:.4f}")
        print(f"  Average conditions per sample: {y_cond_test.sum(axis=1).mean():.2f}")

        # ENHANCED confidence evaluation
        print(f"\nConfidence Prediction Analysis:")
        print(f"  Raw prediction range: {y_conf_pred.min():.6f} - {y_conf_pred.max():.6f}")
        print(f"  Non-zero predictions: {(y_conf_pred > 0.001).sum()}/{y_conf_pred.size}")

        # Inverse transform and evaluate confidence
        conf_mask = y_conf_test > 0.001
        if conf_mask.sum() > 0:
            y_conf_test_orig = np.zeros_like(y_conf_test)
            y_conf_pred_orig = np.zeros_like(y_conf_pred)

            # Inverse transform
            for i in range(y_conf_test.shape[0]):
                for j in range(y_conf_test.shape[1]):
                    if y_conf_test[i, j] > 0.001:
                        y_conf_test_orig[i, j] = self.confidence_scaler.inverse_transform([[y_conf_test[i, j]]])[0, 0]
                    if y_conf_pred[i, j] > 0.001:
                        y_conf_pred_orig[i, j] = self.confidence_scaler.inverse_transform([[y_conf_pred[i, j]]])[0, 0]

            # Calculate metrics only on positions where ground truth is non-zero
            conf_test_nonzero = y_conf_test_orig[conf_mask]
            conf_pred_nonzero = y_conf_pred_orig[conf_mask]

            conf_mse = mean_squared_error(conf_test_nonzero, conf_pred_nonzero)
            conf_mae = mean_absolute_error(conf_test_nonzero, conf_pred_nonzero)

            print(f"  Individual Confidence Regression (on {conf_mask.sum()} non-zero targets):")
            print(f"    MSE: {conf_mse:.4f}")
            print(f"    MAE: {conf_mae:.4f}")
            print(f"    RMSE: {np.sqrt(conf_mse):.4f}")
            print(f"    Prediction range (orig scale): {conf_pred_nonzero.min():.3f} - {conf_pred_nonzero.max():.3f}")
            print(f"    Ground truth range (orig scale): {conf_test_nonzero.min():.3f} - {conf_test_nonzero.max():.3f}")

            # Check if predictions are reasonable
            reasonable_predictions = (conf_pred_nonzero > 0.1).sum()
            print(f"    Predictions > 0.1: {reasonable_predictions}/{len(conf_pred_nonzero)} ({100*reasonable_predictions/len(conf_pred_nonzero):.1f}%)")

        # Individual weight evaluation
        weight_mask = y_weight_test > 0.001
        if weight_mask.sum() > 0:
            y_weight_test_orig = np.zeros_like(y_weight_test)
            y_weight_pred_orig = np.zeros_like(y_weight_pred)

            for i in range(y_weight_test.shape[0]):
                for j in range(y_weight_test.shape[1]):
                    if y_weight_test[i, j] > 0.001:
                        y_weight_test_orig[i, j] = self.weighted_scaler.inverse_transform([[y_weight_test[i, j]]])[0, 0]
                    if y_weight_pred[i, j] > 0.001:
                        y_weight_pred_orig[i, j] = self.weighted_scaler.inverse_transform([[y_weight_pred[i, j]]])[0, 0]

            weight_test_nonzero = y_weight_test_orig[weight_mask]
            weight_pred_nonzero = y_weight_pred_orig[weight_mask]

            weight_mse = mean_squared_error(weight_test_nonzero, weight_pred_nonzero)
            weight_mae = mean_absolute_error(weight_test_nonzero, weight_pred_nonzero)

            print(f"\nIndividual Weight Regression (on {weight_mask.sum()} non-zero targets):")
            print(f"  MSE: {weight_mse:.4f}")
            print(f"  MAE: {weight_mae:.4f}")
            print(f"  RMSE: {np.sqrt(weight_mse):.4f}")
            print(f"  Prediction range (orig scale): {weight_pred_nonzero.min():.3f} - {weight_pred_nonzero.max():.3f}")
            print(f"  Ground truth range (orig scale): {weight_test_nonzero.min():.3f} - {weight_test_nonzero.max():.3f}")
| 1030 |
+
"""
|
| 1031 |
+
Make predictions on new embeddings with comprehensive output formatting.
|
| 1032 |
+
PREDICTION PIPELINE:
|
| 1033 |
+
|
| 1034 |
+
Scale input embedding (using training-fitted scaler)
|
| 1035 |
+
Forward pass through neural network
|
| 1036 |
+
Process raw outputs:
|
| 1037 |
+
|
| 1038 |
+
Condition probabilities: Sigmoid outputs [0,1]
|
| 1039 |
+
Confidence values: Softplus outputs [0,∞)
|
| 1040 |
+
Weight values: Sigmoid outputs [0,1]
|
| 1041 |
+
|
| 1042 |
+
Inverse transform regression outputs to original scale
|
| 1043 |
+
Apply threshold to select predicted conditions
|
| 1044 |
+
Return structured dictionary with multiple views of predictions
|
| 1045 |
+
|
| 1046 |
+
THRESHOLD STRATEGY:
|
| 1047 |
+
|
| 1048 |
+
Condition threshold: 0.3 (lower than typical 0.5)
|
| 1049 |
+
Why lower: Medical diagnosis prefers sensitivity (catch more conditions)
|
| 1050 |
+
False positives are less harmful than false negatives in a screening setting
|
| 1051 |
+
Can be adjusted based on clinical requirements
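A minimal sketch of the threshold-based selection described above (condition names and probabilities are invented for illustration):
```python
# Illustrative threshold-based condition selection; names and probabilities
# are invented for the example, not model outputs.
import numpy as np

class_names = np.array(["Eczema", "Psoriasis", "Irritant Contact Dermatitis",
                        "Tinea", "Urticaria"])
condition_probs = np.array([0.82, 0.45, 0.31, 0.12, 0.05])

threshold = 0.3   # lower threshold favours sensitivity; 0.5 favours precision
selected = np.where(condition_probs > threshold)[0]
print(class_names[selected].tolist())
# ['Eczema', 'Psoriasis', 'Irritant Contact Dermatitis']
```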
|
| 1052 |
+
|
| 1053 |
+
OUTPUT STRUCTURE:
|
| 1054 |
+
Primary Predictions (conditions above threshold):
|
| 1055 |
+
|
| 1056 |
+
dermatologist_skin_condition_on_label_name: List of predicted conditions
|
| 1057 |
+
dermatologist_skin_condition_confidence: Confidence per predicted condition
|
| 1058 |
+
weighted_skin_condition_label: Weight dict for predicted conditions
|
| 1059 |
+
|
| 1060 |
+
Comprehensive View (all conditions):
|
| 1061 |
+
|
| 1062 |
+
all_condition_probabilities: Probability for every possible condition
|
| 1063 |
+
all_individual_confidences: Confidence for every possible condition
|
| 1064 |
+
all_individual_weights: Weight for every possible condition
|
| 1065 |
+
|
| 1066 |
+
Debugging Information:
|
| 1067 |
+
|
| 1068 |
+
raw_confidence_outputs: Pre-transform neural network outputs
|
| 1069 |
+
raw_weight_outputs: Pre-transform neural network outputs
|
| 1070 |
+
condition_threshold: Threshold used for filtering
|
| 1071 |
+
|
| 1072 |
+
Why provide multiple views:
|
| 1073 |
+
|
| 1074 |
+
Primary predictions: For direct clinical use
|
| 1075 |
+
Comprehensive view: For ranking, uncertainty quantification
|
| 1076 |
+
Debug info: For model validation and troubleshooting
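A hedged usage sketch of how these views might be consumed, assuming `classifier` is a trained DermFoundationNeuralNetwork and `embedding` is a single Derm Foundation embedding vector:
```python
# Usage sketch: consuming the returned prediction dictionary.
# Assumes `classifier` is a trained DermFoundationNeuralNetwork and
# `embedding` is a single Derm Foundation embedding vector.
result = classifier.predict(embedding)

# Primary view: conditions above the threshold, for direct use
for name, conf in zip(result['dermatologist_skin_condition_on_label_name'],
                      result['dermatologist_skin_condition_confidence']):
    print(f"{name}: confidence {conf:.2f}")

# Comprehensive view: rank every known condition by predicted probability
ranked = sorted(result['all_condition_probabilities'].items(),
                key=lambda kv: kv[1], reverse=True)
print("Top 3 overall:", ranked[:3])
```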
|
| 1077 |
+
|
| 1078 |
+
MINIMUM VALUE CLAMPING:
|
| 1079 |
+
confidence_orig = max(0.1, confidence_orig)
|
| 1080 |
+
weight_orig = max(0.01, weight_orig)
|
| 1081 |
+
|
| 1082 |
+
Ensures predictions are never exactly zero
|
| 1083 |
+
Confidence ≥0.1: Even lowest predictions are meaningful
|
| 1084 |
+
Weight ≥0.01: Prevents division-by-zero in downstream processing
|
| 1085 |
+
|
| 1086 |
+
SOFTPLUS ADVANTAGE:
|
| 1087 |
+
With softplus activation, even very negative inputs produce small positive
|
| 1088 |
+
outputs, so confidence predictions naturally avoid zero. The max(0.1, x)
|
| 1089 |
+
provides an additional safety margin.
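A tiny NumPy-only illustration of that difference (values are arbitrary):
```python
# Softplus stays strictly positive for negative pre-activations,
# whereas ReLU collapses them to exactly zero.
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def softplus(x):
    return np.log1p(np.exp(x))   # log(1 + e^x) > 0 for all finite x

z = np.array([-5.0, -1.0, 0.0, 2.0])
print("ReLU:    ", relu(z))       # [0.     0.     0.     2.    ]
print("Softplus:", softplus(z))   # [0.0067 0.3133 0.6931 2.1269]
```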
|
| 1090 |
+
RETURN FORMAT:
|
| 1091 |
+
Dictionary structure allows:
|
| 1092 |
+
|
| 1093 |
+
Easy access to specific prediction types
|
| 1094 |
+
Clear semantic meaning (key names describe contents)
|
| 1095 |
+
Extensible (can add new keys without breaking existing code)
|
| 1096 |
+
JSON-serializable for API deployment
|
| 1097 |
+
"""
|
| 1098 |
+
def predict(self, embedding):
|
| 1099 |
+
"""Make predictions on a single embedding with individual outputs"""
|
| 1100 |
+
if self.model is None:
|
| 1101 |
+
print("ERROR: Model not trained. Call train() first.")
|
| 1102 |
+
return None
|
| 1103 |
+
|
| 1104 |
+
if len(embedding.shape) == 1:
|
| 1105 |
+
embedding = embedding.reshape(1, -1)
|
| 1106 |
+
|
| 1107 |
+
# Scale embedding
|
| 1108 |
+
embedding_scaled = self.embedding_scaler.transform(embedding)
|
| 1109 |
+
|
| 1110 |
+
# Make predictions
|
| 1111 |
+
predictions = self.model.predict(embedding_scaled, verbose=0)
|
| 1112 |
+
|
| 1113 |
+
# Process condition predictions
|
| 1114 |
+
condition_probs = predictions['conditions'][0]
|
| 1115 |
+
individual_confidences = predictions['individual_confidences'][0]
|
| 1116 |
+
individual_weights = predictions['individual_weights'][0]
|
| 1117 |
+
|
| 1118 |
+
# Get predicted conditions (above threshold)
|
| 1119 |
+
condition_threshold = 0.3 # Lower threshold
|
| 1120 |
+
predicted_condition_indices = np.where(condition_probs > condition_threshold)[0]
|
| 1121 |
+
|
| 1122 |
+
# Build results
|
| 1123 |
+
predicted_conditions = []
|
| 1124 |
+
predicted_confidences = []
|
| 1125 |
+
predicted_weights_dict = {}
|
| 1126 |
+
|
| 1127 |
+
for idx in predicted_condition_indices:
|
| 1128 |
+
condition_name = self.mlb.classes_[idx]
|
| 1129 |
+
condition_prob = float(condition_probs[idx])
|
| 1130 |
+
|
| 1131 |
+
# Inverse transform individual outputs with better handling
|
| 1132 |
+
confidence_raw = individual_confidences[idx]
|
| 1133 |
+
weight_raw = individual_weights[idx]
|
| 1134 |
+
|
| 1135 |
+
# Always inverse transform, even small values (softplus ensures non-zero)
|
| 1136 |
+
confidence_orig = self.confidence_scaler.inverse_transform([[confidence_raw]])[0, 0]
|
| 1137 |
+
weight_orig = self.weighted_scaler.inverse_transform([[weight_raw]])[0, 0]
|
| 1138 |
+
|
| 1139 |
+
predicted_conditions.append(condition_name)
|
| 1140 |
+
predicted_confidences.append(max(0.1, confidence_orig)) # Minimum confidence of 0.1
|
| 1141 |
+
predicted_weights_dict[condition_name] = max(0.01, weight_orig) # Minimum weight of 0.01
|
| 1142 |
+
|
| 1143 |
+
# Also provide all condition probabilities for reference
|
| 1144 |
+
all_condition_probs = {}
|
| 1145 |
+
all_confidences = {}
|
| 1146 |
+
all_weights = {}
|
| 1147 |
+
|
| 1148 |
+
for i, class_name in enumerate(self.mlb.classes_):
|
| 1149 |
+
all_condition_probs[class_name] = float(condition_probs[i])
|
| 1150 |
+
|
| 1151 |
+
# Always inverse transform all outputs
|
| 1152 |
+
conf_raw = individual_confidences[i]
|
| 1153 |
+
weight_raw = individual_weights[i]
|
| 1154 |
+
|
| 1155 |
+
conf_orig = self.confidence_scaler.inverse_transform([[conf_raw]])[0, 0]
|
| 1156 |
+
weight_orig = self.weighted_scaler.inverse_transform([[weight_raw]])[0, 0]
|
| 1157 |
+
|
| 1158 |
+
all_confidences[class_name] = max(0.0, conf_orig)
|
| 1159 |
+
all_weights[class_name] = max(0.0, weight_orig)
|
| 1160 |
+
|
| 1161 |
+
return {
|
| 1162 |
+
# Main predicted results (above threshold)
|
| 1163 |
+
'dermatologist_skin_condition_on_label_name': predicted_conditions,
|
| 1164 |
+
'dermatologist_skin_condition_confidence': predicted_confidences,
|
| 1165 |
+
'weighted_skin_condition_label': predicted_weights_dict,
|
| 1166 |
+
|
| 1167 |
+
# Additional information for analysis
|
| 1168 |
+
'all_condition_probabilities': all_condition_probs,
|
| 1169 |
+
'all_individual_confidences': all_confidences,
|
| 1170 |
+
'all_individual_weights': all_weights,
|
| 1171 |
+
'condition_threshold': condition_threshold,
|
| 1172 |
+
|
| 1173 |
+
# Debug information
|
| 1174 |
+
'raw_confidence_outputs': individual_confidences.tolist(),
|
| 1175 |
+
'raw_weight_outputs': individual_weights.tolist()
|
| 1176 |
+
}
|
| 1177 |
+
|
| 1178 |
+
def plot_training_history(self):
|
| 1179 |
+
if self.history is None:
|
| 1180 |
+
print("No training history available")
|
| 1181 |
+
return
|
| 1182 |
+
|
| 1183 |
+
# Set matplotlib to use non-interactive backend
|
| 1184 |
+
import matplotlib
|
| 1185 |
+
matplotlib.use('Agg')
|
| 1186 |
+
import matplotlib.pyplot as plt
|
| 1187 |
+
|
| 1188 |
+
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
| 1189 |
+
|
| 1190 |
+
# Loss
|
| 1191 |
+
axes[0, 0].plot(self.history.history['loss'], label='Training Loss')
|
| 1192 |
+
axes[0, 0].plot(self.history.history['val_loss'], label='Validation Loss')
|
| 1193 |
+
axes[0, 0].set_title('Model Loss')
|
| 1194 |
+
axes[0, 0].set_xlabel('Epoch')
|
| 1195 |
+
axes[0, 0].set_ylabel('Loss')
|
| 1196 |
+
axes[0, 0].legend()
|
| 1197 |
+
|
| 1198 |
+
# Condition accuracy
|
| 1199 |
+
axes[0, 1].plot(self.history.history['conditions_accuracy'], label='Training Accuracy')
|
| 1200 |
+
axes[0, 1].plot(self.history.history['val_conditions_accuracy'], label='Validation Accuracy')
|
| 1201 |
+
axes[0, 1].set_title('Condition Classification Accuracy')
|
| 1202 |
+
axes[0, 1].set_xlabel('Epoch')
|
| 1203 |
+
axes[0, 1].set_ylabel('Accuracy')
|
| 1204 |
+
axes[0, 1].legend()
|
| 1205 |
+
|
| 1206 |
+
# Individual Confidence MAE
|
| 1207 |
+
axes[0, 2].plot(self.history.history['individual_confidences_mae'], label='Training MAE')
|
| 1208 |
+
axes[0, 2].plot(self.history.history['val_individual_confidences_mae'], label='Validation MAE')
|
| 1209 |
+
axes[0, 2].set_title('Individual Confidence MAE')
|
| 1210 |
+
axes[0, 2].set_xlabel('Epoch')
|
| 1211 |
+
axes[0, 2].set_ylabel('MAE')
|
| 1212 |
+
axes[0, 2].legend()
|
| 1213 |
+
|
| 1214 |
+
# Individual Weight MAE
|
| 1215 |
+
axes[1, 0].plot(self.history.history['individual_weights_mae'], label='Training MAE')
|
| 1216 |
+
axes[1, 0].plot(self.history.history['val_individual_weights_mae'], label='Validation MAE')
|
| 1217 |
+
axes[1, 0].set_title('Individual Weight MAE')
|
| 1218 |
+
axes[1, 0].set_xlabel('Epoch')
|
| 1219 |
+
axes[1, 0].set_ylabel('MAE')
|
| 1220 |
+
axes[1, 0].legend()
|
| 1221 |
+
|
| 1222 |
+
# Individual confidence loss
|
| 1223 |
+
axes[1, 1].plot(self.history.history['individual_confidences_loss'], label='Training Loss')
|
| 1224 |
+
axes[1, 1].plot(self.history.history['val_individual_confidences_loss'], label='Validation Loss')
|
| 1225 |
+
axes[1, 1].set_title('Individual Confidence Loss')
|
| 1226 |
+
axes[1, 1].set_xlabel('Epoch')
|
| 1227 |
+
axes[1, 1].set_ylabel('Loss')
|
| 1228 |
+
axes[1, 1].legend()
|
| 1229 |
+
|
| 1230 |
+
# Individual weight loss
|
| 1231 |
+
axes[1, 2].plot(self.history.history['individual_weights_loss'], label='Training Loss')
|
| 1232 |
+
axes[1, 2].plot(self.history.history['val_individual_weights_loss'], label='Validation Loss')
|
| 1233 |
+
axes[1, 2].set_title('Individual Weight Loss')
|
| 1234 |
+
axes[1, 2].set_xlabel('Epoch')
|
| 1235 |
+
axes[1, 2].set_ylabel('Loss')
|
| 1236 |
+
axes[1, 2].legend()
|
| 1237 |
+
|
| 1238 |
+
plt.tight_layout()
|
| 1239 |
+
plt.savefig('training_history_fixed.png', dpi=300, bbox_inches='tight')
|
| 1240 |
+
print("Training history plot saved as: training_history_fixed.png")
|
| 1241 |
+
plt.close()
|
| 1242 |
+
"""
|
| 1243 |
+
Persist trained model and preprocessing components to disk.
|
| 1244 |
+
|
| 1245 |
+
SAVED COMPONENTS:
|
| 1246 |
+
|
| 1247 |
+
1. **Keras Model (.keras file)**:
|
| 1248 |
+
- Neural network architecture
|
| 1249 |
+
- Trained weights for all layers
|
| 1250 |
+
- Optimizer state (for resuming training)
|
| 1251 |
+
- Compilation settings (loss functions, metrics)
|
| 1252 |
+
|
| 1253 |
+
2. **Preprocessing Data (.pkl file)**:
|
| 1254 |
+
- MultiLabelBinarizer: Maps condition names ↔ indices
|
| 1255 |
+
- embedding_scaler: Normalizes input embeddings
|
| 1256 |
+
- confidence_scaler: Normalizes confidence values
|
| 1257 |
+
- weighted_scaler: Normalizes weight values
|
| 1258 |
+
- Path to .keras file (for loading)
|
| 1259 |
+
|
| 1260 |
+
WHY SEPARATE FILES:
|
| 1261 |
+
- Keras models save to modern .keras format
|
| 1262 |
+
- Scikit-learn components need pickle serialization
|
| 1263 |
+
- Separation allows independent updates of each component
|
| 1264 |
+
|
| 1265 |
+
LOADING REQUIREMENT:
|
| 1266 |
+
Both files are needed for inference:
|
| 1267 |
+
- .keras: Neural network for making predictions
|
| 1268 |
+
- .pkl: Preprocessors for transforming inputs/outputs
|
| 1269 |
+
|
| 1270 |
+
FILE ORGANIZATION:
|
| 1271 |
+
easi_severity_model_derm_foundation_individual_fixed.pkl (main file)
|
| 1272 |
+
easi_severity_model_derm_foundation_individual_fixed.keras (model)
|
| 1273 |
+
The user loads the .pkl file, which contains the path to the .keras file
|
| 1274 |
+
|
| 1275 |
+
CLEANUP:
|
| 1276 |
+
Removes temporary checkpoint file (best_model_fixed.weights.h5)
|
| 1277 |
+
created during training to avoid confusion with final model.
|
| 1278 |
+
|
| 1279 |
+
ERROR HANDLING:
|
| 1280 |
+
Checks if model exists before saving, provides clear error messages
|
| 1281 |
+
and file paths for debugging.
|
| 1282 |
+
"""
|
| 1283 |
+
|
| 1284 |
+
def save_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
|
| 1285 |
+
"""Save the trained model"""
|
| 1286 |
+
if self.model is None:
|
| 1287 |
+
print("ERROR: No trained model to save.")
|
| 1288 |
+
return False
|
| 1289 |
+
|
| 1290 |
+
# Get current directory
|
| 1291 |
+
current_dir = os.getcwd()
|
| 1292 |
+
|
| 1293 |
+
# Save Keras model with proper extension
|
| 1294 |
+
model_filename = os.path.splitext(filepath)[0]
|
| 1295 |
+
keras_model_path = os.path.join(current_dir, f"{model_filename}.keras")
|
| 1296 |
+
|
| 1297 |
+
print(f"Saving Keras model to: {keras_model_path}")
|
| 1298 |
+
self.model.save(keras_model_path)
|
| 1299 |
+
|
| 1300 |
+
# Save preprocessing components
|
| 1301 |
+
pkl_filepath = os.path.join(current_dir, filepath)
|
| 1302 |
+
model_data = {
|
| 1303 |
+
'mlb': self.mlb,
|
| 1304 |
+
'embedding_scaler': self.embedding_scaler,
|
| 1305 |
+
'confidence_scaler': self.confidence_scaler,
|
| 1306 |
+
'weighted_scaler': self.weighted_scaler,
|
| 1307 |
+
'keras_model_path': keras_model_path
|
| 1308 |
+
}
|
| 1309 |
+
|
| 1310 |
+
print(f"Saving preprocessing data to: {pkl_filepath}")
|
| 1311 |
+
with open(pkl_filepath, 'wb') as f:
|
| 1312 |
+
pickle.dump(model_data, f)
|
| 1313 |
+
|
| 1314 |
+
print(f"Model saved successfully!")
|
| 1315 |
+
print(f" - Main file: {pkl_filepath}")
|
| 1316 |
+
print(f" - Keras model: {keras_model_path}")
|
| 1317 |
+
|
| 1318 |
+
# Clean up temporary checkpoint file
|
| 1319 |
+
checkpoint_file = os.path.join(current_dir, 'best_model_fixed.weights.h5')
|
| 1320 |
+
if os.path.exists(checkpoint_file):
|
| 1321 |
+
os.remove(checkpoint_file)
|
| 1322 |
+
print(f" - Cleaned up temporary checkpoint file")
|
| 1323 |
+
|
| 1324 |
+
return True
|
| 1325 |
+
|
| 1326 |
+
def load_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
|
| 1327 |
+
"""Load trained model"""
|
| 1328 |
+
if not os.path.exists(filepath):
|
| 1329 |
+
print(f"ERROR: Model file not found: {filepath}")
|
| 1330 |
+
return False
|
| 1331 |
+
|
| 1332 |
+
try:
|
| 1333 |
+
with open(filepath, 'rb') as f:
|
| 1334 |
+
model_data = pickle.load(f)
|
| 1335 |
+
|
| 1336 |
+
# Load preprocessing components
|
| 1337 |
+
self.mlb = model_data['mlb']
|
| 1338 |
+
self.embedding_scaler = model_data['embedding_scaler']
|
| 1339 |
+
self.confidence_scaler = model_data['confidence_scaler']
|
| 1340 |
+
self.weighted_scaler = model_data['weighted_scaler']
|
| 1341 |
+
|
| 1342 |
+
# Load Keras model
|
| 1343 |
+
keras_model_path = model_data['keras_model_path']
|
| 1344 |
+
if os.path.exists(keras_model_path):
|
| 1345 |
+
self.model = keras.models.load_model(keras_model_path)
|
| 1346 |
+
print(f"Model loaded from {filepath}")
|
| 1347 |
+
print(f"Available condition classes: {len(self.mlb.classes_)}")
|
| 1348 |
+
return True
|
| 1349 |
+
else:
|
| 1350 |
+
print(f"ERROR: Keras model not found at {keras_model_path}")
|
| 1351 |
+
return False
|
| 1352 |
+
|
| 1353 |
+
except Exception as e:
|
| 1354 |
+
print(f"Error loading model: {e}")
|
| 1355 |
+
return False
|
| 1356 |
+
|
| 1357 |
+
"""
|
| 1358 |
+
WORKFLOW:
|
| 1359 |
+
|
| 1360 |
+
1. Print configuration and fixes applied (user visibility)
|
| 1361 |
+
2. Initialize classifier
|
| 1362 |
+
3. Validate input files exist
|
| 1363 |
+
4. Train model with improved confidence handling
|
| 1364 |
+
5. Plot training history
|
| 1365 |
+
6. Test model predictions (validate fix effectiveness)
|
| 1366 |
+
7. Save trained model
|
| 1367 |
+
|
| 1368 |
+
MODEL TESTING (NEW):
|
| 1369 |
+
After training completes, runs a sample prediction to verify:
|
| 1370 |
+
|
| 1371 |
+
Model produces non-zero confidence values (fix validation)
|
| 1372 |
+
Predictions are in expected ranges
|
| 1373 |
+
Output structure is correct
|
| 1374 |
+
|
| 1375 |
+
This immediate validation catches issues before deployment.
|
| 1376 |
+
WHY TEST WITH SAMPLE:
|
| 1377 |
+
|
| 1378 |
+
Confirms confidence scaling fix worked
|
| 1379 |
+
Provides immediate feedback on model quality
|
| 1380 |
+
Demonstrates expected output format
|
| 1381 |
+
Catches activation function issues (like ReLU→0 bug)
|
| 1382 |
+
|
| 1383 |
+
SUCCESS CRITERIA:
|
| 1384 |
+
✅ Non-zero confidences in reasonable range (e.g., 1-5)
|
| 1385 |
+
✅ Multiple conditions predicted with varying probabilities
|
| 1386 |
+
✅ Weights sum to reasonable values
|
| 1387 |
+
⚠️ Warning if confidence outputs still mostly zero
|
| 1388 |
+
"""
|
| 1389 |
+
|
| 1390 |
+
def main():
|
| 1391 |
+
"""Main training function with enhanced confidence handling"""
|
| 1392 |
+
print("Derm Foundation Neural Network Classifier Training - FIXED VERSION")
|
| 1393 |
+
print("="*70)
|
| 1394 |
+
print("FIXES APPLIED:")
|
| 1395 |
+
print("- Changed confidence activation from ReLU to softplus")
|
| 1396 |
+
print("- Improved confidence scaler fitting (non-zero values only)")
|
| 1397 |
+
print("- Increased confidence loss weight (1.5x)")
|
| 1398 |
+
print("- Enhanced data validation and preprocessing")
|
| 1399 |
+
print("- Better handling of sparse confidence/weight matrices")
|
| 1400 |
+
print("="*70)
|
| 1401 |
+
print("Training neural network to predict:")
|
| 1402 |
+
print("1. Skin conditions (multi-label classification)")
|
| 1403 |
+
print("2. Individual confidence scores per condition (regression)")
|
| 1404 |
+
print("3. Individual weight scores per condition (regression)")
|
| 1405 |
+
print("="*70)
|
| 1406 |
+
|
| 1407 |
+
# Initialize classifier
|
| 1408 |
+
classifier = DermFoundationNeuralNetwork()
|
| 1409 |
+
|
| 1410 |
+
# File paths
|
| 1411 |
+
npz_file = "derm_foundation_embeddings.npz"
|
| 1412 |
+
csv_file = "dataset_scin_labels.csv"
|
| 1413 |
+
model_output = "easi_severity_model_derm_foundation_individual_fixed.pkl"
|
| 1414 |
+
|
| 1415 |
+
# Check if files exist
|
| 1416 |
+
missing_files = []
|
| 1417 |
+
if not os.path.exists(npz_file):
|
| 1418 |
+
missing_files.append(npz_file)
|
| 1419 |
+
if not os.path.exists(csv_file):
|
| 1420 |
+
missing_files.append(csv_file)
|
| 1421 |
+
|
| 1422 |
+
if missing_files:
|
| 1423 |
+
print(f"ERROR: Missing required files:")
|
| 1424 |
+
for file in missing_files:
|
| 1425 |
+
print(f" - {file}")
|
| 1426 |
+
return
|
| 1427 |
+
|
| 1428 |
+
try:
|
| 1429 |
+
# Train the model
|
| 1430 |
+
success = classifier.train(
|
| 1431 |
+
npz_file_path=npz_file,
|
| 1432 |
+
csv_file_path=csv_file,
|
| 1433 |
+
epochs=60, # Increased epochs
|
| 1434 |
+
batch_size=32,
|
| 1435 |
+
learning_rate=0.001
|
| 1436 |
+
)
|
| 1437 |
+
|
| 1438 |
+
if not success:
|
| 1439 |
+
print("Training failed!")
|
| 1440 |
+
return
|
| 1441 |
+
|
| 1442 |
+
# Plot training history
|
| 1443 |
+
try:
|
| 1444 |
+
classifier.plot_training_history()
|
| 1445 |
+
except Exception as e:
|
| 1446 |
+
print(f"Could not plot training history: {e}")
|
| 1447 |
+
|
| 1448 |
+
# Test the model with a sample prediction to verify confidence outputs
|
| 1449 |
+
print("\n" + "="*70)
|
| 1450 |
+
print("TESTING MODEL OUTPUTS")
|
| 1451 |
+
print("="*70)
|
| 1452 |
+
|
| 1453 |
+
# Get a sample embedding for testing
|
| 1454 |
+
try:
|
| 1455 |
+
embeddings = classifier.load_embeddings(npz_file)
|
| 1456 |
+
if embeddings:
|
| 1457 |
+
sample_key = list(embeddings.keys())[0]
|
| 1458 |
+
sample_embedding = embeddings[sample_key]
|
| 1459 |
+
|
| 1460 |
+
print(f"Testing with sample embedding: {sample_key}")
|
| 1461 |
+
test_result = classifier.predict(sample_embedding)
|
| 1462 |
+
|
| 1463 |
+
if test_result:
|
| 1464 |
+
print("✅ Model prediction successful!")
|
| 1465 |
+
print(f"Predicted conditions: {len(test_result['dermatologist_skin_condition_on_label_name'])}")
|
| 1466 |
+
|
| 1467 |
+
# Check confidence outputs
|
| 1468 |
+
all_confidences = list(test_result['all_individual_confidences'].values())
|
| 1469 |
+
nonzero_conf = sum(1 for c in all_confidences if c > 0.01)
|
| 1470 |
+
|
| 1471 |
+
print(f"Confidence range: {min(all_confidences):.4f} - {max(all_confidences):.4f}")
|
| 1472 |
+
print(f"Non-zero confidences: {nonzero_conf}/{len(all_confidences)}")
|
| 1473 |
+
|
| 1474 |
+
if nonzero_conf > 0:
|
| 1475 |
+
print("✅ CONFIDENCE ISSUE APPEARS TO BE FIXED!")
|
| 1476 |
+
else:
|
| 1477 |
+
print("⚠️ Confidence outputs still mostly zero - may need further investigation")
|
| 1478 |
+
|
| 1479 |
+
# Show top predictions
|
| 1480 |
+
if test_result['dermatologist_skin_condition_on_label_name']:
|
| 1481 |
+
print(f"\nSample predictions:")
|
| 1482 |
+
for i, condition in enumerate(test_result['dermatologist_skin_condition_on_label_name'][:3]):
|
| 1483 |
+
prob = test_result['all_condition_probabilities'][condition]
|
| 1484 |
+
conf = test_result['dermatologist_skin_condition_confidence'][i]
|
| 1485 |
+
weight = test_result['weighted_skin_condition_label'][condition]
|
| 1486 |
+
print(f" {condition}: prob={prob:.3f}, conf={conf:.3f}, weight={weight:.3f}")
|
| 1487 |
+
else:
|
| 1488 |
+
print("❌ Model prediction failed")
|
| 1489 |
+
except Exception as e:
|
| 1490 |
+
print(f"Could not test model: {e}")
|
| 1491 |
+
|
| 1492 |
+
# Save the model
|
| 1493 |
+
classifier.save_model(model_output)
|
| 1494 |
+
|
| 1495 |
+
print(f"\n{'='*70}")
|
| 1496 |
+
print("TRAINING COMPLETE!")
|
| 1497 |
+
print(f"{'='*70}")
|
| 1498 |
+
print(f"Model saved as: {model_output}")
|
| 1499 |
+
print(f"Training history plot saved as: training_history_fixed.png")
|
| 1500 |
+
print(f"\nTo use the trained model:")
|
| 1501 |
+
print(f"```python")
|
| 1502 |
+
print(f"classifier = DermFoundationNeuralNetwork()")
|
| 1503 |
+
print(f"classifier.load_model('{model_output}')")
|
| 1504 |
+
print(f"result = classifier.predict(embedding)")
|
| 1505 |
+
print(f"print(result['dermatologist_skin_condition_on_label_name'])")
|
| 1506 |
+
print(f"print(result['dermatologist_skin_condition_confidence'])")
|
| 1507 |
+
print(f"print(result['weighted_skin_condition_label'])")
|
| 1508 |
+
print(f"```")
|
| 1509 |
+
|
| 1510 |
+
# Example prediction output format
|
| 1511 |
+
print(f"\nExpected prediction output format:")
|
| 1512 |
+
print(f"{{")
|
| 1513 |
+
print(f" 'dermatologist_skin_condition_on_label_name': ['Eczema', 'Irritant Contact Dermatitis'],")
|
| 1514 |
+
print(f" 'dermatologist_skin_condition_confidence': [4.2, 3.1],")
|
| 1515 |
+
print(f" 'weighted_skin_condition_label': {{'Eczema': 0.65, 'Irritant Contact Dermatitis': 0.35}}")
|
| 1516 |
+
print(f"}}")
|
| 1517 |
+
|
| 1518 |
+
except Exception as e:
|
| 1519 |
+
print(f"Error during training: {e}")
|
| 1520 |
+
import traceback
|
| 1521 |
+
traceback.print_exc()
|
| 1522 |
+
|
| 1523 |
+
|
| 1524 |
+
if __name__ == "__main__":
|
| 1525 |
+
main()
|
training_history_fixed.png
ADDED
|
Git LFS Details
|