RezinWiz committed on
Commit
dc04619
·
verified ·
1 Parent(s): 310a49f

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -1,2 +1,6 @@
1
  # Auto detect text files and perform LF normalization
2
  * text=auto
3
+ derm_foundation_embeddings.npz filter=lfs diff=lfs merge=lfs -text
4
+ easi_severity_model_derm_foundation_individual_fixed.keras filter=lfs diff=lfs merge=lfs -text
5
+ easi_severity_model_derm_foundation_individual.pkl filter=lfs diff=lfs merge=lfs -text
6
+ training_history_fixed.png filter=lfs diff=lfs merge=lfs -text
dataset_scin_labels.csv ADDED
The diff for this file is too large to render. See raw diff
 
derm_foundation_embeddings.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fe7208829940a3ad6d9d0964948d3004c3d7fcafa17b98db6402a638ce5ac61
3
+ size 17779095
easi_severity_model_derm_foundation_individual.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf635500ebe9f4bf8a02b3d9de17fcc75595e845113ae46fd40c710cb2bf71a7
3
+ size 149217
easi_severity_model_derm_foundation_individual_fixed.keras ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b66fa5443ac22a54c74c6a154ea5dcacc2c9856839a17044043ea77f41560177
3
+ size 40255675
train_easi_model.py ADDED
@@ -0,0 +1,1525 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Derm Foundation Neural Network Classifier Training Script - Fixed Version
5
+
6
+ PURPOSE:
7
+ This script trains a multi-output neural network to predict dermatological
8
+ conditions and their associated metadata from pre-computed embeddings. It
9
+ addresses the challenging problem of multi-label medical diagnosis where:
10
+ 1. Multiple conditions can co-exist (multi-label classification)
11
+ 2. Each diagnosis has an associated confidence level (regression)
12
+ 3. Each diagnosis has a weight indicating relative importance (regression)
13
+
14
+ WHY NEURAL NETWORKS FOR THIS TASK:
15
+ Neural networks are the optimal choice for this problem for several reasons:
16
+
17
+ 1. **Non-linear Relationship Learning**: The relationship between image
18
+ embeddings and skin conditions is highly non-linear. Neural networks excel
19
+ at learning complex, non-linear mappings that simpler models (like logistic
20
+ regression) cannot capture.
21
+
22
+ 2. **Multi-task Learning**: This problem requires predicting three related but
23
+ distinct outputs (conditions, confidence, weights). Neural networks can
24
+ share learned representations across these tasks through shared layers,
25
+ improving generalization and efficiency.
26
+
27
+ 3. **High-dimensional Input**: Embeddings are typically 512-1024 dimensional
28
+ vectors. Neural networks are designed to handle high-dimensional inputs
29
+ effectively through dimensionality reduction in hidden layers.
30
+
31
+ 4. **Multi-label Classification**: Medical diagnosis often involves multiple
32
+ co-existing conditions. Neural networks with sigmoid activation can model
33
+ the independent probability of each condition, unlike single-label methods.
34
+
35
+ 5. **Flexibility**: The architecture can be customized with task-specific
36
+ heads (branches) for different prediction types, allowing specialized
37
+ processing for classification vs regression outputs.
38
+
39
+ WHY HAMMING LOSS IS VALID:
40
+ Hamming loss is an appropriate metric for multi-label classification because:
41
+
42
+ 1. **Accounts for Partial Correctness**: Unlike exact match accuracy, hamming
43
+ loss gives credit for partially correct predictions. Predicting 3 out of 4
44
+ conditions correctly is better than 0 out of 4.
45
+
46
+ 2. **Label-wise Evaluation**: It measures the fraction of incorrectly predicted
47
+ labels, treating each label independently - appropriate when conditions can
48
+ co-occur independently.
49
+
50
+ 3. **Bounded and Interpretable**: Ranges from 0 (perfect) to 1 (completely
51
+ wrong). A hamming loss of 0.1 means 10% of label predictions were incorrect.
52
+
53
+ 4. **Balanced for Sparse Labels**: In medical diagnosis, most samples have few
54
+ positive labels (sparse multi-label). Hamming loss naturally handles this
55
+ imbalance by computing the fraction across all labels.
56
+
57
+ 5. **Clinically Relevant**: In medical applications, both false positives and
58
+ false negatives matter. Hamming loss penalizes both equally, unlike metrics
59
+ that focus on one type of error.
60
+
61
+ MATHEMATICAL JUSTIFICATION:
62
+ For a sample with true labels y and predicted labels ŷ:
63
+ Hamming Loss = (1/n_labels) × Σ(y_i XOR ŷ_i)
64
+
65
+ This averages the disagreement across all possible labels, making it suitable
66
+ for scenarios where:
67
+ - The label space is large (many possible conditions)
68
+ - Label correlations exist but aren't perfectly predictable
69
+ - Clinical accuracy matters at the individual label level
70
+
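For concreteness, a minimal sketch of the metric on toy labels (values are illustrative, not from this dataset):

```python
# Minimal illustration of the Hamming-loss definition above (toy labels).
import numpy as np
from sklearn.metrics import hamming_loss

y_true = np.array([[1, 0, 1, 0],   # sample 1: conditions 0 and 2 present
                   [0, 1, 0, 0]])  # sample 2: condition 1 present
y_pred = np.array([[1, 0, 0, 0],   # misses condition 2
                   [0, 1, 0, 1]])  # adds a spurious condition 3

# 2 of the 8 label slots disagree -> 0.25
print(hamming_loss(y_true, y_pred))  # 0.25
```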
71
+ FIXES APPLIED IN THIS VERSION:
72
+ - Changed confidence activation from ReLU to softplus (prevents zero outputs)
73
+ - Improved confidence scaler fitting (uses only non-zero values)
74
+ - Increased confidence loss weight (1.5x for better learning signal)
75
+ - Enhanced data validation and preprocessing
76
+ - Better handling of sparse confidence/weight matrices
77
+
78
+ Requirements:
79
+ - pandas
80
+ - numpy
81
+ - tensorflow>=2.13.0
82
+ - scikit-learn
83
+ - matplotlib
84
+ - pickle (standard library)
85
+ - os (standard library)
86
+ - derm_foundation_embeddings.npz: Pre-computed embeddings from images
87
+ - dataset_scin_labels.csv: Ground truth labels with conditions, confidences, weights
88
+
89
+ OUTPUT:
90
+ - Trained neural network model (.keras file)
91
+ - Preprocessing components (scalers, label encoder) in .pkl file
92
+ - Training history plots showing convergence
93
+ - Evaluation metrics on test set
94
+ """
95
+
96
+ import numpy as np
97
+ import pandas as pd
98
+ import pickle
99
+ import os
100
+ import tensorflow as tf
101
+ from tensorflow import keras
102
+ from tensorflow.keras import layers
103
+ from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
104
+ from sklearn.model_selection import train_test_split
105
+ from sklearn.metrics import hamming_loss, mean_squared_error, mean_absolute_error
106
+ import matplotlib.pyplot as plt
107
+ import warnings
108
+ warnings.filterwarnings('ignore')
109
+
110
+ """
111
+ Main class implementing the multi-output neural network classifier.
112
+
113
+ ARCHITECTURE OVERVIEW:
114
+ 1. **Shared Feature Extraction**: 3 dense layers (512→256→128) with batch
115
+ normalization and dropout. These layers learn a shared representation
116
+ useful for all prediction tasks.
117
+
118
+ 2. **Task-Specific Heads**: Three separate output branches:
119
+ - Condition classification: Sigmoid activation for multi-label prediction
120
+ - Confidence regression: Softplus activation for positive continuous values
121
+ - Weight regression: Sigmoid activation for [0,1] bounded values
122
+
123
+ WHY MULTI-TASK LEARNING:
124
+ - Conditions, confidence, and weights are related but distinct
125
+ - Sharing early layers allows the model to learn features useful for all tasks
126
+ - Task-specific heads allow specialized processing for each output type
127
+ - Improves generalization by preventing overfitting to any single task
128
+
129
+ TRAINING STRATEGY:
130
+ - Multi-task loss: Weighted combination of classification and regression losses
131
+ - Early stopping: Prevents overfitting by monitoring validation loss
132
+ - Learning rate reduction: Adapts learning rate when progress plateaus
133
+ - Batch normalization: Stabilizes training and allows higher learning rates
134
+ """
135
+
136
+ class DermFoundationNeuralNetwork:
137
+ """
138
+ Initialize the classifier with preprocessing components.
139
+
140
+ PREPROCESSING COMPONENTS:
141
+ - mlb (MultiLabelBinarizer): Converts condition names to binary vectors
142
+ Example: ['Eczema', 'Psoriasis'] → [0,1,0,1,0,...,0]
143
+
144
+ - embedding_scaler (StandardScaler): Normalizes embeddings to mean=0, std=1
145
+ Why: Neural networks train faster with normalized inputs
146
+
147
+ - confidence_scaler (StandardScaler): Normalizes confidence values
148
+ Why: Brings continuous values to similar scale as other outputs
149
+
150
+ - weighted_scaler (StandardScaler): Normalizes weight values
151
+ Why: Ensures balanced gradient contributions during training
152
+
153
+ DESIGN DECISION:
154
+ Separate scalers for each output type allow independent normalization,
155
+ which is crucial when outputs have different scales and distributions.
156
+ """
157
+ def __init__(self):
158
+ self.model = None
159
+ self.mlb = MultiLabelBinarizer()
160
+ self.embedding_scaler = StandardScaler()
161
+ self.confidence_scaler = StandardScaler()
162
+ self.weighted_scaler = StandardScaler()
163
+ self.history = None
164
+ """
165
+ Load pre-computed Derm Foundation embeddings from NPZ file.
166
+
167
+ WHAT ARE EMBEDDINGS:
168
+ Embeddings are dense vector representations of images extracted from a
169
+ pre-trained vision model (Derm Foundation model). They capture high-level
170
+ visual features learned from large-scale dermatology image datasets.
171
+
172
+ WHY USE PRE-COMPUTED EMBEDDINGS:
173
+ 1. **Efficiency**: Computing embeddings is expensive. Pre-computing them
174
+ allows rapid experimentation with different classifier architectures.
175
+
176
+ 2. **Transfer Learning**: Derm Foundation was trained on massive dermatology
177
+ datasets. Its embeddings encode domain-specific visual patterns.
178
+
179
+ 3. **Separation of Concerns**: Image processing and classification are
180
+ separated, allowing independent optimization of each component.
181
+
182
+ FORMAT:
183
+ NPZ file contains a dictionary where:
184
+ - Keys: case_id (string identifiers)
185
+ - Values: embedding vectors (typically 512 or 768 dimensions)
186
+ """
187
+ def load_embeddings(self, npz_file_path):
188
+ """Load embeddings from NPZ file"""
189
+ print(f"Loading embeddings from {npz_file_path}...")
190
+
191
+ if not os.path.exists(npz_file_path):
192
+ print(f"ERROR: Embeddings file not found: {npz_file_path}")
193
+ return None
194
+
195
+ embeddings_data = {}
196
+ with open(npz_file_path, 'rb') as f:
197
+ npz_file = np.load(f, allow_pickle=True)
198
+ for key in npz_file.files:
199
+ embeddings_data[key] = npz_file[key]
200
+
201
+ print(f"Loaded {len(embeddings_data)} embeddings")
202
+
203
+ # Print info about first embedding for debugging
204
+ first_key = list(embeddings_data.keys())[0]
205
+ first_embedding = embeddings_data[first_key]
206
+ print(f"Embedding shape: {first_embedding.shape}")
207
+
208
+ return embeddings_data
209
+
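A minimal sketch of writing and reading the NPZ layout described above (the case ID and the 512-dim size are illustrative):

```python
# Each key is a case_id string, each value is a 1-D embedding vector.
import numpy as np

np.savez("derm_foundation_embeddings.npz",
         **{"case_0001": np.random.rand(512).astype(np.float32)})

with np.load("derm_foundation_embeddings.npz", allow_pickle=True) as npz:
    for case_id in npz.files:
        print(case_id, npz[case_id].shape)  # case_0001 (512,)
```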
210
+ """
211
+ Load ground truth labels from CSV file.
212
+
213
+ REQUIRED COLUMNS:
214
+ 1. case_id: Unique identifier matching embedding keys
215
+ 2. dermatologist_skin_condition_on_label_name: List of condition names
216
+ 3. dermatologist_skin_condition_confidence: Confidence scores (typically 1-5)
217
+ 4. weighted_skin_condition_label: Importance weights (0-1 range)
218
+
219
+ DATA TYPES:
220
+ - case_id must be string to match embedding keys
221
+ - Lists stored as strings (e.g., "['Eczema', 'Psoriasis']") are evaluated
222
+ - Handles various formats: lists, dicts, single values
223
+ """
224
+ def load_dataset(self, csv_path):
225
+ """Load dataset from CSV file"""
226
+ print(f"Loading dataset from {csv_path}...")
227
+
228
+ if not os.path.exists(csv_path):
229
+ print(f"ERROR: Dataset file not found: {csv_path}")
230
+ return None
231
+
232
+ try:
233
+ df = pd.read_csv(csv_path, dtype={'case_id': str})
234
+ df['case_id'] = df['case_id'].astype(str)
235
+
236
+ print(f"Loaded dataset: {len(df)} samples")
237
+
238
+ # Verify required columns
239
+ required_columns = [
240
+ 'case_id',
241
+ 'dermatologist_skin_condition_on_label_name',
242
+ 'dermatologist_skin_condition_confidence',
243
+ 'weighted_skin_condition_label'
244
+ ]
245
+
246
+ missing_columns = [col for col in required_columns if col not in df.columns]
247
+ if missing_columns:
248
+ print(f"ERROR: Missing required columns: {missing_columns}")
249
+ return None
250
+
251
+ return df
252
+
253
+ except Exception as e:
254
+ print(f"Error loading dataset: {e}")
255
+ return None
256
+ """
257
+ Prepare training data with comprehensive validation and preprocessing.
258
+
259
+ COMPLEXITY HANDLING:
260
+ This method handles several challenging data characteristics:
261
+
262
+ 1. **SPARSE MULTI-LABEL MATRICES**: Most samples have few positive labels
263
+ Solution: Track and report sparsity statistics for validation
264
+
265
+ 2. **VARIABLE-LENGTH LISTS**: Different samples have different numbers of
266
+ conditions, confidences, and weights
267
+ Solution: Parse and align lists carefully, use mean values for mismatches
268
+
269
+ 3. **RARE CONDITIONS**: Some conditions appear in very few samples
270
+ Solution: Filter to top N conditions and minimum sample requirements
271
+
272
+ 4. **ZERO VALUES**: Confidence/weight matrices are mostly zeros (sparse)
273
+ Solution: Track zero vs non-zero ratios, fit scalers only on non-zeros
274
+
275
+ FILTERING STRATEGY:
276
+ - min_condition_samples: Removes rare conditions with insufficient data
277
+ - max_conditions: Limits to most frequent conditions to prevent overfitting
278
+ - Both filters ensure model focuses on well-represented, learnable patterns
279
+
280
+ WHY FILTER CONDITIONS:
281
+ 1. **Statistical Validity**: Need sufficient examples to learn patterns
282
+ 2. **Generalization**: Rare conditions lead to overfitting
283
+ 3. **Computational Efficiency**: Fewer output nodes = faster training
284
+ 4. **Clinical Relevance**: Common conditions are higher priority
285
+
286
+ MULTI-LABEL MATRIX STRUCTURE:
287
+ Shape: (n_samples, n_conditions)
288
+ - Rows: Individual patient cases
289
+ - Columns: Binary indicators for each condition (1=present, 0=absent)
290
+ - Multiple 1s per row: Multi-label (multiple conditions co-exist)
291
+
292
+ CONFIDENCE/WEIGHT MATRICES:
293
+ Shape: (n_samples, n_conditions)
294
+ - Values at (i,j): Confidence/weight for condition j in sample i
295
+ - Zero when condition j not present in sample i (sparse structure)
296
+ - Non-zero only where corresponding multi-label entry is 1
297
+
298
+ DATA VALIDATION:
299
+ Extensive logging of:
300
+ - Sample counts (processed vs skipped)
301
+ - Value ranges (min/max/mean)
302
+ - Sparsity statistics (% non-zero)
303
+ - Top conditions by frequency
304
+
305
+ This validation is crucial for:
306
+ - Detecting data quality issues early
307
+ - Understanding model input characteristics
308
+ - Debugging training problems
309
+ """
310
+ def prepare_training_data(self, df, embeddings, min_condition_samples=5, max_conditions=30):
311
+ """Prepare training data with improved confidence and weight handling"""
312
+ print("Preparing training data with enhanced validation...")
313
+
314
+ X = [] # Embeddings
315
+ condition_labels = [] # For multi-label classification
316
+ individual_confidences = [] # Individual confidence per condition
317
+ individual_weights = [] # Individual weight per condition
318
+
319
+ skipped_count = 0
320
+ processed_count = 0
321
+ confidence_stats = [] # Track confidence values for validation
322
+ weight_stats = [] # Track weight values for validation
323
+
324
+ for idx, row in df.iterrows():
325
+ try:
326
+ case_id = row['case_id']
327
+
328
+ if case_id not in embeddings:
329
+ skipped_count += 1
330
+ continue
331
+
332
+ # Parse the label data
333
+ try:
334
+ # Parse condition names
335
+ if isinstance(row['dermatologist_skin_condition_on_label_name'], str):
336
+ condition_names = eval(row['dermatologist_skin_condition_on_label_name'])
337
+ else:
338
+ condition_names = row['dermatologist_skin_condition_on_label_name']
339
+
340
+ # Ensure condition_names is a list
341
+ if not isinstance(condition_names, list):
342
+ condition_names = [condition_names] if condition_names else []
343
+
344
+ # Parse confidence scores
345
+ if isinstance(row['dermatologist_skin_condition_confidence'], str):
346
+ confidences = eval(row['dermatologist_skin_condition_confidence'])
347
+ else:
348
+ confidences = row['dermatologist_skin_condition_confidence']
349
+
350
+ # Ensure confidences is a list and matches conditions
351
+ if not isinstance(confidences, list):
352
+ confidences = [confidences] if confidences is not None else []
353
+
354
+ # Match confidence length to conditions
355
+ if len(confidences) != len(condition_names):
356
+ if len(confidences) == 1:
357
+ confidences = confidences * len(condition_names)
358
+ else:
359
+ print(f"Warning: Confidence length mismatch for {case_id}, using mean")
360
+ mean_conf = np.mean(confidences) if confidences else 3.0
361
+ confidences = [mean_conf] * len(condition_names)
362
+
363
+ # Parse weighted labels
364
+ if isinstance(row['weighted_skin_condition_label'], str):
365
+ weighted_labels = eval(row['weighted_skin_condition_label'])
366
+ else:
367
+ weighted_labels = row['weighted_skin_condition_label']
368
+
369
+ # Handle different weight formats
370
+ if isinstance(weighted_labels, dict):
371
+ # Convert dict to list matching condition order
372
+ weights = []
373
+ for condition in condition_names:
374
+ weights.append(weighted_labels.get(condition, 0.0))
375
+ elif isinstance(weighted_labels, list):
376
+ weights = weighted_labels
377
+ if len(weights) != len(condition_names):
378
+ if len(weights) == 1:
379
+ weights = weights * len(condition_names)
380
+ else:
381
+ mean_weight = np.mean(weights) if weights else 0.3
382
+ weights = [mean_weight] * len(condition_names)
383
+ else:
384
+ # Single value
385
+ weights = [weighted_labels] * len(condition_names) if weighted_labels else [0.3] * len(condition_names)
386
+
387
+ except Exception as e:
388
+ print(f"Error parsing data for {case_id}: {e}")
389
+ skipped_count += 1
390
+ continue
391
+
392
+ # Validate data ranges
393
+ try:
394
+ confidences = [max(0.0, float(c)) for c in confidences] # Ensure non-negative
395
+ weights = [max(0.0, min(1.0, float(w))) for w in weights] # Clamp to [0,1]
396
+ except:
397
+ print(f"Error converting values for {case_id}, skipping")
398
+ skipped_count += 1
399
+ continue
400
+
401
+ # Add to training data
402
+ X.append(embeddings[case_id])
403
+ condition_labels.append(condition_names)
404
+
405
+ # Store individual confidences and weights
406
+ individual_confidences.append({
407
+ 'conditions': condition_names,
408
+ 'confidences': confidences
409
+ })
410
+
411
+ individual_weights.append({
412
+ 'conditions': condition_names,
413
+ 'weights': weights
414
+ })
415
+
416
+ # Track statistics
417
+ confidence_stats.extend(confidences)
418
+ weight_stats.extend(weights)
419
+
420
+ processed_count += 1
421
+
422
+ except Exception as e:
423
+ print(f"Error processing row {idx}: {e}")
424
+ skipped_count += 1
425
+ continue
426
+
427
+ print(f"Training data prepared: {processed_count} samples, {skipped_count} skipped")
428
+
429
+ if len(X) == 0:
430
+ print("ERROR: No training samples found!")
431
+ return None, None, None, None
432
+
433
+ # Print data statistics
434
+ print(f"\nData validation:")
435
+ print(f" Confidence values - min: {min(confidence_stats):.3f}, max: {max(confidence_stats):.3f}, mean: {np.mean(confidence_stats):.3f}")
436
+ print(f" Weight values - min: {min(weight_stats):.3f}, max: {max(weight_stats):.3f}, mean: {np.mean(weight_stats):.3f}")
437
+ print(f" Non-zero confidences: {sum(1 for c in confidence_stats if c > 0.001)}/{len(confidence_stats)} ({100*sum(1 for c in confidence_stats if c > 0.001)/len(confidence_stats):.1f}%)")
438
+ print(f" Non-zero weights: {sum(1 for w in weight_stats if w > 0.001)}/{len(weight_stats)} ({100*sum(1 for w in weight_stats if w > 0.001)/len(weight_stats):.1f}%)")
439
+
440
+ # Convert to numpy arrays
441
+ X = np.array(X)
442
+
443
+ # Prepare condition labels - focus on top conditions only
444
+ y_conditions_raw = self.mlb.fit_transform(condition_labels)
445
+ condition_counts = y_conditions_raw.sum(axis=0)
446
+
447
+ # Get top conditions by frequency
448
+ sorted_indices = np.argsort(condition_counts)[::-1]
449
+ top_condition_indices = sorted_indices[:max_conditions]
450
+
451
+ # Also ensure minimum samples
452
+ frequent_conditions = condition_counts >= min_condition_samples
453
+ final_indices = np.intersect1d(top_condition_indices, np.where(frequent_conditions)[0])
454
+
455
+ print(f"Total condition classes: {len(self.mlb.classes_)}")
456
+ print(f"Top {max_conditions} most frequent conditions selected")
457
+ print(f"Conditions with at least {min_condition_samples} examples: {frequent_conditions.sum()}")
458
+
459
+ # Keep only selected conditions
460
+ selected_classes = self.mlb.classes_[final_indices]
461
+ y_conditions = y_conditions_raw[:, final_indices]
462
+
463
+ # Update MultiLabelBinarizer
464
+ self.mlb = MultiLabelBinarizer()
465
+ self.mlb.classes_ = selected_classes
466
+
467
+ print(f"Final condition classes: {len(selected_classes)}")
468
+ print(f"Multi-label matrix shape: {y_conditions.shape}")
469
+
470
+ # Create individual confidence and weight matrices
471
+ y_confidences = np.zeros((len(X), len(selected_classes)))
472
+ y_weights = np.zeros((len(X), len(selected_classes)))
473
+
474
+ for i, (conf_data, weight_data) in enumerate(zip(individual_confidences, individual_weights)):
475
+ # Map confidences to selected conditions
476
+ for condition, confidence in zip(conf_data['conditions'], conf_data['confidences']):
477
+ if condition in selected_classes:
478
+ condition_idx = np.where(selected_classes == condition)[0]
479
+ if len(condition_idx) > 0:
480
+ y_confidences[i, condition_idx[0]] = confidence
481
+
482
+ # Map weights to selected conditions
483
+ for condition, weight in zip(weight_data['conditions'], weight_data['weights']):
484
+ if condition in selected_classes:
485
+ condition_idx = np.where(selected_classes == condition)[0]
486
+ if len(condition_idx) > 0:
487
+ y_weights[i, condition_idx[0]] = weight
488
+
489
+ # Print matrix statistics
490
+ nonzero_conf = (y_confidences > 0.001).sum()
491
+ nonzero_weight = (y_weights > 0.001).sum()
492
+ total_elements = y_confidences.size
493
+
494
+ print(f"\nMatrix statistics:")
495
+ print(f" Confidence matrix - non-zero: {nonzero_conf}/{total_elements} ({100*nonzero_conf/total_elements:.1f}%)")
496
+ print(f" Weight matrix - non-zero: {nonzero_weight}/{total_elements} ({100*nonzero_weight/total_elements:.1f}%)")
497
+ print(f" Confidence range: {y_confidences[y_confidences > 0].min():.3f} - {y_confidences[y_confidences > 0].max():.3f}")
498
+ print(f" Weight range: {y_weights[y_weights > 0].min():.3f} - {y_weights[y_weights > 0].max():.3f}")
499
+
500
+ # Print top conditions
501
+ condition_counts_filtered = y_conditions.sum(axis=0)
502
+ print("\nTop conditions selected:")
503
+ for i, (condition, count) in enumerate(zip(selected_classes, condition_counts_filtered)):
504
+ print(f" {i+1:2d}. {condition}: {count} samples")
505
+
506
+ return X, y_conditions, y_confidences, y_weights
507
+ """
508
+ Build multi-output neural network architecture.
509
+
510
+ ARCHITECTURE RATIONALE:
511
+
512
+ **SHARED LAYERS (512→256→128)**:
513
+ - Purpose: Learn general features useful for all prediction tasks
514
+ - Size progression: Gradual dimensionality reduction (embeddings→features)
515
+ - Batch Normalization: Stabilizes training, allows higher learning rates
516
+ - Dropout (0.3, 0.3, 0.2): Prevents overfitting, forces robust features
517
+
518
+ Why this depth:
519
+ - 3 layers balances capacity (can learn complex patterns) vs simplicity
520
+ - Too shallow: Can't learn complex patterns
521
+ - Too deep: Overfits, slower training, harder to optimize
522
+
523
+ **TASK-SPECIFIC BRANCHES**:
524
+ Each branch has 2 layers (64→output) for specialized processing:
525
+
526
+ 1. **CONDITION CLASSIFICATION BRANCH**:
527
+ - Activation: Sigmoid (outputs independent probabilities per condition)
528
+ - Why sigmoid: Allows multiple conditions to be predicted simultaneously
529
+ - Loss: Binary cross-entropy (standard for multi-label classification)
530
+
531
+ 2. **CONFIDENCE REGRESSION BRANCH**:
532
+ - Activation: Softplus (ensures positive outputs, smooth gradients)
533
+ - Why softplus not ReLU: ReLU outputs exactly zero for negative inputs,
534
+ causing gradient issues. Softplus outputs small positive values instead.
535
+ - Formula: softplus(x) = log(1 + exp(x))
536
+ - Loss: MSE (Mean Squared Error for continuous values)
537
+ - Loss weight: 1.5x (increased to prioritize confidence learning)
538
+
539
+ 3. **WEIGHT REGRESSION BRANCH**:
540
+ - Activation: Sigmoid (ensures [0,1] bounded output)
541
+ - Why sigmoid: Weights represent proportions/probabilities, must be 0-1
542
+ - Loss: MSE (Mean Squared Error for continuous values)
543
+ - Loss weight: 1.2x (slightly increased priority)
544
+
545
+ **LOSS WEIGHTING**:
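A small numeric sketch of the softplus-vs-ReLU point made for the confidence branch above (values rounded):

```python
# Softplus stays strictly positive where ReLU saturates at exactly zero,
# so the confidence head never emits a hard zero.
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

for x in (-3.0, 0.0, 3.0):
    print(f"x={x:+.1f}  relu={max(0.0, x):.3f}  softplus={softplus(x):.3f}")
# x=-3.0  relu=0.000  softplus=0.049
# x=+0.0  relu=0.000  softplus=0.693
# x=+3.0  relu=3.000  softplus=3.049
```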
546
+ Different loss scales require weighting for balanced training:
547
+ - Condition loss: Binary cross-entropy, typically ~0.3-0.7
548
+ - Confidence loss: MSE on scaled values, typically ~0.01-0.1
549
+ - Weight loss: MSE on scaled values, typically ~0.01-0.1
550
+
551
+ Weights (1.0, 1.5, 1.2) ensure:
552
+ - All tasks contribute meaningfully to total loss
553
+ - Confidence gets extra emphasis (was underfitting in previous versions)
554
+ - Gradient magnitudes are balanced across tasks
555
+
556
+ **WHY ADAM OPTIMIZER**:
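A worked example of the weighted total loss, using the typical per-task magnitudes quoted above (the numbers are illustrative, not measured values):

```python
# Total loss = 1.0 * BCE(conditions) + 1.5 * MSE(confidences) + 1.2 * MSE(weights)
loss_weights = {'conditions': 1.0, 'individual_confidences': 1.5, 'individual_weights': 1.2}
per_task     = {'conditions': 0.50, 'individual_confidences': 0.05, 'individual_weights': 0.05}

total = sum(loss_weights[k] * per_task[k] for k in per_task)
print(total)  # 0.635 -> every head contributes a visible share of the gradient
```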
557
+ - Adaptive learning rates per parameter (handles different loss scales)
558
+ - Momentum for faster convergence
559
+ - Robust to hyperparameter choices
560
+ - Industry standard for multi-task learning
561
+
562
+ **MODEL COMPILATION**:
563
+ The model uses a dictionary output format allowing:
564
+ - Clear separation of different predictions
565
+ - Easy access to specific outputs during inference
566
+ - Flexible loss and metric assignment per output
567
+ """
568
+ def build_model(self, input_dim, num_conditions, learning_rate=0.001):
569
+ """Build neural network with improved confidence and weight outputs"""
570
+ print("Building improved neural network model...")
571
+
572
+ # Input layer
573
+ inputs = keras.Input(shape=(input_dim,), name='embeddings')
574
+
575
+ # Shared feature extraction layers
576
+ x = layers.Dense(512, activation='relu', name='dense1')(inputs) # Increased capacity
577
+ x = layers.BatchNormalization(name='bn1')(x)
578
+ x = layers.Dropout(0.3, name='dropout1')(x)
579
+
580
+ x = layers.Dense(256, activation='relu', name='dense2')(x)
581
+ x = layers.BatchNormalization(name='bn2')(x)
582
+ x = layers.Dropout(0.3, name='dropout2')(x)
583
+
584
+ x = layers.Dense(128, activation='relu', name='dense3')(x)
585
+ x = layers.BatchNormalization(name='bn3')(x)
586
+ x = layers.Dropout(0.2, name='dropout3')(x)
587
+
588
+ # Multi-label condition classification head
589
+ condition_branch = layers.Dense(64, activation='relu', name='condition_dense')(x)
590
+ condition_branch = layers.Dropout(0.2, name='condition_dropout')(condition_branch)
591
+ condition_output = layers.Dense(num_conditions, activation='sigmoid',
592
+ name='conditions')(condition_branch)
593
+
594
+ # Individual confidence regression head - FIXED ACTIVATION
595
+ confidence_branch = layers.Dense(64, activation='relu', name='confidence_dense1')(x)
596
+ confidence_branch = layers.Dropout(0.2, name='confidence_dropout1')(confidence_branch)
597
+ confidence_branch = layers.Dense(32, activation='relu', name='confidence_dense2')(confidence_branch)
598
+ confidence_branch = layers.Dropout(0.1, name='confidence_dropout2')(confidence_branch)
599
+ # Changed from ReLU to softplus - ensures positive, non-zero outputs
600
+ confidence_output = layers.Dense(num_conditions, activation='softplus',
601
+ name='individual_confidences')(confidence_branch)
602
+
603
+ # Individual weight regression head
604
+ weighted_branch = layers.Dense(64, activation='relu', name='weighted_dense1')(x)
605
+ weighted_branch = layers.Dropout(0.2, name='weighted_dropout1')(weighted_branch)
606
+ weighted_branch = layers.Dense(32, activation='relu', name='weighted_dense2')(weighted_branch)
607
+ weighted_branch = layers.Dropout(0.1, name='weighted_dropout2')(weighted_branch)
608
+ # Use sigmoid to ensure 0-1 range
609
+ weighted_output = layers.Dense(num_conditions, activation='sigmoid',
610
+ name='individual_weights')(weighted_branch)
611
+
612
+ # Create model
613
+ model = keras.Model(
614
+ inputs=inputs,
615
+ outputs={
616
+ 'conditions': condition_output,
617
+ 'individual_confidences': confidence_output,
618
+ 'individual_weights': weighted_output
619
+ }
620
+ )
621
+
622
+ # Compile model with improved loss weights
623
+ model.compile(
624
+ optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
625
+ loss={
626
+ 'conditions': 'binary_crossentropy',
627
+ 'individual_confidences': 'mse',
628
+ 'individual_weights': 'mse'
629
+ },
630
+ loss_weights={
631
+ 'conditions': 1.0,
632
+ 'individual_confidences': 1.5, # Increased weight for confidence
633
+ 'individual_weights': 1.2 # Increased weight for weights
634
+ },
635
+ metrics={
636
+ 'conditions': ['accuracy'],
637
+ 'individual_confidences': ['mae'],
638
+ 'individual_weights': ['mae']
639
+ }
640
+ )
641
+
642
+ return model
643
+
644
+ """
645
+ Main training orchestration method with improved confidence handling.
646
+
647
+ TRAINING PIPELINE:
648
+ 1. Load data (embeddings + labels)
649
+ 2. Prepare training matrices (parse, filter, validate)
650
+ 3. Scale features and outputs
651
+ 4. Split train/validation sets
652
+ 5. Build neural network architecture
653
+ 6. Train with callbacks (early stopping, LR reduction, checkpointing)
654
+ 7. Evaluate performance
655
+ 8. Save trained model
656
+
657
+ IMPROVED SCALING STRATEGY (KEY FIX):
658
+ Problem: Previous version scaled all values including zeros
659
+ Solution: Fit scalers only on non-zero values
660
+
661
+ Why this matters:
662
+ - Sparse matrices have many structural zeros (condition not present)
663
+ - Including zeros in scaler fitting shifts mean artificially low
664
+ - Model learns to predict near-zero for everything
665
+ - Confidence predictions collapsed to ~0 (major bug)
666
+
667
+ New approach:
668
+ ```python
669
+ conf_nonzero = y_confidences[y_confidences > 0.001]
670
+ self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
671
+ ```
672
+ - Only non-zero values determine scale
673
+ - Model learns actual confidence distribution (1-5 range)
674
+ - Predictions are meaningful positive values
675
+
676
+ FALLBACK HANDLING:
677
+ If too few non-zero values exist:
678
+
679
+ - Use sensible dummy values (1-5 for confidence, 0-1 for weights)
680
+ - Prevents scaler failure on edge cases
681
+ - Ensures training can proceed
682
+
683
+ TRAIN/TEST SPLIT:
684
+
685
+ - 80/20 split is standard for medical ML
686
+ - Stratification not used (multi-label makes it complex)
687
+ - Random state fixed for reproducibility
688
+
689
+ CALLBACKS:
690
+
691
+ Early Stopping (patience=12):
692
+
693
+ - Monitors validation loss
694
+ - Stops if no improvement for 12 epochs
695
+ - Restores best weights (not final weights)
696
+ - Why: Prevents overfitting to training set
697
+
698
+ ReduceLROnPlateau (factor=0.5, patience=5):
699
+
700
+ - Monitors confidence loss specifically (was problematic)
701
+ - Reduces LR by 50% if loss plateaus
702
+ - Allows fine-tuning in late training
703
+ - Min LR: 1e-7 prevents excessive reduction
704
+
705
+ ModelCheckpoint:
706
+
707
+ - Saves best model weights during training
708
+ - Insurance against training divergence
709
+ - Cleaned up after successful training
710
+
711
+ TRAINING DURATION:
712
+
713
+ - 60 epochs maximum (increased from 50)
714
+ - Early stopping typically triggers around epoch 30-40
715
+ - Batch size 32 balances memory vs convergence speed
716
+
717
+ HYPERPARAMETERS:
718
+
719
+ - Learning rate: 0.001 (standard for Adam)
720
+ - Batch size: 32 (good for datasets of this size)
721
+ - Test split: 0.2 (20% validation, standard practice)
722
+
723
+ POST-TRAINING:
724
+
725
+ - Comprehensive evaluation on test set
726
+ - Detailed metrics for all three outputs
727
+ - Analysis of confidence prediction quality
728
+ """
729
+ def train(self, npz_file_path="derm_foundation_embeddings.npz",
730
+ csv_file_path="dataset_scin_labels.csv",
731
+ test_size=0.2, random_state=42, epochs=50, batch_size=32,
732
+ learning_rate=0.001):
733
+ """Train the neural network with improved confidence handling"""
734
+
735
+ # Load data
736
+ embeddings = self.load_embeddings(npz_file_path)
737
+ if embeddings is None:
738
+ return False
739
+
740
+ df = self.load_dataset(csv_file_path)
741
+ if df is None:
742
+ return False
743
+
744
+ # Prepare training data
745
+ X, y_conditions, y_confidences, y_weights = self.prepare_training_data(df, embeddings)
746
+ if X is None:
747
+ return False
748
+
749
+ # IMPROVED SCALING - fit only on non-zero values
750
+ print("\nFitting scalers...")
751
+ X_scaled = self.embedding_scaler.fit_transform(X)
752
+
753
+ # Fit confidence scaler on non-zero values only
754
+ conf_nonzero = y_confidences[y_confidences > 0.001]
755
+ if len(conf_nonzero) > 50: # Ensure we have enough data
756
+ print(f"Fitting confidence scaler on {len(conf_nonzero)} non-zero values")
757
+ self.confidence_scaler.fit(conf_nonzero.reshape(-1, 1))
758
+ else:
759
+ print("WARNING: Too few non-zero confidence values, using default scaling")
760
+ # Use a reasonable range for confidence values (e.g., 1-5 scale)
761
+ dummy_conf = np.array([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1)
762
+ self.confidence_scaler.fit(dummy_conf)
763
+
764
+ # Fit weight scaler on non-zero values only
765
+ weight_nonzero = y_weights[y_weights > 0.001]
766
+ if len(weight_nonzero) > 50:
767
+ print(f"Fitting weight scaler on {len(weight_nonzero)} non-zero values")
768
+ self.weighted_scaler.fit(weight_nonzero.reshape(-1, 1))
769
+ else:
770
+ print("WARNING: Too few non-zero weight values, using default scaling")
771
+ # Use a reasonable range for weight values (0-1 scale)
772
+ dummy_weight = np.array([0.1, 0.3, 0.5, 0.7, 0.9]).reshape(-1, 1)
773
+ self.weighted_scaler.fit(dummy_weight)
774
+
775
+ # Apply scaling to the matrices (preserve structure)
776
+ y_confidences_scaled = np.zeros_like(y_confidences)
777
+ y_weights_scaled = np.zeros_like(y_weights)
778
+
779
+ # Scale only non-zero values
780
+ for i in range(y_confidences.shape[0]):
781
+ for j in range(y_confidences.shape[1]):
782
+ if y_confidences[i, j] > 0.001:
783
+ y_confidences_scaled[i, j] = self.confidence_scaler.transform([[y_confidences[i, j]]])[0, 0]
784
+ if y_weights[i, j] > 0.001:
785
+ y_weights_scaled[i, j] = self.weighted_scaler.transform([[y_weights[i, j]]])[0, 0]
786
+
787
+ print(f"Scaled confidence range: {y_confidences_scaled[y_confidences_scaled != 0].min():.3f} - {y_confidences_scaled[y_confidences_scaled != 0].max():.3f}")
788
+ print(f"Scaled weight range: {y_weights_scaled[y_weights_scaled != 0].min():.3f} - {y_weights_scaled[y_weights_scaled != 0].max():.3f}")
789
+
790
+ # Split data
791
+ X_train, X_test, y_cond_train, y_cond_test, y_conf_train, y_conf_test, y_weight_train, y_weight_test = train_test_split(
792
+ X_scaled, y_conditions, y_confidences_scaled, y_weights_scaled,
793
+ test_size=test_size, random_state=random_state
794
+ )
795
+
796
+ print(f"\nTraining/test split:")
797
+ print(f" Training samples: {X_train.shape[0]}")
798
+ print(f" Test samples: {X_test.shape[0]}")
799
+
800
+ # Build model
801
+ self.model = self.build_model(
802
+ input_dim=X_scaled.shape[1],
803
+ num_conditions=y_conditions.shape[1],
804
+ learning_rate=learning_rate
805
+ )
806
+
807
+ print(f"\nModel architecture:")
808
+ self.model.summary()
809
+
810
+ # Prepare training data
811
+ train_data = {
812
+ 'conditions': y_cond_train,
813
+ 'individual_confidences': y_conf_train,
814
+ 'individual_weights': y_weight_train
815
+ }
816
+
817
+ val_data = {
818
+ 'conditions': y_cond_test,
819
+ 'individual_confidences': y_conf_test,
820
+ 'individual_weights': y_weight_test
821
+ }
822
+
823
+ # Enhanced callbacks
824
+ early_stopping = keras.callbacks.EarlyStopping(
825
+ monitor='val_loss',
826
+ patience=12, # Increased patience
827
+ restore_best_weights=True,
828
+ verbose=1
829
+ )
830
+
831
+ reduce_lr = keras.callbacks.ReduceLROnPlateau(
832
+ monitor='val_individual_confidences_loss', # Monitor confidence loss specifically
833
+ factor=0.5,
834
+ patience=5,
835
+ min_lr=1e-7,
836
+ mode='min', # We want to minimize the loss
837
+ verbose=1
838
+ )
839
+
840
+ model_checkpoint = keras.callbacks.ModelCheckpoint(
841
+ filepath='best_model_fixed.weights.h5',
842
+ monitor='val_loss',
843
+ save_best_only=True,
844
+ save_weights_only=True,
845
+ verbose=1
846
+ )
847
+
848
+ print(f"\nStarting training for {epochs} epochs...")
849
+
850
+ # Train model
851
+ self.history = self.model.fit(
852
+ X_train, train_data,
853
+ validation_data=(X_test, val_data),
854
+ epochs=epochs,
855
+ batch_size=batch_size,
856
+ callbacks=[early_stopping, reduce_lr, model_checkpoint],
857
+ verbose=1
858
+ )
859
+
860
+ # Evaluate model
861
+ self.evaluate_model(X_test, y_cond_test, y_conf_test, y_weight_test)
862
+
863
+ return True
864
+ """
865
+ Comprehensive model evaluation with enhanced confidence analysis.
866
+ EVALUATION METRICS:
867
+ 1. MULTI-LABEL CLASSIFICATION (Conditions):
868
+ Hamming Loss:
869
+
870
+ Definition: Fraction of incorrectly predicted labels
871
+ Range: [0, 1] where 0 is perfect
872
+ Formula: (1/n_labels) × Σ|y_true ⊕ y_pred|
873
+ Example: If 2 out of 30 labels are wrong, hamming loss = 0.067
874
+ Clinical interpretation: Lower is better, <0.1 is excellent
875
+
876
+ Exact Match Accuracy:
877
+
878
+ Strictest metric: Requires ALL labels perfectly correct
879
+ Range: [0, 1] where 1 is perfect
880
+ Why include: Shows complete prediction correctness
881
+ Medical context: Exact match is ideal but rarely achievable
882
+ (even expert dermatologists disagree on some cases)
883
+
884
+ Average Conditions per Sample:
885
+
886
+ Describes label distribution complexity
887
+ Higher values → harder multi-label problem
888
+ Typical range: 1-3 conditions per sample
889
+
890
+ 2. CONFIDENCE REGRESSION:
891
+ Why evaluate only non-zero targets:
892
+
893
+ Zeros are structural (condition not present)
894
+ Including zeros conflates two problems:
895
+ a) Predicting which conditions exist (classification task)
896
+ b) Predicting confidence for existing conditions (regression task)
897
+ We want to evaluate (b) separately
898
+
899
+ Inverse Transform:
900
+
901
+ Converts scaled predictions back to original scale
902
+ Necessary for interpretable metrics
903
+ Example: Scaled 0.3 → Original 3.2 (on 1-5 scale)
904
+
905
+ MSE (Mean Squared Error):
906
+
907
+ Sensitive to large errors (squared penalty)
908
+ Unit: (confidence units)²
909
+ Lower is better
910
+
911
+ MAE (Mean Absolute Error):
912
+
913
+ Average absolute difference from ground truth
914
+ Same units as original values
915
+ More robust to outliers than MSE
916
+ Clinical interpretation: If MAE=0.5, average error is ±0.5 points
917
+
918
+ RMSE (Root Mean Squared Error):
919
+
920
+ Square root of MSE
921
+ Same units as original values (easier to interpret than MSE)
922
+ Emphasizes larger errors more than MAE
923
+
924
+ Prediction Range Analysis:
925
+
926
+ Verifies predictions are in sensible range
927
+ Example: If ground truth is 1-5, predictions should be similar
928
+ Out-of-range predictions indicate scaling or activation issues
929
+
930
+ 3. WEIGHT REGRESSION:
931
+ Same metrics as confidence but for weight values (0-1 range)
932
+ DIAGNOSTIC CHECKS:
933
+
934
+ "Predictions > 0.1" percentage: Ensures model isn't predicting near-zero
935
+ Range comparison: Truth vs prediction ranges should align
936
+ Non-zero count: Verifies sparse structure is respected
937
+
938
+ WHY THIS EVALUATION IS COMPREHENSIVE:
939
+
940
+ Multiple metrics cover different aspects (classification + regression)
941
+ Separate evaluation of sparse vs dense regions
942
+ Original scale metrics (clinically interpretable)
943
+ Diagnostic checks for common failure modes
944
+ Both aggregate (MSE) and per-sample (MAE) metrics
945
+ """
946
+ def evaluate_model(self, X_test, y_cond_test, y_conf_test, y_weight_test):
947
+ """Evaluate the trained model with enhanced confidence analysis"""
948
+ print("\n" + "="*70)
949
+ print("MODEL EVALUATION - ENHANCED CONFIDENCE ANALYSIS")
950
+ print("="*70)
951
+
952
+ # Make predictions
953
+ predictions = self.model.predict(X_test)
954
+ y_cond_pred = predictions['conditions']
955
+ y_conf_pred = predictions['individual_confidences']
956
+ y_weight_pred = predictions['individual_weights']
957
+
958
+ # Condition classification evaluation
959
+ y_cond_pred_binary = (y_cond_pred > 0.5).astype(int)
960
+ hamming = hamming_loss(y_cond_test, y_cond_pred_binary)
961
+ exact_match = (y_cond_pred_binary == y_cond_test).all(axis=1).mean()
962
+
963
+ print(f"Multi-label Condition Classification:")
964
+ print(f" Hamming Loss: {hamming:.4f}")
965
+ print(f" Exact Match Accuracy: {exact_match:.4f}")
966
+ print(f" Average conditions per sample: {y_cond_test.sum(axis=1).mean():.2f}")
967
+
968
+ # ENHANCED confidence evaluation
969
+ print(f"\nConfidence Prediction Analysis:")
970
+ print(f" Raw prediction range: {y_conf_pred.min():.6f} - {y_conf_pred.max():.6f}")
971
+ print(f" Non-zero predictions: {(y_conf_pred > 0.001).sum()}/{y_conf_pred.size}")
972
+
973
+ # Inverse transform and evaluate confidence
974
+ conf_mask = y_conf_test > 0.001
975
+ if conf_mask.sum() > 0:
976
+ y_conf_test_orig = np.zeros_like(y_conf_test)
977
+ y_conf_pred_orig = np.zeros_like(y_conf_pred)
978
+
979
+ # Inverse transform
980
+ for i in range(y_conf_test.shape[0]):
981
+ for j in range(y_conf_test.shape[1]):
982
+ if y_conf_test[i, j] > 0.001:
983
+ y_conf_test_orig[i, j] = self.confidence_scaler.inverse_transform([[y_conf_test[i, j]]])[0, 0]
984
+ if y_conf_pred[i, j] > 0.001:
985
+ y_conf_pred_orig[i, j] = self.confidence_scaler.inverse_transform([[y_conf_pred[i, j]]])[0, 0]
986
+
987
+ # Calculate metrics only on positions where ground truth is non-zero
988
+ conf_test_nonzero = y_conf_test_orig[conf_mask]
989
+ conf_pred_nonzero = y_conf_pred_orig[conf_mask]
990
+
991
+ conf_mse = mean_squared_error(conf_test_nonzero, conf_pred_nonzero)
992
+ conf_mae = mean_absolute_error(conf_test_nonzero, conf_pred_nonzero)
993
+
994
+ print(f" Individual Confidence Regression (on {conf_mask.sum()} non-zero targets):")
995
+ print(f" MSE: {conf_mse:.4f}")
996
+ print(f" MAE: {conf_mae:.4f}")
997
+ print(f" RMSE: {np.sqrt(conf_mse):.4f}")
998
+ print(f" Prediction range (orig scale): {conf_pred_nonzero.min():.3f} - {conf_pred_nonzero.max():.3f}")
999
+ print(f" Ground truth range (orig scale): {conf_test_nonzero.min():.3f} - {conf_test_nonzero.max():.3f}")
1000
+
1001
+ # Check if predictions are reasonable
1002
+ reasonable_predictions = (conf_pred_nonzero > 0.1).sum()
1003
+ print(f" Predictions > 0.1: {reasonable_predictions}/{len(conf_pred_nonzero)} ({100*reasonable_predictions/len(conf_pred_nonzero):.1f}%)")
1004
+
1005
+ # Individual weight evaluation
1006
+ weight_mask = y_weight_test > 0.001
1007
+ if weight_mask.sum() > 0:
1008
+ y_weight_test_orig = np.zeros_like(y_weight_test)
1009
+ y_weight_pred_orig = np.zeros_like(y_weight_pred)
1010
+
1011
+ for i in range(y_weight_test.shape[0]):
1012
+ for j in range(y_weight_test.shape[1]):
1013
+ if y_weight_test[i, j] > 0.001:
1014
+ y_weight_test_orig[i, j] = self.weighted_scaler.inverse_transform([[y_weight_test[i, j]]])[0, 0]
1015
+ if y_weight_pred[i, j] > 0.001:
1016
+ y_weight_pred_orig[i, j] = self.weighted_scaler.inverse_transform([[y_weight_pred[i, j]]])[0, 0]
1017
+
1018
+ weight_test_nonzero = y_weight_test_orig[weight_mask]
1019
+ weight_pred_nonzero = y_weight_pred_orig[weight_mask]
1020
+
1021
+ weight_mse = mean_squared_error(weight_test_nonzero, weight_pred_nonzero)
1022
+ weight_mae = mean_absolute_error(weight_test_nonzero, weight_pred_nonzero)
1023
+
1024
+ print(f"\nIndividual Weight Regression (on {weight_mask.sum()} non-zero targets):")
1025
+ print(f" MSE: {weight_mse:.4f}")
1026
+ print(f" MAE: {weight_mae:.4f}")
1027
+ print(f" RMSE: {np.sqrt(weight_mse):.4f}")
1028
+ print(f" Prediction range (orig scale): {weight_pred_nonzero.min():.3f} - {weight_pred_nonzero.max():.3f}")
1029
+ print(f" Ground truth range (orig scale): {weight_test_nonzero.min():.3f} - {weight_test_nonzero.max():.3f}")
1030
+ """
1031
+ Make predictions on new embeddings with comprehensive output formatting.
1032
+ PREDICTION PIPELINE:
1033
+
1034
+ Scale input embedding (using training-fitted scaler)
1035
+ Forward pass through neural network
1036
+ Process raw outputs:
1037
+
1038
+ Condition probabilities: Sigmoid outputs [0,1]
1039
+ Confidence values: Softplus outputs [0,∞)
1040
+ Weight values: Sigmoid outputs [0,1]
1041
+
1042
+ Inverse transform regression outputs to original scale
1043
+ Apply threshold to select predicted conditions
1044
+ Return structured dictionary with multiple views of predictions
1045
+
1046
+ THRESHOLD STRATEGY:
1047
+
1048
+ Condition threshold: 0.3 (lower than typical 0.5)
1049
+ Why lower: Medical diagnosis prefers sensitivity (catch more conditions)
1050
+ False positives less harmful than false negatives in screening
1051
+ Can be adjusted based on clinical requirements
1052
+
1053
+ OUTPUT STRUCTURE:
1054
+ Primary Predictions (conditions above threshold):
1055
+
1056
+ dermatologist_skin_condition_on_label_name: List of predicted conditions
1057
+ dermatologist_skin_condition_confidence: Confidence per predicted condition
1058
+ weighted_skin_condition_label: Weight dict for predicted conditions
1059
+
1060
+ Comprehensive View (all conditions):
1061
+
1062
+ all_condition_probabilities: Probability for every possible condition
1063
+ all_individual_confidences: Confidence for every possible condition
1064
+ all_individual_weights: Weight for every possible condition
1065
+
1066
+ Debugging Information:
1067
+
1068
+ raw_confidence_outputs: Pre-transform neural network outputs
1069
+ raw_weight_outputs: Pre-transform neural network outputs
1070
+ condition_threshold: Threshold used for filtering
1071
+
1072
+ Why provide multiple views:
1073
+
1074
+ Primary predictions: For direct clinical use
1075
+ Comprehensive view: For ranking, uncertainty quantification
1076
+ Debug info: For model validation and troubleshooting
1077
+
1078
+ MINIMUM VALUE CLAMPING:
1079
+ ```python
+ confidence_orig = max(0.1, confidence_orig)
1080
+ weight_orig = max(0.01, weight_orig)
+ ```
1081
+
1082
+ Ensures predictions are never exactly zero
1083
+ Confidence ≥0.1: Even lowest predictions are meaningful
1084
+ Weight ≥0.01: Prevents division-by-zero in downstream processing
1085
+
1086
+ SOFTPLUS ADVANTAGE:
1087
+ With softplus activation, even very negative inputs produce small positive
1088
+ outputs, so confidence predictions naturally avoid zero. The max(0.1, x)
1089
+ provides additional safety margin.
1090
+ RETURN FORMAT:
1091
+ Dictionary structure allows:
1092
+
1093
+ Easy access to specific prediction types
1094
+ Clear semantic meaning (key names describe contents)
1095
+ Extensible (can add new keys without breaking existing code)
1096
+ JSON-serializable for API deployment
1097
+ """
1098
+ def predict(self, embedding):
1099
+ """Make predictions on a single embedding with individual outputs"""
1100
+ if self.model is None:
1101
+ print("ERROR: Model not trained. Call train() first.")
1102
+ return None
1103
+
1104
+ if len(embedding.shape) == 1:
1105
+ embedding = embedding.reshape(1, -1)
1106
+
1107
+ # Scale embedding
1108
+ embedding_scaled = self.embedding_scaler.transform(embedding)
1109
+
1110
+ # Make predictions
1111
+ predictions = self.model.predict(embedding_scaled, verbose=0)
1112
+
1113
+ # Process condition predictions
1114
+ condition_probs = predictions['conditions'][0]
1115
+ individual_confidences = predictions['individual_confidences'][0]
1116
+ individual_weights = predictions['individual_weights'][0]
1117
+
1118
+ # Get predicted conditions (above threshold)
1119
+ condition_threshold = 0.3 # Lower threshold
1120
+ predicted_condition_indices = np.where(condition_probs > condition_threshold)[0]
1121
+
1122
+ # Build results
1123
+ predicted_conditions = []
1124
+ predicted_confidences = []
1125
+ predicted_weights_dict = {}
1126
+
1127
+ for idx in predicted_condition_indices:
1128
+ condition_name = self.mlb.classes_[idx]
1129
+ condition_prob = float(condition_probs[idx])
1130
+
1131
+ # Inverse transform individual outputs with better handling
1132
+ confidence_raw = individual_confidences[idx]
1133
+ weight_raw = individual_weights[idx]
1134
+
1135
+ # Always inverse transform, even small values (softplus ensures non-zero)
1136
+ confidence_orig = self.confidence_scaler.inverse_transform([[confidence_raw]])[0, 0]
1137
+ weight_orig = self.weighted_scaler.inverse_transform([[weight_raw]])[0, 0]
1138
+
1139
+ predicted_conditions.append(condition_name)
1140
+ predicted_confidences.append(max(0.1, confidence_orig)) # Minimum confidence of 0.1
1141
+ predicted_weights_dict[condition_name] = max(0.01, weight_orig) # Minimum weight of 0.01
1142
+
1143
+ # Also provide all condition probabilities for reference
1144
+ all_condition_probs = {}
1145
+ all_confidences = {}
1146
+ all_weights = {}
1147
+
1148
+ for i, class_name in enumerate(self.mlb.classes_):
1149
+ all_condition_probs[class_name] = float(condition_probs[i])
1150
+
1151
+ # Always inverse transform all outputs
1152
+ conf_raw = individual_confidences[i]
1153
+ weight_raw = individual_weights[i]
1154
+
1155
+ conf_orig = self.confidence_scaler.inverse_transform([[conf_raw]])[0, 0]
1156
+ weight_orig = self.weighted_scaler.inverse_transform([[weight_raw]])[0, 0]
1157
+
1158
+ all_confidences[class_name] = max(0.0, conf_orig)
1159
+ all_weights[class_name] = max(0.0, weight_orig)
1160
+
1161
+ return {
1162
+ # Main predicted results (above threshold)
1163
+ 'dermatologist_skin_condition_on_label_name': predicted_conditions,
1164
+ 'dermatologist_skin_condition_confidence': predicted_confidences,
1165
+ 'weighted_skin_condition_label': predicted_weights_dict,
1166
+
1167
+ # Additional information for analysis
1168
+ 'all_condition_probabilities': all_condition_probs,
1169
+ 'all_individual_confidences': all_confidences,
1170
+ 'all_individual_weights': all_weights,
1171
+ 'condition_threshold': condition_threshold,
1172
+
1173
+ # Debug information
1174
+ 'raw_confidence_outputs': individual_confidences.tolist(),
1175
+ 'raw_weight_outputs': individual_weights.tolist()
1176
+ }
1177
+
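A hypothetical usage sketch of predict() on a single stored embedding (the case ID below is illustrative; a trained or loaded model is assumed):

```python
import numpy as np

clf = DermFoundationNeuralNetwork()
clf.train()  # or clf.load_model(...) once a trained model file exists

embeddings = np.load("derm_foundation_embeddings.npz", allow_pickle=True)
result = clf.predict(embeddings["case_0001"])

print(result['dermatologist_skin_condition_on_label_name'])
print(result['dermatologist_skin_condition_confidence'])
print(result['weighted_skin_condition_label'])
```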
1178
+ def plot_training_history(self):
1179
+ if self.history is None:
1180
+ print("No training history available")
1181
+ return
1182
+
1183
+ # Set matplotlib to use non-interactive backend
1184
+ import matplotlib
1185
+ matplotlib.use('Agg')
1186
+ import matplotlib.pyplot as plt
1187
+
1188
+ fig, axes = plt.subplots(2, 3, figsize=(18, 10))
1189
+
1190
+ # Loss
1191
+ axes[0, 0].plot(self.history.history['loss'], label='Training Loss')
1192
+ axes[0, 0].plot(self.history.history['val_loss'], label='Validation Loss')
1193
+ axes[0, 0].set_title('Model Loss')
1194
+ axes[0, 0].set_xlabel('Epoch')
1195
+ axes[0, 0].set_ylabel('Loss')
1196
+ axes[0, 0].legend()
1197
+
1198
+ # Condition accuracy
1199
+ axes[0, 1].plot(self.history.history['conditions_accuracy'], label='Training Accuracy')
1200
+ axes[0, 1].plot(self.history.history['val_conditions_accuracy'], label='Validation Accuracy')
1201
+ axes[0, 1].set_title('Condition Classification Accuracy')
1202
+ axes[0, 1].set_xlabel('Epoch')
1203
+ axes[0, 1].set_ylabel('Accuracy')
1204
+ axes[0, 1].legend()
1205
+
1206
+ # Individual Confidence MAE
1207
+ axes[0, 2].plot(self.history.history['individual_confidences_mae'], label='Training MAE')
1208
+ axes[0, 2].plot(self.history.history['val_individual_confidences_mae'], label='Validation MAE')
1209
+ axes[0, 2].set_title('Individual Confidence MAE')
1210
+ axes[0, 2].set_xlabel('Epoch')
1211
+ axes[0, 2].set_ylabel('MAE')
1212
+ axes[0, 2].legend()
1213
+
1214
+ # Individual Weight MAE
1215
+ axes[1, 0].plot(self.history.history['individual_weights_mae'], label='Training MAE')
1216
+ axes[1, 0].plot(self.history.history['val_individual_weights_mae'], label='Validation MAE')
1217
+ axes[1, 0].set_title('Individual Weight MAE')
1218
+ axes[1, 0].set_xlabel('Epoch')
1219
+ axes[1, 0].set_ylabel('MAE')
1220
+ axes[1, 0].legend()
1221
+
1222
+ # Individual confidence loss
1223
+ axes[1, 1].plot(self.history.history['individual_confidences_loss'], label='Training Loss')
1224
+ axes[1, 1].plot(self.history.history['val_individual_confidences_loss'], label='Validation Loss')
1225
+ axes[1, 1].set_title('Individual Confidence Loss')
1226
+ axes[1, 1].set_xlabel('Epoch')
1227
+ axes[1, 1].set_ylabel('Loss')
1228
+ axes[1, 1].legend()
1229
+
1230
+ # Individual weight loss
1231
+ axes[1, 2].plot(self.history.history['individual_weights_loss'], label='Training Loss')
1232
+ axes[1, 2].plot(self.history.history['val_individual_weights_loss'], label='Validation Loss')
1233
+ axes[1, 2].set_title('Individual Weight Loss')
1234
+ axes[1, 2].set_xlabel('Epoch')
1235
+ axes[1, 2].set_ylabel('Loss')
1236
+ axes[1, 2].legend()
1237
+
1238
+ plt.tight_layout()
1239
+ plt.savefig('training_history_fixed.png', dpi=300, bbox_inches='tight')
1240
+ print("Training history plot saved as: training_history_fixed.png")
1241
+ plt.close()
1242
+ """
1243
+ Persist trained model and preprocessing components to disk.
1244
+
1245
+ SAVED COMPONENTS:
1246
+
1247
+ 1. **Keras Model (.keras file)**:
1248
+ - Neural network architecture
1249
+ - Trained weights for all layers
1250
+ - Optimizer state (for resuming training)
1251
+ - Compilation settings (loss functions, metrics)
1252
+
1253
+ 2. **Preprocessing Data (.pkl file)**:
1254
+ - MultiLabelBinarizer: Maps condition names ↔ indices
1255
+ - embedding_scaler: Normalizes input embeddings
1256
+ - confidence_scaler: Normalizes confidence values
1257
+ - weighted_scaler: Normalizes weight values
1258
+ - Path to .keras file (for loading)
1259
+
1260
+ WHY SEPARATE FILES:
1261
+ - Keras models save to modern .keras format
1262
+ - Scikit-learn components need pickle serialization
1263
+ - Separation allows independent updates of each component
1264
+
1265
+ LOADING REQUIREMENT:
1266
+ Both files are needed for inference:
1267
+ - .keras: Neural network for making predictions
1268
+ - .pkl: Preprocessors for transforming inputs/outputs
1269
+
1270
+ FILE ORGANIZATION:
1271
+ easi_severity_model_derm_foundation_individual_fixed.pkl (main file)
1272
+ easi_severity_model_derm_foundation_individual_fixed.keras (model)
1273
+ User loads .pkl file, which contains path to .keras file
1274
+
1275
+ CLEANUP:
1276
+ Removes temporary checkpoint file (best_model_fixed.weights.h5)
1277
+ created during training to avoid confusion with final model.
1278
+
1279
+ ERROR HANDLING:
1280
+ Checks if model exists before saving, provides clear error messages
1281
+ and file paths for debugging.
1282
+ """
1283
+
1284
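+     # Loading sketch (for orientation only; mirrors load_model() below and assumes
+     # both files still sit where they were saved — the .pkl stores the path to the
+     # .keras model under 'keras_model_path'):
+     #   with open('easi_severity_model_derm_foundation_individual_fixed.pkl', 'rb') as f:
+     #       data = pickle.load(f)          # scalers, MultiLabelBinarizer, keras_model_path
+     #   model = keras.models.load_model(data['keras_model_path'])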
+     def save_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
+         """Save the trained model"""
+         if self.model is None:
+             print("ERROR: No trained model to save.")
+             return False
+
+         # Get current directory
+         current_dir = os.getcwd()
+
+         # Save Keras model with proper extension
+         model_filename = os.path.splitext(filepath)[0]
+         keras_model_path = os.path.join(current_dir, f"{model_filename}.keras")
+
+         print(f"Saving Keras model to: {keras_model_path}")
+         self.model.save(keras_model_path)
+
+         # Save preprocessing components
+         pkl_filepath = os.path.join(current_dir, filepath)
+         model_data = {
+             'mlb': self.mlb,
+             'embedding_scaler': self.embedding_scaler,
+             'confidence_scaler': self.confidence_scaler,
+             'weighted_scaler': self.weighted_scaler,
+             'keras_model_path': keras_model_path
+         }
+
+         print(f"Saving preprocessing data to: {pkl_filepath}")
+         with open(pkl_filepath, 'wb') as f:
+             pickle.dump(model_data, f)
+
+         print(f"Model saved successfully!")
+         print(f" - Main file: {pkl_filepath}")
+         print(f" - Keras model: {keras_model_path}")
+
+         # Clean up temporary checkpoint file
+         checkpoint_file = os.path.join(current_dir, 'best_model_fixed.weights.h5')
+         if os.path.exists(checkpoint_file):
+             os.remove(checkpoint_file)
+             print(f" - Cleaned up temporary checkpoint file")
+
+         return True
+
+     def load_model(self, filepath="easi_severity_model_derm_foundation_individual_fixed.pkl"):
+         """Load trained model"""
+         if not os.path.exists(filepath):
+             print(f"ERROR: Model file not found: {filepath}")
+             return False
+
+         try:
+             with open(filepath, 'rb') as f:
+                 model_data = pickle.load(f)
+
+             # Load preprocessing components
+             self.mlb = model_data['mlb']
+             self.embedding_scaler = model_data['embedding_scaler']
+             self.confidence_scaler = model_data['confidence_scaler']
+             self.weighted_scaler = model_data['weighted_scaler']
+
+             # Load Keras model
+             keras_model_path = model_data['keras_model_path']
+             if os.path.exists(keras_model_path):
+                 self.model = keras.models.load_model(keras_model_path)
+                 print(f"Model loaded from {filepath}")
+                 print(f"Available condition classes: {len(self.mlb.classes_)}")
+                 return True
+             else:
+                 print(f"ERROR: Keras model not found at {keras_model_path}")
+                 return False
+
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             return False
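+
+     # Note: save_model() records an absolute keras_model_path built from os.getcwd(),
+     # so moving the .pkl/.keras pair to another directory makes load_model() report
+     # "Keras model not found". A possible workaround (sketch only, not part of this
+     # class) is to re-point the path before loading:
+     #   data['keras_model_path'] = os.path.join(os.getcwd(), os.path.basename(data['keras_model_path']))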
+
+ """
+ WORKFLOW:
+
+ 1. Print configuration and fixes applied (user visibility)
+ 2. Initialize classifier
+ 3. Validate input files exist
+ 4. Train model with improved confidence handling
+ 5. Plot training history
+ 6. Test model predictions (validate fix effectiveness)
+ 7. Save trained model
+
+ MODEL TESTING (NEW):
+ After training completes, runs a sample prediction to verify:
+
+ - Model produces non-zero confidence values (fix validation)
+ - Predictions are in expected ranges
+ - Output structure is correct
+
+ This immediate validation catches issues before deployment.
+
+ WHY TEST WITH SAMPLE:
+
+ - Confirms the confidence scaling fix worked
+ - Provides immediate feedback on model quality
+ - Demonstrates the expected output format
+ - Catches activation function issues (like the ReLU→0 bug)
+
+ SUCCESS CRITERIA:
+ ✅ Non-zero confidences in a reasonable range (e.g., 1-5)
+ ✅ Multiple conditions predicted with varying probabilities
+ ✅ Weights sum to reasonable values
+ ⚠️ Warning if confidence outputs are still mostly zero
+ """
+
+ def main():
+     """Main training function with enhanced confidence handling"""
+     print("Derm Foundation Neural Network Classifier Training - FIXED VERSION")
+     print("="*70)
+     print("FIXES APPLIED:")
+     print("- Changed confidence activation from ReLU to softplus")
+     print("- Improved confidence scaler fitting (non-zero values only)")
+     print("- Increased confidence loss weight (1.5x)")
+     print("- Enhanced data validation and preprocessing")
+     print("- Better handling of sparse confidence/weight matrices")
+     print("="*70)
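+     # Why softplus rather than ReLU for the confidence head (illustrative values):
+     # softplus(x) = log(1 + exp(x)) is strictly positive with a non-zero gradient
+     # everywhere, while ReLU returns exactly 0 (with zero gradient) for negative
+     # pre-activations, which is how the confidence outputs previously collapsed to 0.
+     #   softplus(-2.0) ≈ 0.127   relu(-2.0) = 0.0
+     #   softplus( 0.0) ≈ 0.693   relu( 0.0) = 0.0
+     #   softplus( 2.0) ≈ 2.127   relu( 2.0) = 2.0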
+ print("Training neural network to predict:")
1402
+ print("1. Skin conditions (multi-label classification)")
1403
+ print("2. Individual confidence scores per condition (regression)")
1404
+ print("3. Individual weight scores per condition (regression)")
1405
+ print("="*70)
1406
+
1407
+ # Initialize classifier
1408
+ classifier = DermFoundationNeuralNetwork()
1409
+
1410
+ # File paths
1411
+ npz_file = "derm_foundation_embeddings.npz"
1412
+ csv_file = "dataset_scin_labels.csv"
1413
+ model_output = "easi_severity_model_derm_foundation_individual_fixed.pkl"
1414
+
1415
+ # Check if files exist
1416
+ missing_files = []
1417
+ if not os.path.exists(npz_file):
1418
+ missing_files.append(npz_file)
1419
+ if not os.path.exists(csv_file):
1420
+ missing_files.append(csv_file)
1421
+
1422
+ if missing_files:
1423
+ print(f"ERROR: Missing required files:")
1424
+ for file in missing_files:
1425
+ print(f" - {file}")
1426
+ return
1427
+
+     try:
+         # Train the model
+         success = classifier.train(
+             npz_file_path=npz_file,
+             csv_file_path=csv_file,
+             epochs=60,  # Increased epochs
+             batch_size=32,
+             learning_rate=0.001
+         )
+
+         if not success:
+             print("Training failed!")
+             return
+
+         # Plot training history
+         try:
+             classifier.plot_training_history()
+         except Exception as e:
+             print(f"Could not plot training history: {e}")
+
+         # Test the model with a sample prediction to verify confidence outputs
+         print("\n" + "="*70)
+         print("TESTING MODEL OUTPUTS")
+         print("="*70)
+
+         # Get a sample embedding for testing
+         try:
+             embeddings = classifier.load_embeddings(npz_file)
+             if embeddings:
+                 sample_key = list(embeddings.keys())[0]
+                 sample_embedding = embeddings[sample_key]
+
+                 print(f"Testing with sample embedding: {sample_key}")
+                 test_result = classifier.predict(sample_embedding)
+
+                 if test_result:
+                     print("✅ Model prediction successful!")
+                     print(f"Predicted conditions: {len(test_result['dermatologist_skin_condition_on_label_name'])}")
+
+                     # Check confidence outputs
+                     all_confidences = list(test_result['all_individual_confidences'].values())
+                     nonzero_conf = sum(1 for c in all_confidences if c > 0.01)
+
+                     print(f"Confidence range: {min(all_confidences):.4f} - {max(all_confidences):.4f}")
+                     print(f"Non-zero confidences: {nonzero_conf}/{len(all_confidences)}")
+
+                     if nonzero_conf > 0:
+                         print("✅ CONFIDENCE ISSUE APPEARS TO BE FIXED!")
+                     else:
+                         print("⚠️ Confidence outputs still mostly zero - may need further investigation")
+
+                     # Show top predictions
+                     if test_result['dermatologist_skin_condition_on_label_name']:
+                         print(f"\nSample predictions:")
+                         for i, condition in enumerate(test_result['dermatologist_skin_condition_on_label_name'][:3]):
+                             prob = test_result['all_condition_probabilities'][condition]
+                             conf = test_result['dermatologist_skin_condition_confidence'][i]
+                             weight = test_result['weighted_skin_condition_label'][condition]
+                             print(f" {condition}: prob={prob:.3f}, conf={conf:.3f}, weight={weight:.3f}")
+                 else:
+                     print("❌ Model prediction failed")
+         except Exception as e:
+             print(f"Could not test model: {e}")
+
+         # Save the model
+         classifier.save_model(model_output)
+
+         print(f"\n{'='*70}")
+         print("TRAINING COMPLETE!")
+         print(f"{'='*70}")
+         print(f"Model saved as: {model_output}")
+         print(f"Training history plot saved as: training_history_fixed.png")
+         print(f"\nTo use the trained model:")
+         print(f"```python")
+         print(f"classifier = DermFoundationNeuralNetwork()")
+         print(f"classifier.load_model('{model_output}')")
+         print(f"result = classifier.predict(embedding)")
+         print(f"print(result['dermatologist_skin_condition_on_label_name'])")
+         print(f"print(result['dermatologist_skin_condition_confidence'])")
+         print(f"print(result['weighted_skin_condition_label'])")
+         print(f"```")
+
+         # Example prediction output format
+         print(f"\nExpected prediction output format:")
+         print(f"{{")
+         print(f" 'dermatologist_skin_condition_on_label_name': ['Eczema', 'Irritant Contact Dermatitis'],")
+         print(f" 'dermatologist_skin_condition_confidence': [4.2, 3.1],")
+         print(f" 'weighted_skin_condition_label': {{'Eczema': 0.65, 'Irritant Contact Dermatitis': 0.35}}")
+         print(f"}}")
+
+     except Exception as e:
+         print(f"Error during training: {e}")
+         import traceback
+         traceback.print_exc()
+
+
+ if __name__ == "__main__":
+     main()
training_history_fixed.png ADDED

Git LFS Details

  • SHA256: 142f90f9a26d0b76b801abacacf9e4aab9827363746909cf8774f6e60628b4e1
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB