Thewhey-Brian committed
Commit bd710e9 · 1 Parent(s): d2d9bb7

Deploy nanochat

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .dockerignore +41 -0
  2. Dockerfile +39 -0
  3. app.py +52 -0
  4. nanochat/__init__.py +0 -0
  5. nanochat/__pycache__/__init__.cpython-310.pyc +0 -0
  6. nanochat/__pycache__/__init__.cpython-312.pyc +0 -0
  7. nanochat/__pycache__/adamw.cpython-310.pyc +0 -0
  8. nanochat/__pycache__/checkpoint_manager.cpython-310.pyc +0 -0
  9. nanochat/__pycache__/common.cpython-310.pyc +0 -0
  10. nanochat/__pycache__/core_eval.cpython-310.pyc +0 -0
  11. nanochat/__pycache__/dataloader.cpython-310.pyc +0 -0
  12. nanochat/__pycache__/dataset.cpython-310.pyc +0 -0
  13. nanochat/__pycache__/dataset.cpython-312.pyc +0 -0
  14. nanochat/__pycache__/engine.cpython-310.pyc +0 -0
  15. nanochat/__pycache__/gpt.cpython-310.pyc +0 -0
  16. nanochat/__pycache__/loss_eval.cpython-310.pyc +0 -0
  17. nanochat/__pycache__/muon.cpython-310.pyc +0 -0
  18. nanochat/__pycache__/report.cpython-310.pyc +0 -0
  19. nanochat/__pycache__/tokenizer.cpython-310.pyc +0 -0
  20. nanochat/adamw.py +77 -0
  21. nanochat/checkpoint_manager.py +146 -0
  22. nanochat/common.py +144 -0
  23. nanochat/configurator.py +56 -0
  24. nanochat/core_eval.py +262 -0
  25. nanochat/dataloader.py +49 -0
  26. nanochat/dataset.py +128 -0
  27. nanochat/engine.py +343 -0
  28. nanochat/execution.py +350 -0
  29. nanochat/gpt.py +322 -0
  30. nanochat/logo.svg +8 -0
  31. nanochat/loss_eval.py +63 -0
  32. nanochat/muon.py +187 -0
  33. nanochat/report.py +404 -0
  34. nanochat/tokenizer.py +395 -0
  35. nanochat/ui.html +561 -0
  36. pyproject.toml +55 -0
  37. requirements.txt +11 -0
  38. rustbpe/Cargo.lock +458 -0
  39. rustbpe/Cargo.toml +15 -0
  40. rustbpe/README.md +5 -0
  41. rustbpe/src/lib.rs +476 -0
  42. scripts/__pycache__/base_eval.cpython-310.pyc +0 -0
  43. scripts/__pycache__/base_train.cpython-310.pyc +0 -0
  44. scripts/__pycache__/chat_web.cpython-310.pyc +0 -0
  45. scripts/__pycache__/tok_eval.cpython-310.pyc +0 -0
  46. scripts/__pycache__/tok_train.cpython-310.pyc +0 -0
  47. scripts/base_eval.py +180 -0
  48. scripts/base_loss.py +78 -0
  49. scripts/base_train.py +339 -0
  50. scripts/chat_cli.py +99 -0
.dockerignore ADDED
@@ -0,0 +1,41 @@
+ # Python
+ __pycache__
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ .venv
+ venv/
+ ENV/
+
+ # Git
+ .git
+ .gitignore
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Data directories (will be mounted or downloaded separately)
+ base_checkpoints/
+ mid_checkpoints/
+ chatsft_checkpoints/
+ chatrl_checkpoints/
+ base_data/
+ base_eval/
+ eval_bundle/
+ report/
+
+ # Build artifacts
+ *.egg-info/
+ dist/
+ build/
+ target/
+
+ # Logs
+ *.log
+
+ # UV lock file (not needed in container)
+ uv.lock
Dockerfile ADDED
@@ -0,0 +1,39 @@
+ # Dockerfile for Hugging Face Spaces deployment
+ FROM python:3.10-slim
+
+ # Install system dependencies including Rust (needed for rustbpe)
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
+     && . $HOME/.cargo/env \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Add Rust to PATH
+ ENV PATH="/root/.cargo/bin:${PATH}"
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy project files
+ COPY . /app
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Build the Rust component
+ RUN . $HOME/.cargo/env && maturin develop --release
+
+ # Create data directory for model checkpoints
+ RUN mkdir -p /data
+
+ # Set environment variables
+ ENV NANOCHAT_BASE_DIR=/data
+ ENV PYTHONUNBUFFERED=1
+
+ # Expose port 7860 (HF Spaces default)
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["python", "app.py"]
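A quick way to sanity-check this image before pushing to Spaces is to build and run it locally and open port 7860. The sketch below is hypothetical: it assumes Docker is installed, uses an arbitrary local tag name (nanochat-space), and is wrapped in Python only to keep all examples in one language.

import subprocess

# Build the image from the repo root, then run it with the HF Spaces port published.
subprocess.run(["docker", "build", "-t", "nanochat-space", "."], check=True)
subprocess.run(["docker", "run", "--rm", "-p", "7860:7860", "nanochat-space"], check=True)
# The chat UI should then be reachable at http://localhost:7860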
app.py ADDED
@@ -0,0 +1,52 @@
+ #!/usr/bin/env python3
+ """
+ Hugging Face Spaces entry point for NanoChat.
+ This file is automatically detected and run by HF Spaces.
+ """
+
+ import os
+ import sys
+
+ # Set environment variables for HF Spaces
+ os.environ.setdefault("NANOCHAT_BASE_DIR", "/data")
+
+ # Download model from HF if not present
+ def download_model():
+     """Download model weights from Hugging Face."""
+     checkpoint_dir = "/data/chatsft_checkpoints"
+
+     if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
+         print("Model checkpoints found, skipping download")
+         return
+
+     print("Downloading model from BrianGuo/nanochat-d20-chat...")
+     from huggingface_hub import snapshot_download
+
+     snapshot_download(
+         repo_id="BrianGuo/nanochat-d20-chat",
+         local_dir="/data/chatsft_checkpoints"
+     )
+     print("Model downloaded successfully!")
+
+ if __name__ == "__main__":
+     # Download model before starting
+     download_model()
+
+     # Override sys.argv to pass default arguments for HF Spaces
+     sys.argv = [
+         "app.py",
+         "--port", "7860",  # HF Spaces default port
+         "--host", "0.0.0.0",
+         "--source", "sft",
+         "--model-tag", os.environ.get("MODEL_TAG", "d20"),
+         "--step", os.environ.get("MODEL_STEP", "650"),
+     ]
+
+     # Import and run the web server
+     from scripts.chat_web import app
+     import uvicorn
+
+     print("Starting NanoChat on Hugging Face Spaces...")
+     print(f"Model: {os.environ.get('MODEL_TAG', 'd20')} - Step: {os.environ.get('MODEL_STEP', '650')}")
+
+     uvicorn.run(app, host="0.0.0.0", port=7860)
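If you want to prefetch the same weights outside the container (for example to warm a mounted /data volume), a minimal sketch mirroring the snapshot_download call above; the local_dir here is illustrative, the container itself uses /data/chatsft_checkpoints:

from huggingface_hub import snapshot_download

# Same repo as download_model(), pointed at a local directory of your choice.
local_dir = "./data/chatsft_checkpoints"  # illustrative path
snapshot_download(repo_id="BrianGuo/nanochat-d20-chat", local_dir=local_dir)
print("checkpoints in", local_dir)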
nanochat/__init__.py ADDED
File without changes
nanochat/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes).
nanochat/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes).
nanochat/__pycache__/adamw.cpython-310.pyc ADDED
Binary file (2.71 kB).
nanochat/__pycache__/checkpoint_manager.cpython-310.pyc ADDED
Binary file (5.19 kB).
nanochat/__pycache__/common.cpython-310.pyc ADDED
Binary file (5.76 kB).
nanochat/__pycache__/core_eval.cpython-310.pyc ADDED
Binary file (8.47 kB).
nanochat/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (1.78 kB).
nanochat/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (4.58 kB).
nanochat/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (6.87 kB).
nanochat/__pycache__/engine.cpython-310.pyc ADDED
Binary file (10.3 kB).
nanochat/__pycache__/gpt.cpython-310.pyc ADDED
Binary file (10.7 kB).
nanochat/__pycache__/loss_eval.cpython-310.pyc ADDED
Binary file (2.29 kB).
nanochat/__pycache__/muon.cpython-310.pyc ADDED
Binary file (8.32 kB).
nanochat/__pycache__/report.cpython-310.pyc ADDED
Binary file (10.9 kB).
nanochat/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (12.8 kB).
nanochat/adamw.py ADDED
@@ -0,0 +1,77 @@
+ """
+ Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
+ Not a general optimizer! But works for our specific use.
+ """
+ import torch
+ import torch.distributed as dist
+ from torch import Tensor
+
+
+ class DistAdamW(torch.optim.Optimizer):
+     """
+     Distributed AdamW optimizer.
+     In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
+     """
+     def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         super().__init__(param_groups, defaults)
+
+     @torch.compile
+     @torch.no_grad()
+     def step(self):
+         rank = dist.get_rank()
+         world_size = dist.get_world_size()
+         reduce_scatter_futures: list[torch.Future] = []
+         all_reduce_futures: list[torch.Future] = []
+         grad_slices = []
+         for group in self.param_groups:
+             params: list[Tensor] = group["params"]
+             grad = torch.empty_like(params[-1])  # TODO is this a bug? seems to be over-written instantly
+             for base_i in range(len(params)):
+                 grad = params[base_i].grad
+                 rank_size = grad.shape[0] // world_size
+                 grad_slice = torch.empty_like(grad[:rank_size])
+                 reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                 grad_slices.append(grad_slice)
+
+         idx = 0
+         for group in self.param_groups:
+             beta1, beta2 = group['betas']
+             eps = group['eps']
+             wd = group['weight_decay']
+             params = group['params']
+             for base in range(len(params)):
+                 reduce_scatter_futures[idx].wait()
+                 p = params[base]
+                 rank_size = p.shape[0] // world_size
+                 p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                 lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                 state = self.state[p]
+                 g_slice = grad_slices[idx]
+                 # State init
+                 if not state:
+                     state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
+                     state['exp_avg'] = torch.zeros_like(p_slice)
+                     state['exp_avg_sq'] = torch.zeros_like(p_slice)
+                 exp_avg = state['exp_avg']
+                 exp_avg_sq = state['exp_avg_sq']
+                 state['step'] += 1
+                 t = state['step']
+                 # weight decay
+                 if wd != 0:
+                     eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
+                     p_slice.mul_(1 - eff_weight_decay)
+                 # update running averages
+                 exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
+                 exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
+                 # bias corrections
+                 bias1 = 1 - beta1 ** t
+                 bias2 = 1 - beta2 ** t
+                 # compute step
+                 denom = exp_avg_sq.sqrt().add_(eps)
+                 step_size = lr * (torch.sqrt(bias2) / bias1)
+                 update = exp_avg.div(denom).mul_(step_size)
+                 p_slice.add_(other=update, alpha=-1.0)
+                 idx += 1
+                 all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
+         torch.futures.collect_all(all_reduce_futures).wait()
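For readers less familiar with AdamW, here is a minimal, non-distributed sketch of the same update rule that DistAdamW applies to each parameter shard (decoupled weight decay, exponential moving averages of the first and second gradient moments, bias-corrected step). It is illustrative only and deliberately skips the ZeRO-2 sharding, reduce-scatter, and all-gather that the optimizer above exists for.

import torch

def adamw_step(p, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    # Lazily initialize the optimizer state, mirroring DistAdamW.step above.
    if not state:
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(p)
        state['exp_avg_sq'] = torch.zeros_like(p)
    state['step'] += 1
    t = state['step']
    # Decoupled weight decay, applied directly to the parameter.
    p.mul_(1 - lr * wd)
    # Running first/second moments of the gradient.
    state['exp_avg'].mul_(beta1).add_(grad, alpha=1 - beta1)
    state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Bias-corrected step, same form as the sharded version above.
    bias1 = 1 - beta1 ** t
    bias2 = 1 - beta2 ** t
    denom = state['exp_avg_sq'].sqrt().add_(eps)
    p.add_(state['exp_avg'] / denom, alpha=-lr * (bias2 ** 0.5) / bias1)

p = torch.randn(4)
state = {}
adamw_step(p, torch.randn(4), state)  # one update, no process group needed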
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,146 @@
+ """
+ Utilities for saving and loading model/optim/state checkpoints.
+ """
+ import os
+ import re
+ import glob
+ import json
+ import logging
+ import torch
+
+ from nanochat.common import get_base_dir
+ from nanochat.gpt import GPT, GPTConfig
+ from nanochat.tokenizer import get_tokenizer
+ from nanochat.common import setup_default_logging
+
+ # Set up logging
+ setup_default_logging()
+ logger = logging.getLogger(__name__)
+ def log0(message):
+     if int(os.environ.get('RANK', 0)) == 0:
+         logger.info(message)
+
+ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
+     assert int(os.environ.get('RANK', 0)) == 0  # prevent footguns for now
+     os.makedirs(checkpoint_dir, exist_ok=True)
+     # Save the model state (parameters)
+     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+     torch.save(model_data, model_path)
+     log0(f"Saved model file to: {model_path}")
+     # Save the optimizer state (useful for SFT or any other fine-tuning)
+     if optimizer_data is not None:
+         optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+         torch.save(optimizer_data, optimizer_path)
+         log0(f"Saved optimizer file to: {optimizer_path}")
+     # Save the metadata dict as json
+     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+     with open(meta_path, "w") as f:
+         json.dump(meta_data, f, indent=2)
+     log0(f"Saved metadata file to: {meta_path}")
+
+
+ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
+     # Load the model state
+     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+     model_data = torch.load(model_path, map_location=device)
+     # Load the optimizer state if requested
+     optimizer_data = None
+     if load_optimizer:
+         optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+         optimizer_data = torch.load(optimizer_path, map_location=device)
+     # Load the metadata
+     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+     with open(meta_path, "r") as f:
+         meta_data = json.load(f)
+     return model_data, optimizer_data, meta_data
+
+
+ def build_model(checkpoint_dir, step, device, phase):
+     """
+     A bunch of repetitive code to build a model from a given checkpoint.
+     Returns:
+     - base model - uncompiled, not wrapped in DDP
+     - tokenizer
+     - meta data saved during base model training
+     """
+     assert phase in ["train", "eval"], f"Invalid phase: {phase}"
+     model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
+     # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
+     model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
+     model_config_kwargs = meta_data["model_config"]
+     log0(f"Building model with config: {model_config_kwargs}")
+     model_config = GPTConfig(**model_config_kwargs)
+     with torch.device("meta"):
+         model = GPT(model_config)
+     # Load the model state
+     model.to_empty(device=device)
+     model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+     model.load_state_dict(model_data, strict=True, assign=True)
+     # Put the model in the right training phase / mode
+     if phase == "eval":
+         model.eval()
+     else:
+         model.train()
+     # Load the Tokenizer
+     tokenizer = get_tokenizer()
+     # Sanity check: compatibility between model and tokenizer
+     assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
+     return model, tokenizer, meta_data
+
+
+ def find_largest_model(checkpoint_dir):
+     # attempt to guess the model tag: take the biggest model available
+     model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
+     if not model_tags:
+         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+     # 1) normally all model tags are of the form d<number>, try that first:
+     candidates = []
+     for model_tag in model_tags:
+         match = re.match(r"d(\d+)", model_tag)
+         if match:
+             model_depth = int(match.group(1))
+             candidates.append((model_depth, model_tag))
+     if candidates:
+         candidates.sort(key=lambda x: x[0], reverse=True)
+         return candidates[0][1]
+     # 2) if that failed, take the most recently updated model:
+     model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
+     return model_tags[0]
+
+
+ def find_last_step(checkpoint_dir):
+     # Look into checkpoint_dir and find model_<step>.pt with the highest step
+     checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
+     if not checkpoint_files:
+         raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+     last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
+     return last_step
+
+ # -----------------------------------------------------------------------------
+ # convenience functions that take into account nanochat's directory structure
+
+ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
+     if model_tag is None:
+         # guess the model tag by defaulting to the largest model
+         model_tag = find_largest_model(checkpoints_dir)
+         log0(f"No model tag provided, guessing model tag: {model_tag}")
+     checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
+     if step is None:
+         # guess the step by defaulting to the last step
+         step = find_last_step(checkpoint_dir)
+     assert step is not None, f"No checkpoints found in {checkpoint_dir}"
+     # build the model
+     log0(f"Loading model from {checkpoint_dir} with step {step}")
+     model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
+     return model, tokenizer, meta_data
+
+ def load_model(source, *args, **kwargs):
+     model_dir = {
+         "base": "base_checkpoints",
+         "mid": "mid_checkpoints",
+         "sft": "chatsft_checkpoints",
+         "rl": "chatrl_checkpoints",
+     }[source]
+     base_dir = get_base_dir()
+     checkpoints_dir = os.path.join(base_dir, model_dir)
+     return load_model_from_dir(checkpoints_dir, *args, **kwargs)
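Everything above hinges on a simple naming convention: checkpoints live under <base_dir>/<source>_checkpoints/<model_tag>/ and each step produces model_<step>.pt, optim_<step>.pt, and meta_<step>.json with the step zero-padded to six digits. A small sketch of how a step is encoded into a filename and recovered from it, mirroring save_checkpoint and find_last_step:

import os

step = 650
model_filename = f"model_{step:06d}.pt"   # how save_checkpoint names the model file
print(model_filename)                     # model_000650.pt

# find_last_step recovers the step from the basename by splitting on "_" and "."
recovered = int(os.path.basename(model_filename).split("_")[-1].split(".")[0])
assert recovered == step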
nanochat/common.py ADDED
@@ -0,0 +1,144 @@
+ """
+ Common utilities for nanochat.
+ """
+
+ import os
+ import re
+ import logging
+ import torch
+ import torch.distributed as dist
+
+ class ColoredFormatter(logging.Formatter):
+     """Custom formatter that adds colors to log messages."""
+     # ANSI color codes
+     COLORS = {
+         'DEBUG': '\033[36m',     # Cyan
+         'INFO': '\033[32m',      # Green
+         'WARNING': '\033[33m',   # Yellow
+         'ERROR': '\033[31m',     # Red
+         'CRITICAL': '\033[35m',  # Magenta
+     }
+     RESET = '\033[0m'
+     BOLD = '\033[1m'
+     def format(self, record):
+         # Add color to the level name
+         levelname = record.levelname
+         if levelname in self.COLORS:
+             record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
+         # Format the message
+         message = super().format(record)
+         # Add color to specific parts of the message
+         if levelname == 'INFO':
+             # Highlight numbers and percentages
+             message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
+             message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
+         return message
+
+ def setup_default_logging():
+     handler = logging.StreamHandler()
+     handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logging.basicConfig(
+         level=logging.INFO,
+         handlers=[handler]
+     )
+
+ setup_default_logging()
+ logger = logging.getLogger(__name__)
+
+ def get_base_dir():
+     # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
+     if os.environ.get("NANOCHAT_BASE_DIR"):
+         nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
+     else:
+         home_dir = os.path.expanduser("~")
+         cache_dir = os.path.join(home_dir, ".cache")
+         nanochat_dir = os.path.join(cache_dir, "nanochat")
+     os.makedirs(nanochat_dir, exist_ok=True)
+     return nanochat_dir
+
+ def print0(s="", **kwargs):
+     ddp_rank = int(os.environ.get('RANK', 0))
+     if ddp_rank == 0:
+         print(s, **kwargs)
+
+ def print_banner():
+     # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
+     banner = """
+ █████ █████
+ ░░███ ░░███
+ ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
+ ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░
+ ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
+ ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
+ ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████
+ ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
+     """
+     print0(banner)
+
+ def is_ddp():
+     # TODO is there a proper way
+     return int(os.environ.get('RANK', -1)) != -1
+
+ def get_dist_info():
+     if is_ddp():
+         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
+         ddp_rank = int(os.environ['RANK'])
+         ddp_local_rank = int(os.environ['LOCAL_RANK'])
+         ddp_world_size = int(os.environ['WORLD_SIZE'])
+         return True, ddp_rank, ddp_local_rank, ddp_world_size
+     else:
+         return False, 0, 0, 1
+
+ def compute_init():
+     """Basic initialization that we keep doing over and over, so make common."""
+
+     # Check if CUDA is available, fallback to CPU
+     use_cuda = torch.cuda.is_available()
+
+     if use_cuda:
+         logger.info("CUDA available - using GPU")
+     else:
+         logger.info("CUDA not available - using CPU (inference will be slower)")
+
+     # Reproducibility
+     torch.manual_seed(42)
+     if use_cuda:
+         torch.cuda.manual_seed(42)
+     # skipping full reproducibility for now, possibly investigate slowdown later
+     # torch.use_deterministic_algorithms(True)
+     # torch.backends.cudnn.deterministic = True
+     # torch.backends.cudnn.benchmark = False
+
+     # Precision
+     if use_cuda:
+         torch.set_float32_matmul_precision("high")  # uses tf32 instead of fp32 for matmuls
+
+     # Distributed setup: Distributed Data Parallel (DDP), optional
+     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+     if ddp:
+         assert use_cuda, "Distributed training requires CUDA"
+         device = torch.device("cuda", ddp_local_rank)
+         torch.cuda.set_device(device)  # make "cuda" default to this device
+         dist.init_process_group(backend="nccl", device_id=device)
+         dist.barrier()
+     else:
+         device = torch.device("cuda" if use_cuda else "cpu")
+
+     if ddp_rank == 0:
+         logger.info(f"Distributed world size: {ddp_world_size}")
+
+     return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
+
+ def compute_cleanup():
+     """Companion function to compute_init, to clean things up before script exit"""
+     if is_ddp():
+         dist.destroy_process_group()
+
+ class DummyWandb:
+     """Useful if we wish to not use wandb but have all the same signatures"""
+     def __init__(self):
+         pass
+     def log(self, *args, **kwargs):
+         pass
+     def finish(self):
+         pass
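The DDP helpers above key everything off the environment variables that torchrun sets for each process. A small sketch of how get_dist_info and print0 behave with and without those variables:

import os
from nanochat.common import get_dist_info, print0

# Without torchrun, RANK is unset and we fall back to single-process defaults.
os.environ.pop('RANK', None)
print(get_dist_info())  # (False, 0, 0, 1)

# Under torchrun these variables are set per process, e.g. rank 1 of a world of 8:
os.environ.update({'RANK': '1', 'LOCAL_RANK': '1', 'WORLD_SIZE': '8'})
print(get_dist_info())  # (True, 1, 1, 8)
print0("only rank 0 prints this")  # suppressed here because RANK != 0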
nanochat/configurator.py ADDED
@@ -0,0 +1,56 @@
+ """
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
+ $ python train.py config/override_file.py --batch_size=32
+ this will first run config/override_file.py, then override batch_size to 32
+
+ The code in this file will be run as follows from e.g. train.py:
+ >>> exec(open('configurator.py').read())
+
+ So it's not a Python module, it's just shuttling this code away from train.py
+ The code in this script then overrides the globals()
+
+ I know people are not going to love this, I just really dislike configuration
+ complexity and having to prepend config. to every single variable. If someone
+ comes up with a better simple Python solution I am all ears.
+ """
+
+ import os
+ import sys
+ from ast import literal_eval
+
+ def print0(s="", **kwargs):
+     ddp_rank = int(os.environ.get('RANK', 0))
+     if ddp_rank == 0:
+         print(s, **kwargs)
+
+ for arg in sys.argv[1:]:
+     if '=' not in arg:
+         # assume it's the name of a config file
+         assert not arg.startswith('--')
+         config_file = arg
+         print0(f"Overriding config with {config_file}:")
+         with open(config_file) as f:
+             print0(f.read())
+         exec(open(config_file).read())
+     else:
+         # assume it's a --key=value argument
+         assert arg.startswith('--')
+         key, val = arg.split('=')
+         key = key[2:]
+         if key in globals():
+             try:
+                 # attempt to eval it (e.g. if bool, number, etc.)
+                 attempt = literal_eval(val)
+             except (SyntaxError, ValueError):
+                 # if that goes wrong, just use the string
+                 attempt = val
+             # ensure the types match ok
+             if globals()[key] is not None:
+                 attempt_type = type(attempt)
+                 default_type = type(globals()[key])
+                 assert attempt_type == default_type, f"Type mismatch: {attempt_type} != {default_type}"
+             # cross fingers
+             print0(f"Overriding: {key} = {attempt}")
+             globals()[key] = attempt
+         else:
+             raise ValueError(f"Unknown config key: {key}")
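Since this file is exec()'d rather than imported, a consuming script simply defines its defaults as module-level globals first and then runs the configurator. A minimal hypothetical example (the variable names below are made up for illustration):

# Hypothetical training script using the exec-based configurator above.
batch_size = 64        # defaults live as plain module-level globals
learning_rate = 1e-3

# exec() runs configurator.py in this module's namespace, so a CLI call like
#   python train.py --batch_size=32
# replaces the global batch_size after type-checking it against the default.
exec(open('nanochat/configurator.py').read())

print(batch_size, learning_rate)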
nanochat/core_eval.py ADDED
@@ -0,0 +1,262 @@
+ """
+ Functions for evaluating the CORE metric, as described in the DCLM paper.
+ https://arxiv.org/abs/2406.11794
+
+ TODOs:
+ - All tasks ~match except for squad. We get 31%, the reference is 37%. Figure out why.
+ """
+ import random
+
+ from jinja2 import Template
+ import torch
+ import torch.distributed as dist
+
+ # -----------------------------------------------------------------------------
+ # Prompt rendering utilities
+
+ def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
+     """Render complete prompts for a multiple choice question"""
+     template_str = """
+ {%- for example in fewshot_examples -%}
+ {{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
+
+ {% endfor -%}
+ {{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
+     template = Template(template_str)
+     fewshot_examples = fewshot_examples or []
+     context = {
+         'fewshot_examples': fewshot_examples,
+         'continuation_delimiter': continuation_delimiter,
+         'item': item
+     }
+     prompts = [template.render(choice=choice, **context) for choice in item['choices']]
+     return prompts
+
+
+ def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
+     """Render complete prompts for a schema question"""
+     template_str = """
+ {%- for example in fewshot_examples -%}
+ {{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
+
+ {% endfor -%}
+ {{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
+     template = Template(template_str)
+     fewshot_examples = fewshot_examples or []
+     context = {
+         'fewshot_examples': fewshot_examples,
+         'continuation_delimiter': continuation_delimiter,
+         'item': item
+     }
+     prompts = [template.render(context=context_option, **context)
+                for context_option in item['context_options']]
+     return prompts
+
+
+ def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
+     """
+     Render complete prompt for a language modeling task.
+     Notice that we manually trim the context in the template,
+     which in some datasets seems to have trailing whitespace (which we don't want).
+     """
+     template_str = """
+ {%- for example in fewshot_examples -%}
+ {{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
+
+ {% endfor -%}
+ {{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
+     template = Template(template_str)
+     fewshot_examples = fewshot_examples or []
+     context = {
+         'fewshot_examples': fewshot_examples,
+         'continuation_delimiter': continuation_delimiter,
+         'item': item
+     }
+     # Return two prompts: without and with the continuation
+     prompt_without = template.render(include_continuation=False, **context)
+     prompt_with = template.render(include_continuation=True, **context)
+     # Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
+     # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
+     # token in prompt_with), meaning we don't get a nice and clean prefix in the token space
+     # to detect the final continuation. Tokenizers...
+     prompt_without = prompt_without.strip()
+     return [prompt_without, prompt_with]
+
+
+ def find_common_length(token_sequences, direction='left'):
+     """
+     Find the length of the common prefix or suffix across token sequences
+     - direction: 'left' for prefix, 'right' for suffix
+     """
+     min_len = min(len(seq) for seq in token_sequences)
+     indices = {
+         'left': range(min_len),
+         'right': range(-1, -min_len-1, -1)
+     }[direction]
+     # Find the first position where the token sequences differ
+     for i, idx in enumerate(indices):
+         token = token_sequences[0][idx]
+         if not all(seq[idx] == token for seq in token_sequences):
+             return i
+     return min_len
+
+
+ def stack_sequences(tokens, pad_token_id):
+     """Stack up a list of token sequences, pad to longest on the right"""
+     bsz, seq_len = len(tokens), max(len(x) for x in tokens)
+     input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
+     for i, x in enumerate(tokens):
+         input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
+     return input_ids
+
+
+ def batch_sequences_mc(tokenizer, prompts):
+     # In multiple choice, contexts are the same but the continuation is different (common prefix)
+     tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
+     # figure out the start and end of each continuation
+     answer_start_idx = find_common_length(tokens, direction='left')
+     start_indices = [answer_start_idx] * len(prompts)
+     end_indices = [len(x) for x in tokens]
+     return tokens, start_indices, end_indices
+
+
+ def batch_sequences_schema(tokenizer, prompts):
+     # In schema tasks, contexts vary but continuation is the same (common suffix)
+     tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
+     # figure out the start and end of each context
+     suffix_length = find_common_length(tokens, direction='right')
+     end_indices = [len(x) for x in tokens]
+     start_indices = [ei - suffix_length for ei in end_indices]
+     return tokens, start_indices, end_indices
+
+
+ def batch_sequences_lm(tokenizer, prompts):
+     # In LM tasks, we have two prompts: without and with continuation
+     tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
+     tokens_without, tokens_with = tokens
+     start_idx, end_idx = len(tokens_without), len(tokens_with)
+     assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
+     assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
+     # we only need the with continuation prompt in the LM task, i.e. batch size of 1
+     return [tokens_with], [start_idx], [end_idx]
+
+
+ @torch.no_grad()
+ def forward_model(model, input_ids):
+     """
+     Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
+     The last column of losses is set to nan because we don't have autoregressive targets there.
+     """
+     batch_size, seq_len = input_ids.size()
+     outputs = model(input_ids)
+     # Roll the tensor to the left by one position to get the (autoregressive) target ids
+     target_ids = torch.roll(input_ids, shifts=-1, dims=1)
+     # Calculate cross entropy at all positions
+     losses = torch.nn.functional.cross_entropy(
+         outputs.view(batch_size * seq_len, -1),
+         target_ids.view(batch_size * seq_len),
+         reduction='none'
+     ).view(batch_size, seq_len)
+     # Set the last column to be nan because there is no autoregressive loss there
+     losses[:, -1] = float('nan')
+     # Get the argmax predictions at each position
+     predictions = outputs.argmax(dim=-1)
+     return losses, predictions
+
+
+ @torch.no_grad()
+ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
+     """Evaluate a single example, return True if correct, False otherwise"""
+     item = data[idx]
+     task_type = task_meta['task_type']
+     num_fewshot = task_meta['num_fewshot']
+     continuation_delimiter = task_meta['continuation_delimiter']
+
+     # Sample few-shot examples (excluding current item)
+     fewshot_examples = []
+     if num_fewshot > 0:
+         rng = random.Random(1234 + idx)
+         available_indices = [i for i in range(len(data)) if i != idx]
+         fewshot_indices = rng.sample(available_indices, num_fewshot)
+         fewshot_examples = [data[i] for i in fewshot_indices]
+
+     # Render prompts and batch sequences based on task type
+     if task_type == 'multiple_choice':
+         prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
+         tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
+     elif task_type == 'schema':
+         prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
+         tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
+     elif task_type == 'language_modeling':
+         prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
+         tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
+     else:
+         raise ValueError(f"Unsupported task type: {task_type}")
+
+     # Some models can't forward sequences beyond a certain length (e.g. GPT-2)
+     # In these cases, we have to truncate sequences to max length and adjust the indices
+     if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
+         max_tokens = model.max_seq_len
+         new_tokens, new_start_idxs, new_end_idxs = [], [], []
+         for t, s, e in zip(tokens, start_idxs, end_idxs):
+             if len(t) > max_tokens:
+                 num_to_crop = len(t) - max_tokens
+                 new_tokens.append(t[-max_tokens:])  # take the last max_tokens tokens
+                 new_start_idxs.append(s - num_to_crop)  # shift the indices down
+                 new_end_idxs.append(e - num_to_crop)
+                 assert s - num_to_crop >= 0, "this should never happen right?"
+                 assert e - num_to_crop >= 0, "this should never happen right?"
+             else:
+                 new_tokens.append(t)  # keep unchanged
+                 new_start_idxs.append(s)
+                 new_end_idxs.append(e)
+         tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
+
+     # Stack up all the sequences into a batch
+     pad_token_id = tokenizer.get_bos_token_id()  # using BOS as the pad token is ok
+     input_ids = stack_sequences(tokens, pad_token_id)
+     input_ids = input_ids.to(device)
+
+     # Forward the model, get the autoregressive loss and argmax prediction at each token
+     losses, predictions = forward_model(model, input_ids)
+
+     # See if the losses/predictions come out correctly
+     if task_type == 'language_modeling':
+         # language modeling task is currently always batch size 1
+         si = start_idxs[0]
+         ei = end_idxs[0]
+         # predictions[i] predict input_ids[i+1] autoregressively
+         predicted_tokens = predictions[0, si-1:ei-1]
+         actual_tokens = input_ids[0, si:ei]
+         is_correct = torch.all(predicted_tokens == actual_tokens).item()
+     elif task_type in ['multiple_choice', 'schema']:
+         # For MC/schema: find the option with lowest average loss
+         mean_losses = [losses[i, si-1:ei-1].mean().item()
+                        for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
+         pred_idx = mean_losses.index(min(mean_losses))
+         is_correct = pred_idx == item['gold']
+     else:
+         raise ValueError(f"Unsupported task type: {task_type}")
+
+     return is_correct
+
+
+ def evaluate_task(model, tokenizer, data, device, task_meta):
+     """
+     This function is responsible for evaluating one task across many examples.
+     It also handles dispatch to all processes if the script is run with torchrun.
+     """
+     rank = dist.get_rank() if dist.is_initialized() else 0
+     world_size = dist.get_world_size() if dist.is_initialized() else 1
+     correct = torch.zeros(len(data), dtype=torch.float32, device=device)
+     # stride the examples to each rank
+     for idx in range(rank, len(data), world_size):
+         is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
+         correct[idx] = float(is_correct)
+     # sync results across all the processes if running distributed
+     if world_size > 1:
+         dist.barrier()
+         dist.all_reduce(correct, op=dist.ReduceOp.SUM)
+     # compute the mean
+     mean_correct = correct.mean().item()
+     return mean_correct
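The start/end index bookkeeping above relies on find_common_length to locate where the answer tokens live in each prompt. A tiny illustration with made-up token ids:

from nanochat.core_eval import find_common_length

# Multiple-choice prompts share a common prefix (the question); only the answer differs.
mc = [[1, 2, 3, 7, 8], [1, 2, 3, 9], [1, 2, 3, 4, 5, 6]]
print(find_common_length(mc, direction='left'))      # 3 -> answer tokens start at index 3

# Schema prompts share a common suffix (the continuation); only the context differs.
schema = [[7, 8, 1, 2, 3], [9, 1, 2, 3]]
print(find_common_length(schema, direction='right'))  # 3 -> continuation is the last 3 tokens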
nanochat/dataloader.py ADDED
@@ -0,0 +1,49 @@
+ from collections import deque
+
+ import torch
+
+ from nanochat.common import get_dist_info
+ from nanochat.dataset import parquets_iter_batched
+ from nanochat.tokenizer import get_tokenizer
+
+ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
+     """Stream pretraining text from parquet files, tokenize, yield training batches."""
+     assert split in ["train", "val"], "split must be 'train' or 'val'"
+     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+     needed_tokens = B * T + 1  # +1 is because we also need the target at the last token
+     # get the tokenizer and the bos token
+     tokenizer = get_tokenizer()
+     bos_token = tokenizer.get_bos_token_id()
+     # scratch buffer holds the tokens for one iteration
+     token_buffer = deque()  # we stream tokens on the right and pop from the left
+     scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
+
+     # infinite iterator over document batches
+     def document_batches():
+         while True:
+             # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
+             for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
+                 # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
+                 for i in range(0, len(batch), tokenizer_batch_size):
+                     yield batch[i:i+tokenizer_batch_size]
+     batches = document_batches()
+
+     batch_index = 0
+     while True:
+         # Accumulate enough tokens for one iteration before yielding.
+         while len(token_buffer) < needed_tokens:
+             doc_batch = next(batches)
+             token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
+             for tokens in token_lists:
+                 token_buffer.extend(tokens)
+             batch_index += 1
+         # Move tokens from the deque into the scratch buffer
+         for i in range(needed_tokens):
+             scratch[i] = token_buffer.popleft()
+         # Create the inputs/targets as 1D tensors
+         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
+         targets_cpu = scratch[1:]
+         # Reshape to 2D and move to GPU async
+         inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
+         targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
+         yield inputs, targets
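The `B * T + 1` accounting above is just the usual next-token setup: the same flat token stream, offset by one position, supplies the targets. A CPU-only illustration of that slicing:

import torch

B, T = 2, 4
stream = torch.arange(B * T + 1)   # B*T + 1 tokens: the extra one is the final target
inputs = stream[:-1].view(B, T)    # tokens 0..B*T-1 reshaped to (B, T)
targets = stream[1:].view(B, T)    # same stream shifted left by one position
print(inputs)
print(targets)                     # each target is the "next token" of the corresponding input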
nanochat/dataset.py ADDED
@@ -0,0 +1,128 @@
+ """
+ The base/pretraining dataset is a set of parquet files.
+ This file contains utilities for:
+ - iterating over the parquet files and yielding documents from them
+ - downloading the files on demand if they are not on disk
+
+ For details of how the dataset was prepared, see `repackage_data_reference.py`.
+ """
+
+ import os
+ import argparse
+ import time
+ import requests
+ import pyarrow.parquet as pq
+ from multiprocessing import Pool
+
+ from nanochat.common import get_base_dir
+
+ # -----------------------------------------------------------------------------
+ # The specifics of the current pretraining dataset
+
+ # The URL on the internet where the data is hosted and downloaded from on demand
+ BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
+ MAX_SHARD = 1822  # the last datashard is shard_01822.parquet
+ index_to_filename = lambda index: f"shard_{index:05d}.parquet"  # format of the filenames
+ base_dir = get_base_dir()
+ DATA_DIR = os.path.join(base_dir, "base_data")
+ os.makedirs(DATA_DIR, exist_ok=True)
+
+ # -----------------------------------------------------------------------------
+ # These functions are useful utilities to other modules, can/should be imported
+
+ def list_parquet_files(data_dir=None):
+     """ Looks into a data dir and returns full paths to all parquet files. """
+     data_dir = DATA_DIR if data_dir is None else data_dir
+     parquet_files = sorted([
+         f for f in os.listdir(data_dir)
+         if f.endswith('.parquet') and not f.endswith('.tmp')
+     ])
+     parquet_paths = [os.path.join(data_dir, f) for f in parquet_files]
+     return parquet_paths
+
+ def parquets_iter_batched(split, start=0, step=1):
+     """
+     Iterate through the dataset, in batches of underlying row_groups for efficiency.
+     - split can be "train" or "val". the last parquet file will be val.
+     - start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
+     """
+     assert split in ["train", "val"], "split must be 'train' or 'val'"
+     parquet_paths = list_parquet_files()
+     parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+     for filepath in parquet_paths:
+         pf = pq.ParquetFile(filepath)
+         for rg_idx in range(start, pf.num_row_groups, step):
+             rg = pf.read_row_group(rg_idx)
+             texts = rg.column('text').to_pylist()
+             yield texts
+
+ # -----------------------------------------------------------------------------
+ def download_single_file(index):
+     """ Downloads a single file index, with some backoff """
+
+     # Construct the local filepath for this file and skip if it already exists
+     filename = index_to_filename(index)
+     filepath = os.path.join(DATA_DIR, filename)
+     if os.path.exists(filepath):
+         print(f"Skipping {filepath} (already exists)")
+         return True
+
+     # Construct the remote URL for this file
+     url = f"{BASE_URL}/{filename}"
+     print(f"Downloading {filename}...")
+
+     # Download with retries
+     max_attempts = 5
+     for attempt in range(1, max_attempts + 1):
+         try:
+             response = requests.get(url, stream=True, timeout=30)
+             response.raise_for_status()
+             # Write to temporary file first
+             temp_path = filepath + ".tmp"
+             with open(temp_path, 'wb') as f:
+                 for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1MB chunks
+                     if chunk:
+                         f.write(chunk)
+             # Move temp file to final location
+             os.rename(temp_path, filepath)
+             print(f"Successfully downloaded {filename}")
+             return True
+
+         except (requests.RequestException, IOError) as e:
+             print(f"Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
+             # Clean up any partial files
+             for path in [filepath + ".tmp", filepath]:
+                 if os.path.exists(path):
+                     try:
+                         os.remove(path)
+                     except:
+                         pass
+             # Try a few times with exponential backoff: 2^attempt seconds
+             if attempt < max_attempts:
+                 wait_time = 2 ** attempt
+                 print(f"Waiting {wait_time} seconds before retry...")
+                 time.sleep(wait_time)
+             else:
+                 print(f"Failed to download {filename} after {max_attempts} attempts")
+                 return False
+
+     return False
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
+     parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1, meaning all shards)")
+     parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
+     args = parser.parse_args()
+
+     num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
+     ids_to_download = list(range(num))
+     print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
+     print(f"Target directory: {DATA_DIR}")
+     print()
+     with Pool(processes=args.num_workers) as pool:
+         results = pool.map(download_single_file, ids_to_download)
+
+     # Report results
+     successful = sum(1 for success in results if success)
+     print(f"Done! Downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")
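A quick look at the shard naming and train/val split conventions used above, with no download involved (the three-shard list is made up for illustration):

# Same filename format as index_to_filename above.
index_to_filename = lambda index: f"shard_{index:05d}.parquet"

print(index_to_filename(0))      # shard_00000.parquet
print(index_to_filename(1822))   # shard_01822.parquet, the last shard (MAX_SHARD)

# parquets_iter_batched treats the lexicographically last shard as the validation split:
shards = sorted(index_to_filename(i) for i in range(3))
train, val = shards[:-1], shards[-1:]
print(train, val)                # first N-1 shards are train, the final shard is val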
nanochat/engine.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engine for efficient inference of our models.
3
+
4
+ Everything works around token sequences:
5
+ - The user can send token sequences to the engine
6
+ - The engine returns the next token
7
+
8
+ Notes:
9
+ - The engine knows nothing about tokenization, it's purely token id sequences.
10
+
11
+ The whole thing is made as efficient as possible.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import signal
17
+ import warnings
18
+ from contextlib import contextmanager
19
+ from collections import deque
20
+ from nanochat.common import compute_init
21
+ from nanochat.checkpoint_manager import load_model
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Calculator tool helpers
25
+ @contextmanager
26
+ def timeout(duration, formula):
27
+ def timeout_handler(signum, frame):
28
+ raise Exception(f"'{formula}': timed out after {duration} seconds")
29
+
30
+ signal.signal(signal.SIGALRM, timeout_handler)
31
+ signal.alarm(duration)
32
+ yield
33
+ signal.alarm(0)
34
+
35
+ def eval_with_timeout(formula, max_time=3):
36
+ try:
37
+ with timeout(max_time, formula):
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter("ignore", SyntaxWarning)
40
+ return eval(formula)
41
+ except Exception as e:
42
+ signal.alarm(0)
43
+ # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
44
+ return None
45
+
46
+ def use_calculator(expr):
47
+ """Evaluate a math expression safely."""
48
+ expr = expr.replace(",", "")
49
+ if any([x not in "0123456789*+-/.() " for x in expr]): # for now disallow non-numeric chars
50
+ return None
51
+ if "**" in expr: # for now disallow power operator, could be very expensive
52
+ return None
53
+ return eval_with_timeout(expr)
54
+
55
+ # -----------------------------------------------------------------------------
56
+ class KVCache:
57
+ """
58
+ Works hand-in-hand with the GPT model to maintain the KV cache.
59
+ Note that the .pos advances automatically after the last layer of the Transformer inserts.
60
+ """
61
+
62
+ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
63
+ # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
64
+ self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
65
+ self.kv_cache = None
66
+ self.pos = 0 # current position in time in the cache
67
+
68
+ def reset(self):
69
+ self.pos = 0
70
+
71
+ def get_pos(self):
72
+ return self.pos
73
+
74
+ def prefill(self, other):
75
+ """
76
+ Prefill given another KV cache. Optionally expand along batch dim.
77
+ This is used when we do batch 1 prefill and then want to generate
78
+ multiple samples in parallel from there.
79
+ """
80
+ # 1) validate the shapes
81
+ assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
82
+ assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
83
+ for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
84
+ if ix in [0, 1, 3, 5]:
85
+ # num_layers, batch_size, num_heads, head_dim must match
86
+ assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
87
+ elif ix == 2:
88
+ # batch_size can be expanded
89
+ assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
90
+ elif ix == 4:
91
+ # seq_len: self must be longer than other
92
+ assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
93
+ # 2) initialize the cache
94
+ dtype, device = other.kv_cache.dtype, other.kv_cache.device
95
+ self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
96
+ # 3) copy the data over
97
+ self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
98
+ # 4) update the pos
99
+ self.pos = other.pos
100
+
101
+ def insert_kv(self, layer_idx, k, v):
102
+ # Lazy initialize the cache here because we need to know the dtype/device
103
+ if self.kv_cache is None:
104
+ self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
105
+ # Insert new keys/values to the cache and return the full cache so far
106
+ B, H, T_add, D = k.size()
107
+ t0, t1 = self.pos, self.pos + T_add
108
+ # Dynamically grow the cache if needed
109
+ if t1 > self.kv_cache.size(4):
110
+ t_needed = t1 + 1024 # as much as we need plus buffer of 1024
111
+ t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
112
+ current_shape = list(self.kv_cache.shape)
113
+ current_shape[4] = t_needed
114
+ self.kv_cache.resize_(current_shape)
115
+ # Insert k, v into the cache
116
+ self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
117
+ self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
118
+ # Return the full cached keys/values up to current position (as a view)
119
+ key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
120
+ value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
121
+ # Increment pos after the last layer of the Transformer processes
122
+ if layer_idx == self.kv_cache.size(0) - 1:
123
+ self.pos = t1
124
+ return key_view, value_view
125
+
126
+
127
+ # -----------------------------------------------------------------------------
128
+ @torch.inference_mode()
129
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
130
+ """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
131
+ assert temperature >= 0.0, "temperature must be non-negative"
132
+ if temperature == 0.0:
133
+ return torch.argmax(logits, dim=-1, keepdim=True)
134
+ if top_k is not None:
135
+ k = min(top_k, logits.size(-1))
136
+ vals, idx = torch.topk(logits, k, dim=-1)
137
+ vals = vals / temperature
138
+ probs = F.softmax(vals, dim=-1)
139
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
140
+ return idx.gather(1, choice)
141
+ else:
142
+ logits = logits / temperature
143
+ probs = F.softmax(logits, dim=-1)
144
+ return torch.multinomial(probs, num_samples=1, generator=rng)
145
+
146
+ # -----------------------------------------------------------------------------
147
+
148
+ class RowState:
149
+ # Per-row state tracking during generation
150
+ def __init__(self, current_tokens=None):
151
+ self.current_tokens = current_tokens or [] # Current token sequence for this row
152
+ self.forced_tokens = deque() # Queue of tokens to force inject
153
+ self.in_python_block = False # Whether we are inside a python block
154
+ self.python_expr_tokens = [] # Tokens of the current python expression
155
+ self.completed = False # Whether this row has completed generation
156
+
157
+ class Engine:
158
+
159
+ def __init__(self, model, tokenizer):
160
+ self.model = model
161
+ self.tokenizer = tokenizer # needed for tool use
162
+
163
+ @torch.inference_mode()
164
+ def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
165
+ """Same as generate, but does single prefill and then clones the KV cache."""
166
+ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
167
+ device = self.model.get_device()
168
+ rng = torch.Generator(device=device)
169
+ rng.manual_seed(seed)
170
+
171
+ # Get the special tokens we need to coordinate the tool use state machine
172
+ get_special = lambda s: self.tokenizer.encode_special(s)
173
+ python_start = get_special("<|python_start|>")
174
+ python_end = get_special("<|python_end|>")
175
+ output_start = get_special("<|output_start|>")
176
+ output_end = get_special("<|output_end|>")
177
+ assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
178
+ bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
179
+
180
+ # 1) Run a batch 1 prefill of the prompt tokens
181
+ m = self.model.config
182
+ kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
183
+ kv_cache_prefill = KVCache(
184
+ batch_size=1,
185
+ seq_len=len(tokens),
186
+ **kv_model_kwargs,
187
+ )
188
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
189
+ logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
190
+ logits = logits[:, -1, :]
191
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
192
+ sampled_tokens = next_ids[:, 0].tolist()
193
+
194
+ # 2) Replicate the KV cache for each sample/row
195
+ kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
196
+ kv_cache_decode = KVCache(
197
+ batch_size=num_samples,
198
+ seq_len=kv_length_hint,
199
+ **kv_model_kwargs,
200
+ )
201
+ kv_cache_decode.prefill(kv_cache_prefill)
202
+ del kv_cache_prefill # no need to keep this memory around
203
+
204
+ # 3) Initialize states for each sample
205
+ row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
206
+
207
+ # 4) Main generation loop
208
+ num_generated = 0
209
+ first_iteration = True
210
+ while True:
211
+ # Stop condition: we've reached max tokens
212
+ if max_tokens is not None and num_generated >= max_tokens:
213
+ break
214
+ # Stop condition: all rows are completed
215
+ if all(state.completed for state in row_states):
216
+ break
217
+
218
+ # Get sampled tokens - either from prefill or from forward pass
219
+ if first_iteration:
220
+ # Use the tokens we already sampled from prefill
221
+ sampled_tokens = [sampled_tokens[0]] * num_samples # Broadcast first token to all rows
222
+ # TODO: we should sample a token for each row instead of broadcasting
223
+ first_iteration = False
224
+ else:
225
+ # Forward the model and get the next token for each row
226
+ logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
227
+ logits = logits[:, -1, :] # (B, vocab_size) at last time step
228
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
229
+ sampled_tokens = next_ids[:, 0].tolist()
230
+
231
+ # Process each row: choose the next token, update state, optional tool use
232
+ token_column = [] # contains the next token id along each row
233
+ token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
234
+ for i, state in enumerate(row_states):
235
+ # Select the next token in this row
236
+ is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
237
+ token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
238
+ next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
239
+ token_column.append(next_token)
240
+ # Update the state of this row to include the next token
241
+ state.current_tokens.append(next_token)
242
+ # On <|assistant_end|> or <|bos|>, mark the row as completed
243
+ if next_token == assistant_end or next_token == bos:
244
+ state.completed = True
245
+ # Handle tool logic
246
+ if next_token == python_start:
247
+ state.in_python_block = True
248
+ state.python_expr_tokens = []
249
+ elif next_token == python_end and state.in_python_block:
250
+ state.in_python_block = False
251
+ if state.python_expr_tokens:
252
+ expr = self.tokenizer.decode(state.python_expr_tokens)
253
+ result = use_calculator(expr)
254
+ if result is not None:
255
+ result_tokens = self.tokenizer.encode(str(result))
256
+ state.forced_tokens.append(output_start)
257
+ state.forced_tokens.extend(result_tokens)
258
+ state.forced_tokens.append(output_end)
259
+ state.python_expr_tokens = []
260
+ elif state.in_python_block:
261
+ state.python_expr_tokens.append(next_token)
262
+
263
+ # Yield the token column
264
+ yield token_column, token_masks
265
+ num_generated += 1
266
+ # Prepare ids for next iteration
267
+ ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
268
+
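# Worked example of the calculator tool flow implemented in generate() above
# (a sketch; use_calculator and the special tokens come from the surrounding file/imports):
#   1. the model samples <|python_start|> "12*7" <|python_end|>          (mask=1, sampled)
#   2. the expression tokens are decoded and evaluated: use_calculator("12*7") -> 84
#   3. the result is re-encoded and queued on state.forced_tokens as
#      <|output_start|> "84" <|output_end|>, which are then emitted with mask=0 (forced).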
269
+ def generate_batch(self, tokens, num_samples=1, **kwargs):
270
+ """
271
+ Non-streaming batch generation that just returns the final token sequences.
272
+ Returns a list of token sequences (list of lists of ints).
273
+ Terminal tokens (assistant_end, bos) are not included in the results.
274
+ """
275
+ assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
276
+ bos = self.tokenizer.get_bos_token_id()
277
+ results = [tokens.copy() for _ in range(num_samples)]
278
+ masks = [[0] * len(tokens) for _ in range(num_samples)]
279
+ completed = [False] * num_samples
280
+ for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
281
+ for i, (token, mask) in enumerate(zip(token_column, token_masks)):
282
+ if not completed[i]:
283
+ if token == assistant_end or token == bos:
284
+ completed[i] = True
285
+ else:
286
+ results[i].append(token)
287
+ masks[i].append(mask)
288
+ # Stop if all rows are completed
289
+ if all(completed):
290
+ break
291
+ return results, masks
292
+
293
+
294
+ if __name__ == "__main__":
295
+ """
296
+ Quick inline test to make sure that the naive/slow model.generate function
297
+ is equivalent to the faster Engine.generate function here.
298
+ """
299
+ import time
300
+ # init compute
301
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
302
+ # load the model and tokenizer
303
+ model, tokenizer, meta = load_model("base", device, phase="eval")
304
+ bos_token_id = tokenizer.get_bos_token_id()
305
+ # common hyperparameters
306
+ kwargs = dict(max_tokens=64, temperature=0.0)
307
+ # set the starting prompt
308
+ prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
309
+ # generate the reference sequence using the model.generate() function
310
+ generated_tokens = []
311
+ torch.cuda.synchronize()
312
+ t0 = time.time()
313
+ stream = model.generate(prompt_tokens, **kwargs)
314
+ for token in stream:
315
+ generated_tokens.append(token)
316
+ chunk = tokenizer.decode([token])
317
+ print(chunk, end="", flush=True)
318
+ print()
319
+ torch.cuda.synchronize()
320
+ t1 = time.time()
321
+ print(f"Reference time: {t1 - t0:.2f}s")
322
+ reference_ids = generated_tokens
323
+ # generate tokens with Engine
324
+ generated_tokens = []
325
+ engine = Engine(model, tokenizer)
326
+ stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
327
+ torch.cuda.synchronize()
328
+ t0 = time.time()
329
+ for token_column, token_masks in stream:
330
+ token = token_column[0] # only print out the first row
331
+ generated_tokens.append(token)
332
+ chunk = tokenizer.decode([token])
333
+ print(chunk, end="", flush=True)
334
+ print()
335
+ torch.cuda.synchronize()
336
+ t1 = time.time()
337
+ print(f"Engine time: {t1 - t0:.2f}s")
338
+ # compare the two sequences
339
+ for i in range(len(reference_ids)):
340
+ if reference_ids[i] != generated_tokens[i]:
341
+ print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
342
+ break
343
+ print(f"Match: {reference_ids == generated_tokens}")
nanochat/execution.py ADDED
@@ -0,0 +1,350 @@
1
+ """
2
+ Sandboxed execution utilities for running Python code that comes out of an LLM.
3
+ Adapted from OpenAI HumanEval code:
4
+ https://github.com/openai/human-eval/blob/master/human_eval/execution.py
5
+
6
+ What is covered:
7
+ - Each execution runs in its own process (can be killed if it hangs or crashes)
8
+ - Execution is limited by a timeout to stop infinite loops
9
+ - Memory limits are enforced by default (256MB)
10
+ - stdout and stderr are captured and returned
11
+ - Code runs in a temporary directory that is deleted afterwards
12
+ - Dangerous functions are disabled (examples: os.system, os.kill, shutil.rmtree, subprocess.Popen)
13
+
14
+ What is not covered:
15
+ - Not a true security sandbox
16
+ - Network access is not blocked (e.g. sockets could be opened)
17
+ - Python's dynamic features (e.g. ctypes) could bypass restrictions
18
+ - No kernel-level isolation (no seccomp, no containers, no virtualization)
19
+
20
+ Overall this sandbox is good for evaluation of generated code and protects against
21
+ accidental destructive behavior, but it is not safe against malicious adversarial code.
22
+ """
23
+
24
+ import contextlib
25
+ import faulthandler
26
+ import io
27
+ import multiprocessing
28
+ import os
29
+ import platform
30
+ import signal
31
+ import tempfile
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ # -----------------------------------------------------------------------------
36
+
37
+ @dataclass
38
+ class ExecutionResult:
39
+ """Result of executing Python code in a sandbox."""
40
+ success: bool
41
+ stdout: str
42
+ stderr: str
43
+ error: Optional[str] = None
44
+ timeout: bool = False
45
+ memory_exceeded: bool = False
46
+
47
+ def __repr__(self):
48
+ parts = []
49
+ parts.append(f"ExecutionResult(success={self.success}")
50
+ if self.timeout:
51
+ parts.append(", timeout=True")
52
+ if self.memory_exceeded:
53
+ parts.append(", memory_exceeded=True")
54
+ if self.error:
55
+ parts.append(f", error={self.error!r}")
56
+ if self.stdout:
57
+ parts.append(f", stdout={self.stdout!r}")
58
+ if self.stderr:
59
+ parts.append(f", stderr={self.stderr!r}")
60
+ parts.append(")")
61
+ return "".join(parts)
62
+
63
+
64
+ @contextlib.contextmanager
65
+ def time_limit(seconds: float):
66
+ def signal_handler(signum, frame):
67
+ raise TimeoutException("Timed out!")
68
+
69
+ signal.setitimer(signal.ITIMER_REAL, seconds)
70
+ signal.signal(signal.SIGALRM, signal_handler)
71
+ try:
72
+ yield
73
+ finally:
74
+ signal.setitimer(signal.ITIMER_REAL, 0)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def capture_io():
79
+ """Capture stdout and stderr, and disable stdin."""
80
+ stdout_capture = io.StringIO()
81
+ stderr_capture = io.StringIO()
82
+ stdin_block = WriteOnlyStringIO()
83
+ with contextlib.redirect_stdout(stdout_capture):
84
+ with contextlib.redirect_stderr(stderr_capture):
85
+ with redirect_stdin(stdin_block):
86
+ yield stdout_capture, stderr_capture
87
+
88
+
89
+ @contextlib.contextmanager
90
+ def create_tempdir():
91
+ with tempfile.TemporaryDirectory() as dirname:
92
+ with chdir(dirname):
93
+ yield dirname
94
+
95
+
96
+ class TimeoutException(Exception):
97
+ pass
98
+
99
+
100
+ class WriteOnlyStringIO(io.StringIO):
101
+ """StringIO that throws an exception when it's read from"""
102
+
103
+ def read(self, *args, **kwargs):
104
+ raise IOError
105
+
106
+ def readline(self, *args, **kwargs):
107
+ raise IOError
108
+
109
+ def readlines(self, *args, **kwargs):
110
+ raise IOError
111
+
112
+ def readable(self, *args, **kwargs):
113
+ """Returns True if the IO object can be read."""
114
+ return False
115
+
116
+
117
+ class redirect_stdin(contextlib._RedirectStream): # type: ignore
118
+ _stream = "stdin"
119
+
120
+
121
+ @contextlib.contextmanager
122
+ def chdir(root):
123
+ if root == ".":
124
+ yield
125
+ return
126
+ cwd = os.getcwd()
127
+ os.chdir(root)
128
+ try:
129
+ yield
130
+ except BaseException as exc:
131
+ raise exc
132
+ finally:
133
+ os.chdir(cwd)
134
+
135
+
136
+ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
137
+ """
138
+ This disables various destructive functions and prevents the generated code
139
+ from interfering with the test (e.g. fork bomb, killing other processes,
140
+ removing filesystem files, etc.)
141
+
142
+ WARNING
143
+ This function is NOT a security sandbox. Untrusted code, including model-
144
+ generated code, should not be blindly executed outside of one. See the
145
+ Codex paper for more information about OpenAI's code sandbox, and proceed
146
+ with caution.
147
+ """
148
+
149
+ if maximum_memory_bytes is not None:
150
+ import resource
151
+
152
+ resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
153
+ resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
154
+ if not platform.uname().system == "Darwin":
155
+ resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
156
+
157
+ faulthandler.disable()
158
+
159
+ import builtins
160
+
161
+ builtins.exit = None
162
+ builtins.quit = None
163
+
164
+ import os
165
+
166
+ os.environ["OMP_NUM_THREADS"] = "1"
167
+
168
+ os.kill = None
169
+ os.system = None
170
+ os.putenv = None
171
+ os.remove = None
172
+ os.removedirs = None
173
+ os.rmdir = None
174
+ os.fchdir = None
175
+ os.setuid = None
176
+ os.fork = None
177
+ os.forkpty = None
178
+ os.killpg = None
179
+ os.rename = None
180
+ os.renames = None
181
+ os.truncate = None
182
+ os.replace = None
183
+ os.unlink = None
184
+ os.fchmod = None
185
+ os.fchown = None
186
+ os.chmod = None
187
+ os.chown = None
188
+ os.chroot = None
189
+ os.fchdir = None
190
+ os.lchflags = None
191
+ os.lchmod = None
192
+ os.lchown = None
193
+ os.getcwd = None
194
+ os.chdir = None
195
+
196
+ import shutil
197
+
198
+ shutil.rmtree = None
199
+ shutil.move = None
200
+ shutil.chown = None
201
+
202
+ import subprocess
203
+
204
+ subprocess.Popen = None # type: ignore
205
+
206
+ __builtins__["help"] = None
207
+
208
+ import sys
209
+
210
+ sys.modules["ipdb"] = None
211
+ sys.modules["joblib"] = None
212
+ sys.modules["resource"] = None
213
+ sys.modules["psutil"] = None
214
+ sys.modules["tkinter"] = None
215
+
216
+
217
+ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
218
+ """Execute code in a subprocess with safety guards. Results are written to result_dict."""
219
+ with create_tempdir():
220
+
221
+ # These system calls are needed when cleaning up tempdir.
222
+ import os
223
+ import shutil
224
+
225
+ rmtree = shutil.rmtree
226
+ rmdir = os.rmdir
227
+ chdir = os.chdir
228
+
229
+ # Disable functionalities that can make destructive changes to the test.
230
+ reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
231
+
232
+ # Default to failure
233
+ result_dict.update({
234
+ "success": False,
235
+ "stdout": "",
236
+ "stderr": "",
237
+ "timeout": False,
238
+ "memory_exceeded": False,
239
+ "error": None,
240
+ })
241
+
242
+ try:
243
+ exec_globals = {}
244
+ with capture_io() as (stdout_capture, stderr_capture):
245
+ with time_limit(timeout):
246
+ # WARNING
247
+ # This program exists to execute untrusted model-generated code. Although
248
+ # it is highly unlikely that model-generated code will do something overtly
249
+ # malicious in response to this test suite, model-generated code may act
250
+ # destructively due to a lack of model capability or alignment.
251
+ # Users are strongly encouraged to sandbox this evaluation suite so that it
252
+ # does not perform destructive actions on their host or network. For more
253
+ # information on how OpenAI sandboxes its code, see the accompanying paper.
254
+ # Once you have read this disclaimer and taken appropriate precautions,
255
+ # uncomment the following line and proceed at your own risk:
256
+ exec(code, exec_globals)
257
+
258
+ result_dict.update({
259
+ "success": True,
260
+ "stdout": stdout_capture.getvalue(),
261
+ "stderr": stderr_capture.getvalue(),
262
+ })
263
+
264
+ except TimeoutException:
265
+ result_dict.update({
266
+ "timeout": True,
267
+ "error": "Execution timed out",
268
+ })
269
+
270
+ except MemoryError as e:
271
+ result_dict.update({
272
+ "memory_exceeded": True,
273
+ "error": f"Memory limit exceeded: {e}",
274
+ })
275
+
276
+ except BaseException as e:
277
+ result_dict.update({
278
+ "error": f"{type(e).__name__}: {e}",
279
+ })
280
+
281
+ # Needed for cleaning up.
282
+ shutil.rmtree = rmtree
283
+ os.rmdir = rmdir
284
+ os.chdir = chdir
285
+
286
+
287
+ def execute_code(
288
+ code: str,
289
+ timeout: float = 5.0, # 5 seconds default
290
+ maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
291
+ ) -> ExecutionResult:
292
+ """
293
+ Execute Python code in a sandboxed environment.
294
+
295
+ Args:
296
+ code: Python code to execute as a string
297
+ timeout: Maximum execution time in seconds (default: 5.0)
298
+ maximum_memory_bytes: Memory limit in bytes (default: 256MB, None to disable)
299
+
300
+ Returns:
301
+ ExecutionResult with success status, stdout/stderr, and error information
302
+
303
+ Example:
304
+ >>> result = execute_code("print('hello world')")
305
+ >>> result.success
306
+ True
307
+ >>> result.stdout
308
+ 'hello world\\n'
309
+ """
310
+
311
+ manager = multiprocessing.Manager()
312
+ result_dict = manager.dict()
313
+
314
+ p = multiprocessing.Process(
315
+ target=_unsafe_execute,
316
+ args=(code, timeout, maximum_memory_bytes, result_dict)
317
+ )
318
+ p.start()
319
+ p.join(timeout=timeout + 1)
320
+
321
+ if p.is_alive():
322
+ p.kill()
323
+ return ExecutionResult(
324
+ success=False,
325
+ stdout="",
326
+ stderr="",
327
+ error="Execution timed out (process killed)",
328
+ timeout=True,
329
+ memory_exceeded=False,
330
+ )
331
+
332
+ if not result_dict:
333
+ return ExecutionResult(
334
+ success=False,
335
+ stdout="",
336
+ stderr="",
337
+ error="Execution failed (no result returned)",
338
+ timeout=True,
339
+ memory_exceeded=False,
340
+ )
341
+
342
+ return ExecutionResult(
343
+ success=result_dict["success"],
344
+ stdout=result_dict["stdout"],
345
+ stderr=result_dict["stderr"],
346
+ error=result_dict["error"],
347
+ timeout=result_dict["timeout"],
348
+ memory_exceeded=result_dict["memory_exceeded"],
349
+ )
350
+
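A short usage sketch of execute_code (the code strings are hypothetical; they exercise the happy path, the timeout path, and an exception surfaced via the error field):

from nanochat.execution import execute_code

ok = execute_code("print(2 + 2)")
print(ok.success, repr(ok.stdout))   # True '4\n'

hang = execute_code("while True: pass", timeout=1.0)
print(hang.success, hang.timeout)    # False True

boom = execute_code("raise ValueError('boom')")
print(boom.success, boom.error)      # False ValueError: boom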
nanochat/gpt.py ADDED
@@ -0,0 +1,322 @@
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Multi-Query Attention (MQA) support for more efficient inference
12
+ """
13
+
14
+ import math
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0
23
+ from nanochat.muon import Muon, DistMuon
24
+ from nanochat.adamw import DistAdamW
25
+
26
+ @dataclass
27
+ class GPTConfig:
28
+ sequence_len: int = 1024
29
+ vocab_size: int = 50304
30
+ n_layer: int = 12
31
+ n_head: int = 6 # number of query heads
32
+ n_kv_head: int = 6 # number of key/value heads (MQA)
33
+ n_embd: int = 768
34
+
35
+
36
+ def norm(x):
37
+ # Purely functional rmsnorm with no learnable params
38
+ return F.rms_norm(x, (x.size(-1),))
39
+
40
+
41
+ def apply_rotary_emb(x, cos, sin):
42
+ assert x.ndim == 4 # multihead attention
43
+ d = x.shape[3] // 2
44
+ x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
45
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
46
+ y2 = x1 * (-sin) + x2 * cos
47
+ out = torch.cat([y1, y2], 3) # re-assemble
48
+ out = out.to(x.dtype) # ensure input/output dtypes match
49
+ return out
50
+
51
+
52
+ def repeat_kv(x, n_rep):
53
+ """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
54
+ if n_rep == 1:
55
+ return x
56
+ bs, n_kv_heads, slen, head_dim = x.shape
57
+ return (
58
+ x[:, :, None, :, :]
59
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
60
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
61
+ )
62
+
63
+
64
+ class CausalSelfAttention(nn.Module):
65
+ def __init__(self, config, layer_idx):
66
+ super().__init__()
67
+ self.layer_idx = layer_idx
68
+ self.n_head = config.n_head
69
+ self.n_kv_head = config.n_kv_head
70
+ self.n_embd = config.n_embd
71
+ self.head_dim = self.n_embd // self.n_head
72
+ assert self.n_embd % self.n_head == 0
73
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
74
+ self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
75
+ self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
76
+ self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
77
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
78
+
79
+ def forward(self, x, cos_sin, kv_cache):
80
+ B, T, C = x.size()
81
+
82
+ # Project the input to get queries, keys, and values
83
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
84
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
85
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
86
+
87
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
88
+ cos, sin = cos_sin
89
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
90
+ q, k = norm(q), norm(k) # QK norm
91
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
92
+
93
+ # Apply KV cache: insert current k,v into cache, get the full view so far
94
+ if kv_cache is not None:
95
+ k, v = kv_cache.insert_kv(self.layer_idx, k, v)
96
+ Tq = q.size(2) # number of queries in this forward pass
97
+ Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
98
+
99
+ # Apply MQA: replicate the key/value heads for each query head
100
+ nrep = self.n_head // self.n_kv_head
101
+ k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
102
+
103
+ # Attention: queries attend to keys/values autoregressively. A few cases to handle:
104
+ if kv_cache is None or Tq == Tk:
105
+ # During training (no KV cache), attend as usual with causal attention
106
+ # And even if there is KV cache, we can still use this simple version when Tq == Tk
107
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
108
+ elif Tq == 1:
109
+ # During inference but with a single query in this forward pass:
110
+ # The query has to attend to all the keys/values in the cache
111
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
112
+ else:
113
+ # During inference AND we have a chunk of queries in this forward pass:
114
+ # First, each query attends to all the cached keys/values (i.e. full prefix)
115
+ attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
116
+ prefix_len = Tk - Tq
117
+ if prefix_len > 0: # can't be negative but could be zero
118
+ attn_mask[:, :prefix_len] = True
119
+ # Then, causal attention within this chunk
120
+ attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
121
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
122
+
123
+ # Re-assemble the heads side by side and project back to residual stream
124
+ y = y.transpose(1, 2).contiguous().view(B, T, -1)
125
+ y = self.c_proj(y)
126
+ return y
127
+
128
+
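# Small worked example of the prefix+causal attention mask constructed in
# CausalSelfAttention.forward above, for a decode chunk of Tq=3 new queries
# attending over Tk=5 keys (prefix_len = 2 cached tokens):
#
#   import torch
#   Tq, Tk = 3, 5
#   prefix_len = Tk - Tq
#   mask = torch.zeros((Tq, Tk), dtype=torch.bool)
#   mask[:, :prefix_len] = True
#   mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
#
#   # mask:
#   # [[True, True, True, False, False],
#   #  [True, True, True, True,  False],
#   #  [True, True, True, True,  True ]]
#   # i.e. every new query sees the full cached prefix, plus causal attention
#   # within the new chunk.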
129
+ class MLP(nn.Module):
130
+ def __init__(self, config):
131
+ super().__init__()
132
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
133
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
134
+
135
+ def forward(self, x):
136
+ x = self.c_fc(x)
137
+ x = F.relu(x).square()
138
+ x = self.c_proj(x)
139
+ return x
140
+
141
+
142
+ class Block(nn.Module):
143
+ def __init__(self, config, layer_idx):
144
+ super().__init__()
145
+ self.attn = CausalSelfAttention(config, layer_idx)
146
+ self.mlp = MLP(config)
147
+
148
+ def forward(self, x, cos_sin, kv_cache):
149
+ x = x + self.attn(norm(x), cos_sin, kv_cache)
150
+ x = x + self.mlp(norm(x))
151
+ return x
152
+
153
+
154
+ class GPT(nn.Module):
155
+ def __init__(self, config):
156
+ super().__init__()
157
+ self.config = config
158
+ self.transformer = nn.ModuleDict({
159
+ "wte": nn.Embedding(config.vocab_size, config.n_embd),
160
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
161
+ })
162
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
163
+ # To support meta device initialization, we init the rotary embeddings here, but it's fake
164
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
165
+ # so let's just over-compute them, but assert fail if we ever reach that amount.
166
+ # In the future we can dynamically grow the cache, for now it's fine.
167
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
168
+ head_dim = config.n_embd // config.n_head
169
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
170
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
171
+ self.register_buffer("sin", sin, persistent=False)
172
+ # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
173
+ self.transformer.wte.to(dtype=torch.bfloat16)
174
+
175
+ def init_weights(self):
176
+ self.apply(self._init_weights)
177
+ # zero out classifier weights
178
+ torch.nn.init.zeros_(self.lm_head.weight)
179
+ # zero out c_proj weights in all blocks
180
+ for block in self.transformer.h:
181
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
182
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
183
+ # init the rotary embeddings
184
+ head_dim = self.config.n_embd // self.config.n_head
185
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
186
+ self.cos, self.sin = cos, sin
187
+
188
+ def _init_weights(self, module):
189
+ if isinstance(module, nn.Linear):
190
+ # https://arxiv.org/pdf/2310.17813
191
+ fan_out = module.weight.size(0)
192
+ fan_in = module.weight.size(1)
193
+ std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
194
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
195
+ if module.bias is not None:
196
+ torch.nn.init.zeros_(module.bias)
197
+ elif isinstance(module, nn.Embedding):
198
+ torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
199
+
200
+ # TODO: bump base theta more, e.g. 100K is more common more recently
201
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
202
+ # autodetect the device from model embeddings
203
+ if device is None:
204
+ device = self.transformer.wte.weight.device
205
+ # stride the channels
206
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
207
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
208
+ # stride the time steps
209
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
210
+ # calculate the rotation frequencies at each (time, channel) pair
211
+ freqs = torch.outer(t, inv_freq)
212
+ cos, sin = freqs.cos(), freqs.sin()
213
+ cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
214
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
215
+ return cos, sin
216
+
217
+ def get_device(self):
218
+ return self.transformer.wte.weight.device
219
+
220
+ def estimate_flops(self):
221
+ """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
222
+ nparams = sum(p.numel() for p in self.parameters())
223
+ nparams_embedding = self.transformer.wte.weight.numel()
224
+ l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
225
+ num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
226
+ return num_flops_per_token
227
+
228
+ def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
229
+ model_dim = self.config.n_embd
230
+ ddp, rank, local_rank, world_size = get_dist_info()
231
+ # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
232
+ matrix_params = list(self.transformer.h.parameters())
233
+ embedding_params = list(self.transformer.wte.parameters())
234
+ lm_head_params = list(self.lm_head.parameters())
235
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
236
+ # Create the AdamW optimizer for the embedding and lm_head
237
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
238
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
239
+ if rank == 0:
240
+ print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
241
+ adam_groups = [
242
+ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
243
+ dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
244
+ ]
245
+ adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
246
+ AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
247
+ adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
248
+ # Create the Muon optimizer for the linear layers
249
+ muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
250
+ MuonFactory = DistMuon if ddp else Muon
251
+ muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
252
+ # Combine them the two optimizers into one list
253
+ optimizers = [adamw_optimizer, muon_optimizer]
254
+ for opt in optimizers:
255
+ for group in opt.param_groups:
256
+ group["initial_lr"] = group["lr"]
257
+ return optimizers
258
+
259
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
260
+ B, T = idx.size()
261
+
262
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
263
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
264
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
265
+ assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
266
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
267
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
268
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
269
+
270
+ # Forward the trunk of the Transformer
271
+ x = self.transformer.wte(idx)
272
+ x = norm(x)
273
+ for block in self.transformer.h:
274
+ x = block(x, cos_sin, kv_cache)
275
+ x = norm(x)
276
+
277
+ # Forward the lm_head (compute logits)
278
+ softcap = 15
279
+ if targets is not None:
280
+ # training mode: compute and return the loss
281
+ # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
282
+ logits = self.lm_head(x)
283
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
284
+ logits = logits.float() # use tf32/fp32 for logits
285
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
286
+ return loss
287
+ else:
288
+ # inference mode: compute and return the logits
289
+ logits = self.lm_head(x)
290
+ logits = softcap * torch.tanh(logits / softcap) # logits softcap
291
+ return logits
292
+
293
+ @torch.inference_mode()
294
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
295
+ """
296
+ Naive autoregressive streaming inference.
297
+ To make it super simple, let's assume:
298
+ - batch size is 1
299
+ - ids and the yielded tokens are simple Python lists and ints
300
+ """
301
+ assert isinstance(tokens, list)
302
+ device = self.get_device()
303
+ rng = None
304
+ if temperature > 0:
305
+ rng = torch.Generator(device=device)
306
+ rng.manual_seed(seed)
307
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
308
+ for _ in range(max_tokens):
309
+ logits = self.forward(ids) # (B, T, vocab_size)
310
+ logits = logits[:, -1, :] # (B, vocab_size)
311
+ if top_k is not None:
312
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
313
+ logits[logits < v[:, [-1]]] = -float('Inf')
314
+ if temperature > 0:
315
+ logits = logits / temperature
316
+ probs = F.softmax(logits, dim=-1)
317
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
318
+ else:
319
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
320
+ ids = torch.cat((ids, next_ids), dim=1)
321
+ token = next_ids.item()
322
+ yield token
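A quick numeric sketch of the AdamW learning-rate scaling used in setup_optimizers above: the LRs are tuned at n_embd=768 and scaled by (n_embd/768)**-0.5 (the 3072-dim width below is hypothetical):

n_embd = 3072                               # hypothetical model width (4x the 768 reference)
dmodel_lr_scale = (n_embd / 768) ** -0.5    # = 0.5
unembedding_lr = 0.004 * dmodel_lr_scale    # 0.002
embedding_lr = 0.2 * dmodel_lr_scale        # 0.1
print(dmodel_lr_scale, unembedding_lr, embedding_lr)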
nanochat/logo.svg ADDED
nanochat/loss_eval.py ADDED
@@ -0,0 +1,63 @@
1
+ """
2
+ A number of functions that help with evaluating a base model.
3
+ """
4
+ import math
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ @torch.no_grad()
9
+ def evaluate_bpb(model, batches, steps, token_bytes):
10
+ """
11
+ Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
12
+ which is a tokenization vocab size-independent metric, meaning you are still comparing
13
+ apples:apples if you change the vocab size. The way this works is that instead of just
14
+ calculating the average loss as usual, you calculate the sum loss, and independently
15
+ also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
16
+ the number of bytes that the target tokens represent.
17
+
18
+ The added complexity is so that:
19
+ 1) All "normal" tokens are normalized by the length of the token in bytes
20
+ 2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
21
+ 3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
22
+
23
+ In addition to evaluate_loss, we need the token_bytes tensor:
24
+ It is a 1D tensor of shape (vocab_size,), indicating the number of bytes for
25
+ each token id, or 0 if the token is to not be counted (e.g. special tokens).
26
+ """
27
+ # record the losses
28
+ total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
29
+ total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
30
+ batch_iter = iter(batches)
31
+ for _ in range(steps):
32
+ x, y = next(batch_iter)
33
+ loss2d = model(x, y, loss_reduction='none') # (B, T)
34
+ loss2d = loss2d.view(-1) # flatten
35
+ y = y.view(-1) # flatten
36
+ if (y < 0).any():
37
+ # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
38
+ # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
39
+ valid = y >= 0
40
+ y_safe = torch.where(valid, y, torch.zeros_like(y))
41
+ # map valid targets to their byte length; ignored targets contribute 0 bytes
42
+ num_bytes2d = torch.where(
43
+ valid,
44
+ token_bytes[y_safe],
45
+ torch.zeros_like(y, dtype=token_bytes.dtype)
46
+ )
47
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
48
+ total_bytes += num_bytes2d.sum()
49
+ else:
50
+ # fast path: no ignored targets, safe to index directly
51
+ num_bytes2d = token_bytes[y]
52
+ total_nats += (loss2d * (num_bytes2d > 0)).sum()
53
+ total_bytes += num_bytes2d.sum()
54
+ # sum reduce across all ranks
55
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
56
+ if world_size > 1:
57
+ dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
58
+ dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
59
+ # move both to cpu, calculate bpb and return
60
+ total_nats = total_nats.item()
61
+ total_bytes = total_bytes.item()
62
+ bpb = total_nats / (math.log(2) * total_bytes)
63
+ return bpb
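A tiny numeric sketch of the bits-per-byte conversion computed above (the counts are hypothetical): the summed cross-entropy in nats is divided by ln(2) times the number of bytes spanned by the target tokens.

import math
total_nats = 2.2 * 1000    # hypothetical: 1000 target tokens at an average loss of 2.2 nats
total_bytes = 4 * 1000     # hypothetical: the same tokens average 4 bytes each
bpb = total_nats / (math.log(2) * total_bytes)
print(f"{bpb:.3f}")        # ~0.793 bits per byte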
nanochat/muon.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Muon optimizer from Keller et al.
3
+ Also a lot of borrowing of ideas from modded-nanogpt.
4
+ """
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.distributed as dist
8
+
9
+ @torch.compile
10
+ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
11
+ """
12
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
13
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
14
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
15
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
16
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
17
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
18
+ performance at all relative to UV^T, where USV^T = G is the SVD.
19
+ """
20
+ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
21
+ a, b, c = (3.4445, -4.7750, 2.0315)
22
+ X = G.bfloat16()
23
+ if G.size(-2) > G.size(-1):
24
+ X = X.mT
25
+
26
+ # Ensure spectral norm is at most 1
27
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
28
+ # Perform the NS iterations
29
+ for _ in range(steps):
30
+ A = X @ X.mT
31
+ B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
32
+ X = a * X + B @ X
33
+
34
+ if G.size(-2) > G.size(-1):
35
+ X = X.mT
36
+ return X
37
+
38
+ class Muon(torch.optim.Optimizer):
39
+ """
40
+ Muon - MomentUm Orthogonalized by Newton-schulz
41
+
42
+ https://kellerjordan.github.io/posts/muon/
43
+
44
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
45
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
46
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
47
+ the advantage that it can be stably run in bfloat16 on the GPU.
48
+
49
+ Some warnings:
50
+ - This optimizer should not be used for the embedding layer, the final fully connected layer,
51
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
52
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
53
+
54
+ Arguments:
55
+ lr: The learning rate used by the internal SGD.
56
+ momentum: The momentum used by the internal SGD.
57
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
58
+ ns_steps: The number of Newton-Schulz iteration steps to use.
59
+ """
60
+ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
61
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
62
+ params: list[Tensor] = [*params]
63
+ param_groups = []
64
+ for size in {p.numel() for p in params}:
65
+ group = dict(params=[p for p in params if p.numel() == size])
66
+ param_groups.append(group)
67
+ super().__init__(param_groups, defaults)
68
+
69
+ @torch.no_grad()
70
+ def step(self):
71
+ for group in self.param_groups:
72
+ params: list[Tensor] = group["params"]
73
+ for p in params:
74
+ g = p.grad
75
+ assert g is not None
76
+ state = self.state[p]
77
+ if "momentum_buffer" not in state:
78
+ state["momentum_buffer"] = torch.zeros_like(g)
79
+ buf: Tensor = state["momentum_buffer"]
80
+ buf.lerp_(g, 1 - group["momentum"])
81
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
82
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
83
+ p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
84
+
85
+
86
+ class DistMuon(torch.optim.Optimizer):
87
+ """
88
+ Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
89
+ finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
90
+ - reduce_scatter(AVG) for gradient averaging
91
+ - all_gather to replicate updated weights
92
+
93
+ Notes:
94
+ * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
95
+ params like embeddings or scalars.
96
+ * Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
97
+ by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
98
+ consolidate states beforehand.
99
+
100
+ Args:
101
+ params: iterable of Tensors
102
+ lr: learning rate
103
+ momentum: momentum coefficient in [0,1)
104
+ nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
105
+ ns_steps: number of Newton–Schulz iterations for the orthogonalization
106
+ """
107
+ def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
108
+ nesterov: bool = True, ns_steps: int = 5):
109
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
110
+ params = list(params)
111
+ assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
112
+ rank = dist.get_rank()
113
+ # Group all parameters by their shape
114
+ shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
115
+ param_groups = []
116
+ for shape in shapes:
117
+ group_params = [p for p in params if p.shape == shape]
118
+ device, dtype = group_params[0].device, group_params[0].dtype
119
+ assert all(p.device == device for p in group_params)
120
+ assert all(p.dtype == dtype for p in group_params)
121
+ if rank == 0:
122
+ print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
123
+ param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
124
+ super().__init__(param_groups, defaults)
125
+
126
+ @torch.no_grad()
127
+ def step(self):
128
+ rank = dist.get_rank()
129
+ world_size = dist.get_world_size()
130
+
131
+ # Ensure all grads exist
132
+ assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
133
+
134
+ # Kick off all the reduce scatter operations to average up the gradients across all ranks
135
+ all_reduce_futures = []
136
+ for group in self.param_groups:
137
+ params = group["params"]
138
+ zero_buffer = group["zero_buffer"]
139
+ # Go through params in groups of world_size.
140
+ for base_i in range(0, len(params), world_size):
141
+ # The compute owner of each param is rank i % world_size
142
+ owner_idx = base_i + rank
143
+ # each rank stacks up its chunk of world_size params into a list
144
+ rs_input = [p.grad for p in params[base_i:base_i + world_size]]
145
+ # pad rs_input with the zero buffer to complete the group
146
+ rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
147
+ # the output buffer gets strided across the group based on the rank
148
+ rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
149
+ # reduce scatter the gradients within this group of world_size params
150
+ work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
151
+ all_reduce_futures.append(work)
152
+
153
+ # Now each rank computes the update and gathers
154
+ future_idx = 0
155
+ all_gather_futures = []
156
+ for group in self.param_groups:
157
+ params = group["params"]
158
+ zero_buffer = group["zero_buffer"]
159
+ # Go through params in groups of world_size.
160
+ for base_i in range(0, len(params), world_size):
161
+ # The compute owner of each param is rank i % world_size
162
+ owner_idx = base_i + rank # calculate the index of the param that this rank owns
163
+ # Wait for the reduce scatter to complete
164
+ all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
165
+ future_idx += 1
166
+ # Owner computes the Muon update, result is in its param
167
+ if owner_idx < len(params):
168
+ p = params[owner_idx]
169
+ g = p.grad # now averaged across ranks
170
+ state = self.state[p]
171
+ if "momentum_buffer" not in state:
172
+ state["momentum_buffer"] = torch.zeros_like(g)
173
+ buf: Tensor = state["momentum_buffer"]
174
+ buf.lerp_(g, 1.0 - group["momentum"])
175
+ g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
176
+ g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
177
+ scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
178
+ p.add_(g, alpha=-group["lr"] * scale)
179
+ # Replicate updated parameters to all ranks
180
+ ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
181
+ ag_output = params[base_i:base_i + world_size]
182
+ ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
183
+ work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
184
+ all_gather_futures.append(work)
185
+
186
+ # Wait for all work to finish
187
+ torch.futures.collect_all(all_gather_futures).wait()
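A small sanity-check sketch of the Newton-Schulz orthogonalization above (assumes a working torch.compile backend, since the function is compiled; the shape is arbitrary):

import torch
from nanochat.muon import zeropower_via_newtonschulz5

G = torch.randn(256, 512)
X = zeropower_via_newtonschulz5(G, steps=5)
S = torch.linalg.svdvals(X.float())
# per the docstring, singular values land roughly in the (0.5, 1.5) band rather than exactly at 1
print(S.min().item(), S.max().item())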
nanochat/report.py ADDED
@@ -0,0 +1,404 @@
1
+ """
2
+ Utilities for generating training report cards. More messy code than usual, will fix.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ import socket
10
+ import datetime
11
+ import platform
12
+ import psutil
13
+ import torch
14
+
15
+ def run_command(cmd):
16
+ """Run a shell command and return output, or None if it fails."""
17
+ try:
18
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5)
19
+ if result.returncode == 0:
20
+ return result.stdout.strip()
21
+ return None
22
+ except:
23
+ return None
24
+
25
+ def get_git_info():
26
+ """Get current git commit, branch, and dirty status."""
27
+ info = {}
28
+ info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
29
+ info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
30
+
31
+ # Check if repo is dirty (has uncommitted changes)
32
+ status = run_command("git status --porcelain")
33
+ info['dirty'] = bool(status) if status is not None else False
34
+
35
+ # Get commit message
36
+ info['message'] = run_command("git log -1 --pretty=%B") or ""
37
+ info['message'] = info['message'].split('\n')[0][:80] # First line, truncated
38
+
39
+ return info
40
+
41
+ def get_gpu_info():
42
+ """Get GPU information."""
43
+ if not torch.cuda.is_available():
44
+ return {"available": False}
45
+
46
+ num_devices = torch.cuda.device_count()
47
+ info = {
48
+ "available": True,
49
+ "count": num_devices,
50
+ "names": [],
51
+ "memory_gb": []
52
+ }
53
+
54
+ for i in range(num_devices):
55
+ props = torch.cuda.get_device_properties(i)
56
+ info["names"].append(props.name)
57
+ info["memory_gb"].append(props.total_memory / (1024**3))
58
+
59
+ # Get CUDA version
60
+ info["cuda_version"] = torch.version.cuda or "unknown"
61
+
62
+ return info
63
+
64
+ def get_system_info():
65
+ """Get system information."""
66
+ info = {}
67
+
68
+ # Basic system info
69
+ info['hostname'] = socket.gethostname()
70
+ info['platform'] = platform.system()
71
+ info['python_version'] = platform.python_version()
72
+ info['torch_version'] = torch.__version__
73
+
74
+ # CPU and memory
75
+ info['cpu_count'] = psutil.cpu_count(logical=False)
76
+ info['cpu_count_logical'] = psutil.cpu_count(logical=True)
77
+ info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
78
+
79
+ # User and environment
80
+ info['user'] = os.environ.get('USER', 'unknown')
81
+ info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
82
+ info['working_dir'] = os.getcwd()
83
+
84
+ return info
85
+
86
+ def estimate_cost(gpu_info, runtime_hours=None):
87
+ """Estimate training cost based on GPU type and runtime."""
88
+
89
+ # Rough pricing, from Lambda Cloud
90
+ default_rate = 2.0
91
+ gpu_hourly_rates = {
92
+ "H100": 3.00,
93
+ "A100": 1.79,
94
+ "V100": 0.55,
95
+ }
96
+
97
+ if not gpu_info.get("available"):
98
+ return None
99
+
100
+ # Try to identify GPU type from name
101
+ hourly_rate = None
102
+ gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
103
+ for gpu_type, rate in gpu_hourly_rates.items():
104
+ if gpu_type in gpu_name:
105
+ hourly_rate = rate * gpu_info["count"]
106
+ break
107
+
108
+ if hourly_rate is None:
109
+ hourly_rate = default_rate * gpu_info["count"] # Default estimate
110
+
111
+ return {
112
+ "hourly_rate": hourly_rate,
113
+ "gpu_type": gpu_name,
114
+ "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
115
+ }
116
+
117
+ def generate_header():
118
+ """Generate the header for a training report."""
119
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
120
+
121
+ git_info = get_git_info()
122
+ gpu_info = get_gpu_info()
123
+ sys_info = get_system_info()
124
+ cost_info = estimate_cost(gpu_info)
125
+
126
+ header = f"""# nanochat training report
127
+
128
+ Generated: {timestamp}
129
+
130
+ ## Environment
131
+
132
+ ### Git Information
133
+ - Branch: {git_info['branch']}
134
+ - Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
135
+ - Message: {git_info['message']}
136
+
137
+ ### Hardware
138
+ - Platform: {sys_info['platform']}
139
+ - CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
140
+ - Memory: {sys_info['memory_gb']:.1f} GB
141
+ """
142
+
143
+ if gpu_info.get("available"):
144
+ gpu_names = ", ".join(set(gpu_info["names"]))
145
+ total_vram = sum(gpu_info["memory_gb"])
146
+ header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
147
+ - GPU Memory: {total_vram:.1f} GB total
148
+ - CUDA Version: {gpu_info['cuda_version']}
149
+ """
150
+ else:
151
+ header += "- GPUs: None available\n"
152
+
153
+ if cost_info and cost_info["hourly_rate"] > 0:
154
+ header += f"""- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"""
155
+
156
+ header += f"""
157
+ ### Software
158
+ - Python: {sys_info['python_version']}
159
+ - PyTorch: {sys_info['torch_version']}
160
+
161
+ """
162
+
163
+ # bloat metrics: package all of the source code and assess its weight
164
+ packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
165
+ num_chars = len(packaged)
166
+ num_lines = len(packaged.split('\n'))
167
+ num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
168
+ num_tokens = num_chars // 4 # assume approximately 4 chars per token
169
+
170
+ # count dependencies via uv.lock
171
+ uv_lock_lines = 0
172
+ if os.path.exists('uv.lock'):
173
+ with open('uv.lock', 'r') as f:
174
+ uv_lock_lines = len(f.readlines())
175
+
176
+ header += f"""
177
+ ### Bloat
178
+ - Characters: {num_chars:,}
179
+ - Lines: {num_lines:,}
180
+ - Files: {num_files:,}
181
+ - Tokens (approx): {num_tokens:,}
182
+ - Dependencies (uv.lock lines): {uv_lock_lines:,}
183
+
184
+ """
185
+ return header
186
+
187
+ # -----------------------------------------------------------------------------
188
+
189
+ def slugify(text):
190
+ """Slugify a text string."""
191
+ return text.lower().replace(" ", "-")
192
+
193
+ # the expected files and their order
194
+ EXPECTED_FILES = [
195
+ "tokenizer-training.md",
196
+ "tokenizer-evaluation.md",
197
+ "base-model-training.md",
198
+ "base-model-loss.md",
199
+ "base-model-evaluation.md",
200
+ "midtraining.md",
201
+ "chat-evaluation-mid.md",
202
+ "chat-sft.md",
203
+ "chat-evaluation-sft.md",
204
+ "chat-rl.md",
205
+ "chat-evaluation-rl.md",
206
+ ]
207
+ # the metrics we're currently interested in
208
+ chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]
209
+
210
+ def extract(section, keys):
211
+ """simple def to extract a single key from a section"""
212
+ if not isinstance(keys, list):
213
+ keys = [keys] # convenience
214
+ out = {}
215
+ for line in section.split("\n"):
216
+ for key in keys:
217
+ if key in line:
218
+ out[key] = line.split(":")[1].strip()
219
+ return out
220
+
221
+ def extract_timestamp(content, prefix):
222
+ """Extract timestamp from content with given prefix."""
223
+ for line in content.split('\n'):
224
+ if line.startswith(prefix):
225
+ time_str = line.split(":", 1)[1].strip()
226
+ try:
227
+ return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
228
+ except:
229
+ pass
230
+ return None
231
+
232
+ class Report:
233
+ """Maintains a bunch of logs, generates a final markdown report."""
234
+
235
+ def __init__(self, report_dir):
236
+ os.makedirs(report_dir, exist_ok=True)
237
+ self.report_dir = report_dir
238
+
239
+ def log(self, section, data):
240
+ """Log a section of data to the report."""
241
+ slug = slugify(section)
242
+ file_name = f"{slug}.md"
243
+ file_path = os.path.join(self.report_dir, file_name)
244
+ with open(file_path, "w") as f:
245
+ f.write(f"## {section}\n")
246
+ f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
247
+ for item in data:
248
+ if not item:
249
+ # skip falsy values like None or empty dict etc.
250
+ continue
251
+ if isinstance(item, str):
252
+ # directly write the string
253
+ f.write(item)
254
+ else:
255
+ # render a dict
256
+ for k, v in item.items():
257
+ if isinstance(v, float):
258
+ vstr = f"{v:.4f}"
259
+ elif isinstance(v, int) and v >= 10000:
260
+ vstr = f"{v:,.0f}"
261
+ else:
262
+ vstr = str(v)
263
+ f.write(f"- {k}: {vstr}\n")
264
+ f.write("\n")
265
+ return file_path
266
+
267
+ def generate(self):
268
+ """Generate the final report."""
269
+ report_dir = self.report_dir
270
+ report_file = os.path.join(report_dir, "report.md")
271
+ print(f"Generating report to {report_file}")
272
+ final_metrics = {} # the most important final metrics we'll add as table at the end
273
+ start_time = None
274
+ end_time = None
275
+ with open(report_file, "w") as out_file:
276
+ # write the header first
277
+ header_file = os.path.join(report_dir, "header.md")
278
+ if os.path.exists(header_file):
279
+ with open(header_file, "r") as f:
280
+ header_content = f.read()
281
+ out_file.write(header_content)
282
+ start_time = extract_timestamp(header_content, "Run started:")
283
+ # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
284
+ bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
285
+ bloat_data = bloat_data.group(1) if bloat_data else ""
286
+ # process all the individual sections
287
+ for file_name in EXPECTED_FILES:
288
+ section_file = os.path.join(report_dir, file_name)
289
+ if not os.path.exists(section_file):
290
+ print(f"Warning: {section_file} does not exist, skipping")
291
+ continue
292
+ with open(section_file, "r") as in_file:
293
+ section = in_file.read()
294
+ # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
295
+ if "rl" not in file_name:
296
+ # Skip RL sections for end_time calculation because RL is experimental
297
+ end_time = extract_timestamp(section, "timestamp:")
298
+ # extract the most important metrics from the sections
299
+ if file_name == "base-model-evaluation.md":
300
+ final_metrics["base"] = extract(section, "CORE")
301
+ if file_name == "chat-evaluation-mid.md":
302
+ final_metrics["mid"] = extract(section, chat_metrics)
303
+ if file_name == "chat-evaluation-sft.md":
304
+ final_metrics["sft"] = extract(section, chat_metrics)
305
+ if file_name == "chat-evaluation-rl.md":
306
+ final_metrics["rl"] = extract(section, "GSM8K") # RL only evals GSM8K
307
+ # append this section of the report
308
+ out_file.write(section)
309
+ out_file.write("\n")
310
+ # add the final metrics table
311
+ out_file.write("## Summary\n\n")
312
+ # Copy over the bloat metrics from the header
313
+ out_file.write(bloat_data)
314
+ out_file.write("\n\n")
315
+ # Collect all unique metric names
316
+ all_metrics = set()
317
+ for stage_metrics in final_metrics.values():
318
+ all_metrics.update(stage_metrics.keys())
319
+ # Custom ordering: CORE first, ChatCORE last, rest in middle
320
+ all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
321
+ # Fixed column widths
322
+ stages = ["base", "mid", "sft", "rl"]
323
+ metric_width = 15
324
+ value_width = 8
325
+ # Write table header
326
+ header = f"| {'Metric'.ljust(metric_width)} |"
327
+ for stage in stages:
328
+ header += f" {stage.upper().ljust(value_width)} |"
329
+ out_file.write(header + "\n")
330
+ # Write separator
331
+ separator = f"|{'-' * (metric_width + 2)}|"
332
+ for stage in stages:
333
+ separator += f"{'-' * (value_width + 2)}|"
334
+ out_file.write(separator + "\n")
335
+ # Write table rows
336
+ for metric in all_metrics:
337
+ row = f"| {metric.ljust(metric_width)} |"
338
+ for stage in stages:
339
+ value = final_metrics.get(stage, {}).get(metric, "-")
340
+ row += f" {str(value).ljust(value_width)} |"
341
+ out_file.write(row + "\n")
342
+ out_file.write("\n")
343
+ # Calculate and write total wall clock time
344
+ if start_time and end_time:
345
+ duration = end_time - start_time
346
+ total_seconds = int(duration.total_seconds())
347
+ hours = total_seconds // 3600
348
+ minutes = (total_seconds % 3600) // 60
349
+ out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
350
+ else:
351
+ out_file.write("Total wall clock time: unknown\n")
352
+ # also cp the report.md file to current directory
353
+ print(f"Copying report.md to current directory for convenience")
354
+ shutil.copy(report_file, "report.md")
355
+ return report_file
356
+
357
+ def reset(self):
358
+ """Reset the report."""
359
+ # Remove section files
360
+ for file_name in EXPECTED_FILES:
361
+ file_path = os.path.join(self.report_dir, file_name)
362
+ if os.path.exists(file_path):
363
+ os.remove(file_path)
364
+ # Remove report.md if it exists
365
+ report_file = os.path.join(self.report_dir, "report.md")
366
+ if os.path.exists(report_file):
367
+ os.remove(report_file)
368
+ # Generate and write the header section with start timestamp
369
+ header_file = os.path.join(self.report_dir, "header.md")
370
+ header = generate_header()
371
+ start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
372
+ with open(header_file, "w") as f:
373
+ f.write(header)
374
+ f.write(f"Run started: {start_time}\n\n---\n\n")
375
+ print(f"Reset report and wrote header to {header_file}")
376
+
377
+ # -----------------------------------------------------------------------------
378
+ # nanochat-specific convenience functions
379
+
380
+ class DummyReport:
381
+ def log(self, *args, **kwargs):
382
+ pass
383
+ def reset(self, *args, **kwargs):
384
+ pass
385
+
386
+ def get_report():
387
+ # just for convenience, only rank 0 logs to report
388
+ from nanochat.common import get_base_dir, get_dist_info
389
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
390
+ if ddp_rank == 0:
391
+ report_dir = os.path.join(get_base_dir(), "report")
392
+ return Report(report_dir)
393
+ else:
394
+ return DummyReport()
395
+
396
+ if __name__ == "__main__":
397
+ import argparse
398
+ parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
399
+ parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
400
+ args = parser.parse_args()
401
+ if args.command == "generate":
402
+ get_report().generate()
403
+ elif args.command == "reset":
404
+ get_report().reset()
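For reference, the summary table that `generate()` assembles above uses fixed column widths (15 characters for the metric name, 8 per stage value). Below is a small self-contained sketch of that layout, using made-up metric values purely for illustration:

```python
# Illustrative sketch of the fixed-width summary table written by Report.generate().
# The metric values below are invented for demonstration; only the layout matches.
stages = ["base", "mid", "sft", "rl"]
final_metrics = {"base": {"CORE": 0.22}, "sft": {"GSM8K": 0.05, "ChatCORE": 0.18}}
metric_width, value_width = 15, 8

all_metrics = set()
for stage_metrics in final_metrics.values():
    all_metrics.update(stage_metrics.keys())
# CORE first, ChatCORE last, everything else alphabetical in between
all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))

rows = [f"| {'Metric'.ljust(metric_width)} |" + "".join(f" {s.upper().ljust(value_width)} |" for s in stages)]
rows.append(f"|{'-' * (metric_width + 2)}|" + "".join(f"{'-' * (value_width + 2)}|" for _ in stages))
for metric in all_metrics:
    row = f"| {metric.ljust(metric_width)} |"
    for stage in stages:
        value = final_metrics.get(stage, {}).get(metric, "-")
        row += f" {str(value).ljust(value_width)} |"
    rows.append(row)
print("\n".join(rows))
```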
nanochat/tokenizer.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ BPE Tokenizer in the style of GPT-4.
3
+
4
+ Two implementations are available:
5
+ 1) HuggingFace Tokenizer that can do both training and inference but is really confusing
6
+ 2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
7
+ """
8
+
9
+ import os
10
+ import copy
11
+ from functools import lru_cache
12
+
13
+ SPECIAL_TOKENS = [
14
+ # every document begins with the Beginning of Sequence (BOS) token that delimits documents
15
+ "<|bos|>",
16
+ # tokens below are only used during finetuning to render Conversations into token ids
17
+ "<|user_start|>", # user messages
18
+ "<|user_end|>",
19
+ "<|assistant_start|>", # assistant messages
20
+ "<|assistant_end|>",
21
+ "<|python_start|>", # assistant invokes python REPL tool
22
+ "<|python_end|>",
23
+ "<|output_start|>", # python REPL outputs back to assistant
24
+ "<|output_end|>",
25
+ ]
26
+
27
+ # NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3}
28
+ # I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes.
29
+ # I haven't validated that this is actually a good idea, TODO.
30
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # Generic GPT-4-style tokenizer based on HuggingFace Tokenizer
34
+ from tokenizers import Tokenizer as HFTokenizer
35
+ from tokenizers import pre_tokenizers, decoders, Regex
36
+ from tokenizers.models import BPE
37
+ from tokenizers.trainers import BpeTrainer
38
+
39
+ class HuggingFaceTokenizer:
40
+ """Light wrapper around HuggingFace Tokenizer for some utilities"""
41
+
42
+ def __init__(self, tokenizer):
43
+ self.tokenizer = tokenizer
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, hf_path):
47
+ # init from a HuggingFace pretrained tokenizer (e.g. "gpt2")
48
+ tokenizer = HFTokenizer.from_pretrained(hf_path)
49
+ return cls(tokenizer)
50
+
51
+ @classmethod
52
+ def from_directory(cls, tokenizer_dir):
53
+ # init from a local directory on disk (e.g. "out/tokenizer")
54
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
55
+ tokenizer = HFTokenizer.from_file(tokenizer_path)
56
+ return cls(tokenizer)
57
+
58
+ @classmethod
59
+ def train_from_iterator(cls, text_iterator, vocab_size):
60
+ # train from an iterator of text
61
+ # Configure the HuggingFace Tokenizer
62
+ tokenizer = HFTokenizer(BPE(
63
+ byte_fallback=True, # needed!
64
+ unk_token=None,
65
+ fuse_unk=False,
66
+ ))
67
+ # Normalizer: None
68
+ tokenizer.normalizer = None
69
+ # Pre-tokenizer: GPT-4 style
70
+ # the regex pattern used by GPT-4 to split text into groups before BPE
71
+ # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to
72
+ # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space.
73
+ # (but I haven't validated this! TODO)
74
+ gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
75
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
76
+ pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
77
+ pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
78
+ ])
79
+ # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
80
+ tokenizer.decoder = decoders.ByteLevel()
81
+ # Post-processor: None
82
+ tokenizer.post_processor = None
83
+ # Trainer: BPE
84
+ trainer = BpeTrainer(
85
+ vocab_size=vocab_size,
86
+ show_progress=True,
87
+ min_frequency=0, # no minimum frequency
88
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
89
+ special_tokens=SPECIAL_TOKENS,
90
+ )
91
+ # Kick off the training
92
+ tokenizer.train_from_iterator(text_iterator, trainer)
93
+ return cls(tokenizer)
94
+
95
+ def get_vocab_size(self):
96
+ return self.tokenizer.get_vocab_size()
97
+
98
+ def get_special_tokens(self):
99
+ special_tokens_map = self.tokenizer.get_added_tokens_decoder()
100
+ special_tokens = [w.content for w in special_tokens_map.values()]
101
+ return special_tokens
102
+
103
+ def id_to_token(self, id):
104
+ return self.tokenizer.id_to_token(id)
105
+
106
+ def _encode_one(self, text, prepend=None, append=None):
107
+ # encode a single string
108
+ # prepend/append can be either a special token string or a token id directly.
109
+ assert isinstance(text, str)
110
+ ids = []
111
+ if prepend is not None:
112
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
113
+ ids.append(prepend_id)
114
+ ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
115
+ if append is not None:
116
+ append_id = append if isinstance(append, int) else self.encode_special(append)
117
+ ids.append(append_id)
118
+ return ids
119
+
120
+ def encode_special(self, text):
121
+ # encode a single special token via exact match
122
+ return self.tokenizer.token_to_id(text)
123
+
124
+ def get_bos_token_id(self):
125
+ bos = self.encode_special("<|bos|>")
126
+ return bos
127
+
128
+ def encode(self, text, *args, **kwargs):
129
+ if isinstance(text, str):
130
+ return self._encode_one(text, *args, **kwargs)
131
+ elif isinstance(text, list):
132
+ return [self._encode_one(t, *args, **kwargs) for t in text]
133
+ else:
134
+ raise ValueError(f"Invalid input type: {type(text)}")
135
+
136
+ def __call__(self, *args, **kwargs):
137
+ return self.encode(*args, **kwargs)
138
+
139
+ def decode(self, ids):
140
+ return self.tokenizer.decode(ids, skip_special_tokens=False)
141
+
142
+ def save(self, tokenizer_dir):
143
+ # save the tokenizer to disk
144
+ os.makedirs(tokenizer_dir, exist_ok=True)
145
+ tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
146
+ self.tokenizer.save(tokenizer_path)
147
+ print(f"Saved tokenizer to {tokenizer_path}")
148
+
149
+ # -----------------------------------------------------------------------------
150
+ # Tokenizer based on rustbpe + tiktoken combo
151
+ import pickle
152
+ import rustbpe
153
+ import tiktoken
154
+
155
+ class RustBPETokenizer:
156
+ """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""
157
+
158
+ def __init__(self, enc, bos_token):
159
+ self.enc = enc
160
+ self.bos_token_id = self.encode_special(bos_token)
161
+
162
+ @classmethod
163
+ def train_from_iterator(cls, text_iterator, vocab_size):
164
+ # 1) train using rustbpe
165
+ tokenizer = rustbpe.Tokenizer()
166
+ # the special tokens are inserted later in __init__, we don't train them here
167
+ vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
168
+ assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
169
+ tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
170
+ # 2) construct the associated tiktoken encoding for inference
171
+ pattern = tokenizer.get_pattern()
172
+ mergeable_ranks_list = tokenizer.get_mergeable_ranks()
173
+ mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
174
+ tokens_offset = len(mergeable_ranks)
175
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
176
+ enc = tiktoken.Encoding(
177
+ name="rustbpe",
178
+ pat_str=pattern,
179
+ mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
180
+ special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
181
+ )
182
+ return cls(enc, "<|bos|>")
183
+
184
+ @classmethod
185
+ def from_directory(cls, tokenizer_dir):
186
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
187
+ with open(pickle_path, "rb") as f:
188
+ enc = pickle.load(f)
189
+ return cls(enc, "<|bos|>")
190
+
191
+ @classmethod
192
+ def from_pretrained(cls, tiktoken_name):
193
+ # https://github.com/openai/tiktoken/blob/eedc8563/tiktoken_ext/openai_public.py
194
+ enc = tiktoken.get_encoding(tiktoken_name)
195
+ # tiktoken calls the special document delimiter token "<|endoftext|>"
196
+ # yes this is confusing because this token is almost always PREPENDED to the beginning of the document
197
+ # it most often is used to signal the start of a new sequence to the LLM during inference etc.
198
+ # so in nanoChat we always use "<|bos|>" short for "beginning of sequence", but historically it is often called "<|endoftext|>".
199
+ return cls(enc, "<|endoftext|>")
200
+
201
+ def get_vocab_size(self):
202
+ return self.enc.n_vocab
203
+
204
+ def get_special_tokens(self):
205
+ return self.enc.special_tokens_set
206
+
207
+ def id_to_token(self, id):
208
+ return self.enc.decode([id])
209
+
210
+ @lru_cache(maxsize=32)
211
+ def encode_special(self, text):
212
+ return self.enc.encode_single_token(text)
213
+
214
+ def get_bos_token_id(self):
215
+ return self.bos_token_id
216
+
217
+ def encode(self, text, prepend=None, append=None, num_threads=8):
218
+ # text can be either a string or a list of strings
219
+
220
+ if prepend is not None:
221
+ prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
222
+ if append is not None:
223
+ append_id = append if isinstance(append, int) else self.encode_special(append)
224
+
225
+ if isinstance(text, str):
226
+ ids = self.enc.encode_ordinary(text)
227
+ if prepend is not None:
228
+ ids.insert(0, prepend_id) # TODO: slightly inefficient here? :( hmm
229
+ if append is not None:
230
+ ids.append(append_id)
231
+ elif isinstance(text, list):
232
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
233
+ if prepend is not None:
234
+ for ids_row in ids:
235
+ ids_row.insert(0, prepend_id) # TODO: same
236
+ if append is not None:
237
+ for ids_row in ids:
238
+ ids_row.append(append_id)
239
+ else:
240
+ raise ValueError(f"Invalid input type: {type(text)}")
241
+
242
+ return ids
243
+
244
+ def __call__(self, *args, **kwargs):
245
+ return self.encode(*args, **kwargs)
246
+
247
+ def decode(self, ids):
248
+ return self.enc.decode(ids)
249
+
250
+ def save(self, tokenizer_dir):
251
+ # save the encoding object to disk
252
+ os.makedirs(tokenizer_dir, exist_ok=True)
253
+ pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
254
+ with open(pickle_path, "wb") as f:
255
+ pickle.dump(self.enc, f)
256
+ print(f"Saved tokenizer encoding to {pickle_path}")
257
+
258
+ def render_conversation(self, conversation, max_tokens=2048):
259
+ """
260
+ Tokenize a single Chat conversation (which we call a "doc" or "document" here).
261
+ Returns:
262
+ - ids: list[int] is a list of token ids of this rendered conversation
263
+ - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
264
+ """
265
+ # ids, masks that we will return and a helper function to help build them up.
266
+ ids, mask = [], []
267
+ def add_tokens(token_ids, mask_val):
268
+ if isinstance(token_ids, int):
269
+ token_ids = [token_ids]
270
+ ids.extend(token_ids)
271
+ mask.extend([mask_val] * len(token_ids))
272
+
273
+ # sometimes the first message is a system message...
274
+ # => just merge it with the second (user) message
275
+ if conversation["messages"][0]["role"] == "system":
276
+ # some conversation surgery is necessary here for now...
277
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
278
+ messages = conversation["messages"]
279
+ assert messages[1]["role"] == "user", "System message must be followed by a user message"
280
+ messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
281
+ messages = messages[1:]
282
+ else:
283
+ messages = conversation["messages"]
284
+ assert len(messages) >= 1, f"Conversation has fewer than 1 message: {messages}"
285
+
286
+ # fetch all the special tokens we need
287
+ bos = self.get_bos_token_id()
288
+ user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
289
+ assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
290
+ python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
291
+ output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")
292
+
293
+ # now we can tokenize the conversation
294
+ add_tokens(bos, 0)
295
+ for i, message in enumerate(messages):
296
+
297
+ # some sanity checking here around assumptions, to prevent footguns
298
+ must_be_from = "user" if i % 2 == 0 else "assistant"
299
+ assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"
300
+
301
+ # content can be either a simple string or a list of parts (e.g. containing tool calls)
302
+ content = message["content"]
303
+
304
+ if message["role"] == "user":
305
+ assert isinstance(content, str), "User messages are simply expected to be strings"
306
+ value_ids = self.encode(content)
307
+ add_tokens(user_start, 0)
308
+ add_tokens(value_ids, 0)
309
+ add_tokens(user_end, 0)
310
+ elif message["role"] == "assistant":
311
+ add_tokens(assistant_start, 0)
312
+ if isinstance(content, str):
313
+ # simple string => simply add the tokens
314
+ value_ids = self.encode(content)
315
+ add_tokens(value_ids, 1)
316
+ elif isinstance(content, list):
317
+ for part in content:
318
+ value_ids = self.encode(part["text"])
319
+ if part["type"] == "text":
320
+ # string part => simply add the tokens
321
+ add_tokens(value_ids, 1)
322
+ elif part["type"] == "python":
323
+ # python tool call => add the tokens inside <|python_start|> and <|python_end|>
324
+ add_tokens(python_start, 1)
325
+ add_tokens(value_ids, 1)
326
+ add_tokens(python_end, 1)
327
+ elif part["type"] == "python_output":
328
+ # python output => add the tokens inside <|output_start|> and <|output_end|>
329
+ # none of these tokens are supervised because the tokens come from Python at test time
330
+ add_tokens(output_start, 0)
331
+ add_tokens(value_ids, 0)
332
+ add_tokens(output_end, 0)
333
+ else:
334
+ raise ValueError(f"Unknown part type: {part['type']}")
335
+ else:
336
+ raise ValueError(f"Unknown content type: {type(content)}")
337
+ add_tokens(assistant_end, 1)
338
+
339
+ # truncate to max_tokens tokens MAX (helps prevent OOMs)
340
+ ids = ids[:max_tokens]
341
+ mask = mask[:max_tokens]
342
+ return ids, mask
343
+
344
+ def visualize_tokenization(self, ids, mask):
345
+ """Small helper function useful in debugging: visualize the tokenization of render_conversation"""
346
+ RED = '\033[91m'
347
+ GREEN = '\033[92m'
348
+ RESET = '\033[0m'
349
+ tokens = []
350
+ for i, (token_id, mask_val) in enumerate(zip(ids, mask)):
351
+ token_str = self.decode([token_id])
352
+ color = GREEN if mask_val == 1 else RED
353
+ tokens.append(f"{color}{token_str}{RESET}")
354
+ return '|'.join(tokens)
355
+
356
+ def render_for_completion(self, conversation):
357
+ """
358
+ Used during Reinforcement Learning. In that setting, we want to
359
+ render the conversation priming the Assistant for a completion.
360
+ Unlike the Chat SFT case, we don't need to return the mask.
361
+ """
362
+ # We have some surgery to do: we need to pop the last message (of the Assistant)
363
+ conversation = copy.deepcopy(conversation) # avoid mutating the original
364
+ messages = conversation["messages"]
365
+ assert messages[-1]["role"] == "assistant", "Last message must be from the Assistant"
366
+ messages.pop() # remove the last message (of the Assistant) in place
367
+
368
+ # Now tokenize the conversation
369
+ ids, mask = self.render_conversation(conversation)
370
+
371
+ # Finally, to prime the Assistant for a completion, append the Assistant start token
372
+ assistant_start = self.encode_special("<|assistant_start|>")
373
+ ids.append(assistant_start)
374
+ return ids
375
+
376
+ # -----------------------------------------------------------------------------
377
+ # nanochat-specific convenience functions
378
+
379
+ def get_tokenizer():
380
+ from nanochat.common import get_base_dir
381
+ base_dir = get_base_dir()
382
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
383
+ # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
384
+ return RustBPETokenizer.from_directory(tokenizer_dir)
385
+
386
+ def get_token_bytes(device="cpu"):
387
+ import torch
388
+ from nanochat.common import get_base_dir
389
+ base_dir = get_base_dir()
390
+ tokenizer_dir = os.path.join(base_dir, "tokenizer")
391
+ token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
392
+ assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
393
+ with open(token_bytes_path, "rb") as f:
394
+ token_bytes = torch.load(f, map_location=device)
395
+ return token_bytes
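A brief usage sketch of the tokenizer module above, assuming a tokenizer has already been trained and saved to the base directory (e.g. by scripts/tok_train.py); the sample text and conversation are illustrative:

```python
# Minimal sketch: load the saved tokenizer and exercise encode/decode plus
# conversation rendering. Assumes a trained tokenizer already exists on disk.
from nanochat.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
ids = tokenizer.encode("Hello world", prepend="<|bos|>")
print(ids, "->", tokenizer.decode(ids))

# Render a (hypothetical) single-turn conversation into token ids and a
# supervision mask; mask=1 marks the tokens the Assistant trains on.
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}
ids, mask = tokenizer.render_conversation(conversation)
print(tokenizer.visualize_tokenization(ids, mask))
```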
nanochat/ui.html ADDED
@@ -0,0 +1,561 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ body {
18
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
19
+ background-color: #ffffff;
20
+ color: #111827;
21
+ min-height: 100vh;
22
+ margin: 0;
23
+ display: flex;
24
+ flex-direction: column;
25
+ }
26
+
27
+ .header {
28
+ background-color: #ffffff;
29
+ padding: 1.25rem 1.5rem;
30
+ }
31
+
32
+ .header-left {
33
+ display: flex;
34
+ align-items: center;
35
+ gap: 0.75rem;
36
+ }
37
+
38
+ .header-logo {
39
+ height: 32px;
40
+ width: auto;
41
+ }
42
+
43
+ .header h1 {
44
+ font-size: 1.25rem;
45
+ font-weight: 600;
46
+ margin: 0;
47
+ color: #111827;
48
+ }
49
+
50
+ .new-conversation-btn {
51
+ width: 32px;
52
+ height: 32px;
53
+ padding: 0;
54
+ border: 1px solid #e5e7eb;
55
+ border-radius: 0.5rem;
56
+ background-color: #ffffff;
57
+ color: #6b7280;
58
+ cursor: pointer;
59
+ display: flex;
60
+ align-items: center;
61
+ justify-content: center;
62
+ transition: all 0.2s ease;
63
+ }
64
+
65
+ .new-conversation-btn:hover {
66
+ background-color: #f3f4f6;
67
+ border-color: #d1d5db;
68
+ color: #374151;
69
+ }
70
+
71
+ .chat-container {
72
+ flex: 1;
73
+ overflow-y: auto;
74
+ background-color: #ffffff;
75
+ }
76
+
77
+ .chat-wrapper {
78
+ max-width: 48rem;
79
+ margin: 0 auto;
80
+ padding: 2rem 1.5rem 3rem;
81
+ display: flex;
82
+ flex-direction: column;
83
+ gap: 0.75rem;
84
+ }
85
+
86
+ .message {
87
+ display: flex;
88
+ justify-content: flex-start;
89
+ margin-bottom: 0.5rem;
90
+ color: #0d0d0d;
91
+ }
92
+
93
+ .message.assistant {
94
+ justify-content: flex-start;
95
+ }
96
+
97
+ .message.user {
98
+ justify-content: flex-end;
99
+ }
100
+
101
+ .message-content {
102
+ white-space: pre-wrap;
103
+ line-height: 1.6;
104
+ max-width: 100%;
105
+ }
106
+
107
+ .message.assistant .message-content {
108
+ background: transparent;
109
+ border: none;
110
+ padding: 0.25rem 0;
111
+ cursor: pointer;
112
+ border-radius: 0.5rem;
113
+ padding: 0.5rem;
114
+ margin-left: -0.5rem;
115
+ transition: background-color 0.2s ease;
116
+ }
117
+
118
+ .message.assistant .message-content:hover {
119
+ background-color: #f9fafb;
120
+ }
121
+
122
+ .message.user .message-content {
123
+ background-color: #f3f4f6;
124
+ border-radius: 1.25rem;
125
+ padding: 0.8rem 1rem;
126
+ max-width: 65%;
127
+ cursor: pointer;
128
+ transition: background-color 0.2s ease;
129
+ }
130
+
131
+ .message.user .message-content:hover {
132
+ background-color: #e5e7eb;
133
+ }
134
+
135
+ .message.console .message-content {
136
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
137
+ font-size: 0.875rem;
138
+ background-color: #fafafa;
139
+ padding: 0.75rem 1rem;
140
+ color: #374151;
141
+ max-width: 80%;
142
+ }
143
+
144
+ .input-container {
145
+ background-color: #ffffff;
146
+ padding: 1rem;
147
+ }
148
+
149
+ .input-wrapper {
150
+ max-width: 48rem;
151
+ margin: 0 auto;
152
+ display: flex;
153
+ gap: 0.75rem;
154
+ align-items: flex-end;
155
+ }
156
+
157
+ .chat-input {
158
+ flex: 1;
159
+ padding: 0.8rem 1rem;
160
+ border: 1px solid #d1d5db;
161
+ border-radius: 0.75rem;
162
+ background-color: #ffffff;
163
+ color: #111827;
164
+ font-size: 1rem;
165
+ line-height: 1.5;
166
+ resize: none;
167
+ outline: none;
168
+ min-height: 54px;
169
+ max-height: 200px;
170
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
171
+ }
172
+
173
+ .chat-input::placeholder {
174
+ color: #9ca3af;
175
+ }
176
+
177
+ .chat-input:focus {
178
+ border-color: #2563eb;
179
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
180
+ }
181
+
182
+ .send-button {
183
+ flex-shrink: 0;
184
+ padding: 0;
185
+ width: 54px;
186
+ height: 54px;
187
+ border: 1px solid #111827;
188
+ border-radius: 0.75rem;
189
+ background-color: #111827;
190
+ color: #ffffff;
191
+ display: flex;
192
+ align-items: center;
193
+ justify-content: center;
194
+ cursor: pointer;
195
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
196
+ }
197
+
198
+ .send-button:hover:not(:disabled) {
199
+ background-color: #2563eb;
200
+ border-color: #2563eb;
201
+ }
202
+
203
+ .send-button:disabled {
204
+ cursor: not-allowed;
205
+ border-color: #d1d5db;
206
+ background-color: #e5e7eb;
207
+ color: #9ca3af;
208
+ }
209
+
210
+ .typing-indicator {
211
+ display: inline-block;
212
+ color: #6b7280;
213
+ letter-spacing: 0.15em;
214
+ }
215
+
216
+ .typing-indicator::after {
217
+ content: '···';
218
+ animation: typing 1.4s infinite;
219
+ }
220
+
221
+ @keyframes typing {
222
+ 0%, 60%, 100% { opacity: 0.2; }
223
+ 30% { opacity: 1; }
224
+ }
225
+
226
+ .error-message {
227
+ background-color: #fee2e2;
228
+ border: 1px solid #fecaca;
229
+ color: #b91c1c;
230
+ padding: 0.75rem 1rem;
231
+ border-radius: 0.75rem;
232
+ margin-top: 0.5rem;
233
+ }
234
+ </style>
235
+ </head>
236
+ <body>
237
+ <div class="header">
238
+ <div class="header-left">
239
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
240
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
241
+ <path d="M12 5v14"></path>
242
+ <path d="M5 12h14"></path>
243
+ </svg>
244
+ </button>
245
+ <h1>nanochat</h1>
246
+ </div>
247
+ </div>
248
+
249
+ <div class="chat-container" id="chatContainer">
250
+ <div class="chat-wrapper" id="chatWrapper">
251
+ <!-- Messages will be added here -->
252
+ </div>
253
+ </div>
254
+
255
+ <div class="input-container">
256
+ <div class="input-wrapper">
257
+ <textarea
258
+ id="chatInput"
259
+ class="chat-input"
260
+ placeholder="Ask anything"
261
+ rows="1"
262
+ onkeydown="handleKeyDown(event)"
263
+ ></textarea>
264
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
265
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
266
+ <path d="M22 2L11 13"></path>
267
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
268
+ </svg>
269
+ </button>
270
+ </div>
271
+ </div>
272
+
273
+ <script>
274
+ const API_URL = '';
275
+ const chatContainer = document.getElementById('chatContainer');
276
+ const chatWrapper = document.getElementById('chatWrapper');
277
+ const chatInput = document.getElementById('chatInput');
278
+ const sendButton = document.getElementById('sendButton');
279
+
280
+ let messages = [];
281
+ let isGenerating = false;
282
+ let currentTemperature = 0.8;
283
+ let currentTopK = 50;
284
+
285
+ chatInput.addEventListener('input', function() {
286
+ this.style.height = 'auto';
287
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
288
+ sendButton.disabled = !this.value.trim() || isGenerating;
289
+ });
290
+
291
+ function handleKeyDown(event) {
292
+ if (event.key === 'Enter' && !event.shiftKey) {
293
+ event.preventDefault();
294
+ sendMessage();
295
+ }
296
+ }
297
+
298
+ document.addEventListener('keydown', function(event) {
299
+ // Ctrl+Shift+N for new conversation
300
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
301
+ event.preventDefault();
302
+ if (!isGenerating) {
303
+ newConversation();
304
+ }
305
+ }
306
+ });
307
+
308
+ function newConversation() {
309
+ messages = [];
310
+ chatWrapper.innerHTML = '';
311
+ chatInput.value = '';
312
+ chatInput.style.height = 'auto';
313
+ sendButton.disabled = false;
314
+ isGenerating = false;
315
+ chatInput.focus();
316
+ }
317
+
318
+ function addMessage(role, content, messageIndex = null) {
319
+ const messageDiv = document.createElement('div');
320
+ messageDiv.className = `message ${role}`;
321
+
322
+ const contentDiv = document.createElement('div');
323
+ contentDiv.className = 'message-content';
324
+ contentDiv.textContent = content;
325
+
326
+ // Add click handler for user messages to enable editing
327
+ if (role === 'user' && messageIndex !== null) {
328
+ contentDiv.setAttribute('data-message-index', messageIndex);
329
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
330
+ contentDiv.addEventListener('click', function() {
331
+ if (!isGenerating) {
332
+ editMessage(messageIndex);
333
+ }
334
+ });
335
+ }
336
+
337
+ // Add click handler for assistant messages to enable regeneration
338
+ if (role === 'assistant' && messageIndex !== null) {
339
+ contentDiv.setAttribute('data-message-index', messageIndex);
340
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
341
+ contentDiv.addEventListener('click', function() {
342
+ if (!isGenerating) {
343
+ regenerateMessage(messageIndex);
344
+ }
345
+ });
346
+ }
347
+
348
+ messageDiv.appendChild(contentDiv);
349
+ chatWrapper.appendChild(messageDiv);
350
+
351
+ chatContainer.scrollTop = chatContainer.scrollHeight;
352
+ return contentDiv;
353
+ }
354
+
355
+ function editMessage(messageIndex) {
356
+ // Find the message in the messages array
357
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
358
+
359
+ const messageToEdit = messages[messageIndex];
360
+ if (messageToEdit.role !== 'user') return;
361
+
362
+ // Copy message content to input
363
+ chatInput.value = messageToEdit.content;
364
+ chatInput.style.height = 'auto';
365
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
366
+
367
+ // Remove this message and all subsequent messages from the array
368
+ messages = messages.slice(0, messageIndex);
369
+
370
+ // Remove message elements from DOM starting from messageIndex
371
+ const allMessages = chatWrapper.querySelectorAll('.message');
372
+ for (let i = messageIndex; i < allMessages.length; i++) {
373
+ allMessages[i].remove();
374
+ }
375
+
376
+ // Enable send button and focus input
377
+ sendButton.disabled = false;
378
+ chatInput.focus();
379
+ }
380
+
381
+ async function generateAssistantResponse() {
382
+ isGenerating = true;
383
+ sendButton.disabled = true;
384
+
385
+ const assistantContent = addMessage('assistant', '');
386
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
387
+
388
+ try {
389
+ const response = await fetch(`${API_URL}/chat/completions`, {
390
+ method: 'POST',
391
+ headers: {
392
+ 'Content-Type': 'application/json',
393
+ },
394
+ body: JSON.stringify({
395
+ messages: messages,
396
+ temperature: currentTemperature,
397
+ top_k: currentTopK,
398
+ max_tokens: 512
399
+ }),
400
+ });
401
+
402
+ if (!response.ok) {
403
+ throw new Error(`HTTP error! status: ${response.status}`);
404
+ }
405
+
406
+ const reader = response.body.getReader();
407
+ const decoder = new TextDecoder();
408
+ let fullResponse = '';
409
+ assistantContent.textContent = '';
410
+
411
+ while (true) {
412
+ const { done, value } = await reader.read();
413
+ if (done) break;
414
+
415
+ const chunk = decoder.decode(value);
416
+ const lines = chunk.split('\n');
417
+
418
+ for (const line of lines) {
419
+ if (line.startsWith('data: ')) {
420
+ try {
421
+ const data = JSON.parse(line.slice(6));
422
+ if (data.token) {
423
+ fullResponse += data.token;
424
+ assistantContent.textContent = fullResponse;
425
+ chatContainer.scrollTop = chatContainer.scrollHeight;
426
+ }
427
+ } catch (e) {
428
+ }
429
+ }
430
+ }
431
+ }
432
+
433
+ const assistantMessageIndex = messages.length;
434
+ messages.push({ role: 'assistant', content: fullResponse });
435
+
436
+ // Add click handler to regenerate this assistant message
437
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
438
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
439
+ assistantContent.addEventListener('click', function() {
440
+ if (!isGenerating) {
441
+ regenerateMessage(assistantMessageIndex);
442
+ }
443
+ });
444
+
445
+ } catch (error) {
446
+ console.error('Error:', error);
447
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
448
+ } finally {
449
+ isGenerating = false;
450
+ sendButton.disabled = !chatInput.value.trim();
451
+ }
452
+ }
453
+
454
+ async function regenerateMessage(messageIndex) {
455
+ // Find the message in the messages array
456
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
457
+
458
+ const messageToRegenerate = messages[messageIndex];
459
+ if (messageToRegenerate.role !== 'assistant') return;
460
+
461
+ // Remove this message and all subsequent messages from the array
462
+ messages = messages.slice(0, messageIndex);
463
+
464
+ // Remove message elements from DOM starting from messageIndex
465
+ const allMessages = chatWrapper.querySelectorAll('.message');
466
+ for (let i = messageIndex; i < allMessages.length; i++) {
467
+ allMessages[i].remove();
468
+ }
469
+
470
+ // Regenerate the assistant response
471
+ await generateAssistantResponse();
472
+ }
473
+
474
+ function handleSlashCommand(command) {
475
+ const parts = command.trim().split(/\s+/);
476
+ const cmd = parts[0].toLowerCase();
477
+ const arg = parts[1];
478
+
479
+ if (cmd === '/temperature') {
480
+ if (arg === undefined) {
481
+ addMessage('console', `Current temperature: ${currentTemperature}`);
482
+ } else {
483
+ const temp = parseFloat(arg);
484
+ if (isNaN(temp) || temp < 0 || temp > 2) {
485
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
486
+ } else {
487
+ currentTemperature = temp;
488
+ addMessage('console', `Temperature set to ${currentTemperature}`);
489
+ }
490
+ }
491
+ return true;
492
+ } else if (cmd === '/topk') {
493
+ if (arg === undefined) {
494
+ addMessage('console', `Current top-k: ${currentTopK}`);
495
+ } else {
496
+ const topk = parseInt(arg);
497
+ if (isNaN(topk) || topk < 1 || topk > 200) {
498
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
499
+ } else {
500
+ currentTopK = topk;
501
+ addMessage('console', `Top-k set to ${currentTopK}`);
502
+ }
503
+ }
504
+ return true;
505
+ } else if (cmd === '/clear') {
506
+ newConversation();
507
+ return true;
508
+ } else if (cmd === '/help') {
509
+ addMessage('console',
510
+ 'Available commands:\n' +
511
+ '/temperature - Show current temperature\n' +
512
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
513
+ '/topk - Show current top-k\n' +
514
+ '/topk <value> - Set top-k (1-200)\n' +
515
+ '/clear - Clear conversation\n' +
516
+ '/help - Show this help message'
517
+ );
518
+ return true;
519
+ }
520
+ return false;
521
+ }
522
+
523
+ async function sendMessage() {
524
+ const message = chatInput.value.trim();
525
+ if (!message || isGenerating) return;
526
+
527
+ // Handle slash commands
528
+ if (message.startsWith('/')) {
529
+ chatInput.value = '';
530
+ chatInput.style.height = 'auto';
531
+ handleSlashCommand(message);
532
+ return;
533
+ }
534
+
535
+ chatInput.value = '';
536
+ chatInput.style.height = 'auto';
537
+
538
+ const userMessageIndex = messages.length;
539
+ messages.push({ role: 'user', content: message });
540
+ addMessage('user', message, userMessageIndex);
541
+
542
+ await generateAssistantResponse();
543
+ }
544
+
545
+ sendButton.disabled = false;
546
+
547
+ // Autofocus the chat input on page load
548
+ chatInput.focus();
549
+
550
+ fetch(`${API_URL}/health`)
551
+ .then(response => response.json())
552
+ .then(data => {
553
+ console.log('Engine status:', data);
554
+ })
555
+ .catch(error => {
556
+ console.error('Engine not available:', error);
557
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
558
+ });
559
+ </script>
560
+ </body>
561
+ </html>
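The page above expects a backend that accepts POST /chat/completions with a JSON body of messages plus sampling parameters, and that streams back lines of the form `data: {"token": ...}`. As a rough illustration of that protocol only (the actual server, scripts/chat_web.py, is not shown in this diff, and the host/port below are assumptions), a minimal Python client could look like:

```python
# Hypothetical client for the streaming protocol ui.html consumes.
# The endpoint URL and port are assumptions; the request/response shape
# mirrors what the page sends and parses.
import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.8,
    "top_k": 50,
    "max_tokens": 512,
}
req = urllib.request.Request(
    "http://localhost:8000/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:  # the response body is streamed line by line
        line = raw_line.decode("utf-8").strip()
        if line.startswith("data: "):
            data = json.loads(line[len("data: "):])
            if "token" in data:
                print(data["token"], end="", flush=True)
print()
```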
pyproject.toml ADDED
@@ -0,0 +1,55 @@
1
+ [project]
2
+ name = "nanochat"
3
+ version = "0.1.0"
4
+ description = "the minimal full-stack ChatGPT clone"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "datasets>=4.0.0",
9
+ "fastapi>=0.117.1",
10
+ "files-to-prompt>=0.6",
11
+ "numpy==1.26.4",
12
+ "psutil>=7.1.0",
13
+ "regex>=2025.9.1",
14
+ "tiktoken>=0.11.0",
15
+ "tokenizers>=0.22.0",
16
+ "torch>=2.8.0",
17
+ "uvicorn>=0.36.0",
18
+ "wandb>=0.21.3",
19
+ ]
20
+
21
+ [build-system]
22
+ requires = ["maturin>=1.7,<2.0"]
23
+ build-backend = "maturin"
24
+
25
+ # target torch to cuda 12.8
26
+ [tool.uv.sources]
27
+ torch = [
28
+ { index = "pytorch-cu128" },
29
+ ]
30
+
31
+ [[tool.uv.index]]
32
+ name = "pytorch-cu128"
33
+ url = "https://download.pytorch.org/whl/cu128"
34
+ explicit = true
35
+
36
+ [tool.maturin]
37
+ module-name = "rustbpe"
38
+ bindings = "pyo3"
39
+ python-source = "."
40
+ manifest-path = "rustbpe/Cargo.toml"
41
+
42
+ [dependency-groups]
43
+ dev = [
44
+ "maturin>=1.9.4",
45
+ "pytest>=8.0.0",
46
+ ]
47
+
48
+ [tool.pytest.ini_options]
49
+ markers = [
50
+ "slow: marks tests as slow (deselect with '-m \"not slow\"')",
51
+ ]
52
+ testpaths = ["tests"]
53
+ python_files = ["test_*.py"]
54
+ python_classes = ["Test*"]
55
+ python_functions = ["test_*"]
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ datasets>=4.0.0
2
+ fastapi>=0.117.1
3
+ huggingface_hub>=0.20.0
4
+ numpy==1.26.4
5
+ psutil>=7.1.0
6
+ regex>=2025.9.1
7
+ tiktoken>=0.11.0
8
+ tokenizers>=0.22.0
9
+ torch>=2.8.0
10
+ uvicorn>=0.36.0
11
+ maturin>=1.7,<2.0
rustbpe/Cargo.lock ADDED
@@ -0,0 +1,458 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "ahash"
7
+ version = "0.8.12"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
10
+ dependencies = [
11
+ "cfg-if",
12
+ "getrandom",
13
+ "once_cell",
14
+ "version_check",
15
+ "zerocopy",
16
+ ]
17
+
18
+ [[package]]
19
+ name = "aho-corasick"
20
+ version = "1.1.3"
21
+ source = "registry+https://github.com/rust-lang/crates.io-index"
22
+ checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
23
+ dependencies = [
24
+ "memchr",
25
+ ]
26
+
27
+ [[package]]
28
+ name = "arc-swap"
29
+ version = "1.7.1"
30
+ source = "registry+https://github.com/rust-lang/crates.io-index"
31
+ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
32
+
33
+ [[package]]
34
+ name = "autocfg"
35
+ version = "1.5.0"
36
+ source = "registry+https://github.com/rust-lang/crates.io-index"
37
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
38
+
39
+ [[package]]
40
+ name = "bit-set"
41
+ version = "0.8.0"
42
+ source = "registry+https://github.com/rust-lang/crates.io-index"
43
+ checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
44
+ dependencies = [
45
+ "bit-vec",
46
+ ]
47
+
48
+ [[package]]
49
+ name = "bit-vec"
50
+ version = "0.8.0"
51
+ source = "registry+https://github.com/rust-lang/crates.io-index"
52
+ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
53
+
54
+ [[package]]
55
+ name = "castaway"
56
+ version = "0.2.4"
57
+ source = "registry+https://github.com/rust-lang/crates.io-index"
58
+ checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
59
+ dependencies = [
60
+ "rustversion",
61
+ ]
62
+
63
+ [[package]]
64
+ name = "cfg-if"
65
+ version = "1.0.3"
66
+ source = "registry+https://github.com/rust-lang/crates.io-index"
67
+ checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
68
+
69
+ [[package]]
70
+ name = "compact_str"
71
+ version = "0.9.0"
72
+ source = "registry+https://github.com/rust-lang/crates.io-index"
73
+ checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
74
+ dependencies = [
75
+ "castaway",
76
+ "cfg-if",
77
+ "itoa",
78
+ "rustversion",
79
+ "ryu",
80
+ "static_assertions",
81
+ ]
82
+
83
+ [[package]]
84
+ name = "crossbeam-deque"
85
+ version = "0.8.6"
86
+ source = "registry+https://github.com/rust-lang/crates.io-index"
87
+ checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
88
+ dependencies = [
89
+ "crossbeam-epoch",
90
+ "crossbeam-utils",
91
+ ]
92
+
93
+ [[package]]
94
+ name = "crossbeam-epoch"
95
+ version = "0.9.18"
96
+ source = "registry+https://github.com/rust-lang/crates.io-index"
97
+ checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
98
+ dependencies = [
99
+ "crossbeam-utils",
100
+ ]
101
+
102
+ [[package]]
103
+ name = "crossbeam-utils"
104
+ version = "0.8.21"
105
+ source = "registry+https://github.com/rust-lang/crates.io-index"
106
+ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
107
+
108
+ [[package]]
109
+ name = "dary_heap"
110
+ version = "0.3.7"
111
+ source = "registry+https://github.com/rust-lang/crates.io-index"
112
+ checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728"
113
+
114
+ [[package]]
115
+ name = "either"
116
+ version = "1.15.0"
117
+ source = "registry+https://github.com/rust-lang/crates.io-index"
118
+ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
119
+
120
+ [[package]]
121
+ name = "equivalent"
122
+ version = "1.0.2"
123
+ source = "registry+https://github.com/rust-lang/crates.io-index"
124
+ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
125
+
126
+ [[package]]
127
+ name = "fancy-regex"
128
+ version = "0.16.1"
129
+ source = "registry+https://github.com/rust-lang/crates.io-index"
130
+ checksum = "bf04c5ec15464ace8355a7b440a33aece288993475556d461154d7a62ad9947c"
131
+ dependencies = [
132
+ "bit-set",
133
+ "regex-automata",
134
+ "regex-syntax",
135
+ ]
136
+
137
+ [[package]]
138
+ name = "getrandom"
139
+ version = "0.3.3"
140
+ source = "registry+https://github.com/rust-lang/crates.io-index"
141
+ checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
142
+ dependencies = [
143
+ "cfg-if",
144
+ "libc",
145
+ "r-efi",
146
+ "wasi",
147
+ ]
148
+
149
+ [[package]]
150
+ name = "hashbrown"
151
+ version = "0.15.5"
152
+ source = "registry+https://github.com/rust-lang/crates.io-index"
153
+ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
154
+
155
+ [[package]]
156
+ name = "heck"
157
+ version = "0.5.0"
158
+ source = "registry+https://github.com/rust-lang/crates.io-index"
159
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
160
+
161
+ [[package]]
162
+ name = "indexmap"
163
+ version = "2.11.0"
164
+ source = "registry+https://github.com/rust-lang/crates.io-index"
165
+ checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9"
166
+ dependencies = [
167
+ "equivalent",
168
+ "hashbrown",
169
+ ]
170
+
171
+ [[package]]
172
+ name = "indoc"
173
+ version = "2.0.6"
174
+ source = "registry+https://github.com/rust-lang/crates.io-index"
175
+ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
176
+
177
+ [[package]]
178
+ name = "itoa"
179
+ version = "1.0.15"
180
+ source = "registry+https://github.com/rust-lang/crates.io-index"
181
+ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
182
+
183
+ [[package]]
184
+ name = "libc"
185
+ version = "0.2.175"
186
+ source = "registry+https://github.com/rust-lang/crates.io-index"
187
+ checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
188
+
189
+ [[package]]
190
+ name = "log"
191
+ version = "0.4.28"
192
+ source = "registry+https://github.com/rust-lang/crates.io-index"
193
+ checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
194
+
195
+ [[package]]
196
+ name = "memchr"
197
+ version = "2.7.5"
198
+ source = "registry+https://github.com/rust-lang/crates.io-index"
199
+ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
200
+
201
+ [[package]]
202
+ name = "memoffset"
203
+ version = "0.9.1"
204
+ source = "registry+https://github.com/rust-lang/crates.io-index"
205
+ checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
206
+ dependencies = [
207
+ "autocfg",
208
+ ]
209
+
210
+ [[package]]
211
+ name = "once_cell"
212
+ version = "1.21.3"
213
+ source = "registry+https://github.com/rust-lang/crates.io-index"
214
+ checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
215
+
216
+ [[package]]
217
+ name = "portable-atomic"
218
+ version = "1.11.1"
219
+ source = "registry+https://github.com/rust-lang/crates.io-index"
220
+ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
221
+
222
+ [[package]]
223
+ name = "proc-macro2"
224
+ version = "1.0.101"
225
+ source = "registry+https://github.com/rust-lang/crates.io-index"
226
+ checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
227
+ dependencies = [
228
+ "unicode-ident",
229
+ ]
230
+
231
+ [[package]]
232
+ name = "pyo3"
233
+ version = "0.23.5"
234
+ source = "registry+https://github.com/rust-lang/crates.io-index"
235
+ checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
236
+ dependencies = [
237
+ "cfg-if",
238
+ "indoc",
239
+ "libc",
240
+ "memoffset",
241
+ "once_cell",
242
+ "portable-atomic",
243
+ "pyo3-build-config",
244
+ "pyo3-ffi",
245
+ "pyo3-macros",
246
+ "unindent",
247
+ ]
248
+
249
+ [[package]]
250
+ name = "pyo3-build-config"
251
+ version = "0.23.5"
252
+ source = "registry+https://github.com/rust-lang/crates.io-index"
253
+ checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
254
+ dependencies = [
255
+ "once_cell",
256
+ "target-lexicon",
257
+ ]
258
+
259
+ [[package]]
260
+ name = "pyo3-ffi"
261
+ version = "0.23.5"
262
+ source = "registry+https://github.com/rust-lang/crates.io-index"
263
+ checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
264
+ dependencies = [
265
+ "libc",
266
+ "pyo3-build-config",
267
+ ]
268
+
269
+ [[package]]
270
+ name = "pyo3-log"
271
+ version = "0.12.4"
272
+ source = "registry+https://github.com/rust-lang/crates.io-index"
273
+ checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
274
+ dependencies = [
275
+ "arc-swap",
276
+ "log",
277
+ "pyo3",
278
+ ]
279
+
280
+ [[package]]
281
+ name = "pyo3-macros"
282
+ version = "0.23.5"
283
+ source = "registry+https://github.com/rust-lang/crates.io-index"
284
+ checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
285
+ dependencies = [
286
+ "proc-macro2",
287
+ "pyo3-macros-backend",
288
+ "quote",
289
+ "syn",
290
+ ]
291
+
292
+ [[package]]
293
+ name = "pyo3-macros-backend"
294
+ version = "0.23.5"
295
+ source = "registry+https://github.com/rust-lang/crates.io-index"
296
+ checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
297
+ dependencies = [
298
+ "heck",
299
+ "proc-macro2",
300
+ "pyo3-build-config",
301
+ "quote",
302
+ "syn",
303
+ ]
304
+
305
+ [[package]]
306
+ name = "quote"
307
+ version = "1.0.40"
308
+ source = "registry+https://github.com/rust-lang/crates.io-index"
309
+ checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
310
+ dependencies = [
311
+ "proc-macro2",
312
+ ]
313
+
314
+ [[package]]
315
+ name = "r-efi"
316
+ version = "5.3.0"
317
+ source = "registry+https://github.com/rust-lang/crates.io-index"
318
+ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
319
+
320
+ [[package]]
321
+ name = "rayon"
322
+ version = "1.11.0"
323
+ source = "registry+https://github.com/rust-lang/crates.io-index"
324
+ checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
325
+ dependencies = [
326
+ "either",
327
+ "rayon-core",
328
+ ]
329
+
330
+ [[package]]
331
+ name = "rayon-core"
332
+ version = "1.13.0"
333
+ source = "registry+https://github.com/rust-lang/crates.io-index"
334
+ checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
335
+ dependencies = [
336
+ "crossbeam-deque",
337
+ "crossbeam-utils",
338
+ ]
339
+
340
+ [[package]]
341
+ name = "regex-automata"
342
+ version = "0.4.10"
343
+ source = "registry+https://github.com/rust-lang/crates.io-index"
344
+ checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6"
345
+ dependencies = [
346
+ "aho-corasick",
347
+ "memchr",
348
+ "regex-syntax",
349
+ ]
350
+
351
+ [[package]]
352
+ name = "regex-syntax"
353
+ version = "0.8.6"
354
+ source = "registry+https://github.com/rust-lang/crates.io-index"
355
+ checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
356
+
357
+ [[package]]
358
+ name = "rustbpe"
359
+ version = "0.1.0"
360
+ dependencies = [
361
+ "ahash",
362
+ "compact_str",
363
+ "dary_heap",
364
+ "fancy-regex",
365
+ "indexmap",
366
+ "log",
367
+ "pyo3",
368
+ "pyo3-log",
369
+ "rayon",
370
+ ]
371
+
372
+ [[package]]
373
+ name = "rustversion"
374
+ version = "1.0.22"
375
+ source = "registry+https://github.com/rust-lang/crates.io-index"
376
+ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
377
+
378
+ [[package]]
379
+ name = "ryu"
380
+ version = "1.0.20"
381
+ source = "registry+https://github.com/rust-lang/crates.io-index"
382
+ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
383
+
384
+ [[package]]
385
+ name = "static_assertions"
386
+ version = "1.1.0"
387
+ source = "registry+https://github.com/rust-lang/crates.io-index"
388
+ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
389
+
390
+ [[package]]
391
+ name = "syn"
392
+ version = "2.0.106"
393
+ source = "registry+https://github.com/rust-lang/crates.io-index"
394
+ checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6"
395
+ dependencies = [
396
+ "proc-macro2",
397
+ "quote",
398
+ "unicode-ident",
399
+ ]
400
+
401
+ [[package]]
402
+ name = "target-lexicon"
403
+ version = "0.12.16"
404
+ source = "registry+https://github.com/rust-lang/crates.io-index"
405
+ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
406
+
407
+ [[package]]
408
+ name = "unicode-ident"
409
+ version = "1.0.18"
410
+ source = "registry+https://github.com/rust-lang/crates.io-index"
411
+ checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
412
+
413
+ [[package]]
414
+ name = "unindent"
415
+ version = "0.2.4"
416
+ source = "registry+https://github.com/rust-lang/crates.io-index"
417
+ checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
418
+
419
+ [[package]]
420
+ name = "version_check"
421
+ version = "0.9.5"
422
+ source = "registry+https://github.com/rust-lang/crates.io-index"
423
+ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
424
+
425
+ [[package]]
426
+ name = "wasi"
427
+ version = "0.14.4+wasi-0.2.4"
428
+ source = "registry+https://github.com/rust-lang/crates.io-index"
429
+ checksum = "88a5f4a424faf49c3c2c344f166f0662341d470ea185e939657aaff130f0ec4a"
430
+ dependencies = [
431
+ "wit-bindgen",
432
+ ]
433
+
434
+ [[package]]
435
+ name = "wit-bindgen"
436
+ version = "0.45.1"
437
+ source = "registry+https://github.com/rust-lang/crates.io-index"
438
+ checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36"
439
+
440
+ [[package]]
441
+ name = "zerocopy"
442
+ version = "0.8.26"
443
+ source = "registry+https://github.com/rust-lang/crates.io-index"
444
+ checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
445
+ dependencies = [
446
+ "zerocopy-derive",
447
+ ]
448
+
449
+ [[package]]
450
+ name = "zerocopy-derive"
451
+ version = "0.8.26"
452
+ source = "registry+https://github.com/rust-lang/crates.io-index"
453
+ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
454
+ dependencies = [
455
+ "proc-macro2",
456
+ "quote",
457
+ "syn",
458
+ ]
rustbpe/Cargo.toml ADDED
@@ -0,0 +1,15 @@
1
+ [package]
2
+ name = "rustbpe"
3
+ version = "0.1.0"
4
+ edition = "2024"
5
+
6
+ [dependencies]
7
+ dary_heap = "0.3"
8
+ indexmap = "2.2"
9
+ fancy-regex = "0.16.1"
10
+ log = "0.4.28"
11
+ pyo3 = { version = "0.23.3", features = ["extension-module"] }
12
+ pyo3-log = "0.12.4"
13
+ ahash = "0.8.12"
14
+ rayon = "1.11.0"
15
+ compact_str = "0.9.0"
rustbpe/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # rustbpe
2
+
3
+ > The missing tiktoken training code
4
+
5
+ A very lightweight Rust library for training a GPT tokenizer. The issue is that the inference library [tiktoken](https://github.com/openai/tiktoken) is great, but only does inference. Separately, the huggingface [tokenizers](https://github.com/huggingface/tokenizers) library does training, but it is rather bloated and really hard to navigate because it has to support all the different historical baggage of how people dealt with tokenizers over the years. More recently, I also wrote the [minbpe](https://github.com/karpathy/minbpe) library which does both training and inference, but only in inefficient Python. Basically what I really want is a non-fancy, super simple, but still relatively efficient training code for a GPT tokenizer (more efficient than minbpe, much cleaner/simpler than tokenizers), and then to export the trained vocab for inference with tiktoken. Does that make sense? So here we are. There are more opportunities for optimization here; I just stopped a bit early because unlike minbpe before it, rustbpe is now simple and fast enough, and not a significant bottleneck for nanochat.
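A condensed sketch of the train-with-rustbpe, infer-with-tiktoken flow described above; it mirrors RustBPETokenizer.train_from_iterator in nanochat/tokenizer.py, and the tiny corpus and vocab size are purely illustrative:

```python
# Sketch: train a small vocab with rustbpe, then build a tiktoken Encoding
# for inference. The corpus and vocab size are toy values for illustration.
import rustbpe
import tiktoken

texts = ["hello world", "hello rustbpe"] * 100
pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

tokenizer = rustbpe.Tokenizer()
tokenizer.train_from_iterator(iter(texts), 300, pattern=pattern)  # vocab size must be >= 256

mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
enc = tiktoken.Encoding(
    name="rustbpe",
    pat_str=tokenizer.get_pattern(),
    mergeable_ranks=mergeable_ranks,
    special_tokens={},  # special tokens can be appended after the merges, as tokenizer.py does
)
print(enc.encode_ordinary("hello world"))
```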
rustbpe/src/lib.rs ADDED
@@ -0,0 +1,476 @@
1
+ use std::cmp::Ordering;
2
+ use std::collections::HashMap as StdHashMap;
3
+
4
+ use dary_heap::OctonaryHeap;
5
+ use fancy_regex::Regex;
6
+ use pyo3::prelude::*;
7
+
8
+ use ahash::{AHashMap, AHashSet};
9
+ use compact_str::CompactString;
10
+ use rayon::prelude::*;
11
+
12
+ // Default GPT-4 style regex pattern for splitting text
13
+ const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
14
+
15
+ type Pair = (u32, u32);
16
+
17
+ /// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
18
+ #[pyclass]
19
+ pub struct Tokenizer {
20
+ /// Maps pairs of token IDs to their merged token ID
21
+ pub merges: StdHashMap<Pair, u32>,
22
+ /// The regex pattern used for text splitting
23
+ pub pattern: String,
24
+ /// Compiled regex for efficiency
25
+ compiled_pattern: Regex,
26
+ }
27
+
28
+ // ------------------------ internal helpers ------------------------
29
+
30
+ #[derive(Clone, Debug)]
31
+ struct Word {
32
+ ids: Vec<u32>,
33
+ }
34
+
35
+ impl Word {
36
+ #[inline]
37
+ fn new(ids: Vec<u32>) -> Self {
38
+ Self { ids }
39
+ }
40
+
41
+ #[inline]
42
+ fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
43
+ self.ids.windows(2).map(|w| (w[0], w[1]))
44
+ }
45
+
46
+ /// Merge all non-overlapping occurrences of pair -> new_id.
47
+ /// Returns a small Vec of local pair-count deltas for THIS word only:
48
+ /// -1 for removed pairs, +1 for newly created pairs.
49
+ ///
50
+ /// NOTE: this version deliberately avoids a HashMap in the hot loop.
51
+ fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
52
+ let (a, b) = pair;
53
+ let n = self.ids.len();
54
+ if n < 2 {
55
+ return Vec::new();
56
+ }
57
+
58
+ let mut out: Vec<u32> = Vec::with_capacity(n);
59
+ let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
60
+
61
+ let mut i = 0;
62
+ while i < n {
63
+ if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
64
+ let left = out.last().copied();
65
+ let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };
66
+
67
+ // remove old pairs
68
+ if let Some(x) = left {
69
+ deltas.push(((x, a), -1));
70
+ deltas.push(((x, new_id), 1));
71
+ }
72
+ deltas.push(((a, b), -1));
73
+ if let Some(y) = right {
74
+ deltas.push(((b, y), -1));
75
+ deltas.push(((new_id, y), 1));
76
+ }
77
+
78
+ // write merged token
79
+ out.push(new_id);
80
+ i += 2; // skip 'a' and 'b'
81
+ } else {
82
+ out.push(self.ids[i]);
83
+ i += 1;
84
+ }
85
+ }
86
+
87
+ self.ids = out;
88
+ deltas
89
+ }
90
+ }
91
+
92
+ #[derive(Debug, Eq)]
93
+ struct MergeJob {
94
+ pair: Pair,
95
+ count: u64,
96
+ /// set of word indices where this pair may occur and needs processing
97
+ pos: AHashSet<usize>,
98
+ }
99
+
100
+ impl PartialEq for MergeJob {
101
+ fn eq(&self, other: &Self) -> bool {
102
+ self.count == other.count && self.pair == other.pair
103
+ }
104
+ }
105
+
106
+ impl PartialOrd for MergeJob {
107
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108
+ Some(self.cmp(other))
109
+ }
110
+ }
111
+
112
+ impl Ord for MergeJob {
113
+ fn cmp(&self, other: &Self) -> Ordering {
114
+ // Max-heap by count; tie-break to ascending pair order (deterministic)
115
+ if self.count != other.count {
116
+ self.count.cmp(&other.count)
117
+ } else {
118
+ // ascending order on the pair when counts tie
119
+ other.pair.cmp(&self.pair)
120
+ }
121
+ }
122
+ }
123
+
124
+ #[inline]
125
+ fn count_pairs_parallel(
126
+ words: &[Word],
127
+ counts: &[i32],
128
+ ) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
129
+ words
130
+ .par_iter()
131
+ .enumerate()
132
+ .map(|(i, w)| {
133
+ let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
134
+ let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
135
+ if w.ids.len() >= 2 && counts[i] != 0 {
136
+ for (a, b) in w.pairs() {
137
+ *local_pc.entry((a, b)).or_default() += counts[i];
138
+ local_wtu.entry((a, b)).or_default().insert(i);
139
+ }
140
+ }
141
+ (local_pc, local_wtu)
142
+ })
143
+ .reduce(
144
+ || (AHashMap::new(), AHashMap::new()),
145
+ |(mut acc_pc, mut acc_wtu), (pc, wtu)| {
146
+ for (k, v) in pc {
147
+ *acc_pc.entry(k).or_default() += v;
148
+ }
149
+ for (k, s) in wtu {
150
+ acc_wtu.entry(k).or_default().extend(s);
151
+ }
152
+ (acc_pc, acc_wtu)
153
+ },
154
+ )
155
+ }
156
+
157
+ // ------------------------ END helpers ------------------------
158
+
159
+ impl Tokenizer {
160
+
161
+ /// Core incremental BPE training given unique words and their counts.
162
+ /// `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
163
+ /// `counts`: same length as `words`, count per chunk.
164
+ fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
165
+ assert!(vocab_size >= 256, "vocab_size must be at least 256");
166
+ let num_merges = vocab_size - 256;
167
+ log::info!("Starting BPE training: {} merges to compute", num_merges);
168
+ self.merges.clear();
169
+
170
+ // ---- Initial pair_counts and where_to_update (parallel) ----
171
+ log::info!("Computing initial pair counts from {} unique sequences", words.len());
172
+ let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
173
+
174
+ // ---- Build heap ----
175
+ log::info!("Building heap with {} unique pairs", pair_counts.len());
176
+ let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
177
+ for (pair, pos) in where_to_update.drain() {
178
+ let c = *pair_counts.get(&pair).unwrap_or(&0);
179
+ if c > 0 {
180
+ heap.push(MergeJob {
181
+ pair,
182
+ count: c as u64,
183
+ pos,
184
+ });
185
+ }
186
+ }
187
+
188
+ // ---- Merge loop ----
189
+ log::info!("Starting merge loop");
190
+ let mut merges_done = 0u32;
191
+ let mut last_log_percent = 0u32;
192
+
193
+ while merges_done < num_merges {
194
+ let Some(mut top) = heap.pop() else { break; };
195
+
196
+ // Lazy refresh
197
+ let current = *pair_counts.get(&top.pair).unwrap_or(&0);
198
+ if top.count != current as u64 {
199
+ top.count = current as u64;
200
+ if top.count > 0 {
201
+ heap.push(top);
202
+ }
203
+ continue;
204
+ }
205
+ if top.count == 0 {
206
+ break;
207
+ }
208
+
209
+ // Record merge
210
+ let new_id = 256 + merges_done;
211
+ self.merges.insert(top.pair, new_id);
212
+
213
+ // Merge this pair in all words where it occurs
214
+ let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
215
+ for &word_idx in &top.pos {
216
+ // Apply merge to this word and collect pair-count deltas
217
+ let changes = words[word_idx].merge_pair(top.pair, new_id);
218
+ // Update global pair counts based on this word's count
219
+ for (pair, delta) in changes {
220
+ let delta_total = delta * counts[word_idx];
221
+ if delta_total != 0 {
222
+ *pair_counts.entry(pair).or_default() += delta_total;
223
+ if delta > 0 {
224
+ local_pos_updates.entry(pair).or_default().insert(word_idx);
225
+ }
226
+ }
227
+ }
228
+ }
229
+
230
+ // Add the updated pair counts back to the heap
231
+ for (pair, pos) in local_pos_updates {
232
+ let cnt = *pair_counts.get(&pair).unwrap_or(&0);
233
+ if cnt > 0 {
234
+ heap.push(MergeJob {
235
+ pair,
236
+ count: cnt as u64,
237
+ pos,
238
+ });
239
+ }
240
+ }
241
+
242
+ merges_done += 1;
243
+
244
+ // Log progress every 1%
245
+ let current_percent = (merges_done * 100) / num_merges;
246
+ if current_percent > last_log_percent {
247
+ log::info!(
248
+ "Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
249
+ current_percent, merges_done, num_merges, top.pair, new_id, top.count
250
+ );
251
+ last_log_percent = current_percent;
252
+ }
253
+ }
254
+
255
+ log::info!("Finished training: {} merges completed", merges_done);
256
+ }
257
+ }
258
+
259
+ /// Public methods for the Tokenizer class that will be exposed to Python.
260
+ #[pymethods]
261
+ impl Tokenizer {
262
+ /// Create a new Tokenizer
263
+ #[new]
264
+ pub fn new() -> Self {
265
+ Self {
266
+ merges: StdHashMap::new(),
267
+ pattern: String::new(),
268
+ compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
269
+ }
270
+ }
271
+
272
+ /// Train from a streaming iterator (parallel ingestion).
273
+ /// We refill a Rust Vec<String> buffer under the GIL, then release the GIL
274
+ /// to do the heavy splitting and counting **in parallel** with rayon.
275
+ #[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
276
+ #[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
277
+ pub fn train_from_iterator(
278
+ &mut self,
279
+ py: pyo3::Python<'_>,
280
+ iterator: &pyo3::Bound<'_, pyo3::PyAny>,
281
+ vocab_size: u32,
282
+ buffer_size: usize,
283
+ pattern: Option<String>,
284
+ ) -> PyResult<()> {
285
+ // Use provided pattern or default to GPT-4 pattern
286
+ let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
287
+
288
+ // Update the stored pattern and compile it
289
+ self.pattern = pattern_str.clone();
290
+ self.compiled_pattern = Regex::new(&pattern_str)
291
+ .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
292
+
293
+ // Prepare a true Python iterator object
294
+ let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
295
+ pyo3::Bound::from_borrowed_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
296
+ .into()
297
+ };
298
+
299
+ // Global chunk counts
300
+ let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
301
+
302
+ // Temporary buffer we refill under the GIL
303
+ let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
304
+
305
+ log::info!("Processing sequences from iterator (buffer_size: {})", buffer_size);
306
+ let mut total_sequences = 0u64;
307
+
308
+ // Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
309
+ // Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
310
+ let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
311
+ pyo3::Python::with_gil(|py| {
312
+ buf.clear();
313
+ let it = py_iter.bind(py);
314
+ loop {
315
+ if buf.len() >= buffer_size {
316
+ return Ok(false);
317
+ }
318
+ // next(it)
319
+ let next_obj = unsafe {
320
+ pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
321
+ };
322
+ match next_obj {
323
+ Some(obj) => {
324
+ let s: String = obj.extract()?;
325
+ buf.push(s);
326
+ }
327
+ None => {
328
+ if pyo3::PyErr::occurred(py) {
329
+ return Err(pyo3::PyErr::fetch(py));
330
+ } else {
331
+ return Ok(true); // exhausted
332
+ }
333
+ }
334
+ }
335
+ }
336
+ })
337
+ };
338
+
339
+ // Stream ingestion loop: refill under GIL, process without GIL (parallel)
340
+ loop {
341
+ let exhausted = refill(&mut buf)?;
342
+ if buf.is_empty() && exhausted {
343
+ break;
344
+ }
345
+
346
+ total_sequences += buf.len() as u64;
347
+
348
+ let pattern = self.compiled_pattern.clone();
349
+ let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
350
+ buf.par_iter()
351
+ .map(|s| {
352
+ let mut m: AHashMap<CompactString, i32> = AHashMap::new();
353
+ for mat in pattern.find_iter(s) {
354
+ let piece = mat.expect("regex match failed").as_str();
355
+ *m.entry(CompactString::from(piece)).or_default() += 1;
356
+ }
357
+ m
358
+ })
359
+ .reduce(
360
+ || AHashMap::new(),
361
+ |mut a, b| {
362
+ for (k, v) in b {
363
+ *a.entry(k).or_default() += v;
364
+ }
365
+ a
366
+ },
367
+ )
368
+ });
369
+
370
+ // Merge local into global (single-threaded)
371
+ for (k, v) in local {
372
+ *counts.entry(k).or_default() += v;
373
+ }
374
+
375
+ if exhausted {
376
+ break;
377
+ }
378
+ }
379
+ log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
380
+
381
+ // Materialize words & counts
382
+ let mut words = Vec::with_capacity(counts.len());
383
+ let mut cvec = Vec::with_capacity(counts.len());
384
+ for (chunk, c) in counts.into_iter() {
385
+ words.push(Word::new(chunk.as_bytes().iter().map(|&b| b as u32).collect()));
386
+ cvec.push(c);
387
+ }
388
+
389
+ self.train_core_incremental(words, cvec, vocab_size);
390
+ Ok(())
391
+ }
392
+
393
+ /// Return the regex pattern
394
+ pub fn get_pattern(&self) -> String {
395
+ self.pattern.clone()
396
+ }
397
+
398
+ /// Return the mergeable ranks (token bytes -> token id / rank)
399
+ pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
400
+ let mut mergeable_ranks = Vec::new();
401
+
402
+ // Build vocabulary incrementally from low to high token IDs
403
+ let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
404
+
405
+ for (i, bytes) in token_bytes.iter().enumerate() {
406
+ mergeable_ranks.push((bytes.clone(), i as u32));
407
+ }
408
+
409
+ // Sort merges by token id (so we can reconstruct bytes progressively)
410
+ let mut sorted_merges: Vec<_> = self.merges.iter().collect();
411
+ sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
412
+
413
+ for (&pair, &merged_id) in sorted_merges {
414
+ let (left, right) = pair;
415
+ let mut merged_bytes = token_bytes[left as usize].clone();
416
+ merged_bytes.extend(&token_bytes[right as usize]);
417
+
418
+ if token_bytes.len() <= merged_id as usize {
419
+ token_bytes.resize(merged_id as usize + 1, Vec::new());
420
+ }
421
+ token_bytes[merged_id as usize] = merged_bytes.clone();
422
+
423
+ mergeable_ranks.push((merged_bytes, merged_id));
424
+ }
425
+
426
+ mergeable_ranks
427
+ }
428
+
429
+ /// Encode a string into token IDs
430
+ pub fn encode(&self, text: &str) -> Vec<u32> {
431
+ let mut all_ids = Vec::new();
432
+
433
+ // Split text using the regex pattern
434
+ for m in self.compiled_pattern.find_iter(text) {
435
+ let chunk = m.expect("regex match failed").as_str();
436
+
437
+ // Convert chunk to bytes then to u32 IDs
438
+ let mut ids: Vec<u32> = chunk.bytes().map(|b| b as u32).collect();
439
+
440
+ // Apply merges iteratively
441
+ while ids.len() >= 2 {
442
+ // Find the best pair to merge
443
+ let mut best_pair: Option<(usize, Pair, u32)> = None;
444
+
445
+ for i in 0..ids.len() - 1 {
446
+ let pair: Pair = (ids[i], ids[i + 1]);
447
+ if let Some(&new_id) = self.merges.get(&pair) {
448
+ if best_pair.is_none() || new_id < best_pair.unwrap().2 {
449
+ best_pair = Some((i, pair, new_id));
450
+ }
451
+ }
452
+ }
453
+
454
+ // If we found a pair to merge, apply it
455
+ if let Some((idx, _pair, new_id)) = best_pair {
456
+ ids[idx] = new_id;
457
+ ids.remove(idx + 1);
458
+ } else {
459
+ // No more merges possible
460
+ break;
461
+ }
462
+ }
463
+
464
+ all_ids.extend(ids);
465
+ }
466
+
467
+ all_ids
468
+ }
469
+ }
470
+
471
+ #[pymodule]
472
+ fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
473
+ pyo3_log::init(); // forwards Rust `log` to Python's `logging`
474
+ m.add_class::<Tokenizer>()?;
475
+ Ok(())
476
+ }
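
For readers more at home in Python, the per-chunk merge loop in `Tokenizer::encode` above is equivalent to the following illustrative re-implementation (not code from the repo): repeatedly merge the adjacent pair whose merged token id (rank) is lowest, until no known pair remains.

```python
# Illustrative Python equivalent of the inner loop of Tokenizer::encode (per regex chunk).
def bpe_merge_chunk(chunk: str, merges: dict[tuple[int, int], int]) -> list[int]:
    ids = list(chunk.encode("utf-8"))  # start from raw bytes (ids 0..255)
    while len(ids) >= 2:
        # find the adjacent pair with the lowest merged-token id
        best = None  # (index, new_id)
        for i in range(len(ids) - 1):
            new_id = merges.get((ids[i], ids[i + 1]))
            if new_id is not None and (best is None or new_id < best[1]):
                best = (i, new_id)
        if best is None:
            break  # no merge applies anymore
        i, new_id = best
        ids[i:i + 2] = [new_id]  # collapse the pair into its merged token
    return ids
```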
scripts/__pycache__/base_eval.cpython-310.pyc ADDED
Binary file (5.36 kB).

scripts/__pycache__/base_train.cpython-310.pyc ADDED
Binary file (8.72 kB).

scripts/__pycache__/chat_web.cpython-310.pyc ADDED
Binary file (12.2 kB).

scripts/__pycache__/tok_eval.cpython-310.pyc ADDED
Binary file (10.5 kB).

scripts/__pycache__/tok_train.cpython-310.pyc ADDED
Binary file (3 kB).

scripts/base_eval.py ADDED
@@ -0,0 +1,180 @@
1
+ """
2
+ Evaluate the CORE metric for a given model.
3
+
4
+ Run on a single GPU:
5
+ python base_eval.py
6
+
7
+ Run with torchrun on e.g. 8 GPUs:
8
+ torchrun --nproc_per_node=8 base_eval.py
9
+
10
+ The script will print the CORE metric to the console.
11
+ """
12
+ import os
13
+ import sys
14
+ import time
15
+ import json
16
+ import random
17
+ import yaml
18
+
19
+ import pandas as pd
20
+ import torch
21
+
22
+ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
23
+ from nanochat.tokenizer import HuggingFaceTokenizer
24
+ from nanochat.checkpoint_manager import load_model
25
+ from nanochat.core_eval import evaluate_task
26
+
27
+ # -----------------------------------------------------------------------------
28
+ # nanoChat specific function dealing with I/O etc.
29
+
30
+ def evaluate_model(model, tokenizer, device, max_per_task=-1):
31
+ """
32
+ Evaluate a base model on the CORE benchmark.
33
+ - max_per_task: crop the data to this many examples per task for testing (-1 = disable)
34
+ TODO: clean up this function, delete the need for all the files, for pandas dependency, etc.
35
+ """
36
+ # Load config and task metadata
37
+ base_dir = get_base_dir()
38
+ eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
39
+ config_path = os.path.join(eval_bundle_dir, "core.yaml")
40
+ data_base_path = os.path.join(eval_bundle_dir, "eval_data")
41
+ eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")
42
+ with open(config_path, 'r') as f:
43
+ config = yaml.safe_load(f)
44
+ tasks = config['icl_tasks']
45
+ eval_metadata = pd.read_csv(eval_meta_data)
46
+
47
+ # Evaluate each task
48
+ results = {}
49
+ centered_results = {}
50
+ for task in tasks:
51
+ start_time = time.time()
52
+ label = task['label']
53
+ task_meta = {
54
+ 'task_type': task['icl_task_type'],
55
+ 'dataset_uri': task['dataset_uri'],
56
+ 'num_fewshot': task['num_fewshot'][0],
57
+ 'continuation_delimiter': task.get('continuation_delimiter', ' ')
58
+ }
59
+ print0(f"Evaluating: {label} ({task_meta['num_fewshot']}-shot, type: {task_meta['task_type']})... ", end='')
60
+
61
+ # Load data for this task
62
+ data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
63
+ with open(data_path, 'r') as f:
64
+ data = [json.loads(line.strip()) for line in f]
65
+
66
+ # shuffle the data because in many cases it appears ordered but we want
67
+ # the ability to only run a subset of the data for debugging purposes etc.
68
+ shuffle_rng = random.Random(1337)
69
+ shuffle_rng.shuffle(data)
70
+ if max_per_task > 0:
71
+ data = data[:max_per_task]
72
+
73
+ # run the evaluation for this task
74
+ accuracy = evaluate_task(model, tokenizer, data, device, task_meta)
75
+
76
+ results[label] = accuracy
77
+ row = eval_metadata[eval_metadata["Eval Task"] == label]
78
+ random_baseline = row["Random baseline"].values[0]
79
+ centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
80
+ centered_results[label] = centered_result
81
+ end_time = time.time()
82
+ print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
83
+
84
+ core_metric = sum(centered_results.values()) / len(centered_results)
85
+ out = {
86
+ "results": results,
87
+ "centered_results": centered_results,
88
+ "core_metric": core_metric
89
+ }
90
+ return out
91
+
92
+ # -----------------------------------------------------------------------------
93
+ # HuggingFace loading utilities and light wrappers for a model
94
+
95
+ class ModelWrapper:
96
+ """Lightweight wrapper for a HuggingFace model"""
97
+ def __init__(self, model, max_seq_len=None):
98
+ self.model = model
99
+ self.max_seq_len = max_seq_len
100
+
101
+ def __call__(self, input_ids):
102
+ outputs = self.model(input_ids)
103
+ logits = outputs.logits
104
+ return logits
105
+
106
+ def load_hf_model(hf_path: str, device):
107
+ print0(f"Loading model from: {hf_path}")
108
+ # Load the model
109
+ from transformers import AutoModelForCausalLM
110
+ model = AutoModelForCausalLM.from_pretrained(hf_path)
111
+ model.to(device)
112
+ model.eval()
113
+ max_seq_len = 1024 if "openai-community/gpt2" in hf_path else None
114
+ model = ModelWrapper(model, max_seq_len=max_seq_len)
115
+ # Load the tokenizer
116
+ tokenizer = HuggingFaceTokenizer.from_pretrained(hf_path)
117
+ return model, tokenizer
118
+
119
+ # -----------------------------------------------------------------------------
120
+ def main():
121
+ assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
122
+
123
+ # distributed / precision setup
124
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
125
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
126
+
127
+ # Load model and tokenizer from command line or from file system
128
+ if len(sys.argv) >= 2:
129
+ # atm assume that if a path is given, it's a huggingface model path
130
+ hf_path = sys.argv[1]
131
+ print0(f"Loading huggingface model from: {hf_path}")
132
+ model, tokenizer = load_hf_model(hf_path, device)
133
+ model_name = hf_path # just for logging
134
+ model_slug = hf_path.replace("/", "-") # for the output csv file
135
+ else:
136
+ # load a local model from the file system
137
+ model, tokenizer, meta = load_model("base", device, phase="eval")
138
+ model_name = f"base_model (step {meta['step']})" # just for logging
139
+ model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
140
+
141
+ # Evaluate the model
142
+ with autocast_ctx:
143
+ out = evaluate_model(model, tokenizer, device)
144
+
145
+ # Write out the results to a csv file
146
+ core_metric = None
147
+ centered_results = {}
148
+ if ddp_rank == 0:
149
+ base_dir = get_base_dir()
150
+ output_csv_path = os.path.join(base_dir, "base_eval", f"{model_slug}.csv")
151
+ os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
152
+ results = out["results"]
153
+ centered_results = out["centered_results"]
154
+ core_metric = out["core_metric"]
155
+ with open(output_csv_path, 'w') as f:
156
+ f.write(f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}\n")
157
+ for label in results:
158
+ f.write(f"{label:<35}, {results[label]:<10.6f}, {centered_results[label]:<10.6f}\n")
159
+ f.write(f"{'CORE':<35}, {'':<10}, {core_metric:<10.6f}\n")
160
+ # Print the content of the csv file to console too
161
+ print0("="*80)
162
+ print0(f"Model: {model_name}")
163
+ print0("="*80)
164
+ with open(output_csv_path, 'r') as f:
165
+ print0(f.read())
166
+
167
+ # Log to report
168
+ from nanochat.report import get_report
169
+ get_report().log(section="Base model evaluation", data=[
170
+ {
171
+ "Model": model_name,
172
+ "CORE metric": core_metric,
173
+ },
174
+ centered_results, # the full table
175
+ ])
176
+
177
+ compute_cleanup()
178
+
179
+ if __name__ == "__main__":
180
+ main()
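
As a quick aside on the "centered" score computed in `evaluate_model` above: it rescales raw accuracy so that the task's random baseline maps to 0 and perfect accuracy maps to 1. An illustrative calculation (the numbers are made up):

```python
# Illustrative only: the centered-accuracy rescaling used in evaluate_model().
def centered(accuracy: float, random_baseline_pct: float) -> float:
    rb = 0.01 * random_baseline_pct  # the metadata CSV stores the baseline in percent
    return (accuracy - rb) / (1.0 - rb)

print(centered(0.25, 25.0))  # 0.0 -> no better than chance on a 4-way multiple-choice task
print(centered(0.70, 25.0))  # 0.6
print(centered(1.00, 25.0))  # 1.0 -> perfect
```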
scripts/base_loss.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Loads a checkpoint, and:
3
+ - Evaluates the loss on a larger chunk of train/val splits
4
+ - Samples from the model
5
+
6
+ Example run as:
7
+ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
8
+ """
9
+ import os
10
+ import torch
11
+ from nanochat.checkpoint_manager import load_model
12
+ from nanochat.common import compute_init, print0, compute_cleanup
13
+ from nanochat.dataloader import tokenizing_distributed_data_loader
14
+ from nanochat.tokenizer import get_token_bytes
15
+ from nanochat.loss_eval import evaluate_bpb
16
+ from nanochat.engine import Engine
17
+
18
+ # Configuration
19
+ device_batch_size = 32
20
+ split_tokens = 20*524288 # number of tokens to evaluate per split
21
+ model_tag = None # optional model tag for the output directory name
22
+ model_step = None # optional model step for the output directory name
23
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
24
+
25
+ # Load the base model and the tokenizer
26
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
27
+ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
28
+ sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
29
+
30
+ # Set up the precision we'll run with
31
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
32
+
33
+ # Evaluate the loss on each split
34
+ tokens_per_step = device_batch_size * sequence_len * ddp_world_size
35
+ assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
36
+ steps = split_tokens // tokens_per_step
37
+ token_bytes = get_token_bytes(device=device)
38
+ bpb_results = {}
39
+ for split_name in ["train", "val"]:
40
+ loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
41
+ with autocast_ctx:
42
+ bpb = evaluate_bpb(model, loader, steps, token_bytes)
43
+ print0(f"{split_name} bpb: {bpb:.4f}")
44
+ bpb_results[split_name] = bpb
45
+
46
+ # Master process also samples from the model
47
+ samples = []
48
+ if ddp_rank == 0:
49
+ prompts = [
50
+ "The capital of France is",
51
+ "The chemical symbol of gold is",
52
+ "If yesterday was Friday, then tomorrow will be",
53
+ "The opposite of hot is",
54
+ "The planets of the solar system are:",
55
+ "My favorite color is",
56
+ "If 5*x + 3 = 13, then x is",
57
+ ]
58
+ engine = Engine(model, tokenizer)
59
+ for prompt in prompts:
60
+ tokens = tokenizer(prompt, prepend="<|bos|>")
61
+ with autocast_ctx:
62
+ sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
63
+ sample_str = tokenizer.decode(sample[0])
64
+ print0(sample_str)
65
+ samples.append(sample_str)
66
+
67
+ # Log to report
68
+ from nanochat.report import get_report
69
+ get_report().log(section="Base model loss", data=[
70
+ {
71
+ "train bpb": bpb_results["train"],
72
+ "val bpb": bpb_results["val"],
73
+ },
74
+ {f"sample {i}": sample for i, sample in enumerate(samples)},
75
+ ])
76
+
77
+ # Cleanup
78
+ compute_cleanup()
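
The `split_tokens % tokens_per_step == 0` assert above just requires that the evaluation budget is a whole number of forward passes across all ranks. For illustration, with the defaults here and the d20 setup from `base_train.py` (sequence_len=2048, assumed) on 8 GPUs:

```python
# Illustrative arithmetic for the assert in base_loss.py. sequence_len really comes from
# the checkpoint meta; 2048 is assumed here to match the base_train.py default.
device_batch_size = 32
sequence_len = 2048
ddp_world_size = 8
split_tokens = 20 * 524288

tokens_per_step = device_batch_size * sequence_len * ddp_world_size  # 524,288
assert split_tokens % tokens_per_step == 0
print(split_tokens // tokens_per_step)  # 20 evaluation steps per split
```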
scripts/base_train.py ADDED
@@ -0,0 +1,339 @@
1
+ """
2
+ Train model. Run as:
3
+
4
+ python base_train.py
5
+
6
+ or distributed as:
7
+
8
+ torchrun --nproc_per_node=8 base_train.py
9
+ """
10
+
11
+ import os
12
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
13
+ import time
14
+ import wandb
15
+ import torch
16
+
17
+ from nanochat.gpt import GPT, GPTConfig
18
+ from nanochat.dataloader import tokenizing_distributed_data_loader
19
+ from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
20
+ from nanochat.tokenizer import get_tokenizer, get_token_bytes
21
+ from nanochat.checkpoint_manager import save_checkpoint
22
+ from nanochat.loss_eval import evaluate_bpb
23
+ from nanochat.engine import Engine
24
+ from scripts.base_eval import evaluate_model
25
+ print_banner()
26
+
27
+ # -----------------------------------------------------------------------------
28
+ # User settings
29
+ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
30
+ # Model architecture
31
+ depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
32
+ max_seq_len = 2048 # max context length
33
+ # Training horizon. Only one of these 3 will be used, in this order of precedence.
34
+ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
35
+ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
36
+ target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
37
+ # Optimization
38
+ device_batch_size = 32 # per-device batch size (set to not OOM)
39
+ total_batch_size = 524288 # total desired batch size, in #tokens
40
+ embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
41
+ unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
42
+ weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
43
+ matrix_lr = 0.02 # learning rate for the matrix parameters (Muon)
44
+ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
45
+ # Evaluation
46
+ eval_every = 250 # every how many steps to evaluate the model for val bpb
47
+ eval_tokens = 20*524288 # number of tokens to evaluate val loss on
48
+ core_metric_every = 2000 # every how many steps to evaluate the core metric
49
+ core_metric_max_per_task = 500 # examples per task in estimating the core metric
50
+ sample_every = 2000 # every how many steps to sample from the model
51
+ # Output
52
+ model_tag = "" # optionally override the model tag for the output checkpoint directory name
53
+ # now allow CLI to override the settings via the configurator lol
54
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
55
+ exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
56
+ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
57
+ # -----------------------------------------------------------------------------
58
+
59
+ # Compute init
60
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
61
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
62
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
63
+
64
+ # wandb logging init
65
+ use_dummy_wandb = run == "dummy" or not master_process
66
+ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
67
+
68
+ # Tokenizer will be useful for evaluation, also we need the vocab size
69
+ tokenizer = get_tokenizer()
70
+ token_bytes = get_token_bytes(device=device)
71
+ vocab_size = tokenizer.get_vocab_size()
72
+ print0(f"Vocab size: {vocab_size:,}")
73
+
74
+ # Model kwargs are derived from the desired depth of the model
75
+ num_layers = depth
76
+ model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
77
+ num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
78
+ num_kv_heads = num_heads # 1:1 MQA ratio
79
+ print0(f"num_layers: {num_layers}")
80
+ print0(f"model_dim: {model_dim}")
81
+ print0(f"num_heads: {num_heads}")
82
+ print0(f"num_kv_heads: {num_kv_heads}")
83
+
84
+ # Optimizer / data / training length related hyperparameters
85
+ # figure out the needed gradient accumulation to reach the desired total batch size
86
+ tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
87
+ world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
88
+ assert total_batch_size % world_tokens_per_fwdbwd == 0
89
+ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
90
+ print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
91
+ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
92
+ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
93
+ # -----------------------------------------------------------------------------
94
+ # Initialize the Model
95
+ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
96
+ with torch.device("meta"):
97
+ model_config = GPTConfig(**model_config_kwargs)
98
+ model = GPT(model_config)
99
+ model.to_empty(device="cuda")
100
+ model.init_weights()
101
+ orig_model = model # original, uncompiled model, for saving raw model state_dict
102
+ model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
103
+ num_params = sum(p.numel() for p in model.parameters())
104
+ print0(f"Number of parameters: {num_params:,}")
105
+ num_flops_per_token = model.estimate_flops()
106
+ print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
107
+
108
+ # Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
109
+ assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
110
+ if num_iterations > 0:
111
+ print0(f"Using user-provided number of iterations: {num_iterations:,}")
112
+ elif target_flops > 0:
113
+ # calculate the number of iterations from the target flops
114
+ num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
115
+ print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
116
+ elif target_param_data_ratio > 0:
117
+ # calculate the number of iterations from the target param data ratio
118
+ target_tokens = target_param_data_ratio * num_params
119
+ num_iterations = target_tokens // total_batch_size
120
+ print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
121
+ else:
122
+ raise ValueError("No training horizon specified")
123
+ total_tokens = total_batch_size * num_iterations
124
+ print0(f"Total number of training tokens: {total_tokens:,}")
125
+ print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
126
+ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
127
+
128
+ # -----------------------------------------------------------------------------
129
+ # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
130
+ optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
131
+ adamw_optimizer, muon_optimizer = optimizers
132
+
133
+ # Initialize the DataLoaders for train/val
134
+ base_dir = get_base_dir()
135
+ tokens_dir = os.path.join(base_dir, "tokenized_data")
136
+ train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
137
+ build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
138
+ x, y = next(train_loader) # kick off load of the very first batch of data
139
+
140
+ # -----------------------------------------------------------------------------
141
+ # Set up hyperparameter schedulers
142
+
143
+ # Learning rate scheduler
144
+ # TODO: experiment with a short warmup for the AdamW params (expecting slight improvement)
145
+ warmup_ratio = 0.0 # ratio of iterations for LR warmup
146
+ warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
147
+ final_lr_frac = 0.0 # final LR is this fraction of the initial LR
148
+ def get_lr_multiplier(it):
149
+ warmup_iters = round(warmup_ratio * num_iterations)
150
+ warmdown_iters = round(warmdown_ratio * num_iterations)
151
+ if it < warmup_iters:
152
+ return (it + 1) / warmup_iters
153
+ elif it <= num_iterations - warmdown_iters:
154
+ return 1.0
155
+ else:
156
+ progress = (num_iterations - it) / warmdown_iters
157
+ return progress * 1.0 + (1 - progress) * final_lr_frac
158
+
159
+ # Momentum scheduler for Muon optimizer
160
+ def get_muon_momentum(it):
161
+ frac = min(it / 300, 1)
162
+ momentum = (1 - frac) * 0.85 + frac * 0.95
163
+ return momentum
164
+
165
+ # -----------------------------------------------------------------------------
166
+ # Training loop
167
+ min_val_bpb = float("inf")
168
+ smooth_train_loss = 0 # EMA of training loss
169
+ ema_beta = 0.9 # EMA decay factor
170
+ total_training_time = 0 # total wall-clock time of training
171
+ # note that we run +1 steps only so that we can eval and save at the end
172
+ for step in range(num_iterations + 1):
173
+ last_step = step == num_iterations
174
+ flops_so_far = num_flops_per_token * total_batch_size * step
175
+
176
+ # once in a while: evaluate the val bpb (all ranks participate)
177
+ if last_step or step % eval_every == 0:
178
+ model.eval()
179
+ val_loader = build_val_loader()
180
+ eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
181
+ with autocast_ctx:
182
+ val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
183
+ print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
184
+ if val_bpb < min_val_bpb:
185
+ min_val_bpb = val_bpb
186
+ wandb_run.log({
187
+ "step": step,
188
+ "total_training_flops": flops_so_far,
189
+ "total_training_time": total_training_time,
190
+ "val/bpb": val_bpb,
191
+ })
192
+ model.train()
193
+
194
+ # once in a while: estimate the CORE metric (all ranks participate)
195
+ # use the original uncompiled model because the inputs keep changing shape
196
+ if last_step or (step > 0 and step % core_metric_every == 0):
197
+ model.eval()
198
+ with autocast_ctx:
199
+ results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
200
+ print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
201
+ wandb_run.log({
202
+ "step": step,
203
+ "total_training_flops": flops_so_far,
204
+ "core_metric": results["core_metric"],
205
+ "centered_results": results["centered_results"],
206
+ })
207
+ model.train()
208
+
209
+ # once in a while: sample from the model (only on master process)
210
+ # use the original uncompiled model because the inputs keep changing shape
211
+ if master_process and (last_step or (step > 0 and step % sample_every == 0)):
212
+ model.eval()
213
+ prompts = [
214
+ "The capital of France is",
215
+ "The chemical symbol of gold is",
216
+ "If yesterday was Friday, then tomorrow will be",
217
+ "The opposite of hot is",
218
+ "The planets of the solar system are:",
219
+ "My favorite color is",
220
+ "If 5*x + 3 = 13, then x is",
221
+ ]
222
+ engine = Engine(model, tokenizer)
223
+ for prompt in prompts:
224
+ tokens = tokenizer(prompt, prepend="<|bos|>")
225
+ with autocast_ctx:
226
+ sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
227
+ print0(tokenizer.decode(sample[0]))
228
+ model.train()
229
+
230
+ # save checkpoint at the end of the run (only on master process)
231
+ if master_process and last_step:
232
+ output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
233
+ checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
234
+ save_checkpoint(
235
+ checkpoint_dir,
236
+ step,
237
+ orig_model.state_dict(),
238
+ [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
239
+ {
240
+ "step": step,
241
+ "val_bpb": val_bpb, # loss at last step
242
+ "model_config": model_config_kwargs,
243
+ "user_config": user_config, # inputs to the training script
244
+ "device_batch_size": device_batch_size,
245
+ "max_seq_len": max_seq_len,
246
+ }
247
+ )
248
+
249
+ if last_step:
250
+ break
251
+
252
+ # -------------------------------------------------------------------------
253
+ # single training step
254
+ # evaluate the gradient
255
+ torch.cuda.synchronize()
256
+ t0 = time.time()
257
+ for micro_step in range(grad_accum_steps):
258
+ with autocast_ctx:
259
+ loss = model(x, y)
260
+ train_loss = loss.detach() # for logging
261
+ loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
262
+ loss.backward()
263
+ x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
264
+ # gradient clipping (TODO possibly expertiment with)
265
+ if grad_clip > 0.0:
266
+ torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
267
+ # step the optimizers
268
+ lrm = get_lr_multiplier(step)
269
+ for opt in optimizers:
270
+ for group in opt.param_groups:
271
+ group["lr"] = group["initial_lr"] * lrm
272
+ muon_momentum = get_muon_momentum(step)
273
+ for group in muon_optimizer.param_groups:
274
+ group["momentum"] = muon_momentum
275
+ for opt in optimizers:
276
+ opt.step()
277
+ model.zero_grad(set_to_none=True)
278
+ torch.cuda.synchronize()
279
+ t1 = time.time()
280
+ dt = t1 - t0
281
+ # -------------------------------------------------------------------------
282
+
283
+ # logging
284
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
285
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
286
+ pct_done = 100 * step / num_iterations
287
+ tok_per_sec = int(world_tokens_per_fwdbwd / dt)
288
+ flops_per_sec = num_flops_per_token * total_batch_size / dt
289
+ promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
290
+ mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
291
+ if step > 10:
292
+ total_training_time += dt # only count the time after the first 10 steps
293
+ print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
294
+ if step % 100 == 0:
295
+ wandb_run.log({
296
+ "step": step,
297
+ "total_training_flops": flops_so_far,
298
+ "total_training_time": total_training_time,
299
+ "train/loss": debiased_smooth_loss,
300
+ "train/lrm": lrm,
301
+ "train/dt": dt,
302
+ "train/tok_per_sec": tok_per_sec,
303
+ "train/mfu": mfu,
304
+ })
305
+
306
+ # print a few more stats
307
+ print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
308
+ print0(f"Total training time: {total_training_time/60:.2f}m")
309
+ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
310
+
311
+ # Log to report
312
+ from nanochat.report import get_report
313
+ get_report().log(section="Base model training", data=[
314
+ user_config, # CLI args
315
+ { # stats about the training setup
316
+ "Number of parameters": num_params,
317
+ "Number of FLOPs per token": f"{num_flops_per_token:e}",
318
+ "Calculated number of iterations": num_iterations,
319
+ "Number of training tokens": total_tokens,
320
+ "Tokens : Params ratio": total_batch_size * num_iterations / num_params,
321
+ "DDP world size": ddp_world_size,
322
+ "warmup_ratio": warmup_ratio,
323
+ "warmdown_ratio": warmdown_ratio,
324
+ "final_lr_frac": final_lr_frac,
325
+ },
326
+ { # stats about training outcomes
327
+ "Minimum validation bpb": min_val_bpb,
328
+ "Final validation bpb": val_bpb,
329
+ "CORE metric estimate": results["core_metric"],
330
+ "MFU %": f"{mfu:.2f}%",
331
+ "Total training flops": f"{flops_so_far:e}",
332
+ "Total training time": f"{total_training_time/60:.2f}m",
333
+ "Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
334
+ }
335
+ ])
336
+
337
+ # cleanup
338
+ wandb_run.finish() # wandb run finish
339
+ compute_cleanup()
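
Two bits of arithmetic in the script above are worth spelling out. The batch-size bookkeeping: with the defaults (device_batch_size=32, max_seq_len=2048) on 8 GPUs, each optimization step already sees 32*2048*8 = 524,288 tokens, so grad_accum_steps is 1; on fewer GPUs the same total_batch_size is reached via gradient accumulation. And the learning-rate schedule: it is flat for the first 80% of training and then decays linearly down to final_lr_frac. A standalone copy of the multiplier, for illustration:

```python
# Standalone copy of base_train.py's LR multiplier, with its default schedule settings
# (warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0), for illustration.
def get_lr_multiplier(it, num_iterations, warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

N = 1000
print([get_lr_multiplier(i, N) for i in (0, 799, 800, 900, 1000)])
# [1.0, 1.0, 1.0, 0.5, 0.0] -> flat for the first 80% of steps, then linear decay to zero
```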
scripts/chat_cli.py ADDED
@@ -0,0 +1,99 @@
1
+ """
2
+ New and upgraded chat mode because a lot of the code has changed since the last one.
3
+
4
+ Intended to be run single GPU only atm:
5
+ python -m scripts.chat_cli -i mid
6
+ """
7
+ import argparse
8
+ import torch
9
+ from nanochat.common import compute_init
10
+ from nanochat.engine import Engine
11
+ from nanochat.checkpoint_manager import load_model
12
+
13
+ parser = argparse.ArgumentParser(description='Chat with the model')
14
+ parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
15
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
16
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
17
+ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
18
+ parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
19
+ parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
20
+ args = parser.parse_args()
21
+
22
+ # Init the model and tokenizer
23
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
24
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
25
+ model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
26
+
27
+ # Special tokens for the chat state machine
28
+ bos = tokenizer.get_bos_token_id()
29
+ user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
30
+ assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
31
+
32
+ # Create Engine for efficient generation
33
+ engine = Engine(model, tokenizer)
34
+
35
+ print("\nNanoChat Interactive Mode")
36
+ print("-" * 50)
37
+ print("Type 'quit' or 'exit' to end the conversation")
38
+ print("Type 'clear' to start a new conversation")
39
+ print("-" * 50)
40
+
41
+ conversation_tokens = [bos]
42
+
43
+ while True:
44
+
45
+ if args.prompt:
46
+ # Get the prompt from the launch command
47
+ user_input = args.prompt
48
+ else:
49
+ # Get the prompt interactively from the console
50
+ try:
51
+ user_input = input("\nUser: ").strip()
52
+ except (EOFError, KeyboardInterrupt):
53
+ print("\nGoodbye!")
54
+ break
55
+
56
+ # Handle special commands
57
+ if user_input.lower() in ['quit', 'exit']:
58
+ print("Goodbye!")
59
+ break
60
+
61
+ if user_input.lower() == 'clear':
62
+ conversation_tokens = [bos]
63
+ print("Conversation cleared.")
64
+ continue
65
+
66
+ if not user_input:
67
+ continue
68
+
69
+ # Add User message to the conversation
70
+ conversation_tokens.append(user_start)
71
+ conversation_tokens.extend(tokenizer.encode(user_input))
72
+ conversation_tokens.append(user_end)
73
+
74
+ # Kick off the assistant
75
+ conversation_tokens.append(assistant_start)
76
+ generate_kwargs = {
77
+ "num_samples": 1,
78
+ "max_tokens": 256,
79
+ "temperature": args.temperature,
80
+ "top_k": args.top_k,
81
+ }
82
+ response_tokens = []
83
+ print("\nAssistant: ", end="", flush=True)
84
+ with autocast_ctx:
85
+ for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
86
+ token = token_column[0] # pop the batch dimension (num_samples=1)
87
+ response_tokens.append(token)
88
+ token_text = tokenizer.decode([token])
89
+ print(token_text, end="", flush=True)
90
+ print()
91
+ # we have to ensure that the assistant end token is the last token
92
+ # so even if generation ends due to max tokens, we have to append it to the end
93
+ if response_tokens[-1] != assistant_end:
94
+ response_tokens.append(assistant_end)
95
+ conversation_tokens.extend(response_tokens)
96
+
97
+ # In the prompt mode, we only want a single response and exit
98
+ if args.prompt:
99
+ break
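
For reference, the conversation state maintained above is just a flat token list of the form `<|bos|> <|user_start|> ...user tokens... <|user_end|> <|assistant_start|> ...assistant tokens... <|assistant_end|> ...`, growing turn by turn. A hypothetical helper (not in the repo) that packages one user turn the way the loop does:

```python
# Hypothetical helper, mirroring how chat_cli.py extends conversation_tokens before
# handing the sequence to the Engine (not part of the repo).
def append_user_turn(tokenizer, conversation_tokens, user_text):
    conversation_tokens.append(tokenizer.encode_special("<|user_start|>"))
    conversation_tokens.extend(tokenizer.encode(user_text))
    conversation_tokens.append(tokenizer.encode_special("<|user_end|>"))
    # kick off the assistant; generation should stop at <|assistant_end|>
    conversation_tokens.append(tokenizer.encode_special("<|assistant_start|>"))
    return conversation_tokens
```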