#!/usr/bin/env python3
"""Gradio app wrapping the ``0xnu/skincare-detection`` ViT image classifier.

The model is loaded once at import time into the module-level ``classifier``.
``create_interface`` builds the Gradio Blocks UI (launched under ``__main__``),
and ``classify_from_path`` / ``health_check`` support programmatic use.
"""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import gradio as gr
import numpy as np  # noqa: F401  (currently unused; kept for compatibility)
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sample images offered in the UI; only files that actually exist are shown.
EXAMPLES_DIR = Path("examples")
EXAMPLE_FILES = ("andrea.jpeg", "disorder.jpeg", "joe.jpeg", "woman.jpeg")


class SkincareInference:
    """Lightweight skincare classifier for HuggingFace Spaces."""

    def __init__(self, model_name: str = "0xnu/skincare-detection"):
        """Initialize the model with error handling for cloud deployment.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                image processor and the classification model.

        Raises:
            RuntimeError: If loading the processor or model fails.
        """
        self.model_name = model_name
        self.processor: Optional[ViTImageProcessor] = None
        self.model: Optional[ViTForImageClassification] = None
        self.id2label: Optional[Dict[int, str]] = None
        self.load_model()

    def load_model(self) -> None:
        """Load the processor and model, and cache the class-id -> label mapping.

        Raises:
            RuntimeError: Wrapping (and chained to) any exception raised while
                loading.
        """
        try:
            logger.info("Loading model: %s", self.model_name)

            # Load processor and model
            self.processor = ViTImageProcessor.from_pretrained(self.model_name)
            self.model = ViTForImageClassification.from_pretrained(self.model_name)
            self.model.eval()  # switch to evaluation mode for inference

            # Extract class labels
            self.id2label = self.model.config.id2label
            logger.info("Model loaded successfully with %d classes", len(self.id2label))
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _label_for(self, class_id: int) -> str:
        """Return the label for ``class_id``, or ``LABEL_<id>`` if it is unmapped."""
        return self.id2label.get(class_id, f"LABEL_{class_id}")

    def predict(self, image: Image.Image) -> Tuple[Dict[str, float], str]:
        """Classify an image and return predictions with confidence scores.

        Args:
            image: PIL Image object; non-RGB images are converted to RGB first.

        Returns:
            Tuple of ``(class_probabilities_dict, formatted_output_string)``.
            The dict maps label -> softmax probability and is ordered from most
            to least likely. On failure the dict is empty and the string holds
            an error message.
        """
        if self.model is None or self.processor is None:
            return {}, "❌ Model not loaded properly"

        try:
            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")

            with torch.no_grad():
                logits = self.model(**inputs).logits
            # A single image was processed, so take the first (only) batch row.
            probabilities = torch.softmax(logits, dim=-1)[0].tolist()

            class_probs = {
                self._label_for(class_id): prob
                for class_id, prob in enumerate(probabilities)
            }
            sorted_probs = dict(
                sorted(class_probs.items(), key=lambda item: item[1], reverse=True)
            )

            output_text = self._format_results(sorted_probs, image.size)
            return sorted_probs, output_text
        except Exception as e:
            error_msg = f"❌ Classification failed: {e}"
            logger.error(error_msg)
            return {}, error_msg

    def _format_results(
        self,
        class_probs: Dict[str, float],
        image_size: Tuple[int, int],
        top_k: int = 10,
        min_confidence: float = 0.001,
        bar_width: int = 20,
    ) -> str:
        """Format classification results as readable Markdown text.

        Args:
            class_probs: label -> probability, expected to be sorted descending
                (the first entry is reported as the top prediction).
            image_size: ``(width, height)`` of the classified image in pixels.
            top_k: Maximum number of predictions to list.
            min_confidence: Predictions below this probability are omitted.
            bar_width: Character width of the textual confidence bar.

        Returns:
            Newline-joined Markdown summary.
        """
        if not class_probs:
            return "❌ No predictions to display."

        top_class, top_confidence = next(iter(class_probs.items()))
        width, height = image_size

        output_lines = [
            f"🎯 **Top Prediction:** {top_class.title()}",
            f"📊 **Confidence:** {top_confidence:.1%}",
            f"📏 **Image Size:** {width} × {height} pixels",
            "",
            f"📈 **Top {top_k} Predictions:**",
        ]

        for class_name, confidence in list(class_probs.items())[:top_k]:
            if confidence < min_confidence:  # hide predictions below 0.1% by default
                continue
            # Clamp so rounding or out-of-range inputs never overflow the bar.
            filled = min(bar_width, max(0, int(confidence * bar_width)))
            bar = "█" * filled + "░" * (bar_width - filled)
            output_lines.append(f"• **{class_name.title()}:** {confidence:.1%} `{bar}`")

        return "\n".join(output_lines)


# Initialize the classifier globally so the model is loaded once per process.
classifier: Optional[SkincareInference] = None
model_loaded = False
try:
    classifier = SkincareInference()
    model_loaded = True
