Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from huggingface_hub import hf_hub_download | |
# Load the model
# Process-wide cache for the Keras model so it is downloaded/built once rather
# than on every prediction request. Only successful loads are stored, so a
# failed load is retried on the next call (functools.lru_cache would instead
# memoize the None fallback forever).
_MODEL = None


def load_model():
    """Load the age group classification model, reusing it across calls.

    The first successful call downloads ``resnet50v2_age_classifier_best.h5``
    from the ``Sharris/age-group-classifier`` Hugging Face Hub repository and
    loads it with ``tf.keras.models.load_model``; later calls return that same
    model object without touching the network or disk.

    Returns:
        The loaded Keras model, or None if downloading or loading raised
        (the error is printed).
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    # NOTE: no lock here; two concurrent first calls may each load the model
    # once, after which one of the results is kept.
    try:
        # Download model from Hugging Face Hub
        model_path = hf_hub_download(
            repo_id="Sharris/age-group-classifier",
            filename="resnet50v2_age_classifier_best.h5",
        )
        model = tf.keras.models.load_model(model_path)
    except Exception as e:
        # Deliberately broad: callers treat None as "model unavailable" and
        # show a message instead of failing.
        print(f"Error loading model: {e}")
        return None
    _MODEL = model
    return model
# Display label for each model output index: predict_age_group looks these up
# with the argmax / sorted indices of the prediction vector, so position i in
# this list must correspond to output unit i of the model.
AGE_GROUPS = dict(
    enumerate(
        [
            "Youth (0-20)",
            "Young Adult (21-40)",
            "Middle Age (41-60)",
            "Senior (61-80)",
            "Elderly (81-100)",
        ]
    )
)
def preprocess_image(image, target_size=(224, 224)):
    """Turn a raw image array into a single-image batch for the model.

    Args:
        image: NumPy image array, either 2-D grayscale or 3-D with 1, 3 or 4
            channels. None is passed straight through.
        target_size: ``(width, height)`` to resize to, in the order used by
            ``cv2.resize``. The default ``(224, 224)`` matches the size this
            app has always fed the model.

    Returns:
        A float32 array of shape ``(1, height, width, 3)`` for the supported
        layouts, scaled by 1/255, or None if ``image`` is None.
    """
    if image is None:
        return None

    # Bring every supported layout to 3-channel RGB. A 2-D grayscale array has
    # no channel axis at all, so it needs the same GRAY2RGB conversion as an
    # (H, W, 1) array. Without it, the model would receive (1, H, W).
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Resize to model input size
    image = cv2.resize(image, target_size)
    # Scale to [0, 1]. This assumes 8-bit (0-255) pixel values.
    # TODO confirm this matches the preprocessing used at training time.
    image = image.astype(np.float32) / 255.0
    # Add the leading batch axis expected by model.predict.
    return np.expand_dims(image, axis=0)
def predict_age_group(image):
    """Predict the age group of the face in ``image``.

    Args:
        image: Image array from the Gradio ``gr.Image(type="numpy")`` input,
            or None when nothing was uploaded.

    Returns:
        A text report containing:
          - the top age group and its confidence;
          - every group ranked by probability.
        If something fails, it returns a human-readable error message instead.
        Failures are reported as text, never raised, so the UI always shows
        something.
    """
    if image is None:
        return "Please upload an image."

    model = load_model()
    if model is None:
        return "Model not available. Please check the model file."

    try:
        # Preprocessing sits inside the try so a malformed upload (e.g. an
        # unsupported channel count rejected by OpenCV) becomes a message
        # rather than an unhandled exception in the request handler.
        processed_image = preprocess_image(image)
        if processed_image is None:
            return "Error processing image."

        # verbose=0 stops Keras from printing a progress bar per request.
        predictions = model.predict(processed_image, verbose=0)[0]

        # One stable descending ranking drives both the headline and the list:
        # on ties the lowest index wins, exactly like np.argmax, so the top
        # entry of the list always matches the reported prediction.
        ranked = [int(i) for i in np.argsort(-predictions, kind="stable")]
        top = ranked[0]

        lines = [
            f"**Predicted Age Group:** {AGE_GROUPS[top]}",
            f"**Confidence:** {predictions[top]:.1%}",
            "",
            "**All Predictions:**",
        ]
        for rank, idx in enumerate(ranked):
            marker = "🎯" if rank == 0 else "📊"
            lines.append(f"{marker} {AGE_GROUPS[idx]}: {predictions[idx]:.1%}")
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"Error during prediction: {e}"
# Create Gradio interface
def create_demo():
    """Build the Gradio interface for the age-group classifier.

    Returns:
        A ``gr.Interface`` that sends an uploaded image to
        ``predict_age_group`` and shows the resulting text report. It is not
        launched here; the caller does that.
    """
    title = "🎯 Age Group Classification"

    # Blank lines separate Markdown paragraphs; without them the intro sentence
    # and the bold section headings render as one run-on paragraph.
    description = """
This model classifies facial images into 5 age groups instead of predicting exact ages.

**Why Age Groups?**
- More reliable than exact age prediction
- Solves common bias where 70-year-olds are predicted as 30-year-olds
- More practical for most applications

**Age Groups:**
- 👶 Youth (0-20)
- 🧑 Young Adult (21-40)
- 👨 Middle Age (41-60)
- 👴 Senior (61-80)
- 👵 Elderly (81-100)

Upload a clear frontal face image for best results!
"""

    article = """
### Model Details
- **Architecture:** ResNet50V2 with transfer learning
- **Performance:** 75.5% validation accuracy
- **Training:** 13 epochs with early stopping
- **Dataset:** UTKFace (23,687 images)

### Limitations
- Works best with frontal face images
- Performance may vary with extreme lighting
- Border cases between age groups can be challenging

### Bias Correction
This model was specifically designed to solve age prediction bias, particularly the common issue where seniors are incorrectly classified as young adults.
"""

    # NOTE(review): predict_age_group emits Markdown-style **bold** markers,
    # which a Textbox displays literally. Switching to gr.Markdown would also
    # require Markdown line breaks in predict_age_group's output.
    iface = gr.Interface(
        fn=predict_age_group,
        inputs=gr.Image(type="numpy", label="Upload Face Image"),
        outputs=gr.Textbox(label="Age Group Prediction", lines=10),
        title=title,
        description=description,
        article=article,
        examples=[
            # Add example images if available
        ],
        theme="default",
        allow_flagging="never",
    )
    return iface
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860 with share=True passed through to Gradio.
    demo = create_demo()
    launch_kwargs = {"share": True, "server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_kwargs)