Sammy Harris commited on
Commit
95d3566
Β·
1 Parent(s): 5b696e3

Add Gradio demo app and requirements

Browse files
Files changed (2) hide show
  1. app.py +149 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import cv2
6
+
7
+ # Load the model
8
+ @gr.cache
9
+ def load_model():
10
+ """Load the age group classification model"""
11
+ try:
12
+ model = tf.keras.models.load_model('resnet50v2_age_classifier_best.h5')
13
+ return model
14
+ except:
15
+ # Fallback if model file not found
16
+ return None
17
+
18
+ # Age group labels
19
+ AGE_GROUPS = {
20
+ 0: "Youth (0-20)",
21
+ 1: "Young Adult (21-40)",
22
+ 2: "Middle Age (41-60)",
23
+ 3: "Senior (61-80)",
24
+ 4: "Elderly (81-100)"
25
+ }
26
+
27
+ def preprocess_image(image):
28
+ """Preprocess image for model input"""
29
+ if image is None:
30
+ return None
31
+
32
+ # Convert to RGB if needed
33
+ if len(image.shape) == 3 and image.shape[2] == 4:
34
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
35
+ elif len(image.shape) == 3 and image.shape[2] == 1:
36
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
37
+
38
+ # Resize to model input size
39
+ image = cv2.resize(image, (224, 224))
40
+
41
+ # Normalize to [0, 1]
42
+ image = image.astype(np.float32) / 255.0
43
+
44
+ # Add batch dimension
45
+ image = np.expand_dims(image, axis=0)
46
+
47
+ return image
48
+
49
+ def predict_age_group(image):
50
+ """Predict age group from facial image"""
51
+ if image is None:
52
+ return "Please upload an image."
53
+
54
+ model = load_model()
55
+ if model is None:
56
+ return "Model not available. Please check the model file."
57
+
58
+ # Preprocess the image
59
+ processed_image = preprocess_image(image)
60
+ if processed_image is None:
61
+ return "Error processing image."
62
+
63
+ try:
64
+ # Get predictions
65
+ predictions = model.predict(processed_image)[0]
66
+
67
+ # Get top prediction
68
+ predicted_class = np.argmax(predictions)
69
+ confidence = predictions[predicted_class]
70
+
71
+ # Format results
72
+ result = f"**Predicted Age Group:** {AGE_GROUPS[predicted_class]}\n"
73
+ result += f"**Confidence:** {confidence:.1%}\n\n"
74
+ result += "**All Predictions:**\n"
75
+
76
+ # Sort by confidence
77
+ sorted_indices = np.argsort(predictions)[::-1]
78
+ for i, idx in enumerate(sorted_indices):
79
+ emoji = "🎯" if i == 0 else "πŸ“Š"
80
+ result += f"{emoji} {AGE_GROUPS[idx]}: {predictions[idx]:.1%}\n"
81
+
82
+ return result
83
+
84
+ except Exception as e:
85
+ return f"Error during prediction: {str(e)}"
86
+
87
+ # Create Gradio interface
88
+ def create_demo():
89
+ """Create the Gradio demo interface"""
90
+
91
+ title = "🎯 Age Group Classification"
92
+ description = """
93
+ This model classifies facial images into 5 age groups instead of predicting exact ages.
94
+
95
+ **Why Age Groups?**
96
+ - More reliable than exact age prediction
97
+ - Solves common bias where 70-year-olds are predicted as 30-year-olds
98
+ - More practical for most applications
99
+
100
+ **Age Groups:**
101
+ - πŸ‘Ά Youth (0-20)
102
+ - πŸ§‘ Young Adult (21-40)
103
+ - πŸ‘¨ Middle Age (41-60)
104
+ - πŸ‘΄ Senior (61-80)
105
+ - πŸ‘΅ Elderly (81-100)
106
+
107
+ Upload a clear frontal face image for best results!
108
+ """
109
+
110
+ article = """
111
+ ### Model Details
112
+ - **Architecture:** ResNet50V2 with transfer learning
113
+ - **Performance:** 75.5% validation accuracy
114
+ - **Training:** 13 epochs with early stopping
115
+ - **Dataset:** UTKFace (23,687 images)
116
+
117
+ ### Limitations
118
+ - Works best with frontal face images
119
+ - Performance may vary with extreme lighting
120
+ - Border cases between age groups can be challenging
121
+
122
+ ### Bias Correction
123
+ This model was specifically designed to solve age prediction bias, particularly the common issue where seniors are incorrectly classified as young adults.
124
+ """
125
+
126
+ # Create interface
127
+ iface = gr.Interface(
128
+ fn=predict_age_group,
129
+ inputs=gr.Image(type="numpy", label="Upload Face Image"),
130
+ outputs=gr.Textbox(label="Age Group Prediction", lines=10),
131
+ title=title,
132
+ description=description,
133
+ article=article,
134
+ examples=[
135
+ # Add example images if available
136
+ ],
137
+ theme="default",
138
+ allow_flagging="never"
139
+ )
140
+
141
+ return iface
142
+
143
+ if __name__ == "__main__":
144
+ demo = create_demo()
145
+ demo.launch(
146
+ share=True,
147
+ server_name="0.0.0.0",
148
+ server_port=7860
149
+ )
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0gradio
2
+
3
+ tensorflow>=2.10.0huggingface_hub
4
+
5
+ numpy>=1.21.0tensorflow>=2.10.0
6
+
7
+ Pillow>=8.3.0numpy
8
+
9
+ opencv-python>=4.5.0pillow
10
+ matplotlib
11
+ requests
12
+ tqdm