# CAPIMAC / Datasets.py
# bestow136's picture
# Upload 13 files
# 8ffcfd0 verified
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
class GetDataset(Dataset):
    """Two-view dataset returning one feature tensor per view plus labels.

    ``data[0]`` and ``data[1]`` are each indexed as ``[:, index]``, so a sample
    is selected along the second axis of each view. Both must be NumPy arrays
    because they are passed to ``torch.from_numpy``.
    """

    def __init__(self, data, labels, real_labels=None):
        """Store the views and label sequences without copying them.

        Args:
            data: Indexable whose items ``0`` and ``1`` are the two view arrays.
            labels: One label per sample; its length is the dataset length.
            real_labels: Optional second label per sample. ``None`` or an
                empty sequence means it is unavailable, in which case items
                are returned without it.
        """
        self.data = data
        self.labels = labels
        self.real_labels = real_labels

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            ``(fea0, fea1, label)`` when no real labels are available,
            otherwise ``(fea0, fea1, label, real_label)``. Each feature is a
            float32 tensor with a leading axis of size 1; labels are converted
            with ``np.int64``.
        """
        fea0, fea1 = (
            torch.from_numpy(view[:, index]).float().unsqueeze(0)
            for view in (self.data[0], self.data[1])
        )
        label = np.int64(self.labels[index])

        # ``len(...) == 0`` rather than truthiness: ``real_labels`` may be a
        # NumPy array, whose truth value is ambiguous. ``None`` is checked
        # first because ``len(None)`` would raise TypeError.
        if self.real_labels is None or len(self.real_labels) == 0:
            return fea0, fea1, label

        real_label = np.int64(self.real_labels[index])
        return fea0, fea1, label, real_label

    def __len__(self):
        """Return the number of samples, i.e. ``len(labels)``."""
        return len(self.labels)
class GetAllDataset(Dataset):
    """Two-view dataset that returns three labels alongside each feature pair.

    Samples are picked with ``[:, index]`` from ``data[0]`` and ``data[1]``,
    i.e. along the second axis of each view array.
    """

    def __init__(self, data, labels, class_labels0, class_labels1):
        """Keep references to the two views and the three label sequences."""
        self.data, self.labels = data, labels
        self.class_labels0, self.class_labels1 = class_labels0, class_labels1

    def __len__(self):
        """Return the number of samples, taken from ``len(labels)``."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(fea0, fea1, label, class_labels0, class_labels1)``.

        Each feature is a float32 tensor with a leading axis of size 1; each
        label is converted with ``np.int64``.
        """
        features = []
        for view_id in (0, 1):
            column = self.data[view_id][:, index]
            tensor = torch.from_numpy(column).float()
            features.append(tensor.unsqueeze(0))

        targets = [
            np.int64(sequence[index])
            for sequence in (self.labels, self.class_labels0, self.class_labels1)
        ]
        return (*features, *targets)