""" Simple PyTorch Dataset for VPR (Visual Place Recognition) Combines database and query images with ground truth lookup. """ import json from pathlib import Path from PIL import Image import torch from torch.utils.data import Dataset import torchvision.transforms as T from typing import List, Dict, Tuple class VPRDataset(Dataset): """ Simple VPR Dataset that loads both database and query images. Usage: dataset = VPRDataset('data') # Get an image img, filename, is_query = dataset[0] # Get ground truth matches for a query matches = dataset.gt('place00000123_q0000.jpg') """ def __init__( self, data_dir='data', transform=None, include_queries=True, include_database=True ): """ Args: data_dir: Path to data folder containing database/, query/, and ground_truth.json transform: Optional torchvision transforms to apply to images include_queries: Whether to include query images in the dataset include_database: Whether to include database images in the dataset """ self.data_dir = Path(data_dir) self.include_queries = include_queries self.include_database = include_database # Default transform if none provided if transform is None: self.transform = T.Compose([ T.Resize((480, 640)), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) else: self.transform = transform # Load ground truth gt_path = self.data_dir / 'ground_truth.json' with open(gt_path, 'r') as f: self.ground_truth = json.load(f) # Build the dataset items list self.items = [] if include_database: for item in self.ground_truth['database']: self.items.append({ 'filename': item['filename'], 'path': self.data_dir / 'database' / item['filename'], 'place_id': item['place_id'], 'is_query': False, 'city': item['city'], 'lat': item['lat'], 'lon': item['lon'] }) if include_queries: for item in self.ground_truth['query']: self.items.append({ 'filename': item['filename'], 'path': self.data_dir / 'query' / item['filename'], 'place_id': item['place_id'], 'is_query': True, 'city': item['city'], 'lat': item['lat'], 'lon': item['lon'] }) # Build lookup tables for fast ground truth queries self._build_lookup_tables() def _build_lookup_tables(self): """Build internal lookup tables for efficient ground truth queries.""" # Map filename -> full item info self.filename_to_item = {item['filename']: item for item in self.items} # Map place_id -> list of database filenames self.place_to_db_files = {} for item in self.ground_truth['database']: place_id = item['place_id'] if place_id not in self.place_to_db_files: self.place_to_db_files[place_id] = [] self.place_to_db_files[place_id].append(item['filename']) # Map query filename -> its place_id for fast lookup self.query_to_place = { item['filename']: item['place_id'] for item in self.ground_truth['query'] } def __len__(self): """Return total number of images in dataset.""" return len(self.items) def __getitem__(self, idx) -> Tuple[torch.Tensor, str, bool]: """ Get an image from the dataset. Returns: tuple: (image_tensor, filename, is_query) - image_tensor: Transformed image as torch.Tensor - filename: String filename (e.g., 'place00000123_db0001.jpg') - is_query: Boolean indicating if this is a query image """ item = self.items[idx] # Load image img = Image.open(item['path']).convert('RGB') # Apply transforms if self.transform: img = self.transform(img) return img, item['filename'], item['is_query'] def gt(self, query_filename: str) -> List[str]: """ Get ground truth database matches for a query image. Args: query_filename: Filename of the query image (e.g., 'place00000123_q0000.jpg') Returns: List of database image filenames that match this query (same place_id) Example: >>> dataset = VPRDataset('data') >>> matches = dataset.gt('place00000123_q0000.jpg') >>> print(matches) ['place00000123_db0000.jpg', 'place00000123_db0001.jpg', 'place00000123_db0002.jpg'] """ if query_filename not in self.query_to_place: raise ValueError(f"Query filename '{query_filename}' not found in dataset") place_id = self.query_to_place[query_filename] return self.place_to_db_files.get(place_id, []) def get_query_filenames(self) -> List[str]: """Get list of all query image filenames.""" return list(self.query_to_place.keys()) def get_database_filenames(self) -> List[str]: """Get list of all database image filenames.""" all_db_files = [] for files in self.place_to_db_files.values(): all_db_files.extend(files) return all_db_files def get_item_by_filename(self, filename: str) -> Dict: """ Get full item information by filename. Args: filename: Image filename Returns: Dictionary with keys: filename, path, place_id, is_query, city, lat, lon """ if filename not in self.filename_to_item: raise ValueError(f"Filename '{filename}' not found in dataset") return self.filename_to_item[filename] @staticmethod def get_place_id_from_filename(filename: str) -> int: """ Extract place_id from filename. Args: filename: Image filename (e.g., 'place00000123_db0001.jpg') Returns: Integer place_id (e.g., 123) """ return int(filename.split('_')[0].replace('place', '')) # ============================================================================ # EXAMPLE USAGE # ============================================================================ if __name__ == "__main__": from torch.utils.data import DataLoader print("=" * 60) print("EXAMPLE 1: Basic Dataset Usage") print("=" * 60) # Create dataset with both queries and database dataset = VPRDataset('data') print(f"Total images in dataset: {len(dataset)}") print(f"Query images: {len(dataset.get_query_filenames())}") print(f"Database images: {len(dataset.get_database_filenames())}") print() # Get a single image img, filename, is_query = dataset[0] print(f"First image:") print(f" Filename: {filename}") print(f" Is query: {is_query}") print(f" Image shape: {img.shape}") print() print("=" * 60) print("EXAMPLE 2: Ground Truth Lookup") print("=" * 60) # Get a query filename query_files = dataset.get_query_filenames() query_file = query_files[0] print(f"Query: {query_file}") # Get ground truth matches matches = dataset.gt(query_file) print(f"Ground truth matches ({len(matches)} images):") for match in matches: print(f" - {match}") print() print("=" * 60) print("EXAMPLE 3: Create Separate Query and Database Datasets") print("=" * 60) # Create database-only dataset db_dataset = VPRDataset('data', include_queries=False, include_database=True) print(f"Database-only dataset size: {len(db_dataset)}") # Create query-only dataset query_dataset = VPRDataset('data', include_queries=True, include_database=False) print(f"Query-only dataset size: {len(query_dataset)}") print() print("=" * 60) print("EXAMPLE 4: Using with DataLoader") print("=" * 60) # Create dataloader dataloader = DataLoader(dataset, batch_size=4, shuffle=False) # Get a batch batch_imgs, batch_filenames, batch_is_query = next(iter(dataloader)) print(f"Batch shape: {batch_imgs.shape}") print(f"Batch filenames: {batch_filenames}") print(f"Batch is_query flags: {batch_is_query}") print() print("=" * 60) print("EXAMPLE 5: Get Item Info by Filename") print("=" * 60) item_info = dataset.get_item_by_filename(query_file) print(f"Full info for {query_file}:") for key, value in item_info.items(): if key != 'path': # Skip path for cleaner output print(f" {key}: {value}") print() print("=" * 60) print("EXAMPLE 6: Typical VPR Workflow") print("=" * 60) print("Typical usage pattern:") print(""" # 1. Create separate datasets db_dataset = VPRDataset('data', include_queries=False) query_dataset = VPRDataset('data', include_database=False) # 2. Extract features for all database images db_features = [] db_filenames = [] for img, filename, _ in db_dataset: feat = model(img.unsqueeze(0)) # Your VPR model db_features.append(feat) db_filenames.append(filename) # 3. For each query, find matches for img, query_filename, _ in query_dataset: # Extract query features query_feat = model(img.unsqueeze(0)) # Compute similarities with database similarities = compute_similarity(query_feat, db_features) # Get top-K predictions top_k_indices = similarities.argsort()[::-1][:10] predicted_files = [db_filenames[i] for i in top_k_indices] # Get ground truth gt_files = query_dataset.gt(query_filename) # Evaluate: check if any gt_files are in predicted_files recall_at_10 = any(gt in predicted_files for gt in gt_files) """)