except Exception as e:
    logger.error(f"Failed to initialize classifier: {e}")


def classify_image(image: Optional[Image.Image]) -> str:
    """Main prediction function for the Gradio interface.

    Args:
        image: PIL Image from Gradio, or ``None`` when the input is empty.

    Returns:
        Formatted Markdown text with results or an error message.
    """
    if not model_loaded or classifier is None:
        return "❌ Model not available. Please try again later."

    if image is None:
        return "❌ Please upload an image first."

    try:
        # On failure predict() returns an empty dict plus an error message,
        # so the text is the right thing to show in either case.
        _, formatted_output = classifier.predict(image)
        return formatted_output
    except Exception as e:
        logger.error(f"Classification error: {e}")
        return f"❌ Processing failed: {e}"


def _existing_example_images() -> List[List[str]]:
    """Return Gradio example rows for the sample images present on disk."""
    return [
        [str(EXAMPLES_DIR / name)]
        for name in EXAMPLE_FILES
        if (EXAMPLES_DIR / name).is_file()
    ]


def create_interface() -> gr.Blocks:
    """Create and configure the Gradio interface.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(
        title="Skincare Product Classifier",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 900px !important;
            margin: auto !important;
        }
        """,
    ) as interface:
        # Non-indented literals so no line can be rendered as a Markdown code block.
        gr.Markdown(
            "# 🧴 Skincare Disease Identification\n\n"
            "Upload an image and identify skin diseases easily."
        )

        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                image_input = gr.Image(
                    type="pil",
                    label="Upload Skincare Product Image",
                    height=400,
                )

                classify_btn = gr.Button(
                    "🔍 Identify Disease",
                    variant="primary",
                    size="lg",
                )

                # Model info
                if model_loaded and classifier is not None:
                    gr.Markdown(
                        f"**Model:** `{classifier.model_name}`\n\n"
                        f"**Classes:** {len(classifier.id2label)}\n\n"
                        "**Status:** ✅ Ready"
                    )
                else:
                    gr.Markdown("**Status:** ❌ Model loading failed")

            with gr.Column(scale=1):
                # Output components
                result_text = gr.Markdown(
                    label="Classification Results",
                    value="Upload an image and click classify to see results.",
                )

        # Example images section
        gr.Markdown("### 📋 Try These Examples")

        example_images = _existing_example_images()
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[image_input],
                label="Sample Images",
            )

        # Event handlers: classify on button click and whenever the image changes.
        classify_btn.click(
            fn=classify_image,
            inputs=[image_input],
            outputs=[result_text],
        )

        image_input.change(
            fn=classify_image,
            inputs=[image_input],
            outputs=[result_text],
        )

        # Footer
        gr.Markdown(
            "---\n"
            "**Note:** This classifier works best with clear, well-lit images of skin conditions.\n"
            "Results are AI-generated predictions and should be verified by medical professionals."
        )

    return interface


# Additional utility functions for deployment
def health_check() -> Dict[str, Union[str, int]]:
    """Health check function for deployment monitoring.

    Returns:
        Dict with ``status`` ("healthy"/"unhealthy"), ``model`` (model name or
        "not_loaded") and ``classes`` (number of labels, 0 when not loaded).
    """
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model": classifier.model_name if classifier is not None else "not_loaded",
        "classes": len(classifier.id2label) if classifier is not None else 0,
    }


# For programmatic usage (optional)
def classify_from_path(image_path: Union[str, Path]) -> Dict[str, Any]:
    """Classify an image from a file path (useful for testing).

    Args:
        image_path: Path to the image file.

    Returns:
        Result dict that always contains ``success`` and ``image_path``.
        On success it also has ``predictions`` and ``formatted_output``;
        on failure it has ``error`` (and, when classification itself failed,
        the empty ``predictions`` and the ``formatted_output`` message too).
    """
    if not model_loaded or classifier is None:
        return {"success": False, "error": "Model not loaded", "image_path": image_path}

    try:
        # Context manager closes the file handle once prediction is done.
        with Image.open(image_path) as image:
            class_probs, formatted_output = classifier.predict(image)
    except Exception as e:
        return {"success": False, "error": str(e), "image_path": image_path}

    result: Dict[str, Any] = {
        # predict() signals failure with an empty dict rather than raising.
        "success": bool(class_probs),
        "predictions": class_probs,
        "formatted_output": formatted_output,
        "image_path": image_path,
    }
    if not class_probs:
        result["error"] = formatted_output
    return result


# Export key functions for external usage
__all__ = ["SkincareInference", "classify_image", "classify_from_path", "health_check"]


# Main execution
if __name__ == "__main__":
    try:
        app = create_interface()

        # Launch configuration for HuggingFace Spaces
        app.launch(
            server_name="0.0.0.0",  # Listen on all interfaces
            server_port=7860,  # Standard HF Spaces port
            share=False,  # Disable sharing for security
        )
    except Exception as e:
        logger.error(f"Failed to launch interface: {e}")
        print(f"❌ Application failed to start: {e}")