import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import urllib.request
import os
from transformers import AutoTokenizer, logging
import pandas as pd
from tqdm import tqdm
from safetensors.torch import save_file

logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ----------------- CONFIG -----------------
SAVE_EVERY = 5
MODEL_NAME = "mini_transformer_v3"
N_DATA_WORKERS = 8
PIN_MEMORY = N_DATA_WORKERS > 0 and torch.cuda.is_available()
BATCH_SIZE = 512
EVAL_EVERY = 5
LEARNING_RATE = 3e-4
NUM_EPOCHS = 50
USE_AMP = True
STRIDE = 64
CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DATASET = "DATA/generated_dataset_very_big.csv"
CONTEXT_LENGTH = 128
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 4
N_LAYER = 4

# ----------------- MODEL -----------------
# Pre-LN transformer block: multi-head self-attention followed by a feed-forward MLP.
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Causal mask so each position attends only to itself and earlier tokens,
        # as required for next-token prediction.
        T = x.size(1)
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T, device=x.device)
        x = self.emb(x) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


# ----------------- DATASET -----------------
class SlidingWindowDataset(Dataset):
    def __init__(self, texts, tokenizer, context_length=128, stride=64):
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.stride = stride

        # Flatten all text into a single long stream of token IDs
        self.tokens = []
        for text in texts:
            ids = tokenizer.encode(text, add_special_tokens=False)
            self.tokens.extend(ids)
        self.tokens = torch.tensor(self.tokens, dtype=torch.long)

        self.n_samples = (len(self.tokens) - context_length) // stride

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.context_length + 1
        chunk = self.tokens[start:end]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y


# As long as we flatten the list of strings into one single stream of tokens and then
# cut it into windows of the same length, by definition we don't need padding.
# Padding is needed when we keep multiple separate sentences in a list and want to
# batch them together: then every sequence must be padded to a common length (the
# max length in the batch, or the context length, with truncation where needed).
# Example: a batch like
#   ["ciao", "ciao io sono", "ciao io sono pippo"]
# is tokenized to
#   [101, 2003, 102]
#   [101, 2003, 2026, 2070, 102]
#   [101, 2003, 2026, 2070, 5274, 102]
# and has to be padded to the max length:
#   [101, 2003, 102, 0, 0, 0]
#   [101, 2003, 2026, 2070, 102, 0]
#   [101, 2003, 2026, 2070, 5274, 102]
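# A minimal sketch of the padded-batch alternative described above. This helper is
# hypothetical (it is never called in this script) and only illustrates how the
# Hugging Face tokenizer would pad/truncate a batch of separate sentences.
def _padding_demo(tokenizer, context_length=CONTEXT_LENGTH):
    batch = ["ciao", "ciao io sono", "ciao io sono pippo"]
    enc = tokenizer(
        batch,
        padding="longest",        # pad every sequence to the longest one in the batch
        truncation=True,          # truncate anything longer than max_length
        max_length=context_length,
        return_tensors="pt",
    )
    # enc["input_ids"]: (3, max_len), short rows filled with pad_token_id;
    # enc["attention_mask"]: 1 for real tokens, 0 for padding.
    return enc["input_ids"], enc["attention_mask"]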
# ----------------- DEVICE -----------------
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
    print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")

# ----------------- LOAD DATA -----------------
df = pd.read_csv(DATASET)
texts = [
    f"{row['system_prompt']} {row['question']} {row['answer']}"
    for _, row in df.iterrows()
]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

dataset = SlidingWindowDataset(texts, tokenizer, CONTEXT_LENGTH, STRIDE)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
print(f"dataset train length: {len(train_dataset)}")

loader_train = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)
loader_test = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)

# ----------------- TRAINING SETUP -----------------
model = MiniTransformer(
    vocab_size=vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of parameters: {n_params}")

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = torch.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# ----------------- CHECKPOINT RESUME -----------------
# Sort by epoch number; a plain lexicographic sort would put epoch 10 before epoch 5.
checkpoint_files = sorted(
    [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith(".pt")],
    key=lambda f: int(f.rsplit("_epoch_", 1)[-1].split(".")[0]),
)
if checkpoint_files:
    latest_ckpt = os.path.join(CHECKPOINT_DIR, checkpoint_files[-1])
    ckpt = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    # Restore the AMP scaler state as well, if it was saved from an enabled scaler.
    if scaler.is_enabled() and ckpt.get("scaler_state"):
        scaler.load_state_dict(ckpt["scaler_state"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from {latest_ckpt}")
else:
    start_epoch = 0

model = torch.compile(model)

# ----------------- TRAINING LOOP -----------------
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    total_loss = 0
    for x, y in tqdm(loader_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.amp.autocast(
            "cuda", dtype=torch.float16, enabled=USE_AMP and device.type == "cuda"
        ):
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * x.size(0)

    avg_train_loss = total_loss / len(train_dataset)
    print(f"Train Loss: {avg_train_loss:.4f}")

    # --- Evaluation ---
    if (epoch + 1) % EVAL_EVERY == 0:
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for x, y in loader_test:
                x, y = x.to(device), y.to(device)
                with torch.amp.autocast(
                    "cuda",
                    dtype=torch.float16,
                    enabled=USE_AMP and device.type == "cuda",
                ):
                    logits = model(x)
                    loss = criterion(logits.view(-1, vocab_size), y.view(-1))
                total_loss += loss.item() * x.size(0)

        avg_test_loss = total_loss / len(test_dataset)
        print(f"Test Loss: {avg_test_loss:.4f}")

    # --- Save checkpoint ---
    if SAVE_EVERY > 0 and (epoch + 1) % SAVE_EVERY == 0:
        # torch.compile wraps the model, so save the underlying module's state dict
        # to keep clean (un-prefixed) parameter names that match the resume logic above.
        model_to_save = getattr(model, "_orig_mod", model)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model_to_save.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scaler_state": scaler.state_dict(),
            },
            os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
        )
        save_file(
            model_to_save.state_dict(),
            os.path.join(CHECKPOINT_DIR, f"model_{epoch+1}.safetensors"),
        )

# Check GPU utilization metrics with:
#   nvidia-smi dmon -s u
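# A minimal, hypothetical greedy-decoding helper (an addition for illustration, not
# part of the original script). It can be used after training, or after reloading the
# saved weights with `safetensors.torch.load_file`, to sanity-check the model.
@torch.no_grad()
def generate_greedy(model, tokenizer, prompt, max_new_tokens=32):
    model.eval()
    ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)
    for _ in range(max_new_tokens):
        # Keep only the last CONTEXT_LENGTH tokens so positional embeddings stay in range.
        logits = model(ids[:, -CONTEXT_LENGTH:])
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())

# Example usage (uncomment after training):
# print(generate_greedy(model, tokenizer, "some prompt from the training distribution"))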