drbh committed
Commit 733f7f4 · 1 Parent(s): 0daa7ef

feat: impl backward experts

.gitignore CHANGED
@@ -12,4 +12,8 @@ tests
 torch-ext/registration.h
 torch-ext/yamoe/_ops.py
 csrc/batch_mm.cu
-torch-ext/yamoe/*.abi3.so
+torch-ext/yamoe/*.abi3.so
+
+build-ext
+build
+exploration

build.toml CHANGED
@@ -32,5 +32,6 @@ src = [
   "csrc/sort.cu",
   "csrc/bincount_cumsum.cu",
   "csrc/batch_mm.cu",
-  "csrc/moe.cpp"
+  "csrc/moe.cpp",
+  "csrc/experts_backward.cu"
 ]

compare_example.py CHANGED
@@ -7,10 +7,11 @@
 
 import time
 import torch
-from kernels import get_local_kernel
-from kernels import get_kernel
+from kernels import get_kernel, get_local_kernel
 from pathlib import Path
 from torch.nn import functional as F
+import numpy as np
+import sys
 
 # Set seeds and deterministic flags for reproducibility
 torch.manual_seed(42)
@@ -19,11 +20,23 @@ torch.cuda.manual_seed_all(42)
 torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
-yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
+np.set_printoptions(precision=4)
+
+load_method = 2  # 1: sym, 2: local, 3: hf
+
+if load_method == 1:
+    sys.path.insert(0, "./torch-ext")
+    import yamoe
+elif load_method == 2:
+    yamoe = get_local_kernel(Path("result"), "yamoe")
+elif load_method == 3:
+    yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
+
+binned_experts_ref = yamoe.vendored.yamoe_ref.binned_experts_ref
+GptOssExperts = yamoe.vendored.gpt_oss_mlp.GptOssExperts
 
 # Configuration
 batch_size, seq_len, hidden_dim = 4, 1024, 2880
-# batch_size, seq_len, hidden_dim = 4, 32, 1024
 num_experts, top_k = 8, 2
 
 # Create routing weights
@@ -52,6 +65,7 @@ torch.nn.init.trunc_normal_(gate_up_proj, std=0.02)
 torch.nn.init.trunc_normal_(down_proj, std=0.02)
 
 routing_weights = routing_weights.to(dtype=torch.float32, device="cuda")
+expert_capacity = batch_seq * top_k // num_experts * 2
 
 
 # Warmup
@@ -64,7 +78,7 @@ for _ in range(5):
         gate_up_proj_bias,
         down_proj,
         down_proj_bias,
-        seq_len,
+        expert_capacity,
         num_experts,
         top_k,
     )
@@ -74,6 +88,7 @@ torch.cuda.synchronize()
 torch.cuda.reset_peak_memory_stats()
 start = time.perf_counter()
 
+
 with torch.no_grad():
     output = yamoe.experts(
         hidden_states.view(-1, hidden_dim),
@@ -83,7 +98,7 @@ with torch.no_grad():
         gate_up_proj_bias,
         down_proj,
         down_proj_bias,
-        seq_len,
+        expert_capacity,
         num_experts,
         top_k,
     )
@@ -104,7 +119,7 @@ config.hidden_size = hidden_dim
 config.intermediate_size = 4 * hidden_dim
 config.num_local_experts = num_experts
 
-model = yamoe.reference.GptOssExperts(config)
+model = GptOssExperts(config)
 
 # set the weights and biases from above to the reference model
 model.gate_up_proj.data = gate_up_proj
@@ -133,79 +148,165 @@ ref_memory = peak_mem_mb
 # Reshape reference output to match kernel output
 ref_output_reshaped = ref_output.view(kernel_output.shape)
 
-# Calculate similarity metrics
-mse = torch.nn.functional.mse_loss(kernel_output, ref_output_reshaped).item()
-mae = torch.nn.functional.l1_loss(kernel_output, ref_output_reshaped).item()
+# Test yamoe_ref implementation
+expert_capacity = batch_seq * top_k // num_experts * 2  # Generous capacity
+
+torch.cuda.synchronize()
+torch.cuda.reset_peak_memory_stats()
+start = time.perf_counter()
+
+with torch.no_grad():
+    yamoe_ref_output = binned_experts_ref(
+        hidden_states,
+        router_indices,
+        routing_weights,
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        expert_capacity,
+    )
+
+torch.cuda.synchronize()
+yamoe_ref_time = (time.perf_counter() - start) * 1e3
+yamoe_ref_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+# Reshape yamoe_ref output to match kernel output
+yamoe_ref_output_reshaped = yamoe_ref_output.view(kernel_output.shape)
+
+# Calculate similarity metrics between kernel and reference
+mse_kernel_ref = torch.nn.functional.mse_loss(kernel_output, ref_output_reshaped).item()
+mae_kernel_ref = torch.nn.functional.l1_loss(kernel_output, ref_output_reshaped).item()
 
 # Cosine similarity
 kernel_flat = kernel_output.view(-1)
 ref_flat = ref_output_reshaped.view(-1)
-cosine_sim = torch.nn.functional.cosine_similarity(
+yamoe_ref_flat = yamoe_ref_output_reshaped.view(-1)
+cosine_sim_kernel_ref = torch.nn.functional.cosine_similarity(
     kernel_flat.unsqueeze(0), ref_flat.unsqueeze(0)
 ).item()
 
 # Relative error (L2 norm of difference / L2 norm of reference)
-diff_norm = torch.norm(kernel_output - ref_output_reshaped).item()
+diff_norm_kernel_ref = torch.norm(kernel_output - ref_output_reshaped).item()
 ref_norm = torch.norm(ref_output_reshaped).item()
-rel_error = diff_norm / ref_norm if ref_norm > 0 else float("inf")
+rel_error_kernel_ref = diff_norm_kernel_ref / ref_norm if ref_norm > 0 else float("inf")
 
 # Max absolute difference
-max_abs_diff = torch.max(torch.abs(kernel_output - ref_output_reshaped)).item()
+max_abs_diff_kernel_ref = torch.max(
+    torch.abs(kernel_output - ref_output_reshaped)
+).item()
+
+# Calculate similarity metrics between kernel and yamoe_ref
+mse_kernel_yamoe = torch.nn.functional.mse_loss(
+    kernel_output, yamoe_ref_output_reshaped
+).item()
+mae_kernel_yamoe = torch.nn.functional.l1_loss(
+    kernel_output, yamoe_ref_output_reshaped
+).item()
+cosine_sim_kernel_yamoe = torch.nn.functional.cosine_similarity(
+    kernel_flat.unsqueeze(0), yamoe_ref_flat.unsqueeze(0)
+).item()
+diff_norm_kernel_yamoe = torch.norm(kernel_output - yamoe_ref_output_reshaped).item()
+yamoe_ref_norm = torch.norm(yamoe_ref_output_reshaped).item()
+rel_error_kernel_yamoe = (
+    diff_norm_kernel_yamoe / yamoe_ref_norm if yamoe_ref_norm > 0 else float("inf")
+)
+max_abs_diff_kernel_yamoe = torch.max(
+    torch.abs(kernel_output - yamoe_ref_output_reshaped)
+).item()
+
+# Calculate similarity metrics between reference and yamoe_ref
+mse_ref_yamoe = torch.nn.functional.mse_loss(
+    ref_output_reshaped, yamoe_ref_output_reshaped
+).item()
+mae_ref_yamoe = torch.nn.functional.l1_loss(
+    ref_output_reshaped, yamoe_ref_output_reshaped
+).item()
+cosine_sim_ref_yamoe = torch.nn.functional.cosine_similarity(
+    ref_flat.unsqueeze(0), yamoe_ref_flat.unsqueeze(0)
+).item()
+diff_norm_ref_yamoe = torch.norm(ref_output_reshaped - yamoe_ref_output_reshaped).item()
+rel_error_ref_yamoe = (
+    diff_norm_ref_yamoe / yamoe_ref_norm if yamoe_ref_norm > 0 else float("inf")
+)
+max_abs_diff_ref_yamoe = torch.max(
+    torch.abs(ref_output_reshaped - yamoe_ref_output_reshaped)
+).item()
 
 # Print comparison table
-print("\n" + "=" * 80)
-print(f"{'METRIC':<20} {'KERNEL':<15} {'REFERENCE':<15} {'SIMILARITY/SPEEDUP':<15}")
-print("=" * 80)
+print("\n" + "=" * 110)
+print(
+    f"{'METRIC':<20} {'KERNEL':<15} {'REFERENCE':<15} {'YAMOE_REF':<15} {'KERNEL SPEEDUP':<20} {'REF SPEEDUP':<15}"
+)
+print("=" * 110)
 
 print(
-    f"{'Sum':<20} {kernel_output.sum().item():<15.4f} {ref_output_reshaped.sum().item():<15.4f} {'N/A':<15}"
+    f"{'Sum':<20} {kernel_output.sum().item():<15.4f} {ref_output_reshaped.sum().item():<15.4f} {yamoe_ref_output_reshaped.sum().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
-    f"{'Min':<20} {kernel_output.min().item():<15.4f} {ref_output_reshaped.min().item():<15.4f} {'N/A':<15}"
+    f"{'Min':<20} {kernel_output.min().item():<15.4f} {ref_output_reshaped.min().item():<15.4f} {yamoe_ref_output_reshaped.min().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
-    f"{'Max':<20} {kernel_output.max().item():<15.4f} {ref_output_reshaped.max().item():<15.4f} {'N/A':<15}"
+    f"{'Max':<20} {kernel_output.max().item():<15.4f} {ref_output_reshaped.max().item():<15.4f} {yamoe_ref_output_reshaped.max().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
-    f"{'Norm (L2)':<20} {kernel_output.norm().item():<15.4f} {ref_output_reshaped.norm().item():<15.4f} {'N/A':<15}"
+    f"{'Norm (L2)':<20} {kernel_output.norm().item():<15.4f} {ref_output_reshaped.norm().item():<15.4f} {yamoe_ref_output_reshaped.norm().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
-    f"{'Std':<20} {kernel_output.std().item():<15.4f} {ref_output_reshaped.std().item():<15.4f} {'N/A':<15}"
+    f"{'Std':<20} {kernel_output.std().item():<15.4f} {ref_output_reshaped.std().item():<15.4f} {yamoe_ref_output_reshaped.std().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 
-print("-" * 80)
+print("-" * 110)
 print(
-    f"{'Time (ms)':<20} {kernel_time:<15.3f} {ref_time:<15.3f} {ref_time / kernel_time:<15.2f}x"
+    f"{'Time (ms)':<20} {kernel_time:<15.3f} {ref_time:<15.3f} {yamoe_ref_time:<15.3f} {yamoe_ref_time / kernel_time:<20.2f}x {yamoe_ref_time / ref_time:<15.2f}x"
 )
 print(
-    f"{'Memory (MB)':<20} {kernel_memory:<15.2f} {ref_memory:<15.2f} {ref_memory / kernel_memory:<15.2f}x"
+    f"{'Memory (MB)':<20} {kernel_memory:<15.2f} {ref_memory:<15.2f} {yamoe_ref_memory:<15.2f} {yamoe_ref_memory / kernel_memory:<20.2f}x {yamoe_ref_memory / ref_memory:<15.2f}x"
 )
 
-print("-" * 80)
-print("SIMILARITY METRICS")
-print("-" * 80)
-print(f"{'METRIC':<20} {'VALUE':<15} {'DIFFERENCE':<15}")
-print("-" * 80)
-print(f"{'MSE':<20} {mse:<15.6e} {'N/A':<15}")
-print(f"{'MAE':<20} {mae:<15.6e} {'N/A':<15}")
-print(f"{'Cosine Similarity':<20} {cosine_sim:<15.6f} {abs(1.0 - cosine_sim):<15.6f}")
-print(f"{'Relative Error':<20} {rel_error:<15.6e} {'N/A':<15}")
-print(f"{'Max Abs Diff':<20} {max_abs_diff:<15.6e} {'N/A':<15}")
+print("-" * 110)
+print("SIMILARITY METRICS (vs KERNEL)")
+print("-" * 110)
+print(
+    f"{'METRIC':<20} {'KERNEL vs REF':<20} {'KERNEL vs YAMOE_REF':<20} {'REF vs YAMOE_REF':<20}"
+)
+print("-" * 110)
+print(
+    f"{'MSE':<20} {mse_kernel_ref:<20.6e} {mse_kernel_yamoe:<20.6e} {mse_ref_yamoe:<20.6e}"
+)
+print(
+    f"{'MAE':<20} {mae_kernel_ref:<20.6e} {mae_kernel_yamoe:<20.6e} {mae_ref_yamoe:<20.6e}"
+)
+print(
+    f"{'Cosine Similarity':<20} {cosine_sim_kernel_ref:<20.6f} {cosine_sim_kernel_yamoe:<20.6f} {cosine_sim_ref_yamoe:<20.6f}"
+)
+print(
+    f"{'Relative Error':<20} {rel_error_kernel_ref:<20.6e} {rel_error_kernel_yamoe:<20.6e} {rel_error_ref_yamoe:<20.6e}"
+)
+print(
+    f"{'Max Abs Diff':<20} {max_abs_diff_kernel_ref:<20.6e} {max_abs_diff_kernel_yamoe:<20.6e} {max_abs_diff_ref_yamoe:<20.6e}"
+)
 
-print("-" * 80)
+print("-" * 110)
 print("FIRST 10 ELEMENTS COMPARISON")
-print("-" * 80)
+print("-" * 110)
 
+
-# Get first 10 elements as numpy arrays for nice display
-kernel_first_10 = kernel_flat[:10].cpu().numpy()
-ref_first_10 = ref_flat[:10].cpu().numpy()
-diff_first_10 = kernel_first_10 - ref_first_10
+# Get first N elements as numpy arrays for nice display
+N = 10
+kernel_first_10 = kernel_flat[:N].cpu().numpy()
+ref_first_10 = ref_flat[:N].cpu().numpy()
+yamoe_ref_first_10 = yamoe_ref_flat[:N].cpu().numpy()
+diff_kernel_ref = kernel_first_10 - ref_first_10
+diff_kernel_yamoe = kernel_first_10 - yamoe_ref_first_10
 
-print(f"{'INDEX':<5} {'KERNEL':<12} {'REFERENCE':<12} {'DIFF':<12}")
-print("-" * 45)
-for i in range(10):
+print(
+    f"{'INDEX':<5} {'KERNEL':<12} {'REFERENCE':<12} {'YAMOE_REF':<12} {'K-R DIFF':<12} {'K-Y DIFF':<12}"
+)
+print("-" * 70)
+for i in range(N):
     print(
-        f"{i:<5} {kernel_first_10[i]:<12.6f} {ref_first_10[i]:<12.6f} {diff_first_10[i]:<12.6f}"
+        f"{i:<5} {kernel_first_10[i]:<12.6f} {ref_first_10[i]:<12.6f} {yamoe_ref_first_10[i]:<12.6f} {diff_kernel_ref[i]:<12.6f} {diff_kernel_yamoe[i]:<12.6f}"
     )
 
-print("=" * 80)
+print("=" * 110)

csrc/experts_backward.cu ADDED
@@ -0,0 +1,341 @@
+// Backward pass for MoE experts
+
+#include <ATen/cuda/Atomic.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+
+void sort_cuda(torch::Tensor x,
+               int64_t end_bit,
+               torch::Tensor x_out,
+               torch::Tensor iota_out);
+void bincount_cumsum_cuda(torch::Tensor input,
+                          torch::Tensor &output,
+                          int64_t minlength);
+void gather_cuda(const torch::Tensor &x,
+                 const torch::Tensor &indices,
+                 const torch::Tensor &bins,
+                 torch::Tensor &output,
+                 int64_t E,
+                 int64_t C,
+                 int64_t top_k);
+void scatter_cuda(const torch::Tensor &src,
+                  const torch::Tensor &indices,
+                  const torch::Tensor &bins,
+                  const torch::Tensor &weights,
+                  torch::Tensor &y,
+                  int64_t T,
+                  int64_t E,
+                  int64_t C,
+                  int64_t top_k);
+torch::Tensor index_select_out_cuda(torch::Tensor out,
+                                    torch::Tensor in,
+                                    torch::Tensor idx_int32);
+
+// scatter gradients back to expert outputs and routing weights
+template <typename scalar_t>
+__global__ void binned_scatter_backward_kernel(
+    const scalar_t *__restrict__ grad_y,            // [T, H]
+    const int *__restrict__ indices,                // [S]
+    const int *__restrict__ bins,                   // [E+1]
+    const scalar_t *__restrict__ selected_weights,  // [S]
+    const scalar_t *__restrict__ expert_output,     // [E, C, H]
+    scalar_t *__restrict__ grad_expert_output,      // [E, C, H]
+    scalar_t *__restrict__ grad_selected_weights,   // [S]
+    int T,
+    int K,
+    int H,
+    int E,
+    int C) {
+
+  int e = blockIdx.x;
+  int i = blockIdx.y;
+  if (e >= E || i >= C)
+    return;
+
+  const int start = (e == 0) ? 0 : bins[e - 1];
+  const int end = bins[e];
+  const int n_all = end - start;
+  const int take = (n_all > 0) ? min(n_all, C) : 0;
+
+  if (take == 0 || i >= take) {
+    scalar_t *dst = grad_expert_output + ((size_t)e * C + i) * H;
+    for (int h = threadIdx.x; h < H; h += blockDim.x)
+      dst[h] = scalar_t(0);
+    return;
+  }
+
+  const int sorted_pos = start + i;
+  const int flat_pos = indices[sorted_pos];
+  const int tok = flat_pos / K;
+
+  const scalar_t scale = selected_weights[sorted_pos];
+
+  const scalar_t *grad_y_ptr = grad_y + (size_t)tok * H;
+  scalar_t *grad_exp_ptr = grad_expert_output + ((size_t)e * C + i) * H;
+  const scalar_t *expert_ptr = expert_output + ((size_t)e * C + i) * H;
+
+  for (int h = threadIdx.x; h < H; h += blockDim.x) {
+    grad_exp_ptr[h] += grad_y_ptr[h] * scale;
+  }
+
+  if (threadIdx.x == 0) {
+    scalar_t sum = scalar_t(0);
+    for (int h = 0; h < H; ++h)
+      sum += grad_y_ptr[h] * expert_ptr[h];
+    gpuAtomicAdd(&grad_selected_weights[flat_pos], sum);
+  }
+}
+
+// gather gradients back to hidden states
+template <typename scalar_t>
+__global__ void binned_gather_backward_kernel(
+    const scalar_t *__restrict__ grad_x,  // [E, C, H]
+    const int *__restrict__ indices,      // [S]
+    const int *__restrict__ bins,         // [E+1]
+    scalar_t *__restrict__ grad_hidden,   // [T, H]
+    int T,
+    int K,
+    int H,
+    int E,
+    int C) {
+
+  int e = blockIdx.x;
+  int i = blockIdx.y;
+  if (e >= E || i >= C)
+    return;
+
+  const int start = (e == 0) ? 0 : bins[e - 1];
+  const int end = bins[e];
+  const int n = min(max(end - start, 0), C);
+  if (i >= n)
+    return;
+
+  const int flat_pos = indices[start + i];
+  const int tok = flat_pos / K;
+
+  const scalar_t *gx = grad_x + ((size_t)e * C + i) * H;
+  scalar_t *gh = grad_hidden + (size_t)tok * H;
+
+  for (int h = threadIdx.x; h < H; h += blockDim.x) {
+    gpuAtomicAdd(&gh[h], gx[h]);
+  }
+}
+
+std::vector<torch::Tensor> experts_backward_cuda(
+    const torch::Tensor &grad_out,
+    const torch::Tensor &hidden_states,
+    const torch::Tensor &router_indices,
+    const torch::Tensor &routing_weights,
+    const torch::Tensor &gate_up_proj,
+    const torch::Tensor &gate_up_proj_bias,
+    const torch::Tensor &down_proj,
+    const torch::Tensor &down_proj_bias,
+    int64_t expert_capacity,
+    int64_t num_experts,
+    int64_t top_k) {
+  TORCH_CHECK(grad_out.is_cuda(), "grad_out must be CUDA");
+  TORCH_CHECK(hidden_states.is_cuda(), "hidden_states must be CUDA");
+  TORCH_CHECK(router_indices.is_cuda(), "router_indices must be CUDA");
+  TORCH_CHECK(routing_weights.is_cuda(), "routing_weights must be CUDA");
+  TORCH_CHECK(gate_up_proj.is_cuda() && down_proj.is_cuda(),
+              "weights must be CUDA");
+  TORCH_CHECK(gate_up_proj_bias.is_cuda() && down_proj_bias.is_cuda(),
+              "biases must be CUDA");
+
+  const at::cuda::OptionalCUDAGuard device_guard(grad_out.device());
+
+  const int64_t T = hidden_states.size(0);
+  const int64_t H = hidden_states.size(1);
+  const int64_t E = num_experts;
+  const int64_t C = expert_capacity;
+  const int64_t K = top_k;
+
+  TORCH_CHECK(router_indices.dim() == 2 && router_indices.size(0) == T &&
+                  router_indices.size(1) == K,
+              "router_indices must be [T, K]");
+
+  auto float_opts = hidden_states.options();
+  auto i32_opts = torch::TensorOptions()
+                      .device(hidden_states.device())
+                      .dtype(torch::kInt32);
+
+  // Sort tokens by expert ID
+  torch::Tensor flat_indices =
+      router_indices.contiguous().view({-1}).to(torch::kInt32);
+  torch::Tensor sorted_values = torch::empty_like(flat_indices);
+  torch::Tensor sorted_indices = torch::empty_like(flat_indices);
+  sort_cuda(flat_indices, 32, sorted_values, sorted_indices);
+
+  // Compute expert boundaries
+  torch::Tensor bins = torch::empty({E + 1}, i32_opts);
+  bincount_cumsum_cuda(sorted_values, bins, E);
+  cudaDeviceSynchronize();
+
+  // Gather tokens for each expert
+  torch::Tensor x = torch::empty({E, C, H}, float_opts);
+  gather_cuda(hidden_states.contiguous(), sorted_indices, bins, x, E, C, K);
+
+  // Gate-up projection
+  torch::Tensor gate_up = at::bmm(x.contiguous(), gate_up_proj.contiguous());
+  gate_up.add_(gate_up_proj_bias.unsqueeze(1));
+
+  // GLU activation (recompute forward)
+  auto gu_pair = gate_up.view({E, C, H, 2});
+  torch::Tensor pre_gate = gu_pair.select(3, 0);
+  torch::Tensor pre_up = gu_pair.select(3, 1);
+
+  const double limit = 7.0;
+  const double alpha = 1.702;
+  torch::Tensor gate_clamped = at::clamp_max(pre_gate, limit);
+  torch::Tensor up_clamped = at::clamp(pre_up, -limit, limit);
+  torch::Tensor s = at::sigmoid(gate_clamped * alpha);
+  torch::Tensor gate_act = gate_clamped * s;
+  torch::Tensor up_out = (1 + up_clamped) * gate_act;
+
+  // Down projection
+  torch::Tensor y_expert = at::bmm(up_out.contiguous(), down_proj.contiguous());
+  y_expert.add_(down_proj_bias.unsqueeze(1));
+
+  // Get routing weights in sorted order
+  torch::Tensor flat_router = router_indices.view({T, K});
+  torch::Tensor selected_2d;
+  if (routing_weights.size(1) == K) {
+    selected_2d = routing_weights.contiguous();
+  } else {
+    TORCH_CHECK(routing_weights.size(1) == E,
+                "routing_weights must be [T,K] or [T,E]");
+    selected_2d = at::gather(routing_weights, 1, flat_router.to(torch::kLong));
+  }
+  torch::Tensor selected_flat = selected_2d.contiguous().view({T * K});
+  torch::Tensor weights_sorted = torch::empty_like(selected_flat);
+  index_select_out_cuda(weights_sorted, selected_flat, sorted_indices);
+
+  // Initialize gradients
+  torch::Tensor dHidden = torch::zeros_like(hidden_states);
+  torch::Tensor dRouting;
+  torch::Tensor dWgu = torch::zeros_like(gate_up_proj);
+  torch::Tensor dbgu = torch::zeros_like(gate_up_proj_bias);
+  torch::Tensor dWd = torch::zeros_like(down_proj);
+  torch::Tensor dbd = torch::zeros_like(down_proj_bias);
+
+  // Reshape grad_out to [T,H]
+  TORCH_CHECK(grad_out.numel() == T * H || grad_out.numel() == T * K * H,
+              "grad_out numel must be T*H or T*K*H");
+  torch::Tensor grad_y = (grad_out.numel() == T * H)
+                             ? grad_out.contiguous().view({T, H})
+                             : grad_out.contiguous().view({T, K, H}).sum(1);
+
+  // Backward through scatter
+  torch::Tensor grad_expert_output = torch::zeros({E, C, H}, float_opts);
+  torch::Tensor grad_selected_weights = torch::zeros({T * K}, float_opts);
+  {
+    dim3 grid(E, C);
+    int threads = 256;
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::kHalf,
+        at::kBFloat16,
+        hidden_states.scalar_type(),
+        "binned_scatter_backward",
+        [&] {
+          using st = scalar_t;
+          binned_scatter_backward_kernel<st>
+              <<<grid, threads>>>(grad_y.data_ptr<st>(),
+                                  sorted_indices.data_ptr<int>(),
+                                  bins.data_ptr<int>(),
+                                  weights_sorted.data_ptr<st>(),
+                                  y_expert.data_ptr<st>(),
+                                  grad_expert_output.data_ptr<st>(),
+                                  grad_selected_weights.data_ptr<st>(),
+                                  (int)T,
+                                  (int)K,
+                                  (int)H,
+                                  (int)E,
+                                  (int)C);
+        });
+    cudaDeviceSynchronize();
+  }
+
+  // Route weight gradients
+  torch::Tensor grad_selected_flat = torch::zeros({T * K}, float_opts);
+  grad_selected_flat.index_add_(0,
+                                sorted_indices.to(torch::kLong),
+                                grad_selected_weights);
+
+  if (routing_weights.size(1) == E) {
+    torch::Tensor flat_grad_routing = torch::zeros_like(routing_weights);
+    flat_grad_routing.scatter_add_(1,
+                                   flat_router.to(torch::kLong),
+                                   grad_selected_flat.view({T, K}));
+    dRouting = flat_grad_routing;
+  } else {
+    dRouting = grad_selected_flat.view({T, K});
+  }
+
+  // Backward through down projection
+  dbd = grad_expert_output.sum(1);
+  torch::Tensor grad_intermediate =
+      torch::bmm(grad_expert_output.contiguous(),
+                 down_proj.transpose(1, 2).contiguous());
+  dWd = torch::bmm(up_out.transpose(1, 2).contiguous(),
+                   grad_expert_output.contiguous());
+
+  // Backward through GLU
+  torch::Tensor grad_up_plus_1 = grad_intermediate * gate_act;
+  torch::Tensor grad_glu = grad_intermediate * (up_clamped + 1);
+  torch::Tensor grad_up_clamped = grad_up_plus_1;
+
+  torch::Tensor sigmoid_gate = torch::sigmoid(gate_clamped * alpha);
+  torch::Tensor grad_gate_clamped =
+      grad_glu *
+      (sigmoid_gate + gate_clamped * sigmoid_gate * (1 - sigmoid_gate) * alpha);
+
+  // Unclamp gradients
+  torch::Tensor grad_gate = grad_gate_clamped.clone();
+  grad_gate.masked_fill_(pre_gate > limit, 0);
+  torch::Tensor grad_up = grad_up_clamped.clone();
+  grad_up.masked_fill_(pre_up > limit, 0);
+  grad_up.masked_fill_(pre_up < -limit, 0);
+
+  // Merge gate/up gradients
+  torch::Tensor grad_gate_up_pair = torch::zeros({E, C, H, 2}, float_opts);
+  grad_gate_up_pair.select(3, 0).copy_(grad_gate);
+  grad_gate_up_pair.select(3, 1).copy_(grad_up);
+  torch::Tensor grad_gate_up = grad_gate_up_pair.view({E, C, 2 * H});
+
+  // Backward through gate-up projection
+  dbgu = grad_gate_up.sum(1);
+  torch::Tensor grad_x = torch::bmm(grad_gate_up.contiguous(),
+                                    gate_up_proj.transpose(1, 2).contiguous());
+  dWgu = torch::bmm(x.transpose(1, 2).contiguous(), grad_gate_up.contiguous());
+
+  // Backward through gather
+  torch::Tensor grad_hidden = torch::zeros({T, H}, float_opts);
+  {
+    dim3 grid(E, C);
+    int threads = 256;
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf,
+                                    at::kBFloat16,
+                                    hidden_states.scalar_type(),
+                                    "binned_gather_backward",
+                                    [&] {
+                                      using st = scalar_t;
+                                      binned_gather_backward_kernel<st>
+                                          <<<grid, threads>>>(
+                                              grad_x.data_ptr<st>(),
+                                              sorted_indices.data_ptr<int>(),
+                                              bins.data_ptr<int>(),
+                                              grad_hidden.data_ptr<st>(),
+                                              (int)T,
+                                              (int)K,
+                                              (int)H,
+                                              (int)E,
+                                              (int)C);
+                                    });
+    cudaDeviceSynchronize();
+  }
+  dHidden += grad_hidden;
+
+  return {dHidden, dRouting, dWgu, dbgu, dWd, dbd};
+}

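The new op recomputes the forward activations and returns the gradients in the order {dHidden, dRouting, dWgu, dbgu, dWd, dbd}. A minimal sketch of how it could be wired into autograd (not part of this commit; the wrapper class, the loading path, and the hypothetical ExpertsFn name are assumptions):

# Hypothetical autograd glue around the ops registered by this commit.
import torch
import yamoe  # or: from kernels import get_kernel; yamoe = get_kernel("drbh/yamoe")


class ExpertsFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden_states, router_indices, routing_weights,
                gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
                expert_capacity, num_experts, top_k):
        ctx.save_for_backward(hidden_states, router_indices, routing_weights,
                              gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
        ctx.dims = (expert_capacity, num_experts, top_k)
        return yamoe.experts(hidden_states, router_indices, routing_weights,
                             gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
                             expert_capacity, num_experts, top_k)

    @staticmethod
    def backward(ctx, grad_out):
        saved = ctx.saved_tensors
        expert_capacity, num_experts, top_k = ctx.dims
        # Order follows experts_backward_cuda: dHidden, dRouting, dWgu, dbgu, dWd, dbd
        d_hidden, d_routing, d_wgu, d_bgu, d_wd, d_bd = yamoe.experts_backward(
            grad_out.contiguous(), *saved, expert_capacity, num_experts, top_k
        )
        # One gradient per forward input; router_indices and the int args get None.
        return d_hidden, None, d_routing, d_wgu, d_bgu, d_wd, d_bd, None, None, None
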
readme_example.py CHANGED
@@ -7,8 +7,7 @@
 
 import time
 import torch
-from kernels import get_local_kernel
-from kernels import get_kernel
+from kernels import get_kernel, get_local_kernel
 from pathlib import Path
 from torch.nn import functional as F
 
@@ -83,6 +82,8 @@ torch.cuda.synchronize()
 elapsed_ms = (time.perf_counter() - start) * 1e3
 peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
 
-print(f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}")
+print(
+    f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}"
+)
 print(f"First 3: {output.view(-1)[:3].tolist()}")
 print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")

torch-ext/torch_binding.cpp CHANGED
@@ -67,6 +67,20 @@ TORCH_LIBRARY_EXPAND(
           "int num_experts, "
           "int top_k) -> Tensor");
   ops.impl("experts", torch::kCUDA, &experts_cuda);
+
+  ops.def("experts_backward("
+          "Tensor grad_out, "
+          "Tensor hidden_states, "
+          "Tensor router_indices, "
+          "Tensor routing_weights, "
+          "Tensor gate_up_proj, "
+          "Tensor gate_up_proj_bias, "
+          "Tensor down_proj, "
+          "Tensor down_proj_bias, "
+          "int expert_capacity, "
+          "int num_experts, "
+          "int top_k) -> Tensor[]");
+  ops.impl("experts_backward", torch::kCUDA, &experts_backward_cuda);
 }
 
 REGISTER_EXTENSION(

torch-ext/torch_binding.h CHANGED
@@ -53,3 +53,19 @@ torch::Tensor experts_cuda(
     int64_t num_experts,                 // E - number of experts
     int64_t top_k                        // K - top-k routing
 );
+
+std::vector<torch::Tensor> experts_backward_cuda(
+    const torch::Tensor &grad_out,       // [T, H] - gradient from output
+    const torch::Tensor &hidden_states,  // [T, H] - original input
+    const torch::Tensor &router_indices, // [T, K] - expert indices per token
+    const torch::Tensor &routing_weights,// [T, K] or [T, E] - routing weights
+    const torch::Tensor
+        &gate_up_proj,                   // [E, H, 2*H] - gate/up projection weights
+    const torch::Tensor
+        &gate_up_proj_bias,              // [E, 2*H] - gate/up projection bias
+    const torch::Tensor &down_proj,      // [E, H, H] - down projection weights
+    const torch::Tensor &down_proj_bias, // [E, H] - down projection bias
+    int64_t expert_capacity,             // C - capacity per expert
+    int64_t num_experts,                 // E - number of experts
+    int64_t top_k                        // K - top-k routing
+);

torch-ext/yamoe/__init__.py CHANGED
@@ -1,5 +1,7 @@
 from ._ops import ops
-from . import reference
+from .layers import Yamoe
+from .vendored import yamoe_ref
+from .vendored import gpt_oss_mlp
 
 gather = ops.gather
 scatter = ops.scatter
@@ -7,8 +9,14 @@ sort = ops.sort
 bincount_cumsum = ops.bincount_cumsum
 batch_mm = ops.batch_mm
 experts = ops.experts
+experts_backward = ops.experts_backward
 
 __all__ = [
+    # Debug
+    "ops",
+    # Layer (nn module)
+    "Yamoe",
+    # Functions
     "shuffle",
     "gather",
     "scatter",
@@ -16,6 +24,9 @@ __all__ = [
     "bincount_cumsum",
     "batch_mm",
     "experts",
-    # Export the reference implementation
+    "experts_backward",
+    # Vendored reference implementations
    "reference",
+    "yamoe_ref",
+    "gpt_oss_mlp",
 ]

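A rough sketch of calling the newly exported ops directly. The argument order follows the schema in torch_binding.cpp and the shape comments in torch_binding.h; the concrete sizes (T=8, H=16, E=4, K=2) and the random weights are illustrative only:

import torch
import yamoe

T, H, E, K = 8, 16, 4, 2
C = T * K // E * 2  # expert capacity, same heuristic as compare_example.py

hidden = torch.randn(T, H, device="cuda")
router_indices = torch.randint(0, E, (T, K), device="cuda")
routing_weights = torch.rand(T, E, device="cuda")          # dense [T, E] form
gate_up_proj = torch.randn(E, H, 2 * H, device="cuda") * 0.02
gate_up_proj_bias = torch.zeros(E, 2 * H, device="cuda")
down_proj = torch.randn(E, H, H, device="cuda") * 0.02
down_proj_bias = torch.zeros(E, H, device="cuda")

out = yamoe.experts(
    hidden, router_indices, routing_weights,
    gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
    C, E, K,
)  # per-token output, [T, H]

# Gradients come back in the order produced by experts_backward_cuda:
# dHidden, dRouting, dWgu, dbgu, dWd, dbd
d_hidden, d_routing, d_wgu, d_bgu, d_wd, d_bd = yamoe.experts_backward(
    torch.ones_like(out), hidden, router_indices, routing_weights,
    gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
    C, E, K,
)
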
torch-ext/yamoe/layers.py ADDED
@@ -0,0 +1,104 @@
+import torch
+from ._ops import ops
+
+
+class Yamoe(torch.nn.Module):
+    """Yamoe MoE layer with routing and expert computation"""
+
+    can_torch_compile: bool = True
+
+    def __init__(self):
+        super().__init__()
+        # Pre-allocate buffers to avoid repeated allocations
+        self._routing_weights_buffer = None
+        self._batch_indices_buffer = None
+        self._last_batch_seq = None
+        self._last_num_experts = None
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+        batch_seq = batch_size * seq_len
+
+        num_experts = getattr(self, "num_experts", 128)
+        top_k = getattr(self, "top_k", 4)
+
+        # Route tokens to experts
+        x_flat = hidden_states.view(-1, hidden_dim)
+        logits = torch.nn.functional.linear(
+            x_flat, self.router.weight, self.router.bias
+        )
+
+        # Compute top-k
+        if top_k == 1:
+            routing_weights, router_indices = logits.max(dim=-1, keepdim=True)
+        else:
+            routing_weights, router_indices = torch.topk(logits, top_k, dim=-1)
+
+        routing_weights = routing_weights.softmax(dim=-1)
+
+        # Create router scores
+        router_scores = (
+            torch.zeros_like(logits)
+            .scatter_(1, router_indices, routing_weights)
+            .transpose(0, 1)
+        )
+
+        # Convert routing_weights to sparse format [batch_seq, num_experts]
+        # Reuse buffer if possible to reduce allocations
+        if (
+            self._routing_weights_buffer is None
+            or self._last_batch_seq != batch_seq
+            or self._last_num_experts != num_experts
+            or self._routing_weights_buffer.device != routing_weights.device
+        ):
+            self._routing_weights_buffer = torch.zeros(
+                batch_seq,
+                num_experts,
+                device=routing_weights.device,
+                dtype=routing_weights.dtype,
+            )
+            self._batch_indices_buffer = (
+                torch.arange(batch_seq, device=routing_weights.device)
+                .unsqueeze(1)
+                .expand(-1, top_k)
+            )
+            self._last_batch_seq = batch_seq
+            self._last_num_experts = num_experts
+        else:
+            self._routing_weights_buffer.zero_()
+
+        # Fill sparse routing weights
+        flat_indices = router_indices.view(batch_seq, top_k)
+        flat_weights = routing_weights.view(batch_seq, top_k)
+        self._routing_weights_buffer[self._batch_indices_buffer, flat_indices] = (
+            flat_weights
+        )
+
+        # FIX: Use the correct expert projections
+        gate_up = self.experts.gate_up_proj[:, :, : hidden_dim * top_k].contiguous()
+        gate_up_bias = self.experts.gate_up_proj_bias[
+            :, : hidden_dim * top_k
+        ].contiguous()
+
+        down_proj = self.experts.down_proj[:, :hidden_dim, :].contiguous()
+
+        expert_capacity = batch_seq * top_k // num_experts * 2
+
+        with torch.no_grad():
+            # Compute expert output
+            output = ops.experts(
+                hidden_states.view(-1, hidden_dim),
+                router_indices,
+                self._routing_weights_buffer,
+                gate_up,
+                gate_up_bias,
+                down_proj,
+                self.experts.down_proj_bias,
+                expert_capacity,
+                num_experts,
+                top_k,
+            )
+
+        # Reshape output back to [B, S, H]
+        output = output.view(batch_size, seq_len, hidden_dim)
+        return output, router_scores

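A rough standalone sketch of exercising the Yamoe layer. In practice the kernels hub attaches this module to an existing MLP; here the `router`, `experts`, `num_experts`, and `top_k` attributes that forward reads via getattr are set by hand, and the config values are illustrative assumptions:

import torch
import yamoe
from yamoe.vendored.gpt_oss_mlp import GptOssExperts, GptOssTopKRouter


class Cfg:
    hidden_size = 64
    intermediate_size = 64       # expert_dim == hidden_size keeps the slicing in forward consistent
    num_local_experts = 8
    num_experts_per_tok = 2


cfg = Cfg()
layer = yamoe.Yamoe()
layer.router = GptOssTopKRouter(cfg).cuda()
layer.experts = GptOssExperts(cfg).cuda()
layer.num_experts = cfg.num_local_experts
layer.top_k = cfg.num_experts_per_tok

# Parameters are created with torch.empty in the vendored modules; give them values first.
for p in list(layer.router.parameters()) + list(layer.experts.parameters()):
    torch.nn.init.trunc_normal_(p, std=0.02)

x = torch.randn(2, 16, cfg.hidden_size, device="cuda")
out, router_scores = layer(x)  # out: [2, 16, hidden_size]
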
torch-ext/yamoe/{reference.py → vendored/gpt_oss_mlp.py} RENAMED
@@ -1,5 +1,14 @@
 import torch
-import torch.nn as nn
+from torch import nn
+from torch.nn import functional as F
+
+
+def info(tensor, name):
+    print(name)
+    print(tensor.shape)
+    print(tensor.cpu())
+    print()
+
 
 class GptOssExperts(nn.Module):
     def __init__(self, config):
@@ -8,16 +17,26 @@ class GptOssExperts(nn.Module):
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
-        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
-        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
-        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
+        self.gate_up_proj = nn.Parameter(
+            torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)
+        )
+        self.gate_up_proj_bias = nn.Parameter(
+            torch.empty(self.num_experts, 2 * self.expert_dim)
+        )
+        self.down_proj = nn.Parameter(
+            torch.empty((self.num_experts, self.expert_dim, self.hidden_size))
+        )
+        self.down_proj_bias = nn.Parameter(
+            torch.empty(self.num_experts, self.hidden_size)
+        )
         self.alpha = 1.702
         self.limit = 7.0
 
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
+    ) -> torch.Tensor:
         """
-        When training is is more efficient to just loop over the experts and compute the output for each expert
+        When training it is more efficient to just loop over the experts and compute the output for each expert
         as otherwise the memory would explode.
 
         For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
@@ -29,45 +48,112 @@
         Returns:
             torch.Tensor
         """
-
-        # import ipdb; ipdb.set_trace()
-
         batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        hidden_states = hidden_states.reshape(
+            -1, self.hidden_size
+        )  # (num_tokens, hidden_size)
         num_experts = routing_weights.shape[1]
-        if self.training:
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+
+        if hidden_states.device.type == "cpu" or self.training:
+            next_states = torch.zeros_like(
+                hidden_states, dtype=hidden_states.dtype, device=hidden_states.device
+            )
             with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+                expert_mask = torch.nn.functional.one_hot(
+                    router_indices, num_classes=num_experts
+                )
                 expert_mask = expert_mask.permute(2, 1, 0)
-                # we sum on the top_k and on the sequence lenght to get which experts
+                # we sum on the top_k and on the sequence length to get which experts
                 # are hit this time around
-                expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-            for expert_idx in expert_hitted[:]:
+                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit[:]:
+                # expert_idx only have 1 element, so we can use scale for fast indexing
+                expert_idx = expert_idx[0]
                 with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
+                    _, token_idx = torch.where(expert_mask[expert_idx])
                 current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+                gate_up = (
+                    current_state @ self.gate_up_proj[expert_idx]
+                    + self.gate_up_proj_bias[expert_idx]
+                )
                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                 gate = gate.clamp(min=None, max=self.limit)
                 up = up.clamp(min=-self.limit, max=self.limit)
                 glu = gate * torch.sigmoid(gate * self.alpha)
                 gated_output = (up + 1) * glu
-                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+                out = (
+                    gated_output @ self.down_proj[expert_idx]
+                    + self.down_proj_bias[expert_idx]
+                )
+                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                next_states.index_add_(
+                    0, token_idx, weighted_output.to(hidden_states.dtype)
+                )
             next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+            gate_up = (
+                torch.bmm(hidden_states, self.gate_up_proj)
+                + self.gate_up_proj_bias[..., None, :]
+            )
             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
             gate = gate.clamp(min=None, max=self.limit)
             up = up.clamp(min=-self.limit, max=self.limit)
             glu = gate * torch.sigmoid(gate * self.alpha)
             next_states = torch.bmm(((up + 1) * glu), self.down_proj)
             next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.view(
+                num_experts, batch_size, -1, self.hidden_size
+            )
+            next_states = (
+                next_states
+                * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[
+                    ..., None
+                ]
+            )
             next_states = next_states.sum(dim=0)
         return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+        self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(
+            hidden_states, self.weight, self.bias
+        )  # (seq_len, num_experts)
+        router_top_value, router_indices = torch.topk(
+            router_logits, self.top_k, dim=-1
+        )  # (seq_len, top_k)
+        router_top_value = torch.nn.functional.softmax(
+            router_top_value, dim=1, dtype=router_top_value.dtype
+        )
+        router_scores = torch.zeros_like(router_logits).scatter_(
+            1, router_indices, router_top_value
+        )
+        return router_scores, router_indices
+
+
+# @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.router = GptOssTopKRouter(config)
+        self.experts = GptOssExperts(config)
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(
+            hidden_states
+        )  # (num_experts, seq_len)
+        routed_out = self.experts(
+            hidden_states, router_indices=router_indices, routing_weights=router_scores
+        )
+        return routed_out, router_scores

torch-ext/yamoe/vendored/yamoe_ref.py ADDED
@@ -0,0 +1,82 @@
+import torch
+
+
+def binned_gather(x, indices, bins, expert_capacity, top_k):
+    E, H = bins.shape[0], x.shape[1]
+    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
+    for e in range(E):
+        start = 0 if e == 0 else bins[e - 1]
+        end = bins[e]
+        n = min(end - start, expert_capacity)
+        for i in range(n):
+            flat_pos = indices[start + i]
+            tok = flat_pos // top_k
+            out[e, i] = x[tok]
+    return out
+
+
+def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
+    E, C, H = x.shape
+    N = indices.shape[0] // top_k
+    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
+    for e in range(E):
+        start = 0 if e == 0 else bins[e - 1]
+        end = bins[e]
+        n = end - start
+        if n == 0:
+            continue
+        take = min(n, expert_capacity)
+        for i in range(take):
+            flat_pos = indices[start + i]  # flattened (token, slot)
+            tok = flat_pos // top_k
+            slot = flat_pos % top_k
+            scale = weights[flat_pos] if weights is not None else 1.0
+            out[tok, slot] = x[e, i] * scale
+    return out.sum(dim=1)
+
+
+def sort_tokens_by_expert(router_indices, num_experts):
+    flat_indices = router_indices.flatten()
+    sorted_values, sorted_indices = torch.sort(flat_indices)
+    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
+    bins = torch.cumsum(tokens_per_expert, dim=0)
+    return sorted_indices, sorted_values, bins, tokens_per_expert
+
+
+def binned_experts_ref(
+    hidden_states,
+    router_indices,
+    routing_weights,
+    gate_up_proj,
+    gate_up_proj_bias,
+    down_proj,
+    down_proj_bias,
+    expert_capacity,
+):
+    B, S, H = hidden_states.shape
+    E, K = routing_weights.shape[1], router_indices.shape[1]
+
+    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
+    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
+
+    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+
+    # clamp to limit
+    limit = 7.0
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+
+    glu = gate * torch.sigmoid(gate * 1.702)
+    x = (up + 1) * glu
+    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
+
+    # build routing weights aligned to (token, slot)
+    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
+    flat_router = router_indices.view(-1, K)  # [B*S, K]
+    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]
+
+    # scatter back
+    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]
+
+    return y.view(B, S, H)
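
A tiny worked example of the binning layout used by this reference (values are illustrative; the import path assumes the vendored module is importable as a package):

import torch
from yamoe.vendored.yamoe_ref import sort_tokens_by_expert, binned_gather

router_indices = torch.tensor([[1], [0], [1], [1]])  # 4 tokens, top_k=1, 2 experts
idx, vals, bins, counts = sort_tokens_by_expert(router_indices, num_experts=2)
# counts == tensor([1, 3]); bins == tensor([1, 4]):
# expert 0 owns sorted positions [0, 1), expert 1 owns [1, 4).
x = torch.arange(4.0).unsqueeze(1)                   # [T, H] with H = 1
gathered = binned_gather(x, idx, bins, expert_capacity=3, top_k=1)  # [E, C, H]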