Bowerbird viewpoint classifier (ResNet18)

  • Task: classify each frame into one of four viewpoints:
    ["back", "front", "left_side", "right_side"]
  • Base model: torchvision.models.resnet18 with weights="IMAGENET1K_V1"
  • Input size: 224 × 224 (after cropping)
  • Preprocessing (training/eval):
    • Resize to 256 px on the shorter side
    • Train: RandomResizedCrop(224), RandomRotation(7°), ColorJitter
    • Eval: CenterCrop(224)
  • Normalization:
    • mean = [0.485, 0.456, 0.406]
    • std = [0.229, 0.224, 0.225]
  • Checkpoint file: Bbird_viewpoint_classifier.pth
  • The checkpoint stores a PyTorch state_dict for ResNet18 with a final linear layer of 4 outputs (one per viewpoint class).

This checkpoint is not a general-purpose classifier: it is specific to the four viewpoint classes listed above. When rebuilding the model, the classification head must have 4 outputs, and predictions must be read in the same class order: back, front, left_side, right_side.

Usage

import torch
from torch import nn
from torchvision.models import resnet18
from huggingface_hub import hf_hub_download

# Replace this with the actual repo id on the Hub if different
repo_id = "sarequi/bowerbird-viewpoint-classifier"

# Download checkpoint
ckpt_path = hf_hub_download(
    repo_id=repo_id,
    filename="Bbird_viewpoint_classifier.pth",
)

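# Rebuild the model architecture exactly as in training.
# weights=None skips the ImageNet download; the checkpoint below
# overwrites every parameter anyway.
NUM_CLASSES = 4
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# Load the fine-tuned weights and switch to inference mode
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Output index i of the model corresponds to VIEWPOINT_CLASSES[i]
VIEWPOINT_CLASSES = ["back", "front", "left_side", "right_side"]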
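
To turn the model's 4 raw outputs into a viewpoint label, the logits can be passed through a softmax and the highest-scoring index looked up in the class list. The helper below is a minimal sketch in plain Python; the name `decode_logits` is hypothetical and not part of the released code.

```python
import math

VIEWPOINT_CLASSES = ["back", "front", "left_side", "right_side"]


def decode_logits(logits):
    """Map 4 raw logits to (predicted label, softmax probability)."""
    if len(logits) != len(VIEWPOINT_CLASSES):
        raise ValueError("expected one logit per viewpoint class")
    # Subtract the maximum before exponentiating for numerical stability
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class (first one wins on ties)
    best = max(range(len(probs)), key=probs.__getitem__)
    return VIEWPOINT_CLASSES[best], probs[best]
```

With the model loaded as above, the logits for one frame could be obtained under `torch.no_grad()` as `model(batch)[0].tolist()`, where `batch` is assumed to be a preprocessed tensor of shape (1, 3, 224, 224).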