File size: 2,214 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class TinyImageNetDataset(Dataset):
    """Tiny ImageNet image dataset that ignores class labels.

    Indexes ``.JPEG`` files from either the training split
    (``root_dir/train/<class>/images/*.JPEG``) or the validation split
    (``root_dir/val/images/*.JPEG``). Each item is returned as
    ``(image, 0)``: the label is a constant dummy value.

    Args:
        root_dir: Dataset root directory containing ``train`` and/or ``val``.
        transform: Optional callable applied to each RGB PIL image.
        train: If True, index the training split; otherwise the validation split.
    """

    def __init__(self, root_dir, transform=None, train=True):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []

        if train:
            # Train set structure: root/train/class/images/*.JPEG
            train_dir = os.path.join(root_dir, 'train')
            # Sorted so the sample order is deterministic regardless of the
            # filesystem's directory-listing order.
            for class_name in sorted(os.listdir(train_dir)):
                # Skip stray files next to the class folders (e.g. wnids.txt,
                # .DS_Store), which would otherwise raise NotADirectoryError.
                if not os.path.isdir(os.path.join(train_dir, class_name)):
                    continue
                cls_dir = os.path.join(train_dir, class_name, 'images')
                self.image_paths.extend(self._list_jpegs(cls_dir))
        else:
            # Val set structure: root/val/images/*.JPEG
            images_dir = os.path.join(root_dir, 'val', 'images')
            self.image_paths.extend(self._list_jpegs(images_dir))

    @staticmethod
    def _list_jpegs(directory):
        """Return sorted full paths of the ``.JPEG`` files directly inside ``directory``."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith('.JPEG')
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return ``(image, 0)``."""
        # The context manager releases the file handle promptly; convert()
        # returns a new image that does not depend on the opened file.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # Dummy label

def get_dataloaders(config):
    """Build the Tiny ImageNet training and validation DataLoaders.

    Args:
        config: Object exposing ``image_size``, ``dataset_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        tuple: ``(train_loader, val_loader)``. The training loader shuffles;
        the validation loader keeps a fixed order.
    """
    # ToTensor yields values in [0, 1]; mean/std of 0.5 rescales them to [-1, 1].
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    train_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No random augmentation on the validation split, so evaluation is
    # deterministic across runs.
    val_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = TinyImageNetDataset(config.dataset_path, transform=train_transform, train=True)
    val_dataset = TinyImageNetDataset(config.dataset_path, transform=val_transform, train=False)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    return train_loader, val_loader