|
|
import os |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms |
|
|
|
|
|
class TinyImageNetDataset(Dataset):
    """Tiny ImageNet images loaded from the extracted on-disk directory layout.

    Expected layout under ``root_dir``::

        train/<class_id>/images/*.JPEG
        val/images/*.JPEG
        val/val_annotations.txt   # read only when return_labels=True

    Each sample is ``(image, target)``. By default ``target`` is always ``0``
    (the dataset is treated as unlabeled); pass ``return_labels=True`` to get
    the integer class index instead.

    Directory listings are sorted, so index ``i`` refers to the same file on
    every machine and every run.
    """

    def __init__(self, root_dir, transform=None, train=True, *, return_labels=False):
        """Index the image files of one split.

        Args:
            root_dir: Directory containing the ``train`` and ``val`` folders.
            transform: Optional callable applied to each RGB ``PIL.Image``.
            train: Index the training split if True, otherwise validation.
            return_labels: If True, ``__getitem__`` returns the class index
                (position of the class folder among the sorted ``train/``
                subdirectories) instead of the constant ``0``.

        Raises:
            OSError: If an expected directory or annotation file is missing
                or unreadable (e.g. ``FileNotFoundError``).
            KeyError: With ``return_labels=True`` on the validation split, if
                an image has no entry in ``val_annotations.txt`` or its class
                folder is absent from ``train/``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.return_labels = return_labels
        self.image_paths = []
        # Class index for each entry of image_paths; empty unless return_labels.
        self.labels = []
        # Sorted class folder names; only filled when the split needs them,
        # so the val split without labels never touches train/ (as before).
        self.classes = []

        train_dir = os.path.join(root_dir, 'train')
        if train or return_labels:
            self.classes = self._list_class_dirs(train_dir)

        if train:
            self._index_train(train_dir)
        else:
            self._index_val(os.path.join(root_dir, 'val'))

    @staticmethod
    def _list_class_dirs(train_dir):
        """Return the sorted names of the class subdirectories of ``train_dir``."""
        # Skip stray files (e.g. .DS_Store): treating them as class folders
        # would make the listdir of their "images" subfolder fail.
        return sorted(
            name for name in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, name))
        )

    def _index_train(self, train_dir):
        """Collect every ``.JPEG`` under ``train/<class>/images``, class by class."""
        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(train_dir, cls, 'images')
            for img_name in sorted(os.listdir(cls_dir)):
                if img_name.endswith('.JPEG'):
                    self.image_paths.append(os.path.join(cls_dir, img_name))
                    if self.return_labels:
                        self.labels.append(label)

    def _index_val(self, val_dir):
        """Collect every ``.JPEG`` under ``val/images`` (and labels if requested)."""
        images_dir = os.path.join(val_dir, 'images')
        class_of = self._read_val_annotations(val_dir) if self.return_labels else {}
        class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        for img_name in sorted(os.listdir(images_dir)):
            if img_name.endswith('.JPEG'):
                self.image_paths.append(os.path.join(images_dir, img_name))
                if self.return_labels:
                    self.labels.append(class_to_idx[class_of[img_name]])

    @staticmethod
    def _read_val_annotations(val_dir):
        """Map each validation file name to its class id via ``val_annotations.txt``."""
        mapping = {}
        path = os.path.join(val_dir, 'val_annotations.txt')
        with open(path, encoding='utf-8') as fh:
            for line in fh:
                # Whitespace-separated: file name, class id, then extra fields.
                fields = line.split()
                if len(fields) >= 2:
                    mapping[fields[0]] = fields[1]
        return mapping

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, return ``(image, target)``."""
        # The context manager releases the file handle promptly; convert()
        # returns a new image, so it stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = self.labels[idx] if self.return_labels else 0
        return img, target
|
|
|
|
|
def get_dataloaders(config):
    """Build the Tiny ImageNet training and validation DataLoaders.

    Args:
        config: Object exposing ``image_size``, ``dataset_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and
        applies random horizontal flips. The validation loader keeps a fixed
        order and applies no random augmentation, so evaluation is repeatable.
    """
    # Training pipeline: resize, random flip, tensor conversion, then map
    # each channel from [0, 1] to [-1, 1] with mean 0.5 / std 0.5.
    train_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Validation pipeline: identical preprocessing without the random flip.
    # Augmenting validation data would make metrics vary between runs.
    val_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = TinyImageNetDataset(config.dataset_path, transform=train_transform, train=True)
    val_dataset = TinyImageNetDataset(config.dataset_path, transform=val_transform, train=False)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )
    # Same loading settings as training (incl. pinned memory for faster
    # host-to-GPU copies), but without shuffling.
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader