"""Gradio demo that classifies a face image into one of five age groups."""

import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
from huggingface_hub import hf_hub_download

# Hugging Face Hub location of the trained classifier.
MODEL_REPO_ID = "Sharris/age-group-classifier"
MODEL_FILENAME = "resnet50v2_age_classifier_best.h5"

# Spatial size (width, height) every image is resized to before inference.
MODEL_INPUT_SIZE = (224, 224)

# Age group labels, keyed by the model's output class index.
AGE_GROUPS = {
    0: "Youth (0-20)",
    1: "Young Adult (21-40)",
    2: "Middle Age (41-60)",
    3: "Senior (61-80)",
    4: "Elderly (81-100)",
}

# Markdown shown in the UI. Kept at column 0 so no source indentation leaks
# into the rendered text (indented Markdown can be interpreted as code).
_DESCRIPTION = """
This model classifies facial images into 5 age groups instead of predicting exact ages.

**Why Age Groups?**
- More reliable than exact age prediction
- Solves common bias where 70-year-olds are predicted as 30-year-olds
- More practical for most applications

**Age Groups:**
- 👶 Youth (0-20)
- 🧑 Young Adult (21-40)
- 👨 Middle Age (41-60)
- 👴 Senior (61-80)
- 👵 Elderly (81-100)

Upload a clear frontal face image for best results!
"""

_ARTICLE = """
### Model Details
- **Architecture:** ResNet50V2 with transfer learning
- **Performance:** 75.5% validation accuracy
- **Training:** 13 epochs with early stopping
- **Dataset:** UTKFace (23,687 images)

### Limitations
- Works best with frontal face images
- Performance may vary with extreme lighting
- Border cases between age groups can be challenging

### Bias Correction
This model was specifically designed to solve age prediction bias, particularly the common issue where seniors are incorrectly classified as young adults.
"""

# Holds the model after the first successful load so that later predictions
# do not repeat the Hub lookup and Keras deserialisation on every request.
_MODEL_CACHE = None


def load_model():
    """Return the age group classification model, loading it on first use.

    The model file is fetched from the Hugging Face Hub and deserialised with
    Keras. A successfully loaded model is cached at module level and reused
    by subsequent calls. Failures are not cached, so a later call retries.

    Returns:
        The loaded Keras model, or ``None`` if downloading or loading failed
        (the error is printed).
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is not None:
        return _MODEL_CACHE

    try:
        # Download model from Hugging Face Hub
        model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
        )
        _MODEL_CACHE = tf.keras.models.load_model(model_path)
    except Exception as e:
        # Deliberately broad: this is the boundary where download and
        # deserialisation errors are turned into a "model unavailable"
        # result for the UI instead of crashing the request.
        print(f"Error loading model: {e}")
        return None
    return _MODEL_CACHE


def preprocess_image(image):
    """Convert an input image into a batched float array for the model.

    Args:
        image: Image as a NumPy array (2-D grayscale, or 3-D with 1, 3 or 4
            channels), or ``None``.

    Returns:
        A ``float32`` array with a leading batch dimension of 1, resized to
        ``MODEL_INPUT_SIZE`` and scaled by 1/255, or ``None`` when ``image``
        is ``None``.
    """
    if image is None:
        return None

    image = np.asarray(image)

    # Convert to 3-channel RGB if needed. The 2-D case covers grayscale
    # arrays that carry no channel axis at all.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    elif image.ndim == 3 and image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Resize to model input size
    image = cv2.resize(image, MODEL_INPUT_SIZE)

    # Normalize to [0, 1] (for 8-bit input).
    # NOTE(review): assumes the model was trained on inputs scaled this way
    # -- confirm against the training pipeline.
    image = image.astype(np.float32) / 255.0

    # Add batch dimension
    return np.expand_dims(image, axis=0)


def predict_age_group(image):
    """Predict the age group for a face image and format the result as text.

    Args:
        image: Image as a NumPy array, or ``None`` when nothing was uploaded.

    Returns:
        A Markdown-formatted string with the top prediction, its confidence
        and every class score in descending order; or a human-readable
        message when there is no image, the model is unavailable,
        preprocessing fails, or prediction raises.
    """
    if image is None:
        return "Please upload an image."

    model = load_model()
    if model is None:
        return "Model not available. Please check the model file."

    # Preprocess the image
    processed_image = preprocess_image(image)
    if processed_image is None:
        return "Error processing image."

    try:
        # Scores for the single image in the batch.
        predictions = model.predict(processed_image)[0]

        # Class indices ordered from most to least confident; the first one
        # is the top prediction. Cast to int so they are plain dict keys.
        ranked_classes = [int(idx) for idx in np.argsort(predictions)[::-1]]
        predicted_class = int(np.argmax(predictions))
        confidence = predictions[predicted_class]

        parts = [
            f"**Predicted Age Group:** {AGE_GROUPS[predicted_class]}\n",
            f"**Confidence:** {confidence:.1%}\n\n",
            "**All Predictions:**\n",
        ]
        for rank, idx in enumerate(ranked_classes):
            emoji = "🎯" if rank == 0 else "📊"
            parts.append(f"{emoji} {AGE_GROUPS[idx]}: {predictions[idx]:.1%}\n")
        return "".join(parts)
    except Exception as e:
        # Report any inference/formatting failure to the user as text rather
        # than letting the request fail.
        return f"Error during prediction: {str(e)}"


def create_demo():
    """Build the Gradio interface for the age group classifier.

    Returns:
        A ``gr.Interface`` that takes an uploaded image (delivered to
        ``predict_age_group`` as a NumPy array) and shows the prediction
        text in a textbox.
    """
    title = "🎯 Age Group Classification"

    iface = gr.Interface(
        fn=predict_age_group,
        inputs=gr.Image(type="numpy", label="Upload Face Image"),
        outputs=gr.Textbox(label="Age Group Prediction", lines=10),
        title=title,
        description=_DESCRIPTION,
        article=_ARTICLE,
        examples=[
            # Add example images if available
        ],
        theme="default",
        allow_flagging="never",
    )
    return iface


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )