Spaces:
Sleeping
Sleeping
| """ | |
| Simple PyTorch Dataset for VPR (Visual Place Recognition) | |
| Combines database and query images with ground truth lookup. | |
| """ | |
| import json | |
| from pathlib import Path | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset | |
| import torchvision.transforms as T | |
| from typing import List, Dict, Tuple | |
class VPRDataset(Dataset):
    """
    Simple VPR Dataset that loads both database and query images.

    Image metadata is read from ``<data_dir>/ground_truth.json``, which must
    contain a ``'database'`` list and a ``'query'`` list. Each entry needs the
    keys ``filename``, ``place_id``, ``city``, ``lat`` and ``lon``. Image files
    are looked up under ``<data_dir>/database/`` and ``<data_dir>/query/``.

    Usage:
        dataset = VPRDataset('data')

        # Get an image
        img, filename, is_query = dataset[0]

        # Get ground truth matches for a query
        matches = dataset.gt('place00000123_q0000.jpg')
    """

    def __init__(
        self,
        data_dir='data',
        transform=None,
        include_queries=True,
        include_database=True
    ):
        """
        Args:
            data_dir: Path to data folder containing database/, query/, and ground_truth.json
            transform: Optional torchvision transforms to apply to images.
                When None, a resize to 480x640 followed by tensor conversion
                and mean/std normalization is used.
            include_queries: Whether to include query images in the dataset
            include_database: Whether to include database images in the dataset

        Raises:
            FileNotFoundError: If ground_truth.json does not exist in data_dir.
            KeyError: If the ground truth lacks the 'database' / 'query'
                sections or an entry lacks a required key.
        """
        self.data_dir = Path(data_dir)
        self.include_queries = include_queries
        self.include_database = include_database

        # Default transform if none provided
        if transform is None:
            self.transform = T.Compose([
                T.Resize((480, 640)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

        # Load ground truth. The encoding is pinned so parsing does not depend
        # on the platform's locale default.
        gt_path = self.data_dir / 'ground_truth.json'
        with open(gt_path, 'r', encoding='utf-8') as f:
            self.ground_truth = json.load(f)

        # Build the dataset items list: database entries first, then queries.
        self.items = []
        if include_database:
            self.items.extend(
                self._make_item(entry, 'database', False)
                for entry in self.ground_truth['database']
            )
        if include_queries:
            self.items.extend(
                self._make_item(entry, 'query', True)
                for entry in self.ground_truth['query']
            )

        # Build lookup tables for fast ground truth queries
        self._build_lookup_tables()

    def _make_item(self, entry, subdir, is_query):
        """Build the internal record for one ground-truth entry.

        Args:
            entry: One dict from the ground truth 'database' or 'query' list.
            subdir: Folder under data_dir that holds the image file.
            is_query: Whether the entry is a query image.

        Returns:
            Dictionary with keys: filename, path, place_id, is_query, city, lat, lon
        """
        return {
            'filename': entry['filename'],
            'path': self.data_dir / subdir / entry['filename'],
            'place_id': entry['place_id'],
            'is_query': is_query,
            'city': entry['city'],
            'lat': entry['lat'],
            'lon': entry['lon']
        }

    def _build_lookup_tables(self) -> None:
        """Build internal lookup tables for efficient ground truth queries.

        ``filename_to_item`` only covers the images included in this dataset.
        The two ground-truth tables are built from the full ground truth
        regardless of the include_* flags, so a query-only dataset can still
        resolve its database matches.
        """
        # Map filename -> full item info
        self.filename_to_item = {item['filename']: item for item in self.items}

        # Map place_id -> list of database filenames
        self.place_to_db_files = {}
        for entry in self.ground_truth['database']:
            self.place_to_db_files.setdefault(entry['place_id'], []).append(entry['filename'])

        # Map query filename -> its place_id for fast lookup
        self.query_to_place = {
            entry['filename']: entry['place_id']
            for entry in self.ground_truth['query']
        }

    def __len__(self) -> int:
        """Return total number of images in dataset."""
        return len(self.items)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, str, bool]:
        """
        Get an image from the dataset.

        Args:
            idx: Position of the image in ``self.items``.

        Returns:
            tuple: (image_tensor, filename, is_query)
                - image_tensor: Transformed image (a torch.Tensor with the default transform)
                - filename: String filename (e.g., 'place00000123_db0001.jpg')
                - is_query: Boolean indicating if this is a query image
        """
        item = self.items[idx]

        # Load image. convert() returns a new image, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(item['path']) as raw:
            img = raw.convert('RGB')

        # Apply transforms
        if self.transform is not None:
            img = self.transform(img)

        return img, item['filename'], item['is_query']

    def gt(self, query_filename: str) -> List[str]:
        """
        Get ground truth database matches for a query image.

        Args:
            query_filename: Filename of the query image (e.g., 'place00000123_q0000.jpg')

        Returns:
            List of database image filenames that match this query (same place_id).
            Empty when no database image shares the query's place_id. The list
            is a fresh copy, so callers may modify it freely.

        Raises:
            ValueError: If query_filename is not a query in the ground truth.

        Example:
            >>> dataset = VPRDataset('data')
            >>> matches = dataset.gt('place00000123_q0000.jpg')
            >>> print(matches)
            ['place00000123_db0000.jpg', 'place00000123_db0001.jpg', 'place00000123_db0002.jpg']
        """
        try:
            place_id = self.query_to_place[query_filename]
        except KeyError:
            raise ValueError(f"Query filename '{query_filename}' not found in dataset") from None
        # Copy so that caller-side mutation cannot corrupt the lookup table.
        return list(self.place_to_db_files.get(place_id, ()))

    def get_query_filenames(self) -> List[str]:
        """Get list of all query image filenames."""
        return list(self.query_to_place)

    def get_database_filenames(self) -> List[str]:
        """Get list of all database image filenames, grouped by place_id."""
        return [
            name
            for files in self.place_to_db_files.values()
            for name in files
        ]

    def get_item_by_filename(self, filename: str) -> Dict:
        """
        Get full item information by filename.

        Args:
            filename: Image filename

        Returns:
            Dictionary with keys: filename, path, place_id, is_query, city, lat, lon.
            This is the same dict object stored in ``self.items``.

        Raises:
            ValueError: If the filename is not among the images included in
                this dataset.
        """
        try:
            return self.filename_to_item[filename]
        except KeyError:
            raise ValueError(f"Filename '{filename}' not found in dataset") from None
def get_place_id_from_filename(filename: str) -> int:
    """
    Extract place_id from filename.

    Only the final path component is inspected, so directory names that
    contain underscores do not interfere with parsing.

    Args:
        filename: Image filename (e.g., 'place00000123_db0001.jpg'), with or
            without leading directories.

    Returns:
        Integer place_id (e.g., 123)

    Raises:
        ValueError: If the part before the first underscore is not
            'place' followed by digits.
    """
    basename = Path(filename).name
    # Everything before the first underscore is the 'place<digits>' token.
    place_token = basename.partition('_')[0]
    try:
        return int(place_token.replace('place', ''))
    except ValueError as err:
        raise ValueError(
            f"Cannot extract place_id from filename '{filename}'"
        ) from err
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
def main():
    """Run the usage examples against the ``data`` folder in the working directory.

    Prints dataset sizes, a sample image, a ground truth lookup, split
    datasets, a DataLoader batch, per-file metadata and a sketch of a typical
    VPR evaluation loop. Examples that need at least one image or one query
    are skipped with a message when the dataset has none.
    """
    from torch.utils.data import DataLoader

    banner = "=" * 60

    print(banner)
    print("EXAMPLE 1: Basic Dataset Usage")
    print(banner)
    # Create dataset with both queries and database
    dataset = VPRDataset('data')
    print(f"Total images in dataset: {len(dataset)}")
    print(f"Query images: {len(dataset.get_query_filenames())}")
    print(f"Database images: {len(dataset.get_database_filenames())}")
    print()

    if len(dataset) == 0:
        # Indexing or batching an empty dataset would raise, so stop here.
        print("Dataset is empty; skipping the remaining examples.")
        return

    # Get a single image
    img, filename, is_query = dataset[0]
    print("First image:")
    print(f" Filename: {filename}")
    print(f" Is query: {is_query}")
    print(f" Image shape: {img.shape}")
    print()

    print(banner)
    print("EXAMPLE 2: Ground Truth Lookup")
    print(banner)
    # Get a query filename
    query_files = dataset.get_query_filenames()
    query_file = query_files[0] if query_files else None
    if query_file is None:
        print("No query images listed in the ground truth.")
    else:
        print(f"Query: {query_file}")
        # Get ground truth matches
        matches = dataset.gt(query_file)
        print(f"Ground truth matches ({len(matches)} images):")
        for match in matches:
            print(f" - {match}")
    print()

    print(banner)
    print("EXAMPLE 3: Create Separate Query and Database Datasets")
    print(banner)
    # Create database-only dataset
    db_dataset = VPRDataset('data', include_queries=False, include_database=True)
    print(f"Database-only dataset size: {len(db_dataset)}")
    # Create query-only dataset
    query_dataset = VPRDataset('data', include_queries=True, include_database=False)
    print(f"Query-only dataset size: {len(query_dataset)}")
    print()

    print(banner)
    print("EXAMPLE 4: Using with DataLoader")
    print(banner)
    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    # Get a batch
    batch_imgs, batch_filenames, batch_is_query = next(iter(dataloader))
    print(f"Batch shape: {batch_imgs.shape}")
    print(f"Batch filenames: {batch_filenames}")
    print(f"Batch is_query flags: {batch_is_query}")
    print()

    print(banner)
    print("EXAMPLE 5: Get Item Info by Filename")
    print(banner)
    if query_file is None:
        print("No query images listed in the ground truth.")
    else:
        item_info = dataset.get_item_by_filename(query_file)
        print(f"Full info for {query_file}:")
        for key, value in item_info.items():
            if key != 'path':  # Skip path for cleaner output
                print(f" {key}: {value}")
    print()

    print(banner)
    print("EXAMPLE 6: Typical VPR Workflow")
    print(banner)
    print("Typical usage pattern:")
    # The snippet below is illustrative only: `model` and `compute_similarity`
    # are placeholders for the user's own code.
    print("""
    # 1. Create separate datasets
    db_dataset = VPRDataset('data', include_queries=False)
    query_dataset = VPRDataset('data', include_database=False)

    # 2. Extract features for all database images
    db_features = []
    db_filenames = []
    for img, filename, _ in db_dataset:
        feat = model(img.unsqueeze(0))  # Your VPR model
        db_features.append(feat)
        db_filenames.append(filename)

    # 3. For each query, find matches
    for img, query_filename, _ in query_dataset:
        # Extract query features
        query_feat = model(img.unsqueeze(0))

        # Compute similarities with database
        similarities = compute_similarity(query_feat, db_features)

        # Get top-K predictions
        top_k_indices = similarities.argsort()[::-1][:10]
        predicted_files = [db_filenames[i] for i in top_k_indices]

        # Get ground truth
        gt_files = query_dataset.gt(query_filename)

        # Evaluate: check if any gt_files are in predicted_files
        recall_at_10 = any(gt in predicted_files for gt in gt_files)
    """)


if __name__ == "__main__":
    main()