import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import load_dataset
from torch.utils.data import DataLoader
from tokenizer import Tokenizer
from config import ModelArgs
tokenizer = Tokenizer().ready_tokenizer()

# Dataset selection flags: exactly one of these should be True.
tinystories = True
fw = False
fw_train = None
fw_test = None
if tinystories:
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)
if fw:
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    # FineWeb ships no validation split, so hold out 1% for evaluation.
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)
# GPT-style tokenizers ship without a pad token; add one so batches can be
# padded to a fixed length.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
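
# Assumption, not shown in this file: any model trained with this tokenizer
# must enlarge its embedding matrix to cover the new [PAD] id, e.g. for a
# Hugging Face model: model.resize_token_embeddings(len(tokenizer)).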
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        max_length=ModelArgs.block_size,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
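
# Alternative sketch (unused below, where tokenization happens inside the
# collate function): pre-tokenize the whole dataset once with datasets.map.
# The remove_columns value is an assumption based on TinyStories' 'text' column.
#   tokenized_train = fw_train.map(tokenize_function, batched=True, remove_columns=["text"])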
def prepare_dataset(split, device, batch_size):
    print("Device is: ", device)

    def collate_fn(batch):
        # Tokenize the raw text for this batch to a fixed block_size.
        texts = [item["text"] for item in batch]
        input_encodings = tokenizer(texts, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
        # Next-token labels: shift input_ids left by one position; the final
        # slot has no successor token, so it is filled with the EOS id.
        input_encodings["labels"] = input_encodings["input_ids"].clone()
        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id
        return input_encodings
    data_loader = None
    if tinystories:
        if split == 'train':
            data_loader = DataLoader(
                fw_train,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False  # shuffling is delegated to the DistributedSampler
            )
        elif split == 'val':
            data_loader = DataLoader(
                fw_test,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_test, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    elif fw:
        if split == 'train':
            data_loader = DataLoader(
                fw_train['train'],
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train['train'], shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
        elif split == 'val':
            data_loader = DataLoader(
                fw_train['test'],
                batch_size=batch_size,
                # shuffle=True is the DistributedSampler default; made explicit for consistency
                sampler=DistributedSampler(fw_train['test'], shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    return data_loader
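
# Usage sketch (assumptions: torch.distributed is already initialized, e.g.
# via torchrun + init_process_group, since DistributedSampler requires it;
# the rank variable and the batch size of 32 are illustrative, not from this file):
#   import torch.distributed as dist
#   dist.init_process_group(backend="nccl")
#   rank = dist.get_rank()
#   train_loader = prepare_dataset('train', device=f"cuda:{rank}", batch_size=32)
#   for batch in train_loader:
#       input_ids = batch["input_ids"]   # shape: (32, ModelArgs.block_size)
#       labels = batch["labels"]         # input_ids shifted left by one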