import math
import os
import typing

import numpy as np
import pandas as pd
import torch

import utils

base_path = ""  # Fill in the directory of the pretrained checkpoints, e.g., "...../data_and_model/"

LOGGER = utils.get_logger(__name__)

DNA_ALPHABET = {'A': 0, 'C': 1, 'G': 2, 'T': 3}  # , 'M': 4}
INDEX_TO_DNA = {v: k for k, v in DNA_ALPHABET.items()}
lookup_array = np.array([INDEX_TO_DNA[i] for i in range(len(INDEX_TO_DNA))])


def dna_detokenize(seq):
    return ''.join([list(DNA_ALPHABET.keys())[int(i)] for i in seq])


def batch_dna_detokenize(batch_seq):
    """
    batch_seq: numpy array of shape [batch_size, seq_len]
    return: list of strings
    """
    detokenized_batch = lookup_array[batch_seq]
    detokenized_batch = [''.join(seq) for seq in detokenized_batch]
    return detokenized_batch


def dna_tokenize(seq):
    return [DNA_ALPHABET[c] for c in seq]


def batch_dna_tokenize(batch_seq):
    """
    batch_seq: list of strings
    return: numpy array of shape [batch_size, seq_len]
    """
    tokenized_batch = np.array([[DNA_ALPHABET[c] for c in seq] for seq in batch_seq])
    return tokenized_batch


class GosaiDataset(torch.utils.data.Dataset):
    def __init__(self):
        data_df = pd.read_csv(os.path.join(base_path, 'mdlm/gosai_data/processed_data/gosai_all.csv'))
        self.seqs = torch.tensor(data_df['seq'].apply(lambda x: [DNA_ALPHABET[c] for c in x]).tolist())
        self.clss = torch.tensor(data_df[['hepg2', 'k562', 'sknsh']].to_numpy())
        LOGGER.info(f'Loaded data: seqs shape: {self.seqs.shape}, clss shape: {self.clss.shape}')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return {'seqs': self.seqs[idx], 'clss': self.clss[idx],
                'attention_mask': torch.ones(len(self.seqs[idx]))}


def get_datasets_gosai():
    return GosaiDataset()


def get_dataloaders_gosai(config, skip_valid=False, valid_seed=None):
    num_gpus = torch.cuda.device_count()
    if config.loader.global_batch_size % (
            num_gpus * config.trainer.accumulate_grad_batches) != 0:
        raise ValueError(
            f'Train batch size {config.loader.global_batch_size} '
            f'not divisible by {num_gpus} gpus with accumulation '
            f'{config.trainer.accumulate_grad_batches}.')
    if config.loader.eval_global_batch_size % num_gpus != 0:
        raise ValueError(
            f'Eval batch size {config.loader.eval_global_batch_size} '
            f'not divisible by {num_gpus}.')
    train_set = GosaiDataset()
    # Randomly sample subsets of the train_set as valid_set and test_set.
    valid_set = torch.utils.data.Subset(
        train_set, np.random.choice(len(train_set), 40000, replace=False))
    test_set = torch.utils.data.Subset(
        train_set, np.random.choice(len(train_set), 40000, replace=False))
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.loader.batch_size,
        num_workers=config.loader.num_workers,
        pin_memory=config.loader.pin_memory,
        shuffle=not config.data.streaming,
        persistent_workers=True)
    if skip_valid:
        valid_loader = None
        test_loader = None
    else:
        if valid_seed is None:
            shuffle_valid = False
            generator = None
        else:
            shuffle_valid = True
            generator = torch.Generator().manual_seed(valid_seed)
        valid_loader = torch.utils.data.DataLoader(
            valid_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)
    return train_loader, valid_loader, test_loader
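
# --- Hedged usage sketch (not part of the original module) ---
# `get_dataloaders_gosai` only reads the config fields referenced above, so a
# plain namespace-style object is enough to drive it. The helper below is
# hypothetical and never called; the field values are illustrative assumptions,
# the processed CSV must exist under `base_path`, and at least one GPU is
# assumed (the divisibility checks divide by `torch.cuda.device_count()`).
def _example_get_dataloaders():
    from types import SimpleNamespace
    config = SimpleNamespace(
        loader=SimpleNamespace(
            global_batch_size=512, eval_global_batch_size=512,
            batch_size=512, eval_batch_size=512,
            num_workers=4, pin_memory=True),
        trainer=SimpleNamespace(accumulate_grad_batches=1),
        data=SimpleNamespace(streaming=False))
    return get_dataloaders_gosai(config)
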
# Samplers adapted from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):

    def __init__(self, *args, generator=None, **kwargs):
        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
        # which should be reproducible if pl.seed_everything was called beforehand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        kwargs.pop('shuffle', None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {'random_state': self.generator.get_state(),
                'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.generator.set_state(state_dict.get('random_state'))
        self.counter = state_dict['counter']
        # self.start_counter = self.counter
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {'epoch': self.epoch, 'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.counter = state_dict['counter']
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(
                    padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
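

# --- Hedged sanity check (not part of the original module) ---
# A minimal round trip through the tokenizer helpers defined at the top of this
# file; it needs no external data, unlike GosaiDataset, which reads the
# processed CSV under `base_path`. The sequences below are illustrative.
if __name__ == '__main__':
    example_seqs = ['ACGT', 'TTAC']                        # equal-length strings
    tokens = batch_dna_tokenize(example_seqs)              # int array, shape [2, 4]
    assert batch_dna_detokenize(tokens) == example_seqs    # lossless round trip
    print('tokenize/detokenize round trip OK:', tokens.tolist())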