Upload 7 files

- inference.py +244 -0
- loss_plot_step_0_to_120.png +0 -0
- model_architecture.py +519 -0
- step_600.pt +3 -0
- step_800.pt +3 -0
- wikitext2_tokens_128k.pt +3 -0
- wikitext2_val_tokens_128k.pt +3 -0
inference.py
ADDED
@@ -0,0 +1,244 @@
# ==============================================================================
# Inference Script
# ==============================================================================
# --- Necessary Imports ---
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import math
import torch.nn.functional as F
from transformers import AutoTokenizer
import os
import glob
import time
import datetime
import traceback
import dataclasses  # Make sure this is imported

# --- Model Configuration ---
# IMPORTANT: This definition MUST exactly match the one used during training
# when the checkpoint was saved.
@dataclass
class ModelArgs:
    # --- ~221M Config used for training step_1200 ---
    hidden_size: int = 768; num_hidden_layers: int = 12; num_attention_heads: int = 12
    num_key_value_heads: int = 12; intermediate_size: int = 2048; vocab_size: int = 128000
    rms_norm_eps: float = 1e-5; rope_theta: float = 500000.0; max_position_embeddings: int = 4096
    head_dim: int = field(init=False)
    add_recency_bias: bool = False  # Ensure this matches the value used when saving the checkpoint

    def __post_init__(self):
        self.head_dim = self.hidden_size // self.num_attention_heads
        if self.hidden_size % self.num_attention_heads != 0: raise ValueError("hidden_size % num_attention_heads != 0")
        if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("num_attention_heads % num_key_value_heads != 0")

# --- Model Components (RMSNorm, RoPE funcs, Attention, FeedForward, TransformerBlock, Llama) ---
# --- These definitions are copied from model_architecture.py and must stay in sync ---
# --- with the training script that produced the checkpoint. ---
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6): super().__init__(); self.eps = eps; self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x): original_dtype = x.dtype; output = self._norm(x.float()).to(original_dtype); return output * self.weight

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str | torch.device, theta: float = 10000.0):
    if head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE")
    theta_indices = torch.arange(0, head_dim, 2).float(); theta_freqs = 1.0 / (theta**(theta_indices / head_dim))
    target_device = torch.device(device) if isinstance(device, str) else device; theta_freqs = theta_freqs.to(target_device)
    positions = torch.arange(seq_len, device=target_device).float(); freqs = torch.outer(positions, theta_freqs).float(); return freqs, positions

def apply_rotary_embeddings(x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor):
    positions = positions.long(); max_pos = freqs_cis_full.shape[0]
    if torch.max(positions) >= max_pos: positions = torch.clamp(positions, max=max_pos - 1)
    freqs = freqs_cis_full[positions]; freqs = freqs.unsqueeze(0).unsqueeze(2)
    bsz, seq_len, n_part_heads, head_dim = x.shape; x1 = x[..., : head_dim // 2]; x2 = x[..., head_dim // 2 :]
    cos_freqs = torch.cos(freqs).type_as(x); sin_freqs = torch.sin(freqs).type_as(x)
    rotated_x1 = x1 * cos_freqs - x2 * sin_freqs; rotated_x2 = x1 * sin_freqs + x2 * cos_freqs
    rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1); return rotated_x.type_as(x)

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__(); self.args = args; self.num_heads = args.num_attention_heads; self.num_kv_heads = args.num_key_value_heads
        self.head_dim = args.head_dim; self.repeats = self.num_heads // self.num_kv_heads
        self.wq = nn.Linear(args.hidden_size, args.num_attention_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.num_attention_heads * args.head_dim, args.hidden_size, bias=False)
    def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        bsz, n_kv_heads, seqlen, head_dim = x.shape
        if n_rep == 1: return x
        return (x[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seqlen, head_dim).reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim))
    def _create_recency_bias(self, seqlen, full_seqlen, device, dtype, bias_strength=0.1, decay_rate=0.9):
        bias = torch.zeros((1, 1, seqlen, full_seqlen), device=device, dtype=dtype); indices = torch.arange(full_seqlen, device=device)
        rel_pos = torch.arange(seqlen, device=device).unsqueeze(1) - indices.unsqueeze(0); mask = rel_pos >= 0
        decaying_bias = bias_strength * (decay_rate ** (-rel_pos[mask])); bias[:, :, mask] = decaying_bias.type_as(bias); return bias
    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        bsz, seqlen, _ = x.shape; xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.num_heads, self.head_dim); xk = xk.view(bsz, seqlen, self.num_kv_heads, self.head_dim); xv = xv.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
        xq = apply_rotary_embeddings(xq, freqs_cis_full, positions); xk = apply_rotary_embeddings(xk, freqs_cis_full, positions)
        xk = xk.transpose(1, 2); xv = xv.transpose(1, 2)
        if cache is not None: cache_k, cache_v = cache; keys = torch.cat((cache_k.to(xk.device), xk), dim=2); values = torch.cat((cache_v.to(xv.device), xv), dim=2)
        else: keys = xk; values = xv
        updated_cache = (keys.detach(), values.detach()); keys_repeated = self._repeat_kv(keys, self.repeats); values_repeated = self._repeat_kv(values, self.repeats)
        xq = xq.transpose(1, 2); scores = torch.matmul(xq.float(), keys_repeated.transpose(-2, -1).float()) / math.sqrt(self.head_dim)
        if self.args.add_recency_bias:
            full_seqlen = keys_repeated.shape[-2]; recency_bias = self._create_recency_bias(seqlen, full_seqlen, device=scores.device, dtype=scores.dtype); scores = scores + recency_bias
        if mask is not None:
            full_seqlen = keys_repeated.shape[-2]; expected_mask_shape_end = (seqlen, full_seqlen)
            if mask.shape[-2:] != expected_mask_shape_end:
                try: mask_slice = mask[:, :, -seqlen:, :full_seqlen]; scores = scores + mask_slice.float()
                except Exception: pass
            else: scores = scores + mask.float()
        scores = nn.functional.softmax(scores, dim=-1).type_as(xq); output = torch.matmul(scores, values_repeated)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1); output = self.wo(output); return output, updated_cache

class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs): super().__init__(); self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False); self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False); self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs): super().__init__(); self.args = args; self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.attention = Attention(args); self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.feed_forward = FeedForward(args)
    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        r, cache = self.attention(self.attention_norm(x), freqs_cis_full, positions, mask, cache); h = x + r; r = self.feed_forward(self.ffn_norm(h)); out = h + r; return out, cache

class Llama(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__(); self.args = args; self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size); self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.num_hidden_layers)])
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.tok_embeddings.weight.requires_grad = True
        freqs_cis, _ = precompute_theta_pos_frequencies(args.head_dim, args.max_position_embeddings, device='cpu', theta=args.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
    def forward(self, tokens: torch.Tensor, positions: torch.Tensor):
        bsz, seqlen = tokens.shape; h = self.tok_embeddings(tokens); freqs_cis_full = self.freqs_cis.to(h.device); mask = None
        if seqlen > 1: mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device); mask = torch.triu(mask, diagonal=1).type_as(h)
        positions = positions.to(h.device)
        for layer in self.layers: h, _ = layer(h, freqs_cis_full, positions, mask, cache=None)  # Pass cache=None for non-cached forward
        h = self.norm(h); output = F.linear(h, self.tok_embeddings.weight); return output  # Use tied weights

@torch.no_grad()
def generate(model: Llama, tokenizer: AutoTokenizer, prompt: str, max_new_tokens: int, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None):
    model.eval()  # CORRECTED: Separate line
    try:
        model_device = next(model.parameters()).device
        model_dtype = next(model.parameters()).dtype
    except StopIteration:
        print("Warning: Model has no parameters. Assuming CPU and float32.")
        model_device = torch.device("cpu")
        model_dtype = torch.float32

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=True); tokens = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)
    cache = [(torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype),
              torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype))
             for _ in range(model.args.num_hidden_layers)]
    generated_token_ids = []; current_tokens = tokens; print(f"Generating {max_new_tokens} tokens from prompt: '{prompt}'"); print("Output: ", end='')
    full_freqs_cis = model.freqs_cis.to(model_device)
    for i in range(max_new_tokens):
        current_seq_len = current_tokens.shape[1]; start_pos = cache[0][0].shape[2]; positions = torch.arange(start_pos, start_pos + current_seq_len, device=model_device)
        current_mask = None
        if i == 0 and current_seq_len > 1: current_mask = torch.full((1, 1, current_seq_len, current_seq_len), float("-inf"), device=model_device); current_mask = torch.triu(current_mask, diagonal=1).type(model_dtype)
        h = model.tok_embeddings(current_tokens); updated_cache_list = []
        for layer_idx, layer in enumerate(model.layers): h, updated_layer_cache = layer(h, full_freqs_cis, positions, current_mask, cache[layer_idx]); updated_cache_list.append(updated_layer_cache)
        cache = updated_cache_list; h = model.norm(h); logits = F.linear(h, model.tok_embeddings.weight)
        next_token_logits = logits[:, -1, :]
        if temperature == 0: next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            next_token_logits = next_token_logits / temperature
            if top_k is not None and top_k > 0: v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1))); next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
            if top_p is not None and 0.0 < top_p < 1.0:
                probs_for_filter = F.softmax(next_token_logits, dim=-1); probs_sort, probs_idx = torch.sort(probs_for_filter, descending=True); probs_sum = torch.cumsum(probs_sort, dim=-1)
                mask_top_p = probs_sum > top_p; mask_top_p[..., 0] = False; mask_top_p[..., 1:] = mask_top_p[..., :-1].clone(); indices_to_remove = mask_top_p.scatter(1, probs_idx, mask_top_p); next_token_logits[indices_to_remove] = float('-inf')
            probs = F.softmax(next_token_logits, dim=-1); next_token_id = torch.multinomial(probs, num_samples=1)
        if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id: print("\n[EOS token reached]"); break
        next_token_id_item = next_token_id.item(); generated_token_ids.append(next_token_id_item); current_tokens = next_token_id.clone()
        print(tokenizer.decode([next_token_id_item]), end='', flush=True)
        if len(generated_token_ids) >= max_new_tokens: break
    print("\n--- Generation Complete ---"); final_token_ids = prompt_ids + generated_token_ids; full_generated_text = tokenizer.decode(final_token_ids, skip_special_tokens=False)
    print(f"\nFull generated text:\n{full_generated_text}"); return full_generated_text
# --- End of copied model components ---

# --- Main Inference Execution ---
if __name__ == "__main__":

    # --- Configuration for Inference ---
    # --- !! USE SPECIFIC WINDOWS PATH !! ---
    raw_checkpoint_path = r".\step_800.pt"  # <<< checkpoint to load (step_800 in this upload)
    # --- Normalize the path ---
    checkpoint_path = os.path.normpath(raw_checkpoint_path)
    # --- End Adjust ---

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"\n--- Inference Setup ---")
    print(f"Using device: {device}")
    print(f"Attempting to load checkpoint: {checkpoint_path}")

    # --- Load Checkpoint and Model Args ---
    if not os.path.exists(checkpoint_path):
        # Removed the fallback logic as we are specifying an exact path
        exit(f"Error: Checkpoint file not found at the specified path: {checkpoint_path}")

    try:
        # Load checkpoint to CPU first
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)  # weights_only=False needed for ModelArgs

        # Load args dict and instantiate ModelArgs
        saved_args_data = checkpoint.get('model_args', checkpoint.get('model_args_dict'))  # Check both keys
        if not saved_args_data: exit("Error: model_args not found in checkpoint.")
        if not isinstance(saved_args_data, dict): saved_args_dict = dataclasses.asdict(saved_args_data)
        else: saved_args_dict = saved_args_data
        init_field_names = {f.name for f in dataclasses.fields(ModelArgs) if f.init}
        filtered_args_dict = {k: v for k, v in saved_args_dict.items() if k in init_field_names}
        config_inf = ModelArgs(**filtered_args_dict)

        print(f"Loaded model config from checkpoint: {config_inf}")

        # --- Instantiate Model ---
        model_inf = Llama(config_inf)  # Instantiate on CPU
        print("Model instantiated on CPU.")

        # --- Load Weights ---
        model_inf.load_state_dict(checkpoint['model_state_dict'])
        print("Model weights loaded.")
        model_inf.to(device)  # Move model to target device
        print(f"Model moved to {device}.")

        # --- Prepare for Inference ---
        model_inf.eval()
        if device.type == 'cuda':
            try: model_inf = model_inf.half(); print("Converted loaded model to float16 for inference.")
            except Exception as e: print(f"Could not convert model to float16: {e}")

    except Exception as e: exit(f"Error loading checkpoint or instantiating model: {e}")

    # --- Load Tokenizer ---
    tokenizer_name_inf = "deepseek-ai/DeepSeek-R1"
    print(f"Loading tokenizer: {tokenizer_name_inf}")
    try:
        tokenizer_inf = AutoTokenizer.from_pretrained(tokenizer_name_inf, trust_remote_code=True)
        if tokenizer_inf.pad_token is None:
            if tokenizer_inf.eos_token: tokenizer_inf.pad_token = tokenizer_inf.eos_token
            else: tokenizer_inf.add_special_tokens({'pad_token': '[PAD]'})
        print("Tokenizer loaded.")
    except Exception as e: exit(f"Error loading tokenizer: {e}")

    # --- Run Generation ---
    print(f"\n--- Running Generation with Loaded Checkpoint ({os.path.basename(checkpoint_path)}) ---")  # Updated print
    prompt_inf = "Valkyria Chronicles is a tactical role-playing game developed and published by"
    max_gen_len = 100
    gen_temperature = 0.7
    gen_top_k = 50
    gen_top_p = 0.9

    try:
        start_time_inf = time.time()
        _ = generate(
            model=model_inf, tokenizer=tokenizer_inf, prompt=prompt_inf,
            max_new_tokens=max_gen_len, temperature=gen_temperature,
            top_k=gen_top_k, top_p=gen_top_p
        )
        end_time_inf = time.time()
        print(f"\nInference duration: {datetime.timedelta(seconds=int(end_time_inf - start_time_inf))}")
        print("\n(Output quality depends heavily on limited training. Expect limited coherence.)")
    except Exception as e: print(f"\nAn error occurred during generation: {e}"); traceback.print_exc()

    print("\n--- Inference Script Section Finished ---")
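Note: the top-k / top-p filtering inside `generate` above is packed onto a few long lines; the following standalone sketch applies the same filtering steps to a made-up logits tensor purely for illustration (the vocabulary size and logit values are invented, not taken from the model):

import torch
import torch.nn.functional as F

# Made-up next-token logits for a toy vocabulary of 10 tokens.
next_token_logits = torch.tensor([[2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0, -2.5]])
temperature, top_k, top_p = 0.7, 5, 0.9

next_token_logits = next_token_logits / temperature
# Top-k: mask out everything below the k-th largest logit.
v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
# Top-p (nucleus): sort by probability, drop the tail once cumulative mass exceeds top_p.
probs_sort, probs_idx = torch.sort(F.softmax(next_token_logits, dim=-1), descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask_top_p = probs_sum > top_p
mask_top_p[..., 0] = False
mask_top_p[..., 1:] = mask_top_p[..., :-1].clone()
next_token_logits[mask_top_p.scatter(1, probs_idx, mask_top_p)] = float('-inf')
# Sample one token id from the filtered distribution.
next_token_id = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
print(next_token_id.item())
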
loss_plot_step_0_to_120.png
ADDED
model_architecture.py
ADDED
@@ -0,0 +1,519 @@
# ==============================================================================
# Single-File Script ~221M Model - Resume Training for ~4 Hours
# ==============================================================================
# --- Necessary Imports ---
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import math
import torch.nn.functional as F
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import os
from tqdm import tqdm
import traceback
# Corrected import: Added IterableDataset AND Dataset
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torch.optim as optim
# Use torch.amp imports (recommended over torch.cuda.amp)
from torch.amp import GradScaler, autocast
from datasets import load_dataset, IterableDataset as HFIterableDataset
import datetime
import random
import matplotlib.pyplot as plt
import glob
import time
import dataclasses  # Make sure this is imported

# --- Model Configuration ---
@dataclass
class ModelArgs:
    # --- ~221M Config for 4GB VRAM ---
    hidden_size: int = 768; num_hidden_layers: int = 12; num_attention_heads: int = 12
    num_key_value_heads: int = 12; intermediate_size: int = 2048; vocab_size: int = 128000
    rms_norm_eps: float = 1e-5; rope_theta: float = 500000.0; max_position_embeddings: int = 4096
    head_dim: int = field(init=False)
    add_recency_bias: bool = False  # Keep this option if desired

    def __post_init__(self):
        self.head_dim = self.hidden_size // self.num_attention_heads
        if self.hidden_size % self.num_attention_heads != 0: raise ValueError("hidden_size % num_attention_heads != 0")
        if self.num_attention_heads % self.num_key_value_heads != 0: raise ValueError("num_attention_heads % num_key_value_heads != 0")

# --- Model Components (RMSNorm, RoPE funcs, Attention, FeedForward, TransformerBlock, Llama) ---
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6): super().__init__(); self.eps = eps; self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x): original_dtype = x.dtype; output = self._norm(x.float()).to(original_dtype); return output * self.weight

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str | torch.device, theta: float = 10000.0):
    if head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE")
    theta_indices = torch.arange(0, head_dim, 2).float(); theta_freqs = 1.0 / (theta**(theta_indices / head_dim))
    target_device = torch.device(device) if isinstance(device, str) else device; theta_freqs = theta_freqs.to(target_device)
    positions = torch.arange(seq_len, device=target_device).float(); freqs = torch.outer(positions, theta_freqs).float(); return freqs, positions

def apply_rotary_embeddings(x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor):
    positions = positions.long(); max_pos = freqs_cis_full.shape[0]
    if torch.max(positions) >= max_pos: positions = torch.clamp(positions, max=max_pos - 1)
    freqs = freqs_cis_full[positions]; freqs = freqs.unsqueeze(0).unsqueeze(2)
    bsz, seq_len, n_part_heads, head_dim = x.shape; x1 = x[..., : head_dim // 2]; x2 = x[..., head_dim // 2 :]
    cos_freqs = torch.cos(freqs).type_as(x); sin_freqs = torch.sin(freqs).type_as(x)
    rotated_x1 = x1 * cos_freqs - x2 * sin_freqs; rotated_x2 = x1 * sin_freqs + x2 * cos_freqs
    rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1); return rotated_x.type_as(x)

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__(); self.args = args; self.num_heads = args.num_attention_heads; self.num_kv_heads = args.num_key_value_heads
        self.head_dim = args.head_dim; self.repeats = self.num_heads // self.num_kv_heads
        self.wq = nn.Linear(args.hidden_size, args.num_attention_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.num_attention_heads * args.head_dim, args.hidden_size, bias=False)
    def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        bsz, n_kv_heads, seqlen, head_dim = x.shape
        if n_rep == 1: return x
        return (x[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seqlen, head_dim).reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim))
    def _create_recency_bias(self, seqlen, full_seqlen, device, dtype, bias_strength=0.1, decay_rate=0.9):
        bias = torch.zeros((1, 1, seqlen, full_seqlen), device=device, dtype=dtype); indices = torch.arange(full_seqlen, device=device)
        rel_pos = torch.arange(seqlen, device=device).unsqueeze(1) - indices.unsqueeze(0); mask = rel_pos >= 0
        decaying_bias = bias_strength * (decay_rate ** (-rel_pos[mask])); bias[:, :, mask] = decaying_bias.type_as(bias); return bias
    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        bsz, seqlen, _ = x.shape; xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.num_heads, self.head_dim); xk = xk.view(bsz, seqlen, self.num_kv_heads, self.head_dim); xv = xv.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
        xq = apply_rotary_embeddings(xq, freqs_cis_full, positions); xk = apply_rotary_embeddings(xk, freqs_cis_full, positions)
        xk = xk.transpose(1, 2); xv = xv.transpose(1, 2)
        if cache is not None: cache_k, cache_v = cache; keys = torch.cat((cache_k.to(xk.device), xk), dim=2); values = torch.cat((cache_v.to(xv.device), xv), dim=2)
        else: keys = xk; values = xv
        updated_cache = (keys.detach(), values.detach()); keys_repeated = self._repeat_kv(keys, self.repeats); values_repeated = self._repeat_kv(values, self.repeats)
        xq = xq.transpose(1, 2); scores = torch.matmul(xq.float(), keys_repeated.transpose(-2, -1).float()) / math.sqrt(self.head_dim)
        if self.args.add_recency_bias:
            full_seqlen = keys_repeated.shape[-2]; recency_bias = self._create_recency_bias(seqlen, full_seqlen, device=scores.device, dtype=scores.dtype); scores = scores + recency_bias
        if mask is not None:
            full_seqlen = keys_repeated.shape[-2]; expected_mask_shape_end = (seqlen, full_seqlen)
            if mask.shape[-2:] != expected_mask_shape_end:
                try: mask_slice = mask[:, :, -seqlen:, :full_seqlen]; scores = scores + mask_slice.float()
                except Exception: pass
            else: scores = scores + mask.float()
        scores = nn.functional.softmax(scores, dim=-1).type_as(xq); output = torch.matmul(scores, values_repeated)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1); output = self.wo(output); return output, updated_cache

class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs): super().__init__(); self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False); self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False); self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs): super().__init__(); self.args = args; self.attention_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.attention = Attention(args); self.ffn_norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.feed_forward = FeedForward(args)
    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        r, cache = self.attention(self.attention_norm(x), freqs_cis_full, positions, mask, cache); h = x + r; r = self.feed_forward(self.ffn_norm(h)); out = h + r; return out, cache

class Llama(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__(); self.args = args; self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size); self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.num_hidden_layers)])
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps); self.tok_embeddings.weight.requires_grad = True
        freqs_cis, _ = precompute_theta_pos_frequencies(args.head_dim, args.max_position_embeddings, device='cpu', theta=args.rope_theta)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
    def forward(self, tokens: torch.Tensor, positions: torch.Tensor):
        bsz, seqlen = tokens.shape; h = self.tok_embeddings(tokens); freqs_cis_full = self.freqs_cis.to(h.device); mask = None
        if seqlen > 1: mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device); mask = torch.triu(mask, diagonal=1).type_as(h)
        positions = positions.to(h.device)
        for layer in self.layers: h, _ = layer(h, freqs_cis_full, positions, mask, cache=None)
        h = self.norm(h); output = F.linear(h, self.tok_embeddings.weight); return output

# --- Generate function (Added Top-P Sampling) ---
@torch.no_grad()
def generate(model: Llama, tokenizer: AutoTokenizer, prompt: str, max_new_tokens: int, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None):
    model.eval()
    try: model_device = next(model.parameters()).device; model_dtype = next(model.parameters()).dtype
    except StopIteration: model_device = torch.device("cpu"); model_dtype = torch.float32; print("Warning: Model has no parameters.")
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=True); tokens = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)
    cache = [(torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype),
              torch.zeros((1, model.args.num_key_value_heads, 0, model.args.head_dim), device=model_device, dtype=model_dtype))
             for _ in range(model.args.num_hidden_layers)]
    generated_token_ids = []; current_tokens = tokens; print(f"Generating {max_new_tokens} tokens from prompt: '{prompt}'"); print("Output: ", end='')
    full_freqs_cis = model.freqs_cis.to(model_device)
    for i in range(max_new_tokens):
        current_seq_len = current_tokens.shape[1]; start_pos = cache[0][0].shape[2]; positions = torch.arange(start_pos, start_pos + current_seq_len, device=model_device)
        current_mask = None
        if i == 0 and current_seq_len > 1: current_mask = torch.full((1, 1, current_seq_len, current_seq_len), float("-inf"), device=model_device); current_mask = torch.triu(current_mask, diagonal=1).type(model_dtype)
        h = model.tok_embeddings(current_tokens); updated_cache_list = []
        for layer_idx, layer in enumerate(model.layers): h, updated_layer_cache = layer(h, full_freqs_cis, positions, current_mask, cache[layer_idx]); updated_cache_list.append(updated_layer_cache)
        cache = updated_cache_list; h = model.norm(h); logits = F.linear(h, model.tok_embeddings.weight)
        next_token_logits = logits[:, -1, :]
        if temperature == 0: next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            next_token_logits = next_token_logits / temperature
            if top_k is not None and top_k > 0: v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1))); next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
            if top_p is not None and 0.0 < top_p < 1.0:
                probs_for_filter = F.softmax(next_token_logits, dim=-1); probs_sort, probs_idx = torch.sort(probs_for_filter, descending=True); probs_sum = torch.cumsum(probs_sort, dim=-1)
                mask_top_p = probs_sum > top_p; mask_top_p[..., 0] = False; mask_top_p[..., 1:] = mask_top_p[..., :-1].clone(); indices_to_remove = mask_top_p.scatter(1, probs_idx, mask_top_p); next_token_logits[indices_to_remove] = float('-inf')
            probs = F.softmax(next_token_logits, dim=-1); next_token_id = torch.multinomial(probs, num_samples=1)
        if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id: print("\n[EOS token reached]"); break
        next_token_id_item = next_token_id.item(); generated_token_ids.append(next_token_id_item); current_tokens = next_token_id.clone()
        print(tokenizer.decode([next_token_id_item]), end='', flush=True)
        if len(generated_token_ids) >= max_new_tokens: break
    print("\n--- Generation Complete ---"); final_token_ids = prompt_ids + generated_token_ids; full_generated_text = tokenizer.decode(final_token_ids, skip_special_tokens=False)
    print(f"\nFull generated text:\n{full_generated_text}"); return full_generated_text

# --- Dataset Class (Map Style for WikiText) ---
class SimpleLMDataset(Dataset):
    def __init__(self, token_ids: list[int], sequence_length: int):
        self.token_ids = token_ids; self.sequence_length = sequence_length
        self.num_sequences = max(0, len(token_ids) - sequence_length)
        if self.num_sequences == 0: raise ValueError(f"Dataset token count ({len(token_ids)}) not > sequence length ({sequence_length}).")
    def __len__(self): return self.num_sequences
    def __getitem__(self, idx):
        chunk = self.token_ids[idx : idx + self.sequence_length + 1]
        if len(chunk) < self.sequence_length + 1:
            last_valid_idx = len(self.token_ids) - self.sequence_length - 1
            chunk = self.token_ids[last_valid_idx : last_valid_idx + self.sequence_length + 1]
        input_ids = torch.tensor(chunk[:-1], dtype=torch.long); target_ids = torch.tensor(chunk[1:], dtype=torch.long)
        return input_ids, target_ids

# --- Dataset Class (Iterable for SlimPajama - Kept for reference/fallback) ---
class TokenizedSequenceDataset(IterableDataset):
    def __init__(self, dataset_name, dataset_config, split, tokenizer, sequence_length, buffer_size=10000):
        try: self.dataset = load_dataset(dataset_name, dataset_config, split=split, streaming=True); print(f"Successfully loaded streaming dataset: {dataset_name} ({split})")
        except Exception as e: raise RuntimeError(f"Failed to load streaming dataset {dataset_name} ({split}): {e}") from e
        self.tokenizer = tokenizer; self.sequence_length = sequence_length; self.buffer_size = buffer_size; self.buffer = []
        try: self.iter_dataset = iter(self.dataset)
        except Exception as e: raise RuntimeError(f"Failed to create iterator for dataset {dataset_name} ({split}): {e}") from e
    def __iter__(self):
        while True:
            while len(self.buffer) < self.sequence_length + 1:
                try:
                    item = next(self.iter_dataset); text = item.get('text', '')
                    if text and text.strip(): token_ids = self.tokenizer.encode(text, add_special_tokens=False); self.buffer.extend(token_ids)
                except StopIteration:
                    if len(self.buffer) < self.sequence_length + 1: return
                    else: break
            if len(self.buffer) < self.sequence_length + 1: return
            chunk = self.buffer[:self.sequence_length + 1]; input_ids = torch.tensor(chunk[:-1], dtype=torch.long); target_ids = torch.tensor(chunk[1:], dtype=torch.long)
            yield input_ids, target_ids; self.buffer = self.buffer[1:]

# --- Checkpoint Loading Function ---
def load_checkpoint(checkpoint_dir: str, model: Llama, optimizer, scaler, scheduler, device):
    latest_checkpoint_path = None; highest_step = -1
    if os.path.isdir(checkpoint_dir):
        checkpoints = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
        for ckpt_path in checkpoints:
            try: step = int(os.path.basename(ckpt_path).split('_')[-1].split('.')[0])
            except ValueError: continue
            if step > highest_step: highest_step = step; latest_checkpoint_path = ckpt_path
    if latest_checkpoint_path:
        print(f"Loading checkpoint from: {latest_checkpoint_path}")
        try:
            checkpoint = torch.load(latest_checkpoint_path, map_location='cpu', weights_only=False)  # Use False for safety
            current_args_dict = model.args.__dict__
            saved_args_data = checkpoint.get('model_args', checkpoint.get('model_args_dict'))
            if not saved_args_data: print("Warning: Checkpoint missing model_args."); saved_args_dict = None; args_match = False
            elif not isinstance(saved_args_data, dict): saved_args_dict = dataclasses.asdict(saved_args_data)  # Use imported module
            else: saved_args_dict = saved_args_data
            args_match = True
            if saved_args_dict:
                for f in dataclasses.fields(ModelArgs):  # Use dataclasses.fields
                    if f.init and f.name != 'head_dim':
                        current_val = current_args_dict.get(f.name); saved_val = saved_args_dict.get(f.name)
                        if current_val != saved_val: print(f"Mismatch in arg '{f.name}': Current={current_val}, Saved={saved_val}"); args_match = False; break
            else: args_match = False
            if not args_match: print("ERROR: Model args mismatch. Cannot load checkpoint."); return 0
            model.load_state_dict(checkpoint['model_state_dict']); model.to(device)
            if optimizer is not None:
                try: optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                except Exception as e: print(f"Warning: Could not load optimizer state dict: {e}")
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor): state[k] = v.to(device)
            if scaler is not None:
                try: scaler.load_state_dict(checkpoint['scaler_state_dict'])
                except Exception as e: print(f"Warning: Could not load scaler state dict: {e}")
            if scheduler is not None:
                try: scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                except Exception as e: print(f"Warning: Could not load scheduler state dict: {e}")
            start_step = checkpoint['step']; print(f"Resuming training from step {start_step + 1}"); return start_step
        except Exception as e: print(f"Error loading checkpoint {latest_checkpoint_path}: {e}"); traceback.print_exc(); return 0
    else: print("No checkpoint found. Starting training from scratch."); return 0

# --- Plotting Function ---
def plot_loss(train_losses, val_losses, val_steps_list, checkpoint_dir, start_step=0):
    plt.figure(figsize=(12, 6)); smoothing_window = 10
    train_steps = list(range(start_step + 1, start_step + len(train_losses) + 1))
    plt.plot(train_steps, train_losses, label='Training Loss (Raw)', alpha=0.3)
    if len(train_losses) > smoothing_window:
        train_losses_smoothed = [sum(train_losses[max(0, i-smoothing_window):i+1])/min(i+1, smoothing_window) for i in range(len(train_losses))]
        plt.plot(train_steps, train_losses_smoothed, label=f'Training Loss (Smoothed {smoothing_window} steps)', alpha=0.9)
    if val_losses and val_steps_list: plt.plot(val_steps_list, val_losses, label='Validation Loss', marker='o', linestyle='--')
    plt.xlabel("Optimizer Steps"); plt.ylabel("Loss"); plt.yscale('log'); plt.title("Training and Validation Loss Over Steps")
    plt.legend(); plt.grid(True); plot_filename = f"loss_plot_step_{start_step}_to_{start_step+len(train_losses)}.png"
    plot_path = os.path.join(checkpoint_dir, plot_filename); plt.savefig(plot_path)
    print(f"Loss plot saved to {plot_path}")

# --- Basic Training Function (Single GPU, AMP, LR Schedule, Validation, Checkpointing, Plotting) ---
def simple_train(
    model: Llama, tokenizer: AutoTokenizer, train_dataset: IterableDataset | Dataset, val_dataset: IterableDataset | Dataset | None,
    optimizer: torch.optim.Optimizer, criterion: nn.Module, scheduler,
    num_epochs: int, device: torch.device, gradient_accumulation_steps: int = 1,
    use_amp: bool = False, max_train_steps: int | None = None, start_step: int = 0,
    save_interval: int = 1000, checkpoint_dir: str = ".",
    validation_interval: int = 500, val_steps: int = 50
):
    model.train(); total_steps = start_step; global_step_this_run = 0
    scaler = GradScaler(enabled=use_amp and device.type == 'cuda')
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_loss_history = []; val_loss_history = []; val_steps_history = []
    print(f"\n--- Starting Training (Resuming from step {start_step}, Target Steps: {max_train_steps if max_train_steps else 'N/A'}) ---")
    print(f"--- (AMP: {use_amp and device.type == 'cuda'}) ---")
    is_iterable = isinstance(train_dataset, IterableDataset)
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0, shuffle=(not is_iterable))
    if val_dataset: val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0)
    training_complete = False
    # Adjust tqdm total based on remaining steps
    tqdm_total = (max_train_steps - start_step) if max_train_steps is not None else None
    print(f"Starting loop, aiming for {max_train_steps} total steps...")
    # Use total=None for iterable datasets if max_steps not set, as length is unknown
    pbar = tqdm(total=tqdm_total, desc=f"Optim Steps ({start_step}...)")

    # Need to manually track iterations vs optimizer steps
    data_iterator = iter(train_loader)
    accum_count = 0  # Counter for gradient accumulation steps

    while not training_complete:
        # Check if we need to stop before starting the next optimizer step
        if max_train_steps is not None and total_steps >= max_train_steps:
            training_complete = True; break

        # --- Accumulation Loop ---
        accum_loss = 0.0
        optimizer.zero_grad()  # Zero gradients at start of accumulation cycle

        for _ in range(gradient_accumulation_steps):
            try:
                input_ids, target_ids = next(data_iterator)
            except StopIteration:
                print("\nDataLoader exhausted within accumulation cycle or epoch.")
                # If loader exhausted before completing max_steps, stop training
                training_complete = True; break  # Break inner accum loop

            input_ids = input_ids.to(device); target_ids = target_ids.to(device)
            seqlen = input_ids.shape[1]; positions = torch.arange(seqlen, device=device)

            # Use torch.amp.autocast
            with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp and device.type == 'cuda'):
                logits = model(input_ids, positions)
                loss = criterion(logits.view(-1, logits.size(-1)).float(), target_ids.view(-1))
                loss = loss / gradient_accumulation_steps  # Normalize loss for accumulation

            scaler.scale(loss).backward()
            accum_loss += loss.item()  # Accumulate *normalized* loss item

        if training_complete: break  # Exit outer loop if data exhausted

        # --- Optimizer Step ---
        scaler.unscale_(optimizer)
        # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer); scaler.update(); scheduler.step(); optimizer.zero_grad(set_to_none=True)
        total_steps += 1; global_step_this_run += 1
        pbar.update(1)  # Update progress bar by one optimizer step

        # --- Logging ---
        current_loss = accum_loss * gradient_accumulation_steps  # Log un-normalized loss for this step
        train_loss_history.append(current_loss)
        # Note: epoch_loss calculation might be less meaningful with iterable dataset and max_steps
        # avg_loss_so_far = sum(train_loss_history[-50:]) / min(len(train_loss_history), 50) # Example: rolling average
        pbar.set_postfix({"Loss": f"{current_loss:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.6f}", "Steps": total_steps})

        # --- Validation ---
        if val_dataset and total_steps % validation_interval == 0 and total_steps > 0:
            model.eval(); val_loss = 0.0; val_batches = 0; print(f"\nRunning validation at step {total_steps}...")
            val_pbar = tqdm(enumerate(val_loader), total=val_steps, desc="Validation")
            with torch.no_grad():
                val_iter = iter(val_loader)
                for val_step in range(val_steps):
                    try:
                        val_input_ids, val_target_ids = next(val_iter)
                        val_input_ids = val_input_ids.to(device); val_target_ids = val_target_ids.to(device)
                        val_seqlen = val_input_ids.shape[1]; val_positions = torch.arange(val_seqlen, device=device)
                        with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp and device.type == 'cuda'):
                            val_logits = model(val_input_ids, val_positions)
                            v_loss = criterion(val_logits.view(-1, val_logits.size(-1)).float(), val_target_ids.view(-1))
                        val_loss += v_loss.item(); val_batches += 1; val_pbar.update(1); val_pbar.set_postfix({"Val Loss": f"{val_loss/val_batches:.4f}"})
                    except StopIteration: print("Validation loader exhausted early."); break
            val_pbar.close()
            avg_val_loss = val_loss / val_batches if val_batches > 0 else float('inf')
            val_loss_history.append(avg_val_loss); val_steps_history.append(total_steps)
            print(f"Validation finished. Average Val Loss: {avg_val_loss:.4f}"); model.train()

        # --- Checkpointing ---
        if total_steps % save_interval == 0 and total_steps > 0:
            save_path = os.path.join(checkpoint_dir, f"step_{total_steps}.pt")
            try:
                model_args_dict = dataclasses.asdict(model.args)
                save_content = { 'step': total_steps, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                                 'scaler_state_dict': scaler.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'model_args_dict': model_args_dict }
                torch.save(save_content, save_path); print(f"\nCheckpoint saved to {save_path}")
            except Exception as e: print(f"\nError saving checkpoint: {e}")

        # --- Check Max Steps ---
        if max_train_steps is not None and total_steps >= max_train_steps:
            print(f"\nReached max_train_steps ({max_train_steps}). Stopping training."); training_complete = True; break

    pbar.close()  # Close pbar if loop finishes naturally
    print("--- Training Finished ---")
    return train_loss_history, val_loss_history, val_steps_history


# --- Main Execution Block ---
if __name__ == "__main__":
    # --- Configuration ---
    config = ModelArgs(add_recency_bias=False)  # Use ~221M config
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Model Configuration:\n{config}")
    print(f"Calculated Head Dimension: {config.head_dim}")
    print(f"\nUsing device: {device}")

    # --- Component Tests (Commented out) ---
    """ """

    # --- Tokenizer ---
    print("\n--- Tokenizer Loading ---")
    tokenizer_name = "deepseek-ai/DeepSeek-R1"
    print(f"Loading tokenizer: {tokenizer_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
        print("Tokenizer loaded successfully.")
        if tokenizer.vocab_size != config.vocab_size: exit(f"FATAL: Tokenizer vocab size mismatch!")
        else: print(f"Tokenizer vocab size ({tokenizer.vocab_size}) matches model config.")
        if tokenizer.pad_token is None:
            if tokenizer.eos_token: tokenizer.pad_token = tokenizer.eos_token; print(f"Set PAD token to EOS token: {tokenizer.pad_token}")
            else: tokenizer.add_special_tokens({'pad_token': '[PAD]'}); print("Added a generic [PAD] token.")
    except Exception as e: exit(f"Error loading tokenizer '{tokenizer_name}': {e}")

    # --- Training Setup ---
    print("\n--- Training Setup ---")
    train_batch_size = 1
    train_seq_len = 256
    grad_accum_steps = 16
    use_amp_training = True if device.type == 'cuda' else False
    learning_rate = 1e-4  # Lower LR
    num_epochs = 1
    # --- ADJUSTED MAX STEPS for ~4 hour run ---
    max_steps_for_run = 1200  # Absolute target step for this run (start_step + new_steps)
    # --- ADJUSTED Total Scheduler Steps (longer term goal) ---
    total_scheduler_steps = 10000  # Example longer goal
    warmup_steps = 100
    # --- Save to current directory ---
    checkpoint_dir = "."
    save_interval = 200  # Save less frequently
    validation_interval = 100  # Validate less frequently
    val_steps = 20

    # --- Dataset ---
    print("\nLoading and preparing WikiText-2 dataset...")
    train_dataset, val_dataset = None, None
    try:
        # Using WikiText-2 directly
        token_file = "./wikitext2_tokens_128k.pt"
        val_token_file = "./wikitext2_val_tokens_128k.pt"
        force_remake_dataset = False
        if os.path.exists(token_file) and os.path.exists(val_token_file) and not force_remake_dataset:
            print(f"Loading tokenized data from {token_file} and {val_token_file}...")
            all_token_ids = torch.load(token_file)
            all_val_token_ids = torch.load(val_token_file)
            print("Tokenized data loaded.")
        else:
            print("Token files not found or remake forced, processing WikiText-2...")
            print("Processing train split...")
            train_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
            train_full_text = "\n".join([item['text'] for item in train_raw_dataset if item['text'].strip()])
            all_token_ids = tokenizer.encode(train_full_text)
            torch.save(all_token_ids, token_file)
            print(f"Saved tokenized train data ({len(all_token_ids)} tokens) to {token_file}")
            print("Processing validation split...")
            val_raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
            val_full_text = "\n".join([item['text'] for item in val_raw_dataset if item['text'].strip()])
            all_val_token_ids = tokenizer.encode(val_full_text)
            torch.save(all_val_token_ids, val_token_file)
            print(f"Saved tokenized validation data ({len(all_val_token_ids)} tokens) to {val_token_file}")

        if len(all_token_ids) <= train_seq_len: exit("Train dataset too short.")
        if len(all_val_token_ids) <= train_seq_len: exit("Validation dataset too short.")
        train_dataset = SimpleLMDataset(all_token_ids, sequence_length=train_seq_len)
        val_dataset = SimpleLMDataset(all_val_token_ids, sequence_length=train_seq_len)
        print("Using WikiText-2 dataset.")
    except Exception as e: exit(f"Dataset error: {e}")

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True) if val_dataset else None
    print(f"DataLoaders created. Training Seq Len: {train_seq_len}")
    print(f"Train sequences: {len(train_dataset)}, Val sequences: {len(val_dataset) if val_dataset else 0}")

    # --- Model, Optimizer, Scheduler, Loss ---
    train_model = Llama(config).to(device)
    print(f"Training model instantiated ({'float32' if not use_amp_training else 'mixed precision'}). Recency Bias: {config.add_recency_bias}")
    total_params_train = sum(p.numel() for p in train_model.parameters() if p.requires_grad)
    print(f"Total Trainable Parameters: {total_params_train / 1e6:.2f} Million")

    optimizer = optim.AdamW(train_model.parameters(), lr=learning_rate, weight_decay=0.1)
    criterion = nn.CrossEntropyLoss()
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_scheduler_steps
    )
    scaler = GradScaler(enabled=use_amp_training and device.type == 'cuda')
    print(f"Optimizer: AdamW, Base LR: {learning_rate}")
    print(f"Scheduler: Cosine with {warmup_steps} warmup steps up to {total_scheduler_steps} steps")
    print(f"Loss Function: CrossEntropyLoss")

    # --- Load Checkpoint ---
    # Pass optimizer, scaler, scheduler to be loaded
    start_step = load_checkpoint(checkpoint_dir, train_model, optimizer, scaler, scheduler, device)

    # Calculate steps to run in this session
    steps_to_run_this_session = max(0, max_steps_for_run - start_step)
    # The absolute step number to stop at in this run
    current_run_target_step = start_step + steps_to_run_this_session

    if steps_to_run_this_session <= 0:
        print(f"Already completed or exceeded target steps ({max_steps_for_run}). Exiting.")
        exit()

    # --- Run Training ---
    print(f"\n--- Running Training (Will run for {steps_to_run_this_session} steps in this session, target total: {max_steps_for_run}) ---")
    start_time = time.time()
    train_loss_hist, val_loss_hist, val_steps_hist = [], [], []
    try:
        # Pass the absolute target step for this run
        train_loss_hist, val_loss_hist, val_steps_hist = simple_train(
            model=train_model, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset,
            optimizer=optimizer, criterion=criterion, scheduler=scheduler,
            num_epochs=num_epochs, device=device, gradient_accumulation_steps=grad_accum_steps,
            use_amp=use_amp_training, max_train_steps=current_run_target_step, start_step=start_step,
            save_interval=save_interval, checkpoint_dir=checkpoint_dir,  # Pass "."
            validation_interval=validation_interval, val_steps=val_steps
        )
        print("\nTraining loop finished.")
        end_time = time.time(); print(f"Training duration for this session: {datetime.timedelta(seconds=int(end_time - start_time))}")

        # --- Plotting ---
        if train_loss_hist:
            # Adjust steps for plotting if resuming
            plot_train_steps = list(range(start_step + 1, start_step + len(train_loss_hist) + 1))
            # Filter validation steps/losses that occurred *during this run*
            plot_val_steps = [s for s in val_steps_hist if s >= start_step]
            plot_val_loss = [val_loss_hist[i] for i, s in enumerate(val_steps_hist) if s >= start_step]
            plot_loss(train_loss_hist, plot_val_loss, plot_val_steps, checkpoint_dir, start_step=start_step)  # Pass "."


        # --- Generation After Training ---
        print("\n--- Generation After Training ---")
        train_model.eval()
        if device.type == 'cuda':
            try: train_model = train_model.half(); print("Trained model converted to float16 for generation.")
            except Exception as e: print(f"Could not convert trained model to float16: {e}.")
        test_prompt_after = "The meaning of life is"
        _ = generate(model=train_model, tokenizer=tokenizer, prompt=test_prompt_after, max_new_tokens=60, temperature=0.7, top_k=50, top_p=0.9)
        print("\n(Check if output shows more structure than random)")

    except torch.cuda.OutOfMemoryError: print("\n--- CUDA Out of Memory during Training ---"); print("Try reducing train_seq_len or gradient_accumulation_steps further.")
    except Exception as e: print(f"\nAn error occurred during training: {e}"); traceback.print_exc()

    print("\n--- Script Finished ---")
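As a quick sanity check on the training configuration defined above (batch size 1, 16 gradient-accumulation steps, sequence length 256, target of 1200 optimizer steps), the tokens seen per optimizer step and over the full run are just arithmetic on those constants:

train_batch_size, grad_accum_steps, train_seq_len = 1, 16, 256
tokens_per_optimizer_step = train_batch_size * grad_accum_steps * train_seq_len  # 4096 tokens per optimizer step
max_steps_for_run = 1200
total_tokens_target = tokens_per_optimizer_step * max_steps_for_run              # 4,915,200 tokens over the full run
print(tokens_per_optimizer_step, total_tokens_target)
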
step_600.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:530888d8725a62ab23389bea6a1b7d2a116c3f3c1be2b594f5480ab481dbff04
size 2199232318
step_800.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2c4b611bfda30d3e87bcc48c1676d4b959da17dd08d00d55f511f1a2b7dd498
size 2199232318
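A quick way to see what one of the uploaded checkpoints contains without building the model; this is a sketch that assumes the .pt files have been fetched through Git LFS and relies only on the keys written by `simple_train` ('step', 'model_state_dict', 'model_args_dict', ...):

import torch

ckpt = torch.load("step_800.pt", map_location="cpu", weights_only=False)  # large file (~2.2 GB), loads on CPU
print("saved at optimizer step:", ckpt["step"])
print("model args:", ckpt.get("model_args_dict", ckpt.get("model_args")))
print("tensors in state dict:", len(ckpt["model_state_dict"]))
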
wikitext2_tokens_128k.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cecec1d0a73cb0418069e202c7890691c82fd8eef1bc7fd8165345cfc8be3e1b
size 7303192
wikitext2_val_tokens_128k.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:015fabc451eda93a3b3c644a6df177d29234478916d030bba0b86c802e3f0efa
size 761320