import torch def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32): assert head_dim % 2 == 0, "Embedding dimension must be even" # Compute the inverse frequencies inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim)) # Frequency adjustments if freq_config is not None: low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"] high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"] wavelen = 2 * torch.pi / inv_freq inv_freq_llama = torch.where( wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq ) smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / ( freq_config["high_freq_factor"] - freq_config["low_freq_factor"] ) smoothed_inv_freq = ( (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq ) is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen) inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) inv_freq = inv_freq_llama # Generate position indices positions = torch.arange(context_length, dtype=dtype) # Compute the angles angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2) # Expand angles to match the head_dim angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim) # Precompute sine and cosine cos = torch.cos(angles) sin = torch.sin(angles) return cos, sin def apply_rope(x, cos, sin): # x: (batch_size, num_heads, seq_len, head_dim) batch_size, num_heads, seq_len, head_dim = x.shape assert head_dim % 2 == 0, "Head dimension must be even" # Split x into first half and second half x1 = x[..., : head_dim // 2] # First half x2 = x[..., head_dim // 2 :] # Second half # Adjust sin and cos shapes cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim) sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Apply the rotary transformation rotated = torch.cat((-x2, x1), dim=-1) x_rotated = (x * cos) + (rotated * sin) # It's ok to use lower-precision after applying cos and sin rotation return x_rotated.to(dtype=x.dtype) def model_memory_size(model, input_dtype=torch.float32): total_params = 0 total_grads = 0 for param in model.parameters(): # Calculate total number of elements per parameter param_size = param.numel() total_params += param_size # Check if gradients are stored for this parameter if param.requires_grad: total_grads += param_size # Calculate buffer size (non-parameters that require memory) total_buffers = sum(buf.numel() for buf in model.buffers()) # Size in bytes = (Number of elements) * (Size of each element in bytes) # We assume parameters and gradients are stored in the same type as input dtype element_size = torch.tensor(0, dtype=input_dtype).element_size() total_memory_bytes = (total_params + total_grads + total_buffers) * element_size # Convert bytes to gigabytes total_memory_gb = total_memory_bytes / (1024**3) return total_memory_gb import os from pathlib import Path import tiktoken from tiktoken.load import load_tiktoken_bpe class Tokenizer: """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.""" def __init__(self, model_path): if not os.path.isfile(model_path): raise FileNotFoundError(model_path) mergeable = load_tiktoken_bpe(model_path) # hard-coded from Meta's tokenizer.json self.special = { "<|begin_of_text|>": 128000, "<|end_of_text|>": 128001, "<|start_header_id|>": 128006, "<|end_header_id|>": 128007, "<|eot_id|>": 128009, } self.special.update({f"<|reserved_{i}|>": 128002 + i for i in range(256) if 128002 + i not in self.special.values()}) self.model = tiktoken.Encoding( name=Path(model_path).name, pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" r"|[^\r\n\p{L}\p{N}]?\p{L}+" r"|\p{N}{1,3}" r"| ?[^\s\p{L}\p{N}]+[\r\n]*" r"|\s*[\r\n]+" r"|\s+(?!\S)" r"|\s+", mergeable_ranks=mergeable, special_tokens=self.special, ) def encode(self, text, bos=False, eos=False): ids = ([self.special["<|begin_of_text|>"]] if bos else []) \ + self.model.encode(text) if eos: ids.append(self.special["<|end_of_text|>"]) return ids def decode(self, ids): return self.model.decode(ids) class ChatFormat: def __init__(self, tokenizer: Tokenizer, *, default_system="You are a helpful assistant."): self.tok = tokenizer self.default_system = default_system def _header(self, role): """Encode <|start_header_id|>role<|end_header_id|>\n\n""" return ( [self.tok.special["<|start_header_id|>"]] + self.tok.encode(role) + [self.tok.special["<|end_header_id|>"]] + self.tok.encode("\n\n") ) def encode(self, user_message, system_message=None): sys_msg = system_message if system_message is not None else self.default_system ids = [self.tok.special["<|begin_of_text|>"]] # system ids += self._header("system") ids += self.tok.encode(sys_msg) ids += [self.tok.special["<|eot_id|>"]] # user ids += self._header("user") ids += self.tok.encode(user_message) ids += [self.tok.special["<|eot_id|>"]] # assistant header (no content yet) ids += self._header("assistant") return ids def assign(left, right, tensor_name="unknown"): if left.shape != right.shape: raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}") if isinstance(right, torch.Tensor): return torch.nn.Parameter(right.clone().detach()) else: return torch.nn.Parameter(torch.tensor(right)) def load_weights_into_llama(model, param_config, params): model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") for l in range(param_config["n_layers"]): # Load attention weights model.trf_blocks[l].att.W_query.weight = assign( model.trf_blocks[l].att.W_query.weight, params[f"model.layers.{l}.self_attn.q_proj.weight"], f"model.layers.{l}.self_attn.q_proj.weight" ) model.trf_blocks[l].att.W_key.weight = assign( model.trf_blocks[l].att.W_key.weight, params[f"model.layers.{l}.self_attn.k_proj.weight"], f"model.layers.{l}.self_attn.k_proj.weight" ) model.trf_blocks[l].att.W_value.weight = assign( model.trf_blocks[l].att.W_value.weight, params[f"model.layers.{l}.self_attn.v_proj.weight"], f"model.layers.{l}.self_attn.v_proj.weight" ) model.trf_blocks[l].att.out_proj.weight = assign( model.trf_blocks[l].att.out_proj.weight, params[f"model.layers.{l}.self_attn.o_proj.weight"], f"model.layers.{l}.self_attn.o_proj.weight" ) model.trf_blocks[l].norm1.weight = assign( model.trf_blocks[l].norm1.weight, params[f"model.layers.{l}.input_layernorm.weight"], f"model.layers.{l}.input_layernorm.weight" ) # Load FeedForward weights model.trf_blocks[l].ff.fc1.weight = assign( model.trf_blocks[l].ff.fc1.weight, params[f"model.layers.{l}.mlp.gate_proj.weight"], f"model.layers.{l}.mlp.gate_proj.weight" ) model.trf_blocks[l].ff.fc2.weight = assign( model.trf_blocks[l].ff.fc2.weight, params[f"model.layers.{l}.mlp.up_proj.weight"], f"model.layers.{l}.mlp.up_proj.weight" ) model.trf_blocks[l].ff.fc3.weight = assign( model.trf_blocks[l].ff.fc3.weight, params[f"model.layers.{l}.mlp.down_proj.weight"], f"model.layers.{l}.mlp.down_proj.weight" ) model.trf_blocks[l].norm2.weight = assign( model.trf_blocks[l].norm2.weight, params[f"model.layers.{l}.post_attention_layernorm.weight"], f"model.layers.{l}.post_attention_layernorm.weight" ) # Load output layer weights model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight") if "lm_head.weight" in params.keys(): model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight") else: model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") print("Model uses weight tying.") def text_to_token_ids(text, tokenizer): encoded = tokenizer.encode(text) encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension return encoded_tensor def token_ids_to_text(token_ids, tokenizer): flat = token_ids.squeeze(0) # remove batch dimension return tokenizer.decode(flat.tolist()) def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None): # For-loop is the same as before: Get logits, and only focus on last time step for _ in range(max_new_tokens): idx_cond = idx[:, -context_size:] with torch.no_grad(): logits = model(idx_cond) logits = logits[:, -1, :] # New: Filter logits with top_k sampling if top_k is not None: # Keep only top_k values top_logits, _ = torch.topk(logits, top_k) min_val = top_logits[:, -1] logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits) # New: Apply temperature scaling if temperature > 0.0: logits = logits / temperature # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) # Sample from the distribution idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) # Otherwise same as before: get idx of the vocab entry with the highest logits value else: idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1) if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified break #print(f"{idx_next} ") # Same as before: append sampled index to the running sequence idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) return idx