# minitransformer.py
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import urllib.request
import os
from transformers import AutoTokenizer, logging
import pandas as pd
from tqdm import tqdm
from safetensors.torch import save_file
logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ----------------- CONFIG -----------------
SAVE_EVERY = 5
MODEL_NAME = "mini_transformer_v3"
N_DATA_WORKERS = 8
PIN_MEMORY = N_DATA_WORKERS > 0 and torch.cuda.is_available()
BATCH_SIZE = 512
EVAL_EVERY = 5
LEARNING_RATE = 3e-4
NUM_EPOCHS = 50
USE_AMP = True
STRIDE = 64
CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DATASET = "DATA/generated_dataset_very_big.csv"
CONTEXT_LENGTH = 128
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 4
N_LAYER = 4
# ----------------- MODEL -----------------
# Pre-norm transformer block: causal self-attention followed by an MLP.
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Boolean causal mask: True marks future positions, which the attention
        # must not look at when training for next-token prediction.
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1,
            ),
        )
    def forward(self, x):
        T = x.size(1)
        h = self.ln1(x)  # normalize once, reuse for query, key and value
        attn_out, _ = self.attn(
            h, h, h, attn_mask=self.causal_mask[:T, :T], need_weights=False
        )
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
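# Illustrative sketch (not called anywhere in this script): a quick shape check
# showing that the block maps a (batch, seq_len, emb_dim) tensor to a tensor of
# the same shape. The helper name _block_shape_demo is an assumption made for
# this example only.
def _block_shape_demo():
    block = TransformerBlock(emb_dim=32, num_heads=4, context_length=16)
    x = torch.randn(2, 16, 32)        # (batch, seq_len, emb_dim)
    assert block(x).shape == x.shape  # residual block preserves the shape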
class MiniTransformer(nn.Module):
def __init__(
self,
vocab_size,
emb_dim,
context_length,
num_heads,
num_layers,
dropout=0.1,
):
super().__init__()
self.emb = nn.Embedding(vocab_size, emb_dim)
self.pos_emb = nn.Embedding(context_length, emb_dim)
self.blocks = nn.Sequential(
*[
TransformerBlock(emb_dim, num_heads, context_length, dropout)
for _ in range(num_layers)
]
)
self.ln_f = nn.LayerNorm(emb_dim)
self.head = nn.Linear(emb_dim, vocab_size, bias=False)
self.context_length = context_length
def forward(self, x):
B, T = x.shape
pos = torch.arange(T, device=x.device)
x = self.emb(x) + self.pos_emb(pos)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)
return logits
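# Illustrative sketch (not used by the training run below): greedy autoregressive
# decoding, to show how the logits produced by the model are meant to be consumed.
# The function name greedy_generate is an assumption made for this example only.
@torch.no_grad()
def greedy_generate(model, tokenizer, prompt, max_new_tokens=20, device="cpu"):
    model.eval()
    ids = torch.tensor(
        [tokenizer.encode(prompt, add_special_tokens=False)], device=device
    )
    for _ in range(max_new_tokens):
        # Only the last context_length tokens fit in the positional embedding.
        logits = model(ids[:, -model.context_length:])
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0].tolist())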
# ----------------- DATASET -----------------
class SlidingWindowDataset(Dataset):
def __init__(self, texts, tokenizer, context_length=128, stride=64):
self.tokenizer = tokenizer
self.context_length = context_length
self.stride = stride
# Flatten all text into a single long stream of token IDs
self.tokens = []
for text in texts:
ids = tokenizer.encode(text, add_special_tokens=False)
self.tokens.extend(ids)
self.tokens = torch.tensor(self.tokens, dtype=torch.long)
        # number of full (context_length + 1)-token windows in the stream
        self.n_samples = max(0, (len(self.tokens) - context_length - 1) // stride + 1)
def __len__(self):
return self.n_samples
def __getitem__(self, idx):
start = idx * self.stride
end = start + self.context_length + 1
chunk = self.tokens[start:end]
x = chunk[:-1]
y = chunk[1:]
return x, y
# Because we flatten the list of strings into one long token stream and then
# cut it into windows of the same length, no padding is needed by construction.
# Padding is only required when a batch is built from separate sentences of
# different lengths: then every sequence must be padded to a common length
# (the max length in the batch, or the context length, truncating if needed).
# Example: a batch like
#   ["ciao", "ciao io sono", "ciao io sono pippo"]
# tokenizes to
#   [101, 2003, 102]
#   [101, 2003, 2026, 2070, 102]
#   [101, 2003, 2026, 2070, 5274, 102]
# and must be padded to the longest sequence:
#   [101, 2003, 102, 0, 0, 0]
#   [101, 2003, 2026, 2070, 102, 0]
#   [101, 2003, 2026, 2070, 5274, 102]
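# Illustrative sketch (not used by this script, which needs no padding): how a
# batch of separate sentences could be padded with the Hugging Face tokenizer
# itself. The helper name pad_text_batch is an assumption made for this
# example only.
def pad_text_batch(texts, tokenizer, max_length):
    # padding="longest" pads every sequence to the longest one in the batch;
    # truncation=True caps sequences at max_length.
    return tokenizer(
        texts,
        padding="longest",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )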
# ----------------- DEVICE -----------------
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")
if device.type == "cuda":
print(torch.cuda.get_device_name(0))
print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")
# ----------------- LOAD DATA -----------------
df = pd.read_csv(DATASET)
texts = [
f"{row['system_prompt']} {row['question']} {row['answer']}"
for _, row in df.iterrows()
]
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size
dataset = SlidingWindowDataset(texts, tokenizer, CONTEXT_LENGTH, STRIDE)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
print(f"dataset train lenght: {len(train_dataset)}")
loader_train = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=N_DATA_WORKERS,
pin_memory=PIN_MEMORY,
)
loader_test = DataLoader(
test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=N_DATA_WORKERS,
pin_memory=PIN_MEMORY,
)
# ----------------- TRAINING SETUP -----------------
model = MiniTransformer(
vocab_size=vocab_size,
emb_dim=EMBEDDING_DIMENSION,
context_length=CONTEXT_LENGTH,
num_heads=HEAD_NUMBER,
num_layers=N_LAYER,
).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of parameters: {n_params}")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = torch.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# ----------------- CHECKPOINT RESUME -----------------
# Sort checkpoints by epoch number, not lexicographically (otherwise
# "epoch_9" would be picked over "epoch_10" as the latest one).
checkpoint_files = sorted(
    [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith(".pt")],
    key=lambda f: int(f.rsplit("_", 1)[-1].removesuffix(".pt")),
)
if checkpoint_files:
    latest_ckpt = os.path.join(CHECKPOINT_DIR, checkpoint_files[-1])
    ckpt = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    if "scaler_state" in ckpt:
        scaler.load_state_dict(ckpt["scaler_state"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from {latest_ckpt}")
else:
    start_epoch = 0
model = torch.compile(model)
# ----------------- TRAINING LOOP -----------------
for epoch in range(start_epoch, NUM_EPOCHS):
model.train()
total_loss = 0
for x, y in tqdm(loader_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
optimizer.zero_grad()
with torch.amp.autocast(
"cuda", dtype=torch.float16, enabled=USE_AMP and device.type == "cuda"
):
logits = model(x)
loss = criterion(logits.view(-1, vocab_size), y.view(-1))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_loss += loss.item() * x.size(0)
avg_train_loss = total_loss / len(train_dataset)
print(f"Train Loss: {avg_train_loss:.4f}")
# --- Evaluation ---
if (epoch + 1) % EVAL_EVERY == 0:
model.eval()
total_loss = 0
with torch.no_grad():
for x, y in loader_test:
x, y = x.to(device), y.to(device)
with torch.amp.autocast(
"cuda",
dtype=torch.float16,
enabled=USE_AMP and device.type == "cuda",
):
logits = model(x)
loss = criterion(logits.view(-1, vocab_size), y.view(-1))
total_loss += loss.item() * x.size(0)
avg_test_loss = total_loss / len(test_dataset)
print(f"Test Loss: {avg_test_loss:.4f}")
    # --- Save checkpoint ---
    if SAVE_EVERY > 0 and (epoch + 1) % SAVE_EVERY == 0:
        # Save the uncompiled module's weights so the keys carry no
        # "_orig_mod." prefix and the checkpoint loads cleanly before compile.
        raw_model = getattr(model, "_orig_mod", model)
        torch.save(
            {
                "epoch": epoch,
                "model_state": raw_model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scaler_state": scaler.state_dict(),
            },
            os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
        )
        save_file(
            raw_model.state_dict(),
            os.path.join(CHECKPOINT_DIR, f"model_{epoch+1}.safetensors"),
        )
# check GPU utilization metrics here:
# nvidia-smi dmon -s u
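# Illustrative sketch (not executed here): reloading the exported safetensors
# weights into a freshly constructed MiniTransformer. load_file is the
# counterpart of save_file in safetensors.torch; the helper name
# load_exported_weights is an assumption made for this example only.
def load_exported_weights(path, device="cpu"):
    from safetensors.torch import load_file
    restored = MiniTransformer(
        vocab_size=vocab_size,
        emb_dim=EMBEDDING_DIMENSION,
        context_length=CONTEXT_LENGTH,
        num_heads=HEAD_NUMBER,
        num_layers=N_LAYER,
    )
    restored.load_state_dict(load_file(path, device=device))
    return restored.to(device)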