resnet50-flowers102-classifier

This model is a fine-tuned version of microsoft/resnet-50 on the flowers102 dataset.

It achieves the following results on the evaluation set:

  • Loss: 0.2261
  • Accuracy: 0.9304
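
For readers who want to reproduce numbers of this kind, the following is a minimal sketch of how a mean loss and top-1 accuracy are typically computed. It assumes cross-entropy loss, which this card does not state explicitly, and the `evaluate` helper and the toy inputs are hypothetical stand-ins rather than the evaluation code used for this model.

```python
import torch
import torch.nn as nn


def evaluate(model, batches):
    """Return (mean cross-entropy loss, top-1 accuracy) over (images, targets) batches."""
    criterion = nn.CrossEntropyLoss(reduction="sum")  # assumption: cross-entropy
    total_loss, correct, seen = 0.0, 0, 0
    model.eval()
    with torch.no_grad():
        for images, targets in batches:
            logits = model(images)
            total_loss += criterion(logits, targets).item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.numel()
    return total_loss / seen, correct / seen


# Toy stand-ins (hypothetical): an identity "model" fed hand-made logits.
toy_model = nn.Identity()
toy_batches = [
    (torch.tensor([[5.0, 0.0], [0.0, 5.0], [5.0, 0.0]]), torch.tensor([0, 1, 1])),
]
loss, accuracy = evaluate(toy_model, toy_batches)
print(f"loss={loss:.4f} accuracy={accuracy:.4f}")
```

In practice, `toy_model` would be replaced by the loaded classifier and `toy_batches` by a data loader over the evaluation split.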

Model description

This model uses a ResNet50 backbone, initialized with ImageNet weights, and was fine-tuned for the 102-class flower classification task.

The custom model structure (ResNetClassifier class) is required to load the weights correctly.

Intended uses & limitations

This model is intended for research use in fine-grained image classification tasks related to flowers.

How to use

Since this model uses a custom class (ResNetClassifier) inheriting from PyTorchModelHubMixin, you must define the class locally and load the configuration separately to get the labels.

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoConfig

# 1. Define the custom model class
class ResNetClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_classes=102, model_name='resnet50', freeze_backbone=True):
        super().__init__()
        self.num_classes = num_classes
        self.model_name = model_name
        self.freeze_backbone = freeze_backbone
        # weights=None: the fine-tuned weights are loaded later by from_pretrained
        self.backbone = models.resnet50(weights=None)
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.backbone(x)

MODEL_ID = "sukinggg/resnet50-flowers102-classifier"
IMAGE_PATH = "path/to/your/image.jpg" # Update this path

# 2. Load the model weights
model = ResNetClassifier.from_pretrained(MODEL_ID)
model.eval()

# 3. Load the configuration to get the labels
config = AutoConfig.from_pretrained(MODEL_ID)
# Order labels by class index rather than relying on dict insertion order
labels = [config.id2label[i] for i in sorted(config.id2label)]

# 4. Define the preprocessing transformation
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 5. Run inference
pil_image = Image.open(IMAGE_PATH).convert('RGB')
input_tensor = preprocess(pil_image).unsqueeze(0)

with torch.no_grad():
    outputs = model(input_tensor)
    # argmax over the class dimension; batch size is 1 here
    predicted_idx = outputs.argmax(dim=1).item()

predicted_label = labels[predicted_idx]

print("Detected label:", predicted_label)
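
With 102 visually similar classes, the single top prediction can hide close runners-up. As a sketch, the snippet below converts logits to probabilities with softmax and lists the k most likely labels. The `top_k_predictions` helper and the demo values are hypothetical; in practice the model's `outputs` and `config.id2label` from the example above would be passed in.

```python
import torch


def top_k_predictions(logits, id2label, k=5):
    """Return the k most likely (label, probability) pairs for a single image."""
    probs = torch.softmax(logits, dim=1)  # logits shape: (1, num_classes)
    top_probs, top_idxs = probs.topk(k, dim=1)
    return [
        (id2label[idx], prob)
        for prob, idx in zip(top_probs[0].tolist(), top_idxs[0].tolist())
    ]


# Stand-in values so the sketch runs on its own.
demo_logits = torch.tensor([[0.1, 2.0, -1.0, 0.5]])
demo_id2label = {0: "class_a", 1: "class_b", 2: "class_c", 3: "class_d"}
ranked = top_k_predictions(demo_logits, demo_id2label, k=2)
for label, prob in ranked:
    print(f"{label}: {prob:.3f}")
```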