"""
Simple PyTorch Dataset for VPR (Visual Place Recognition)
Combines database and query images with ground truth lookup.
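
Expected data folder layout (inferred from the paths this module reads):

    data/
        database/           # reference images
        query/              # query images
        ground_truth.json   # image metadata and place IDs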
"""
import json
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from typing import List, Dict, Tuple
class VPRDataset(Dataset):
"""
Simple VPR Dataset that loads both database and query images.
Usage:
dataset = VPRDataset('data')
# Get an image
img, filename, is_query = dataset[0]
# Get ground truth matches for a query
matches = dataset.gt('place00000123_q0000.jpg')
"""
def __init__(
self,
data_dir='data',
transform=None,
include_queries=True,
include_database=True
):
"""
Args:
data_dir: Path to data folder containing database/, query/, and ground_truth.json
transform: Optional torchvision transforms to apply to images
include_queries: Whether to include query images in the dataset
include_database: Whether to include database images in the dataset
"""
self.data_dir = Path(data_dir)
self.include_queries = include_queries
self.include_database = include_database
        # Default transform if none provided: resize to 480x640, convert to
        # tensor, and normalize with the standard ImageNet mean/std.
if transform is None:
self.transform = T.Compose([
T.Resize((480, 640)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
else:
self.transform = transform
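        # Any torchvision pipeline works here; a lighter alternative, for
        # example: T.Compose([T.Resize((224, 224)), T.ToTensor()])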
# Load ground truth
gt_path = self.data_dir / 'ground_truth.json'
with open(gt_path, 'r') as f:
self.ground_truth = json.load(f)
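        # Expected ground_truth.json shape (inferred from the fields read below):
        #   {"database": [{"filename": ..., "place_id": ..., "city": ...,
        #                  "lat": ..., "lon": ...}, ...],
        #    "query":    [...same structure...]}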
# Build the dataset items list
self.items = []
if include_database:
for item in self.ground_truth['database']:
self.items.append({
'filename': item['filename'],
'path': self.data_dir / 'database' / item['filename'],
'place_id': item['place_id'],
'is_query': False,
'city': item['city'],
'lat': item['lat'],
'lon': item['lon']
})
if include_queries:
for item in self.ground_truth['query']:
self.items.append({
'filename': item['filename'],
'path': self.data_dir / 'query' / item['filename'],
'place_id': item['place_id'],
'is_query': True,
'city': item['city'],
'lat': item['lat'],
'lon': item['lon']
})
# Build lookup tables for fast ground truth queries
self._build_lookup_tables()
def _build_lookup_tables(self):
"""Build internal lookup tables for efficient ground truth queries."""
# Map filename -> full item info
self.filename_to_item = {item['filename']: item for item in self.items}
        # Map place_id -> list of database filenames
        self.place_to_db_files = {}
        for item in self.ground_truth['database']:
            self.place_to_db_files.setdefault(item['place_id'], []).append(item['filename'])
# Map query filename -> its place_id for fast lookup
self.query_to_place = {
item['filename']: item['place_id']
for item in self.ground_truth['query']
}
def __len__(self):
"""Return total number of images in dataset."""
return len(self.items)
def __getitem__(self, idx) -> Tuple[torch.Tensor, str, bool]:
"""
Get an image from the dataset.
Returns:
tuple: (image_tensor, filename, is_query)
- image_tensor: Transformed image as torch.Tensor
- filename: String filename (e.g., 'place00000123_db0001.jpg')
- is_query: Boolean indicating if this is a query image
"""
item = self.items[idx]
# Load image
img = Image.open(item['path']).convert('RGB')
# Apply transforms
if self.transform:
img = self.transform(img)
return img, item['filename'], item['is_query']
def gt(self, query_filename: str) -> List[str]:
"""
Get ground truth database matches for a query image.
Args:
query_filename: Filename of the query image (e.g., 'place00000123_q0000.jpg')
Returns:
List of database image filenames that match this query (same place_id)
Example:
>>> dataset = VPRDataset('data')
>>> matches = dataset.gt('place00000123_q0000.jpg')
>>> print(matches)
['place00000123_db0000.jpg', 'place00000123_db0001.jpg', 'place00000123_db0002.jpg']
"""
if query_filename not in self.query_to_place:
raise ValueError(f"Query filename '{query_filename}' not found in dataset")
place_id = self.query_to_place[query_filename]
return self.place_to_db_files.get(place_id, [])
def get_query_filenames(self) -> List[str]:
"""Get list of all query image filenames."""
return list(self.query_to_place.keys())
def get_database_filenames(self) -> List[str]:
"""Get list of all database image filenames."""
        return [item['filename'] for item in self.ground_truth['database']]
def get_item_by_filename(self, filename: str) -> Dict:
"""
Get full item information by filename.
Args:
filename: Image filename
Returns:
Dictionary with keys: filename, path, place_id, is_query, city, lat, lon
"""
if filename not in self.filename_to_item:
raise ValueError(f"Filename '{filename}' not found in dataset")
return self.filename_to_item[filename]
@staticmethod
def get_place_id_from_filename(filename: str) -> int:
"""
Extract place_id from filename.
Args:
filename: Image filename (e.g., 'place00000123_db0001.jpg')
Returns:
Integer place_id (e.g., 123)
"""
return int(filename.split('_')[0].replace('place', ''))
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
if __name__ == "__main__":
from torch.utils.data import DataLoader
print("=" * 60)
print("EXAMPLE 1: Basic Dataset Usage")
print("=" * 60)
# Create dataset with both queries and database
dataset = VPRDataset('data')
print(f"Total images in dataset: {len(dataset)}")
print(f"Query images: {len(dataset.get_query_filenames())}")
print(f"Database images: {len(dataset.get_database_filenames())}")
print()
# Get a single image
img, filename, is_query = dataset[0]
print(f"First image:")
print(f" Filename: {filename}")
print(f" Is query: {is_query}")
print(f" Image shape: {img.shape}")
print()
print("=" * 60)
print("EXAMPLE 2: Ground Truth Lookup")
print("=" * 60)
# Get a query filename
query_files = dataset.get_query_filenames()
query_file = query_files[0]
print(f"Query: {query_file}")
# Get ground truth matches
matches = dataset.gt(query_file)
print(f"Ground truth matches ({len(matches)} images):")
for match in matches:
print(f" - {match}")
print()
print("=" * 60)
print("EXAMPLE 3: Create Separate Query and Database Datasets")
print("=" * 60)
# Create database-only dataset
db_dataset = VPRDataset('data', include_queries=False, include_database=True)
print(f"Database-only dataset size: {len(db_dataset)}")
# Create query-only dataset
query_dataset = VPRDataset('data', include_queries=True, include_database=False)
print(f"Query-only dataset size: {len(query_dataset)}")
print()
print("=" * 60)
print("EXAMPLE 4: Using with DataLoader")
print("=" * 60)
# Create dataloader
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
# Get a batch
batch_imgs, batch_filenames, batch_is_query = next(iter(dataloader))
print(f"Batch shape: {batch_imgs.shape}")
print(f"Batch filenames: {batch_filenames}")
print(f"Batch is_query flags: {batch_is_query}")
print()
print("=" * 60)
print("EXAMPLE 5: Get Item Info by Filename")
print("=" * 60)
item_info = dataset.get_item_by_filename(query_file)
print(f"Full info for {query_file}:")
for key, value in item_info.items():
if key != 'path': # Skip path for cleaner output
print(f" {key}: {value}")
print()
print("=" * 60)
print("EXAMPLE 6: Typical VPR Workflow")
print("=" * 60)
print("Typical usage pattern:")
print("""
# 1. Create separate datasets
db_dataset = VPRDataset('data', include_queries=False)
query_dataset = VPRDataset('data', include_database=False)
# 2. Extract features for all database images
db_features = []
db_filenames = []
for img, filename, _ in db_dataset:
feat = model(img.unsqueeze(0)) # Your VPR model
db_features.append(feat)
db_filenames.append(filename)
# 3. For each query, find matches
for img, query_filename, _ in query_dataset:
# Extract query features
query_feat = model(img.unsqueeze(0))
    # Compute similarities with the database (compute_similarity is your
    # own function, e.g. cosine similarity over the stacked features)
    similarities = compute_similarity(query_feat, db_features)
    # Get top-K predictions (torch.topk returns the highest similarities first;
    # argsort()[::-1] would fail on a torch tensor, which rejects negative steps)
    top_k_indices = similarities.topk(10).indices
    predicted_files = [db_filenames[i] for i in top_k_indices]
    # Get ground truth
    gt_files = query_dataset.gt(query_filename)
    # Per-query hit: is any ground-truth match in the top 10?
    # Averaging these hits over all queries gives Recall@10.
    hit_at_10 = any(gt in predicted_files for gt in gt_files)
""")