"""Gradio demo: top-k image classification with a ConvNeXt model from Hugging Face."""

import gradio as gr
import torch
from transformers import AutoImageProcessor, ConvNextForImageClassification

# Checkpoint to load from the Hugging Face Hub.
# You can swap this string for another ConvNeXt image-classification checkpoint.
model_name = "facebook/convnext-base-224-22k-1k"

# Number of predictions shown to the user; shared by the prediction function
# and the Gradio Label component so the two cannot drift apart.
TOP_K = 5

# Load the pre-trained image processor and model once, at import time,
# so every request reuses the same objects.
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = ConvNextForImageClassification.from_pretrained(model_name)


def classify_image(img):
    """Classify an image and return its most likely labels.

    Args:
        img: A PIL image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the input component is empty or has been cleared.

    Returns:
        A dict mapping class label to probability for the top ``TOP_K``
        classes (fewer if the model has fewer classes), which is the format
        ``gr.Label`` renders. Returns ``None`` when no image was supplied.
    """
    # Gradio calls the function with None when the image is cleared;
    # returning None leaves the Label output empty instead of crashing.
    if img is None:
        return None

    # Uploads may be RGBA, grayscale or palette images; normalise to RGB
    # before preprocessing so the processor always sees three channels.
    if img.mode != "RGB":
        img = img.convert("RGB")

    inputs = image_processor(images=img, return_tensors="pt")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis; [0] selects the single image in the batch.
    probs = torch.softmax(logits, dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it in
    # case model_name is swapped for a checkpoint with fewer than TOP_K labels.
    k = min(TOP_K, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, k=k)

    # .tolist() yields plain Python floats/ints, which is what Gradio's Label
    # and the id2label lookup expect.
    return {
        model.config.id2label[idx]: score
        for score, idx in zip(top_probs.tolist(), top_indices.tolist())
    }


# Build the Gradio interface.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=TOP_K),
    title="Image Classification with ConvNeXt (Top-5)",
    description="Upload an image to see the top 5 predicted classes using a ConvNeXt image model from Hugging Face.",
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()