aartstudio commited on
Commit
2c32f6c
·
verified ·
1 Parent(s): 994a388

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -29
app.py CHANGED
@@ -1,29 +1,49 @@
1
- import gradio as gr
2
- from transformers import ViTFeatureExtractor, ViTForImageClassification
3
- from PIL import Image
4
- import torch
5
-
6
- # Load pre-trained model and feature extractor
7
- model_name = "google/vit-base-patch16-224"
8
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
9
- model = ViTForImageClassification.from_pretrained(model_name)
10
-
11
- # Define the prediction function
12
- def classify_image(img):
13
- inputs = feature_extractor(images=img, return_tensors="pt")
14
- with torch.no_grad():
15
- outputs = model(**inputs)
16
- logits = outputs.logits
17
- predicted_class_idx = logits.argmax(-1).item()
18
- predicted_label = model.config.id2label[predicted_class_idx]
19
- return predicted_label
20
-
21
- # Build the Gradio interface
22
- interface = gr.Interface(fn=classify_image,
23
- inputs=gr.Image(type="pil"),
24
- outputs="text",
25
- title="Image Classification with ViT",
26
- description="Upload an image and classify it using Vision Transformer (ViT)")
27
-
28
- # Launch the app
29
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, ConvNextForImageClassification
4
+
5
+ # Choose a stronger, free image model from Hugging Face
6
+ # You can swap this string for any other image-classification model on HF
7
+ model_name = "facebook/convnext-base-224-22k-1k"
8
+
9
+ # Load pre-trained image processor and model
10
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
11
+ model = ConvNextForImageClassification.from_pretrained(model_name)
12
+
13
+ # Define the prediction function (top 5 classes)
14
+ def classify_image(img):
15
+ # Preprocess image
16
+ inputs = image_processor(images=img, return_tensors="pt")
17
+
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+
21
+ logits = outputs.logits
22
+ probs = torch.softmax(logits, dim=-1)[0] # shape: [num_classes]
23
+
24
+ # Get top 5 predictions
25
+ topk = torch.topk(probs, k=5)
26
+ top_probs = topk.values
27
+ top_indices = topk.indices
28
+
29
+ # Map indices to labels and convert to a dict that Gradio's Label understands
30
+ results = {}
31
+ for score, idx in zip(top_probs, top_indices):
32
+ label = model.config.id2label[idx.item()]
33
+ results[label] = float(score.item())
34
+
35
+ return results # Gradio Label will show top-k nicely
36
+
37
+
38
+ # Build the Gradio interface
39
+ interface = gr.Interface(
40
+ fn=classify_image,
41
+ inputs=gr.Image(type="pil"),
42
+ outputs=gr.Label(num_top_classes=5),
43
+ title="Image Classification with ConvNeXt (Top-5)",
44
+ description="Upload an image to see the top 5 predicted classes using a ConvNeXt image model from Hugging Face."
45
+ )
46
+
47
+ # Launch the app
48
+ if __name__ == "__main__":
49
+ interface.launch()