drbh committed
Commit a9b8fe6 · 1 Parent(s): b992f14

feat: update readme and add benching scripts

Files changed (4):
  1. README.md +21 -18
  2. compare_example.py +211 -0
  3. perf_plot.py +536 -0
  4. readme_example.py +88 -0
README.md CHANGED
@@ -38,16 +38,23 @@ oooo ooo .oooo. ooo. .oo. .oo. .ooooo. .ooooo.
 import time
 import torch
+from kernels import get_local_kernel
 from kernels import get_kernel
 from pathlib import Path
 from torch.nn import functional as F
 
-yamoe = get_kernel("drbh/yamoe")
+# Set seeds and deterministic flags for reproducibility
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+torch.cuda.manual_seed_all(42)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
 
 # Configuration
-torch.manual_seed(0)
-batch_size, seq_len, hidden_dim = 128, 2048, 2880
-num_experts, top_k = 32, 4
+batch_size, seq_len, hidden_dim = 16, 256, 2880
+num_experts, top_k = 8, 2
 
 # Create routing weights
 logits = torch.randn(batch_size, seq_len, num_experts)
@@ -60,13 +67,13 @@ flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top
 batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
 routing_weights[batch_indices, flat_indices] = flat_weights
 
-# Create model tensors (scaled to prevent overflow)
-hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda().half() * 0.1
-gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda().half() * 0.02
-gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda().half()
-down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda().half() * 0.02
-down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda().half()
-routing_weights = routing_weights.cuda().half()
+# Create model tensors
+hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
+gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda()
+gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
+down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda()
+down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
+routing_weights = routing_weights.cuda()
 router_indices = flat_indices.cuda()
 
 # Warmup
@@ -107,11 +114,7 @@ torch.cuda.synchronize()
 elapsed_ms = (time.perf_counter() - start) * 1e3
 peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
 
-print(f"Output sum: {output.sum().item():.4f}")
-print(f"Kernel time: {elapsed_ms:.3f} ms")
-print(f"Peak GPU memory: {peak_mem_mb:.2f} MB")
-# Output sum: 124.2500
-# Kernel time: 85.722 ms
-# Peak GPU memory: 8403.40 MB
-
+print(f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}")
+print(f"First 3: {output.view(-1)[:3].tolist()}")
+print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")
  ```
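
The updated README builds the dense `(batch_seq, num_experts)` routing matrix by scattering the top-k softmax probabilities into zeros. A minimal CPU-only sketch of just that scatter step, on toy dimensions chosen purely for illustration:

```python
import torch
from torch.nn import functional as F

# Toy sizes for illustration only (the README uses 16x256 tokens, 8 experts, top_k=2)
batch_seq, num_experts, top_k = 4, 4, 2

logits = torch.randn(batch_seq, num_experts)
probs = F.softmax(logits, dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)  # both (batch_seq, top_k)

# Scatter the top-k probabilities into a dense per-token routing row;
# unselected experts keep weight 0.
routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
routing_weights[batch_indices, indices] = weights

print(routing_weights)  # each row has exactly top_k non-zero entries
```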
compare_example.py ADDED
@@ -0,0 +1,211 @@
+# /// script
+# requires-python = "==3.10"
+# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
+# [tool.uv.sources]
+# kernels = { git = "https://github.com/huggingface/kernels.git" }
+# ///
+
+import time
+import torch
+from kernels import get_local_kernel
+from kernels import get_kernel
+from pathlib import Path
+from torch.nn import functional as F
+
+# Set seeds and deterministic flags for reproducibility
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+torch.cuda.manual_seed_all(42)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
+
+# Configuration
+batch_size, seq_len, hidden_dim = 4, 1024, 2880
+# batch_size, seq_len, hidden_dim = 4, 32, 1024
+num_experts, top_k = 8, 2
+
+# Create routing weights
+logits = torch.randn(batch_size, seq_len, num_experts)
+probs = F.softmax(logits, dim=-1)
+weights, indices = torch.topk(probs, top_k, dim=-1)
+
+batch_seq = batch_size * seq_len
+routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
+flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
+batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
+routing_weights[batch_indices, flat_indices] = flat_weights
+
+# Create model tensors
+hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
+# gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda()
+gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
+# down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda()
+down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
+# routing_weights = routing_weights.cuda()
+router_indices = flat_indices.cuda()
+
+gate_up_proj = torch.empty(num_experts, hidden_dim, 2 * hidden_dim, device="cuda")
+down_proj = torch.empty(num_experts, hidden_dim, hidden_dim, device="cuda")
+torch.nn.init.trunc_normal_(gate_up_proj, std=0.02)
+torch.nn.init.trunc_normal_(down_proj, std=0.02)
+
+routing_weights = routing_weights.to(dtype=torch.float32, device="cuda")
+
+
+# Warmup
+for _ in range(5):
+    _ = yamoe.experts(
+        hidden_states.view(-1, hidden_dim),
+        router_indices,
+        routing_weights.view(-1, num_experts),
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        seq_len,
+        num_experts,
+        top_k,
+    )
+
+# Benchmark
+torch.cuda.synchronize()
+torch.cuda.reset_peak_memory_stats()
+start = time.perf_counter()
+
+with torch.no_grad():
+    output = yamoe.experts(
+        hidden_states.view(-1, hidden_dim),
+        router_indices,
+        routing_weights.view(-1, num_experts),
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        seq_len,
+        num_experts,
+        top_k,
+    )
+
+torch.cuda.synchronize()
+elapsed_ms = (time.perf_counter() - start) * 1e3
+peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+# Store kernel results
+kernel_output = output.clone()
+kernel_time = elapsed_ms
+kernel_memory = peak_mem_mb
+
+## OPTIONAL
+# Compare to reference implementation
+config = type("Config", (), {})()
+config.hidden_size = hidden_dim
+config.intermediate_size = 4 * hidden_dim
+config.num_local_experts = num_experts
+
+model = yamoe.reference.GptOssExperts(config)
+
+# set the weights and biases from above to the reference model
+model.gate_up_proj.data = gate_up_proj
+model.gate_up_proj_bias.data = gate_up_proj_bias
+model.down_proj.data = down_proj
+model.down_proj_bias.data = down_proj_bias
+
+model = model.cuda()
+model.eval()
+
+torch.cuda.synchronize()
+torch.cuda.reset_peak_memory_stats()
+start = time.perf_counter()
+
+with torch.no_grad():
+    ref_output = model(hidden_states, router_indices, routing_weights)
+
+torch.cuda.synchronize()
+elapsed_ms = (time.perf_counter() - start) * 1e3
+peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+# Store reference results
+ref_time = elapsed_ms
+ref_memory = peak_mem_mb
+
+# Reshape reference output to match kernel output
+ref_output_reshaped = ref_output.view(kernel_output.shape)
+
+# Calculate similarity metrics
+mse = torch.nn.functional.mse_loss(kernel_output, ref_output_reshaped).item()
+mae = torch.nn.functional.l1_loss(kernel_output, ref_output_reshaped).item()
+
+# Cosine similarity
+kernel_flat = kernel_output.view(-1)
+ref_flat = ref_output_reshaped.view(-1)
+cosine_sim = torch.nn.functional.cosine_similarity(
+    kernel_flat.unsqueeze(0), ref_flat.unsqueeze(0)
+).item()
+
+# Relative error (L2 norm of difference / L2 norm of reference)
+diff_norm = torch.norm(kernel_output - ref_output_reshaped).item()
+ref_norm = torch.norm(ref_output_reshaped).item()
+rel_error = diff_norm / ref_norm if ref_norm > 0 else float("inf")
+
+# Max absolute difference
+max_abs_diff = torch.max(torch.abs(kernel_output - ref_output_reshaped)).item()
+
+# Print comparison table
+print("\n" + "=" * 80)
+print(f"{'METRIC':<20} {'KERNEL':<15} {'REFERENCE':<15} {'SIMILARITY/SPEEDUP':<15}")
+print("=" * 80)
+
+print(
+    f"{'Sum':<20} {kernel_output.sum().item():<15.4f} {ref_output_reshaped.sum().item():<15.4f} {'N/A':<15}"
+)
+print(
+    f"{'Min':<20} {kernel_output.min().item():<15.4f} {ref_output_reshaped.min().item():<15.4f} {'N/A':<15}"
+)
+print(
+    f"{'Max':<20} {kernel_output.max().item():<15.4f} {ref_output_reshaped.max().item():<15.4f} {'N/A':<15}"
+)
+print(
+    f"{'Norm (L2)':<20} {kernel_output.norm().item():<15.4f} {ref_output_reshaped.norm().item():<15.4f} {'N/A':<15}"
+)
+print(
+    f"{'Std':<20} {kernel_output.std().item():<15.4f} {ref_output_reshaped.std().item():<15.4f} {'N/A':<15}"
+)
+
+print("-" * 80)
+print(
+    f"{'Time (ms)':<20} {kernel_time:<15.3f} {ref_time:<15.3f} {ref_time / kernel_time:<15.2f}x"
+)
+print(
+    f"{'Memory (MB)':<20} {kernel_memory:<15.2f} {ref_memory:<15.2f} {ref_memory / kernel_memory:<15.2f}x"
+)
+
+print("-" * 80)
+print("SIMILARITY METRICS")
+print("-" * 80)
+print(f"{'METRIC':<20} {'VALUE':<15} {'DIFFERENCE':<15}")
+print("-" * 80)
+print(f"{'MSE':<20} {mse:<15.6e} {'N/A':<15}")
+print(f"{'MAE':<20} {mae:<15.6e} {'N/A':<15}")
+print(f"{'Cosine Similarity':<20} {cosine_sim:<15.6f} {abs(1.0 - cosine_sim):<15.6f}")
+print(f"{'Relative Error':<20} {rel_error:<15.6e} {'N/A':<15}")
+print(f"{'Max Abs Diff':<20} {max_abs_diff:<15.6e} {'N/A':<15}")
+
+print("-" * 80)
+print("FIRST 10 ELEMENTS COMPARISON")
+print("-" * 80)
+
+# Get first 10 elements as numpy arrays for nice display
+kernel_first_10 = kernel_flat[:10].cpu().numpy()
+ref_first_10 = ref_flat[:10].cpu().numpy()
+diff_first_10 = kernel_first_10 - ref_first_10
+
+print(f"{'INDEX':<5} {'KERNEL':<12} {'REFERENCE':<12} {'DIFF':<12}")
+print("-" * 45)
+for i in range(10):
+    print(
+        f"{i:<5} {kernel_first_10[i]:<12.6f} {ref_first_10[i]:<12.6f} {diff_first_10[i]:<12.6f}"
+    )
+
+print("=" * 80)
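
compare_example.py reports MSE, MAE, cosine similarity, relative L2 error, and max absolute difference between the kernel and reference outputs. For reuse elsewhere, the same metrics can be bundled into one helper; this is a sketch, and `similarity_metrics` is a hypothetical name, not something the commit defines:

```python
import torch

def similarity_metrics(a: torch.Tensor, b: torch.Tensor) -> dict:
    # Hypothetical helper mirroring the metrics printed by compare_example.py
    a_flat, b_flat = a.reshape(-1), b.reshape(-1)
    ref_norm = torch.norm(b_flat).item()
    return {
        "mse": torch.nn.functional.mse_loss(a_flat, b_flat).item(),
        "mae": torch.nn.functional.l1_loss(a_flat, b_flat).item(),
        "cosine": torch.nn.functional.cosine_similarity(
            a_flat.unsqueeze(0), b_flat.unsqueeze(0)
        ).item(),
        # relative error = ||a - b||_2 / ||b||_2
        "rel_error": torch.norm(a_flat - b_flat).item() / ref_norm
        if ref_norm > 0
        else float("inf"),
        "max_abs_diff": torch.max(torch.abs(a_flat - b_flat)).item(),
    }
```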
perf_plot.py ADDED
@@ -0,0 +1,536 @@
+# /// script
+# requires-python = "==3.10"
+# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels", "matplotlib"]
+# [tool.uv.sources]
+# kernels = { git = "https://github.com/huggingface/kernels.git" }
+# ///
+
+import time
+import torch
+from kernels import get_local_kernel, get_kernel
+from pathlib import Path
+from torch.nn import functional as F
+import sys
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import numpy as np
+
+# sys.path.insert(0, "./torch-ext")
+# import yamoe
+# import yamoe.reference as reference
+
+yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
+reference = yamoe.reference
+
+# Setup
+torch.manual_seed(0)
+
+# Parameter combinations to test
+configs = [
+    {"seq_len": 512, "hidden_dim": 2880, "num_experts": 32, "top_k": 4},
+    {"seq_len": 1024, "hidden_dim": 2880, "num_experts": 32, "top_k": 4},
+    {"seq_len": 512, "hidden_dim": 1024, "num_experts": 32, "top_k": 4},
+    {"seq_len": 512, "hidden_dim": 2880, "num_experts": 16, "top_k": 2},
+    {"seq_len": 2048, "hidden_dim": 1024, "num_experts": 16, "top_k": 2},
+    {"seq_len": 768, "hidden_dim": 2048, "num_experts": 64, "top_k": 8},
+]
+
+# Strategic batch sizes: small (1,2), medium (4,8), large (16,32), extra large (64)
+batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+all_results = []
+
+# Test each configuration
+for config_idx, config in enumerate(configs):
+    seq_len = config["seq_len"]
+    hidden_dim = config["hidden_dim"]
+    num_experts = config["num_experts"]
+    top_k = config["top_k"]
+
+    print(f"\n{'=' * 70}")
+    print(
+        f"Config {config_idx + 1}: seq={seq_len}, hidden={hidden_dim}, experts={num_experts}, top_k={top_k}"
+    )
+    print(f"{'=' * 70}")
+
+    yamoe_times = []
+    reference_times = []
+    yamoe_memory = []
+    reference_memory = []
+    speedups = []
+
+    # Iterate over batch sizes
+    for batch_size in batch_sizes:
+        print(f"\nBatch size = {batch_size}")
+
+        try:
+            # Create logits for this batch size
+            logits = torch.randn(batch_size, seq_len, num_experts)
+
+            # Inline routing creation
+            weights, indices = torch.topk(logits, top_k, dim=-1)
+            weights = F.softmax(weights, dim=-1)
+            batch_seq = batch_size * seq_len
+            routing_weights = torch.zeros(
+                batch_seq, num_experts, device=logits.device, dtype=weights.dtype
+            )
+            flat_indices, flat_weights = (
+                indices.reshape(-1, top_k),
+                weights.reshape(-1, top_k),
+            )
+            batch_indices = (
+                torch.arange(batch_seq, device=logits.device)
+                .unsqueeze(1)
+                .expand(-1, top_k)
+            )
+            routing_weights[batch_indices, flat_indices] = flat_weights
+            router_indices = flat_indices
+
+            # Create tensors and convert to CUDA half precision
+            hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda().half()
+            gate_up_proj = (
+                torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda().half()
+            )
+            gate_up_proj_bias = torch.ones(num_experts, 2 * hidden_dim).cuda().half()
+            down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda().half()
+            down_proj_bias = torch.ones(num_experts, hidden_dim).cuda().half()
+            logits, routing_weights = (
+                logits.cuda().half(),
+                routing_weights.cuda().half(),
+            )
+            router_indices = router_indices.cuda()
+
+            # Test Yamoe kernel first
+            yamoe_success = True
+            yamoe_time = None
+            yamoe_mem = None
+
+            try:
+                # Warmup runs for yamoe
+                for _ in range(5):
+                    _ = yamoe.experts(
+                        hidden_states.view(-1, hidden_dim),
+                        router_indices,
+                        routing_weights.view(-1, num_experts),
+                        gate_up_proj,
+                        gate_up_proj_bias,
+                        down_proj,
+                        down_proj_bias,
+                        seq_len,
+                        num_experts,
+                        top_k,
+                    )
+
+                # Time and measure memory for yamoe kernel
+                torch.cuda.synchronize()
+                torch.cuda.reset_peak_memory_stats()
+
+                yamoe_runs = []
+                for _ in range(10):
+                    start = time.perf_counter()
+                    output = yamoe.experts(
+                        hidden_states.view(-1, hidden_dim),
+                        router_indices,
+                        routing_weights.view(-1, num_experts),
+                        gate_up_proj,
+                        gate_up_proj_bias,
+                        down_proj,
+                        down_proj_bias,
+                        seq_len,
+                        num_experts,
+                        top_k,
+                    )
+                    torch.cuda.synchronize()
+                    yamoe_runs.append((time.perf_counter() - start) * 1e3)
+
+                yamoe_time = sum(yamoe_runs) / len(yamoe_runs)
+                yamoe_mem = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+            except RuntimeError as e:
+                if "out of memory" in str(e).lower():
+                    print(f" Yamoe: OOM - skipping this batch size")
+                    yamoe_success = False
+                else:
+                    raise e
+
+            # Test reference model
+            ref_success = True
+            ref_time = None
+            ref_mem = None
+
+            try:
+                # Setup reference model
+                config_obj = type("Config", (), {})()
+                config_obj.hidden_size = hidden_dim
+                config_obj.intermediate_size = 4 * hidden_dim
+                config_obj.num_local_experts = num_experts
+
+                model = reference.GptOssExperts(config_obj)
+                model.gate_up_proj.data = gate_up_proj
+                model.gate_up_proj_bias.data = gate_up_proj_bias
+                model.down_proj.data = down_proj
+                model.down_proj_bias.data = down_proj_bias
+                model = model.cuda().half()
+                model.eval()
+
+                # Warmup runs for reference
+                with torch.no_grad():
+                    for _ in range(5):
+                        _ = model(hidden_states, router_indices, routing_weights)
+
+                # Time and measure memory for reference model
+                torch.cuda.synchronize()
+                torch.cuda.reset_peak_memory_stats()
+
+                ref_runs = []
+                with torch.no_grad():
+                    for _ in range(10):
+                        start = time.perf_counter()
+                        ref_output = model(
+                            hidden_states, router_indices, routing_weights
+                        )
+                        torch.cuda.synchronize()
+                        ref_runs.append((time.perf_counter() - start) * 1e3)
+
+                ref_time = sum(ref_runs) / len(ref_runs)
+                ref_mem = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+            except RuntimeError as e:
+                if "out of memory" in str(e).lower():
+                    print(f" Reference: OOM - skipping this batch size")
+                    ref_success = False
+                else:
+                    raise e
+
+            # Report results if both succeeded
+            if yamoe_success and ref_success:
+                yamoe_times.append(yamoe_time)
+                yamoe_memory.append(yamoe_mem)
+                reference_times.append(ref_time)
+                reference_memory.append(ref_mem)
+                speedup = ref_time / yamoe_time
+                speedups.append(speedup)
+
+                throughput_yamoe = (
+                    (batch_size * seq_len * hidden_dim) / (yamoe_time / 1000) / 1e9
+                )  # GFLOPS
+                throughput_ref = (
+                    (batch_size * seq_len * hidden_dim) / (ref_time / 1000) / 1e9
+                )  # GFLOPS
+
+                print(
+                    f" Yamoe: {yamoe_time:.3f} ms / {yamoe_mem:.1f} MB / {throughput_yamoe:.2f} GFLOPS"
+                )
+                print(
+                    f" Reference: {ref_time:.3f} ms / {ref_mem:.1f} MB / {throughput_ref:.2f} GFLOPS"
+                )
+                print(
+                    f" Speedup: {speedup:.2f}x, Memory reduction: {ref_mem / yamoe_mem:.2f}x, "
+                    f"Efficiency gain: {throughput_yamoe / throughput_ref:.2f}x"
+                )
+            elif yamoe_success and not ref_success:
+                # Only Yamoe succeeded - still record its results
+                yamoe_times.append(yamoe_time)
+                yamoe_memory.append(yamoe_mem)
+                # Use None/placeholder values for reference
+                reference_times.append(None)
+                reference_memory.append(None)
+                speedups.append(None)
+
+                throughput_yamoe = (
+                    (batch_size * seq_len * hidden_dim) / (yamoe_time / 1000) / 1e9
+                )
+                print(
+                    f" Yamoe: {yamoe_time:.3f} ms / {yamoe_mem:.1f} MB / {throughput_yamoe:.2f} GFLOPS"
+                )
+                print(f" Reference: OOM - unable to measure")
+                print(f" Yamoe runs successfully while Reference OOMs")
+            elif not yamoe_success and ref_success:
+                # Only Reference succeeded
+                yamoe_times.append(None)
+                yamoe_memory.append(None)
+                reference_times.append(ref_time)
+                reference_memory.append(ref_mem)
+                speedups.append(None)
+
+                throughput_ref = (
+                    (batch_size * seq_len * hidden_dim) / (ref_time / 1000) / 1e9
+                )
+                print(f" Yamoe: OOM - unable to measure")
+                print(
+                    f" Reference: {ref_time:.3f} ms / {ref_mem:.1f} MB / {throughput_ref:.2f} GFLOPS"
+                )
+                print(f" Reference runs successfully while Yamoe OOMs")
+            else:
+                # Both failed
+                yamoe_times.append(None)
+                yamoe_memory.append(None)
+                reference_times.append(None)
+                reference_memory.append(None)
+                speedups.append(None)
+                print(f" Both implementations OOM at batch_size={batch_size}")
+
+        except Exception as e:
+            print(f" Unexpected error at batch_size={batch_size}: {str(e)}")
+            # Add None values to maintain list consistency
+            yamoe_times.append(None)
+            yamoe_memory.append(None)
+            reference_times.append(None)
+            reference_memory.append(None)
+            speedups.append(None)
+
+        # Clear GPU memory after each batch size test
+        torch.cuda.empty_cache()
+
+    all_results.append(
+        {
+            "config": config,
+            "yamoe_times": yamoe_times,
+            "reference_times": reference_times,
+            "yamoe_memory": yamoe_memory,
+            "reference_memory": reference_memory,
+            "speedups": speedups,
+        }
+    )
+
+# Create comprehensive visualization with time and memory
+fig = plt.figure(figsize=(24, 16))
+
+# Create 3 rows: time comparison, memory comparison, combined metrics
+for config_idx, result in enumerate(all_results[:6]):
+    # Time comparison subplot
+    ax1 = plt.subplot(3, 6, config_idx + 1)
+    x = np.arange(len(batch_sizes))
+    width = 0.35
+
+    # Filter out None values for plotting
+    yamoe_times_filtered = [t if t is not None else 0 for t in result["yamoe_times"]]
+    ref_times_filtered = [t if t is not None else 0 for t in result["reference_times"]]
+
+    bars1 = ax1.bar(
+        x - width / 2,
+        yamoe_times_filtered,
+        width,
+        label="Yamoe",
+        color="#1f77b4",
+        alpha=0.8,
+    )
+    bars2 = ax1.bar(
+        x + width / 2,
+        ref_times_filtered,
+        width,
+        label="Reference",
+        color="#ff7f0e",
+        alpha=0.8,
+    )
+
+    # Add speedup annotations (only where both values exist)
+    for i, (y_time, r_time) in enumerate(
+        zip(result["yamoe_times"], result["reference_times"])
+    ):
+        if y_time is not None and r_time is not None:
+            speedup = r_time / y_time
+            ax1.text(
+                i,
+                max(y_time, r_time) * 1.05,
+                f"{speedup:.1f}x",
+                ha="center",
+                va="bottom",
+                fontsize=7,
+                fontweight="bold",
+                color="green",
+            )
+        elif y_time is not None and r_time is None:
+            ax1.text(
+                i,
+                y_time * 1.05,
+                "Y-OK",
+                ha="center",
+                va="bottom",
+                fontsize=7,
+                fontweight="bold",
+                color="blue",
+            )
+        elif y_time is None and r_time is not None:
+            ax1.text(
+                i,
+                r_time * 1.05,
+                "R-OK",
+                ha="center",
+                va="bottom",
+                fontsize=7,
+                fontweight="bold",
+                color="orange",
+            )
+        else:
+            ax1.text(
+                i,
+                0.1,
+                "OOM",
+                ha="center",
+                va="bottom",
+                fontsize=7,
+                fontweight="bold",
+                color="red",
+            )
+
+    ax1.set_ylabel("Time (ms)", fontsize=9)
+    ax1.set_yscale("log")
+    ax1.set_xticks(x)
+    ax1.set_xticklabels(batch_sizes, fontsize=8)
+    ax1.grid(True, alpha=0.3, axis="y")
+
+    config = result["config"]
+    ax1.set_title(
+        f"Time: seq={config['seq_len']}, h={config['hidden_dim']}, e={config['num_experts']}",
+        fontsize=8,
+        fontweight="bold",
+    )
+
+    if config_idx == 0:
+        ax1.legend(loc="upper left", fontsize=8)
+
+    # Memory comparison subplot
+    ax2 = plt.subplot(3, 6, config_idx + 7)
+
+    # Filter out None values for memory plotting
+    yamoe_mem_filtered = [m if m is not None else 0 for m in result["yamoe_memory"]]
+    ref_mem_filtered = [m if m is not None else 0 for m in result["reference_memory"]]
+
+    bars3 = ax2.bar(
+        x - width / 2,
+        yamoe_mem_filtered,
+        width,
+        label="Yamoe",
+        color="#2ca02c",
+        alpha=0.8,
+    )
+    bars4 = ax2.bar(
+        x + width / 2,
+        ref_mem_filtered,
+        width,
+        label="Reference",
+        color="#d62728",
+        alpha=0.8,
+    )
+
+    # Add memory reduction annotations (only where both values exist)
+    for i, (y_mem, r_mem) in enumerate(
+        zip(result["yamoe_memory"], result["reference_memory"])
+    ):
+        if y_mem is not None and r_mem is not None:
+            reduction = r_mem / y_mem
+            ax2.text(
+                i,
+                max(y_mem, r_mem) * 1.05,
+                f"{reduction:.1f}x",
+                ha="center",
+                va="bottom",
+                fontsize=7,
+                fontweight="bold",
+                color="purple",
+            )
+
+    ax2.set_ylabel("Memory (MB)", fontsize=9)
+    ax2.set_yscale("log")
+    ax2.set_xticks(x)
+    ax2.set_xticklabels(batch_sizes, fontsize=8)
+    ax2.grid(True, alpha=0.3, axis="y")
+    ax2.set_title(
+        f"Memory: seq={config['seq_len']}, h={config['hidden_dim']}, e={config['num_experts']}",
+        fontsize=8,
+        fontweight="bold",
+    )
+
+    if config_idx == 0:
+        ax2.legend(loc="upper left", fontsize=8)
+
+    # Combined speedup and memory efficiency subplot
+    ax3 = plt.subplot(3, 6, config_idx + 13)
+
+    # Calculate speedups and memory reductions, handling None values
+    valid_speedups = []
+    valid_mem_reductions = []
+    valid_batch_sizes_speedup = []
+    valid_batch_sizes_mem = []
+
+    for i, (r, y) in enumerate(zip(result["reference_times"], result["yamoe_times"])):
+        if r is not None and y is not None:
+            valid_speedups.append(r / y)
+            valid_batch_sizes_speedup.append(batch_sizes[i])
+
+    for i, (r, y) in enumerate(zip(result["reference_memory"], result["yamoe_memory"])):
+        if r is not None and y is not None:
+            valid_mem_reductions.append(r / y)
+            valid_batch_sizes_mem.append(batch_sizes[i])
+
+    if valid_speedups:
+        ax3.plot(
+            valid_batch_sizes_speedup,
+            valid_speedups,
+            "o-",
+            label="Time Speedup",
+            color="green",
+            linewidth=2,
+            markersize=6,
+        )
+    if valid_mem_reductions:
+        ax3.plot(
+            valid_batch_sizes_mem,
+            valid_mem_reductions,
+            "s-",
+            label="Memory Reduction",
+            color="purple",
+            linewidth=2,
+            markersize=6,
+        )
+
+    ax3.set_xlabel("Batch Size", fontsize=9)
+    ax3.set_ylabel("Improvement Factor", fontsize=9)
+    ax3.set_xticks(batch_sizes)
+    ax3.grid(True, alpha=0.3)
+    ax3.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
+    ax3.set_title(
+        f"Improvements: seq={config['seq_len']}, h={config['hidden_dim']}",
+        fontsize=8,
+        fontweight="bold",
+    )
+
+    if config_idx == 0:
+        ax3.legend(loc="upper left", fontsize=8)
+
+plt.suptitle(
+    "MoE Performance & Memory Comparison - Yamoe vs Reference",
+    fontsize=16,
+    fontweight="bold",
+    y=0.98,
+)
+plt.tight_layout()
+plt.savefig("moe_performance_comparison.png", dpi=150, bbox_inches="tight")
+plt.show()
+
+# Removed heatmap section per user request
+
+# Print detailed summary
+print("\n" + "=" * 80)
+print("DETAILED SUMMARY")
+print("=" * 80)
+
+for idx, result in enumerate(all_results[:6]):
+    config = result["config"]
+    print(f"\nConfiguration {idx + 1}:")
+    print(
+        f" Parameters: seq_len={config['seq_len']}, hidden_dim={config['hidden_dim']}, "
+        f"experts={config['num_experts']}, top_k={config['top_k']}"
+    )
+    # Handle None values in speedups
+    valid_speedups = [s for s in result["speedups"] if s is not None]
+    if valid_speedups:
+        print(f" Average Speedup: {sum(valid_speedups) / len(valid_speedups):.2f}x")
+        max_speedup = max(valid_speedups)
+        min_speedup = min(valid_speedups)
+        max_idx = result["speedups"].index(max_speedup)
+        min_idx = result["speedups"].index(min_speedup)
+        print(f" Max Speedup: {max_speedup:.2f}x at batch_size={batch_sizes[max_idx]}")
+        print(f" Min Speedup: {min_speedup:.2f}x at batch_size={batch_sizes[min_idx]}")
+    else:
+        print(" No valid speedup measurements (all OOM)")
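
perf_plot.py times every (config, batch size) pair the same way: a few warmup calls, `torch.cuda.synchronize()` plus `torch.cuda.reset_peak_memory_stats()`, then the mean over ten synchronized runs. A sketch of that pattern as a standalone helper; the `gpu_benchmark` name and signature are assumptions for illustration:

```python
import time
import torch

def gpu_benchmark(fn, warmup: int = 5, iters: int = 10):
    """Return (mean latency in ms, peak memory in MB) for a no-arg CUDA callable."""
    for _ in range(warmup):
        fn()  # warmup so allocator and compilation effects don't skew timing
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        torch.cuda.synchronize()  # wait for kernels before stopping the clock
        times_ms.append((time.perf_counter() - start) * 1e3)
    peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
    return sum(times_ms) / len(times_ms), peak_mb
```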
readme_example.py ADDED
@@ -0,0 +1,88 @@
+# /// script
+# requires-python = "==3.10"
+# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
+# [tool.uv.sources]
+# kernels = { git = "https://github.com/huggingface/kernels.git" }
+# ///
+
+import time
+import torch
+from kernels import get_local_kernel
+from kernels import get_kernel
+from pathlib import Path
+from torch.nn import functional as F
+
+# Set seeds and deterministic flags for reproducibility
+torch.manual_seed(42)
+torch.cuda.manual_seed(42)
+torch.cuda.manual_seed_all(42)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
+
+# Configuration
+batch_size, seq_len, hidden_dim = 16, 256, 2880
+num_experts, top_k = 8, 2
+
+# Create routing weights
+logits = torch.randn(batch_size, seq_len, num_experts)
+probs = F.softmax(logits, dim=-1)
+weights, indices = torch.topk(probs, top_k, dim=-1)
+
+batch_seq = batch_size * seq_len
+routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
+flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
+batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
+routing_weights[batch_indices, flat_indices] = flat_weights
+
+# Create model tensors
+hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
+gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda()
+gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
+down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda()
+down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
+routing_weights = routing_weights.cuda()
+router_indices = flat_indices.cuda()
+
+# Warmup
+for _ in range(5):
+    _ = yamoe.experts(
+        hidden_states.view(-1, hidden_dim),
+        router_indices,
+        routing_weights.view(-1, num_experts),
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        seq_len,
+        num_experts,
+        top_k,
+    )
+
+# Benchmark
+torch.cuda.synchronize()
+torch.cuda.reset_peak_memory_stats()
+start = time.perf_counter()
+
+with torch.no_grad():
+    output = yamoe.experts(
+        hidden_states.view(-1, hidden_dim),
+        router_indices,
+        routing_weights.view(-1, num_experts),
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        seq_len,
+        num_experts,
+        top_k,
+    )
+
+torch.cuda.synchronize()
+elapsed_ms = (time.perf_counter() - start) * 1e3
+peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+print(f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}")
+print(f"First 3: {output.view(-1)[:3].tolist()}")
+print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")
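
readme_example.py pins seeds and cuDNN flags so repeated runs print identical numbers. A quick sanity-check sketch of that property (toy shapes chosen for illustration; assumes a CUDA device is available):

```python
import torch

def seeded_input():
    # Re-seeding before each call should reproduce the same random tensor
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    return torch.randn(2, 4, 8).cuda()

assert torch.equal(seeded_input(), seeded_input())  # identical across calls
```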