| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| import torch | |
class GetDataset(Dataset):
    """Map-style dataset over two feature views with per-sample labels.

    Each item is built from column ``index`` of ``data[0]`` and ``data[1]``.

    Args:
        data: Indexable holding two numpy arrays at ``data[0]`` and
            ``data[1]``. Samples are read along axis 1 (``array[:, index]``);
            presumably each array is ``(n_features, n_samples)`` -- confirm
            with callers.
        labels: Per-sample labels; ``len(labels)`` is the dataset length.
        real_labels: Optional second per-sample label sequence. ``None`` or
            an empty sequence means items carry no real label.

    Items are ``(fea0, fea1, label)`` or ``(fea0, fea1, label, real_label)``,
    where ``fea0``/``fea1`` are float32 tensors with a leading axis of size 1
    and the labels are ``np.int64`` scalars.

    Note:
        ``torch.from_numpy`` shares memory with its source array and
        ``Tensor.float()`` returns the same tensor for float32 input, so for
        float32 data the returned feature tensors alias ``data``; clone them
        before mutating in place.
    """

    def __init__(self, data, labels, real_labels=None):
        self.data = data
        self.labels = labels
        self.real_labels = real_labels

    def _has_real_labels(self):
        """Return True when a non-empty ``real_labels`` sequence was given."""
        # Use len() rather than truthiness so numpy arrays are accepted too.
        return self.real_labels is not None and len(self.real_labels) > 0

    def __getitem__(self, index):
        # Column ``index`` of each view -> float32 tensor with a new axis 0.
        fea0 = torch.from_numpy(self.data[0][:, index]).float().unsqueeze(0)
        fea1 = torch.from_numpy(self.data[1][:, index]).float().unsqueeze(0)
        label = np.int64(self.labels[index])
        if not self._has_real_labels():
            return fea0, fea1, label
        return fea0, fea1, label, np.int64(self.real_labels[index])

    def __len__(self):
        return len(self.labels)
class GetAllDataset(Dataset):
    """Two-view dataset whose items also carry one class label per view.

    Args:
        data: Indexable holding two numpy arrays at ``data[0]`` and
            ``data[1]``; sample ``index`` is column ``array[:, index]`` of
            each (presumably ``(n_features, n_samples)`` -- confirm with
            callers).
        labels: Per-sample labels; ``len(labels)`` is the dataset length.
        class_labels0: Per-sample class labels associated with view 0.
        class_labels1: Per-sample class labels associated with view 1.

    Items are ``(fea0, fea1, label, class_label0, class_label1)``: two
    float32 tensors with a leading axis of size 1, then three ``np.int64``
    scalars.
    """

    def __init__(self, data, labels, class_labels0, class_labels1):
        self.data = data
        self.labels = labels
        self.class_labels0 = class_labels0
        self.class_labels1 = class_labels1

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # One float32 tensor per view, each with a new leading axis.
        views = tuple(
            torch.from_numpy(self.data[view][:, index]).float().unsqueeze(0)
            for view in (0, 1)
        )
        label_sources = (self.labels, self.class_labels0, self.class_labels1)
        scalars = tuple(np.int64(source[index]) for source in label_sources)
        return views + scalars