import torch
from torch import nn

from TransformerBlock import TransformerBlock
from tools import compute_rope_params


class Llama3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities: precompute the RoPE cos/sin tables once and share them
        # across all transformer blocks (registered as non-persistent buffers so they
        # move with the model but are not saved in the state dict)
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        # Causal attention mask: True above the diagonal marks future positions to be masked out
        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
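

# A minimal usage sketch, under stated assumptions: it presumes `TransformerBlock`
# accepts the same `cfg` dict (the real block may require additional keys such as
# `hidden_dim` or `n_kv_groups`) and that `compute_rope_params` accepts
# `freq_config=None` to disable frequency rescaling. All config values below are
# deliberately tiny, hypothetical numbers chosen only to exercise the forward pass.
if __name__ == "__main__":
    TOY_CONFIG = {
        "vocab_size": 128,       # hypothetical toy vocabulary size
        "context_length": 64,    # maximum sequence length for the RoPE tables
        "emb_dim": 32,           # embedding dimension
        "n_heads": 4,            # attention heads (emb_dim must be divisible by this)
        "n_layers": 2,           # number of transformer blocks
        "rope_base": 500_000.0,  # RoPE theta base, as used by Llama 3
        "rope_freq": None,       # assumed: no RoPE frequency rescaling for this sketch
        "dtype": torch.float32,
    }

    model = Llama3Model(TOY_CONFIG)
    dummy_input = torch.randint(0, TOY_CONFIG["vocab_size"], (2, 10))  # (batch, num_tokens)
    logits = model(dummy_input)
    print(logits.shape)  # expected: torch.Size([2, 10, 128]) -> (batch, num_tokens, vocab_size)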