resnet50-flowers102-classifier
This model is a fine-tuned version of microsoft/resnet-50 on the Flowers102 dataset.
It achieves the following results on the evaluation set:
- Loss: 0.2261
- Accuracy: 0.9304
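The card does not say exactly how these metrics were computed. A common setup, assumed here, is mean cross-entropy loss and top-1 accuracy averaged over every sample in the evaluation split. The sketch below shows such a loop. The function name evaluate is hypothetical, and the toy model and data only demonstrate the call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu") -> tuple[float, float]:
    """Return (mean cross-entropy loss, top-1 accuracy) over all samples in loader.

    Hypothetical helper: assumes the reported metrics are per-sample averages.
    """
    model.eval()
    total_loss, correct, count = 0.0, 0, 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        logits = model(images)
        # Sum per-sample losses so the final average is per sample, not per batch.
        total_loss += F.cross_entropy(logits, targets, reduction="sum").item()
        # A prediction is correct when the highest-scoring class equals the target.
        correct += (logits.argmax(dim=1) == targets).sum().item()
        count += targets.numel()
    return total_loss / count, correct / count


if __name__ == "__main__":
    # Toy check with random data and a linear "model" (not the real classifier).
    torch.manual_seed(0)
    dataset = TensorDataset(torch.randn(20, 8), torch.randint(0, 102, (20,)))
    toy_model = nn.Linear(8, 102)
    loss, acc = evaluate(toy_model, DataLoader(dataset, batch_size=6))
    print(f"loss={loss:.4f} accuracy={acc:.4f}")
```

Summing per-sample losses and dividing by the sample count avoids skewing the average when the last batch is smaller than the others.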
Model description
This model uses a ResNet-50 backbone initialized with ImageNet weights and fine-tuned for 102-class flower classification.
The custom ResNetClassifier class must be defined locally to load the weights correctly.
Intended uses & limitations
This model is intended for research on fine-grained image classification of flowers.
How to use
This model uses a custom class, ResNetClassifier, that inherits from PyTorchModelHubMixin. You must therefore define the class locally before loading the weights, and load the configuration separately to obtain the class labels.
```python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoConfig

# 1. Define the custom model class
class ResNetClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_classes=102, model_name='resnet50', freeze_backbone=True):
        super().__init__()
        # Kept so the constructor matches the arguments saved with the model.
        self.num_classes = num_classes
        self.model_name = model_name
        self.freeze_backbone = freeze_backbone
        # No ImageNet weights here: from_pretrained loads the fine-tuned weights.
        self.backbone = models.resnet50(weights=None)
        num_ftrs = self.backbone.fc.in_features
        # Replace the 1000-class ImageNet head with a num_classes-way head.
        self.backbone.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.backbone(x)

MODEL_ID = "sukinggg/resnet50-flowers102-classifier"
IMAGE_PATH = "path/to/your/image.jpg"  # Update this path

# 2. Load the model weights
model = ResNetClassifier.from_pretrained(MODEL_ID)
model.eval()

# 3. Load the configuration to get the labels
config = AutoConfig.from_pretrained(MODEL_ID)
# id2label maps integer class ids to names; index it directly
# instead of relying on the order of its values.
id2label = config.id2label

# 4. Define the preprocessing transformation (standard ImageNet resize, crop, and normalization)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 5. Run inference
pil_image = Image.open(IMAGE_PATH).convert('RGB')
input_tensor = preprocess(pil_image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    outputs = model(input_tensor)
predicted_idx = outputs.argmax(dim=1).item()
predicted_label = id2label[predicted_idx]
print("Detected label:", predicted_label)
```
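The snippet above reports only the single most likely class. To see how confident the model is, you can convert the logits to probabilities with a softmax and list the top few classes. The minimal sketch below does this. The helper top_k_labels is hypothetical, and it assumes id2label maps integer class ids to names, as config.id2label does in transformers.

```python
import torch


def top_k_labels(logits: torch.Tensor, id2label: dict[int, str], k: int = 5) -> list[tuple[str, float]]:
    """Return the k most likely (label, probability) pairs for a single image.

    logits has shape (1, num_classes), as the model produces for one image.
    """
    # Softmax over the class dimension, then take the first (only) image.
    probs = torch.softmax(logits, dim=-1)[0]
    values, indices = probs.topk(k)
    return [(id2label[int(i)], float(v)) for v, i in zip(values, indices)]


if __name__ == "__main__":
    # Toy logits for 4 classes instead of the real 102-class output.
    demo_logits = torch.tensor([[0.1, 2.0, -1.0, 1.0]])
    demo_labels = {0: "a", 1: "b", 2: "c", 3: "d"}
    for label, prob in top_k_labels(demo_logits, demo_labels, k=2):
        print(f"{label}: {prob:.3f}")
```

With the card's variables, the call would be top_k_labels(outputs, config.id2label).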
Model tree for sukinggg/resnet50-flowers102-classifier
Base model: microsoft/resnet-50