Sammy Harris
Fix Gradio compatibility: remove deprecated @gr.cache and add model download from HF Hub
2cf61e8
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
from huggingface_hub import hf_hub_download
# Process-wide cache for the loaded classifier. It stays empty until a load
# succeeds, so a failed attempt is retried on the next call.
_MODEL_CACHE = {}


def load_model():
    """Return the age group classification model, loading it at most once.

    The first successful call downloads the weights file from the Hugging
    Face Hub and deserializes it with Keras; the resulting model object is
    kept in a module-level cache and returned directly on later calls, so
    callers that invoke this once per request do not reload the model.

    Returns:
        The loaded Keras model, or ``None`` if downloading or loading failed
        (the error is printed). Failures are not cached.
    """
    cached = _MODEL_CACHE.get("model")
    if cached is not None:
        return cached

    try:
        # Download model from Hugging Face Hub
        model_path = hf_hub_download(
            repo_id="Sharris/age-group-classifier",
            filename="resnet50v2_age_classifier_best.h5",
        )
        model = tf.keras.models.load_model(model_path)
    except Exception as e:
        # Best-effort boundary: callers treat None as "model unavailable".
        print(f"Error loading model: {e}")
        return None

    _MODEL_CACHE["model"] = model
    return model
# Age group labels, keyed by the class index used to look them up
# (0 = youngest bracket, 4 = oldest).
AGE_GROUPS = dict(
    enumerate(
        (
            "Youth (0-20)",
            "Young Adult (21-40)",
            "Middle Age (41-60)",
            "Senior (61-80)",
            "Elderly (81-100)",
        )
    )
)
def preprocess_image(image):
    """Preprocess an image array for model input.

    Args:
        image: NumPy image array, or ``None``. Accepted layouts are
            (H, W) grayscale, (H, W, 1) grayscale, (H, W, 3) and
            (H, W, 4) RGBA.
            NOTE(review): pixel values are assumed to be in the 0-255
            range, since they are divided by 255 — confirm against callers.

    Returns:
        ``None`` if ``image`` is ``None``; otherwise a float32 array with a
        leading batch axis, resized to 224x224 and scaled by 1/255.
    """
    if image is None:
        return None

    # Convert to 3-channel RGB if needed. A 2-D array has no channel axis at
    # all, so it is expanded the same way as a single-channel (H, W, 1) image.
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Resize to model input size
    image = cv2.resize(image, (224, 224))

    # Normalize to [0, 1]
    image = image.astype(np.float32) / 255.0

    # Add batch dimension
    return np.expand_dims(image, axis=0)
def predict_age_group(image):
    """Predict the age group for a facial image and format it as Markdown-style text.

    Args:
        image: NumPy image array as delivered by the UI, or ``None`` when
            nothing was uploaded.

    Returns:
        A string. On success it contains the top predicted age group, its
        confidence, and every class sorted by descending score. Otherwise it
        is a short human-readable message (no image, model unavailable,
        preprocessing failure, or an error raised during prediction).
    """
    if image is None:
        return "Please upload an image."

    model = load_model()
    if model is None:
        return "Model not available. Please check the model file."

    # Preprocess the image
    processed_image = preprocess_image(image)
    if processed_image is None:
        return "Error processing image."

    try:
        # Scores for the single image in the batch
        predictions = model.predict(processed_image)[0]

        # Top prediction; cast to a plain int so the label lookup does not
        # depend on NumPy integer hashing.
        predicted_class = int(np.argmax(predictions))
        confidence = predictions[predicted_class]

        lines = [
            f"**Predicted Age Group:** {AGE_GROUPS[predicted_class]}",
            f"**Confidence:** {confidence:.1%}",
            "",
            "**All Predictions:**",
        ]

        # All classes, highest score first
        sorted_indices = np.argsort(predictions)[::-1]
        for rank, idx in enumerate(sorted_indices):
            emoji = "🎯" if rank == 0 else "📊"
            lines.append(f"{emoji} {AGE_GROUPS[int(idx)]}: {predictions[idx]:.1%}")

        # Every line, including the last, is newline-terminated.
        return "\n".join(lines) + "\n"
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return f"Error during prediction: {str(e)}"
# Create Gradio interface
def create_demo():
    """Create the Gradio demo interface.

    Builds a ``gr.Interface`` that sends an uploaded image (as a NumPy
    array) to ``predict_age_group`` and shows the returned text in a
    textbox, together with the title, description and model notes below.

    Returns:
        The configured ``gr.Interface`` instance (not yet launched).
    """
    title = "🎯 Age Group Classification"

    # The text blocks are kept flush-left so no indentation leaks into the
    # rendered description/article.
    description = """
This model classifies facial images into 5 age groups instead of predicting exact ages.
**Why Age Groups?**
- More reliable than exact age prediction
- Solves common bias where 70-year-olds are predicted as 30-year-olds
- More practical for most applications
**Age Groups:**
- 👶 Youth (0-20)
- 🧑 Young Adult (21-40)
- 👨 Middle Age (41-60)
- 👴 Senior (61-80)
- 👵 Elderly (81-100)
Upload a clear frontal face image for best results!
"""

    article = """
### Model Details
- **Architecture:** ResNet50V2 with transfer learning
- **Performance:** 75.5% validation accuracy
- **Training:** 13 epochs with early stopping
- **Dataset:** UTKFace (23,687 images)
### Limitations
- Works best with frontal face images
- Performance may vary with extreme lighting
- Border cases between age groups can be challenging
### Bias Correction
This model was specifically designed to solve age prediction bias, particularly the common issue where seniors are incorrectly classified as young adults.
"""

    # Create interface
    iface = gr.Interface(
        fn=predict_age_group,
        inputs=gr.Image(type="numpy", label="Upload Face Image"),
        outputs=gr.Textbox(label="Age Group Prediction", lines=10),
        title=title,
        description=description,
        article=article,
        examples=[
            # Add example images if available
        ],
        theme="default",
        allow_flagging="never",
    )
    return iface
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it.
    demo = create_demo()
    # Bind on all interfaces at port 7860 and request a public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)