import torch


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments (Llama 3.1-style RoPE rescaling for longer contexts)
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        # Long wavelengths (low frequencies) are scaled down by `factor`
        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        # Medium wavelengths are interpolated smoothly between the scaled and unscaled frequencies
        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Duplicate the angles to cover the full head dimension
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine of the angles
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin
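
# Usage sketch (illustrative, kept as a comment so it is not executed on import;
# theta_base=500_000 and the freq_config values below follow the published
# Llama 3.1 rope-scaling settings and are assumptions, not values defined here):
#
#   llama31_freq_config = {
#       "factor": 8.0,
#       "low_freq_factor": 1.0,
#       "high_freq_factor": 4.0,
#       "original_context_length": 8192,
#   }
#   cos, sin = compute_rope_params(
#       head_dim=64, theta_base=500_000, context_length=8192, freq_config=llama31_freq_config
#   )
#   # cos.shape == sin.shape == (8192, 64)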
def apply_rope(x, cos, sin):
    # x is expected to have shape (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into its first and second halves along the head dimension
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2 :]

    # Reshape cos and sin to (1, 1, seq_len, head_dim) for broadcasting
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)
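
# Usage sketch (illustrative, kept as a comment; the shapes are arbitrary examples):
#
#   batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
#   cos, sin = compute_rope_params(head_dim=head_dim, context_length=seq_len)
#   queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
#   queries_rot = apply_rope(queries, cos, sin)
#   # queries_rot.shape == (2, 8, 16, 64); keys are rotated the same way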
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Number of elements in this parameter
        param_size = param.numel()
        total_params += param_size
        # Gradients are only stored for parameters that require them
        if param.requires_grad:
            total_grads += param_size

    # Account for buffers (non-parameter tensors that still occupy memory)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (number of elements) * (bytes per element), assuming
    # parameters, gradients, and buffers are all stored in input_dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb
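
# Usage sketch (illustrative, kept as a comment; `model` stands for any
# instantiated PyTorch module and is not defined in this file):
#
#   print(f"float32:  {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
#   print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")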
import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe


class Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
    def __init__(self, model_path):
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # Fixed IDs of the Llama-3 special tokens
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(256)
                             if 128002 + i not in self.special.values()})

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False):
        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
            + self.model.encode(text)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        return self.model.decode(ids)
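
# Usage sketch (illustrative, kept as a comment; the path is a placeholder and must
# point to a downloaded Llama-3 tiktoken tokenizer.model file):
#
#   tokenizer = Tokenizer("path/to/tokenizer.model")
#   ids = tokenizer.encode("Hello, world!", bos=True)
#   text = tokenizer.decode(ids)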
class ChatFormat:
    """Builds Llama-3 instruct-style prompts on top of a Tokenizer."""

    def __init__(self, tokenizer: Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )

    def encode(self, user_message, system_message=None):
        sys_msg = system_message if system_message is not None else self.default_system

        ids = [self.tok.special["<|begin_of_text|>"]]

        # System message
        ids += self._header("system")
        ids += self.tok.encode(sys_msg)
        ids += [self.tok.special["<|eot_id|>"]]

        # User message
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # Open the assistant header so the model continues with its reply
        ids += self._header("assistant")

        return ids
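
# Usage sketch (illustrative, kept as a comment; assumes `tokenizer` was created
# as in the Tokenizer example above):
#
#   chat = ChatFormat(tokenizer)
#   prompt_ids = chat.encode("What do llamas eat?")
#   # prompt_ids starts with <|begin_of_text|>, contains the system and user turns
#   # (each closed by <|eot_id|>), and ends with an open assistant header.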
def assign(left, right, tensor_name="unknown"):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    # Return a new Parameter holding a detached copy of the pretrained weights
    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))
def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Load feed-forward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params.keys():
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        # Fall back to weight tying: reuse the token-embedding matrix as the output head
        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
        print("Model uses weight tying.")
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add a batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove the batch dimension
    return tokenizer.decode(flat.tolist())
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # Generate one token at a time
    for _ in range(max_new_tokens):
        # Crop the running sequence to the supported context size
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]  # only the last time step is needed

        # Filter logits with top-k sampling
        if top_k is not None:
            # Keep only the top_k largest logits; mask the rest with -inf
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # Apply temperature scaling and sample from the resulting distribution
        if temperature > 0.0:
            logits = logits / temperature

            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise pick the token with the highest logit (greedy decoding)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop early if the end-of-sequence token is produced
        if eos_id is not None and idx_next == eos_id:
            break

        # Append the sampled token to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
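
# Usage sketch (illustrative, kept as a comment; `model`, `chat`, `tokenizer`, and
# the configuration values are placeholders, not definitions from this file):
#
#   token_ids = generate(
#       model=model,
#       idx=torch.tensor(chat.encode("What do llamas eat?")).unsqueeze(0),
#       max_new_tokens=150,
#       context_size=8192,
#       top_k=1,
#       temperature=0.0,
#       eos_id=tokenizer.special["<|eot_id|>"],
#   )
#   print(token_ids_to_text(token_ids, tokenizer))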