drbh committed
Commit 733f7f4
1 Parent(s): 0daa7ef

feat: impl backward experts

Files changed:
- .gitignore +5 -1
- build.toml +2 -1
- compare_example.py +147 -46
- csrc/experts_backward.cu +341 -0
- readme_example.py +4 -3
- torch-ext/torch_binding.cpp +14 -0
- torch-ext/torch_binding.h +16 -0
- torch-ext/yamoe/__init__.py +13 -2
- torch-ext/yamoe/layers.py +104 -0
- torch-ext/yamoe/{reference.py → vendored/gpt_oss_mlp.py} +111 -25
- torch-ext/yamoe/vendored/yamoe_ref.py +82 -0
.gitignore
CHANGED
@@ -12,4 +12,8 @@ tests
 torch-ext/registration.h
 torch-ext/yamoe/_ops.py
 csrc/batch_mm.cu
-torch-ext/yamoe/*.abi3.so
+torch-ext/yamoe/*.abi3.so
+
+build-ext
+build
+exploration
build.toml
CHANGED
@@ -32,5 +32,6 @@ src = [
     "csrc/sort.cu",
     "csrc/bincount_cumsum.cu",
     "csrc/batch_mm.cu",
-    "csrc/moe.cpp"
+    "csrc/moe.cpp",
+    "csrc/experts_backward.cu"
 ]
compare_example.py
CHANGED
@@ -7,10 +7,11 @@
 
 import time
 import torch
-from kernels import get_local_kernel
-from kernels import get_kernel
+from kernels import get_kernel, get_local_kernel
 from pathlib import Path
 from torch.nn import functional as F
+import numpy as np
+import sys
 
 # Set seeds and deterministic flags for reproducibility
 torch.manual_seed(42)
@@ -19,11 +20,23 @@ torch.cuda.manual_seed_all(42)
 torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
+np.set_printoptions(precision=4)
+
+load_method = 2  # 1: sym, 2: local, 3: hf
+
+if load_method == 1:
+    sys.path.insert(0, "./torch-ext")
+    import yamoe
+elif load_method == 2:
+    yamoe = get_local_kernel(Path("result"), "yamoe")
+elif load_method == 3:
+    yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
+
+binned_experts_ref = yamoe.vendored.yamoe_ref.binned_experts_ref
+GptOssExperts = yamoe.vendored.gpt_oss_mlp.GptOssExperts
 
 # Configuration
 batch_size, seq_len, hidden_dim = 4, 1024, 2880
-# batch_size, seq_len, hidden_dim = 4, 32, 1024
 num_experts, top_k = 8, 2
 
 # Create routing weights
@@ -52,6 +65,7 @@ torch.nn.init.trunc_normal_(gate_up_proj, std=0.02)
 torch.nn.init.trunc_normal_(down_proj, std=0.02)
 
 routing_weights = routing_weights.to(dtype=torch.float32, device="cuda")
+expert_capacity = batch_seq * top_k // num_experts * 2
 
 
 # Warmup
@@ -64,7 +78,7 @@ for _ in range(5):
         gate_up_proj_bias,
         down_proj,
         down_proj_bias,
+        expert_capacity,
         num_experts,
         top_k,
     )
@@ -74,6 +88,7 @@ torch.cuda.synchronize()
 torch.cuda.reset_peak_memory_stats()
 start = time.perf_counter()
 
+
 with torch.no_grad():
     output = yamoe.experts(
         hidden_states.view(-1, hidden_dim),
@@ -83,7 +98,7 @@ with torch.no_grad():
         gate_up_proj_bias,
         down_proj,
         down_proj_bias,
+        expert_capacity,
         num_experts,
         top_k,
     )
@@ -104,7 +119,7 @@ config.hidden_size = hidden_dim
 config.intermediate_size = 4 * hidden_dim
 config.num_local_experts = num_experts
 
+model = GptOssExperts(config)
 
 # set the weights and biases from above to the reference model
 model.gate_up_proj.data = gate_up_proj
@@ -133,79 +148,165 @@ ref_memory = peak_mem_mb
 # Reshape reference output to match kernel output
 ref_output_reshaped = ref_output.view(kernel_output.shape)
 
+# Test yamoe_ref implementation
+expert_capacity = batch_seq * top_k // num_experts * 2  # Generous capacity
+
+torch.cuda.synchronize()
+torch.cuda.reset_peak_memory_stats()
+start = time.perf_counter()
+
+with torch.no_grad():
+    yamoe_ref_output = binned_experts_ref(
+        hidden_states,
+        router_indices,
+        routing_weights,
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        expert_capacity,
+    )
+
+torch.cuda.synchronize()
+yamoe_ref_time = (time.perf_counter() - start) * 1e3
+yamoe_ref_memory = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+# Reshape yamoe_ref output to match kernel output
+yamoe_ref_output_reshaped = yamoe_ref_output.view(kernel_output.shape)
+
+# Calculate similarity metrics between kernel and reference
+mse_kernel_ref = torch.nn.functional.mse_loss(kernel_output, ref_output_reshaped).item()
+mae_kernel_ref = torch.nn.functional.l1_loss(kernel_output, ref_output_reshaped).item()
 
 # Cosine similarity
 kernel_flat = kernel_output.view(-1)
 ref_flat = ref_output_reshaped.view(-1)
+yamoe_ref_flat = yamoe_ref_output_reshaped.view(-1)
+cosine_sim_kernel_ref = torch.nn.functional.cosine_similarity(
     kernel_flat.unsqueeze(0), ref_flat.unsqueeze(0)
 ).item()
 
 # Relative error (L2 norm of difference / L2 norm of reference)
+diff_norm_kernel_ref = torch.norm(kernel_output - ref_output_reshaped).item()
 ref_norm = torch.norm(ref_output_reshaped).item()
+rel_error_kernel_ref = diff_norm_kernel_ref / ref_norm if ref_norm > 0 else float("inf")
 
 # Max absolute difference
+max_abs_diff_kernel_ref = torch.max(
+    torch.abs(kernel_output - ref_output_reshaped)
+).item()
+
+# Calculate similarity metrics between kernel and yamoe_ref
+mse_kernel_yamoe = torch.nn.functional.mse_loss(
+    kernel_output, yamoe_ref_output_reshaped
+).item()
+mae_kernel_yamoe = torch.nn.functional.l1_loss(
+    kernel_output, yamoe_ref_output_reshaped
+).item()
+cosine_sim_kernel_yamoe = torch.nn.functional.cosine_similarity(
+    kernel_flat.unsqueeze(0), yamoe_ref_flat.unsqueeze(0)
+).item()
+diff_norm_kernel_yamoe = torch.norm(kernel_output - yamoe_ref_output_reshaped).item()
+yamoe_ref_norm = torch.norm(yamoe_ref_output_reshaped).item()
+rel_error_kernel_yamoe = (
+    diff_norm_kernel_yamoe / yamoe_ref_norm if yamoe_ref_norm > 0 else float("inf")
+)
+max_abs_diff_kernel_yamoe = torch.max(
+    torch.abs(kernel_output - yamoe_ref_output_reshaped)
+).item()
+
+# Calculate similarity metrics between reference and yamoe_ref
+mse_ref_yamoe = torch.nn.functional.mse_loss(
+    ref_output_reshaped, yamoe_ref_output_reshaped
+).item()
+mae_ref_yamoe = torch.nn.functional.l1_loss(
+    ref_output_reshaped, yamoe_ref_output_reshaped
+).item()
+cosine_sim_ref_yamoe = torch.nn.functional.cosine_similarity(
+    ref_flat.unsqueeze(0), yamoe_ref_flat.unsqueeze(0)
+).item()
+diff_norm_ref_yamoe = torch.norm(ref_output_reshaped - yamoe_ref_output_reshaped).item()
+rel_error_ref_yamoe = (
+    diff_norm_ref_yamoe / yamoe_ref_norm if yamoe_ref_norm > 0 else float("inf")
+)
+max_abs_diff_ref_yamoe = torch.max(
+    torch.abs(ref_output_reshaped - yamoe_ref_output_reshaped)
+).item()
 
 # Print comparison table
+print("\n" + "=" * 110)
+print(
+    f"{'METRIC':<20} {'KERNEL':<15} {'REFERENCE':<15} {'YAMOE_REF':<15} {'KERNEL SPEEDUP':<20} {'REF SPEEDUP':<15}"
+)
+print("=" * 110)
 
 print(
+    f"{'Sum':<20} {kernel_output.sum().item():<15.4f} {ref_output_reshaped.sum().item():<15.4f} {yamoe_ref_output_reshaped.sum().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
+    f"{'Min':<20} {kernel_output.min().item():<15.4f} {ref_output_reshaped.min().item():<15.4f} {yamoe_ref_output_reshaped.min().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
+    f"{'Max':<20} {kernel_output.max().item():<15.4f} {ref_output_reshaped.max().item():<15.4f} {yamoe_ref_output_reshaped.max().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
+    f"{'Norm (L2)':<20} {kernel_output.norm().item():<15.4f} {ref_output_reshaped.norm().item():<15.4f} {yamoe_ref_output_reshaped.norm().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 print(
+    f"{'Std':<20} {kernel_output.std().item():<15.4f} {ref_output_reshaped.std().item():<15.4f} {yamoe_ref_output_reshaped.std().item():<15.4f} {'N/A':<20} {'N/A':<15}"
 )
 
+print("-" * 110)
 print(
+    f"{'Time (ms)':<20} {kernel_time:<15.3f} {ref_time:<15.3f} {yamoe_ref_time:<15.3f} {yamoe_ref_time / kernel_time:<20.2f}x {yamoe_ref_time / ref_time:<15.2f}x"
 )
 print(
+    f"{'Memory (MB)':<20} {kernel_memory:<15.2f} {ref_memory:<15.2f} {yamoe_ref_memory:<15.2f} {yamoe_ref_memory / kernel_memory:<20.2f}x {yamoe_ref_memory / ref_memory:<15.2f}x"
 )
 
+print("-" * 110)
+print("SIMILARITY METRICS (vs KERNEL)")
+print("-" * 110)
+print(
+    f"{'METRIC':<20} {'KERNEL vs REF':<20} {'KERNEL vs YAMOE_REF':<20} {'REF vs YAMOE_REF':<20}"
+)
+print("-" * 110)
+print(
+    f"{'MSE':<20} {mse_kernel_ref:<20.6e} {mse_kernel_yamoe:<20.6e} {mse_ref_yamoe:<20.6e}"
+)
+print(
+    f"{'MAE':<20} {mae_kernel_ref:<20.6e} {mae_kernel_yamoe:<20.6e} {mae_ref_yamoe:<20.6e}"
)
+print(
+    f"{'Cosine Similarity':<20} {cosine_sim_kernel_ref:<20.6f} {cosine_sim_kernel_yamoe:<20.6f} {cosine_sim_ref_yamoe:<20.6f}"
+)
+print(
+    f"{'Relative Error':<20} {rel_error_kernel_ref:<20.6e} {rel_error_kernel_yamoe:<20.6e} {rel_error_ref_yamoe:<20.6e}"
+)
+print(
+    f"{'Max Abs Diff':<20} {max_abs_diff_kernel_ref:<20.6e} {max_abs_diff_kernel_yamoe:<20.6e} {max_abs_diff_ref_yamoe:<20.6e}"
+)
 
+print("-" * 110)
 print("FIRST 10 ELEMENTS COMPARISON")
+print("-" * 110)
 
+# Get first N elements as numpy arrays for nice display
+N = 10
+kernel_first_10 = kernel_flat[:N].cpu().numpy()
+ref_first_10 = ref_flat[:N].cpu().numpy()
+yamoe_ref_first_10 = yamoe_ref_flat[:N].cpu().numpy()
+diff_kernel_ref = kernel_first_10 - ref_first_10
+diff_kernel_yamoe = kernel_first_10 - yamoe_ref_first_10
+
+print(
+    f"{'INDEX':<5} {'KERNEL':<12} {'REFERENCE':<12} {'YAMOE_REF':<12} {'K-R DIFF':<12} {'K-Y DIFF':<12}"
+)
+print("-" * 70)
+for i in range(N):
     print(
+        f"{i:<5} {kernel_first_10[i]:<12.6f} {ref_first_10[i]:<12.6f} {yamoe_ref_first_10[i]:<12.6f} {diff_kernel_ref[i]:<12.6f} {diff_kernel_yamoe[i]:<12.6f}"
     )
 
+print("=" * 110)
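The comparison section above boils down to a handful of tensor-similarity metrics (MSE, MAE, cosine similarity, relative L2 error, max absolute difference). A self-contained sketch of that pattern on toy CPU tensors, for readers who want to try the metrics in isolation; the tensor names here are illustrative and not part of the commit:

import torch

a = torch.randn(4, 1024, 2880)
b = a + 1e-4 * torch.randn_like(a)  # pretend "b" is a second implementation's output

mse = torch.nn.functional.mse_loss(a, b).item()
mae = torch.nn.functional.l1_loss(a, b).item()
cos = torch.nn.functional.cosine_similarity(a.view(1, -1), b.view(1, -1)).item()
rel = (torch.norm(a - b) / torch.norm(b)).item()      # relative L2 error
max_abs = (a - b).abs().max().item()                   # worst-case element difference
print(f"mse={mse:.3e} mae={mae:.3e} cos={cos:.6f} rel={rel:.3e} max={max_abs:.3e}")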
csrc/experts_backward.cu
ADDED
@@ -0,0 +1,341 @@
// Backward pass for MoE experts

#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/torch.h>

void sort_cuda(torch::Tensor x,
               int64_t end_bit,
               torch::Tensor x_out,
               torch::Tensor iota_out);
void bincount_cumsum_cuda(torch::Tensor input,
                          torch::Tensor &output,
                          int64_t minlength);
void gather_cuda(const torch::Tensor &x,
                 const torch::Tensor &indices,
                 const torch::Tensor &bins,
                 torch::Tensor &output,
                 int64_t E,
                 int64_t C,
                 int64_t top_k);
void scatter_cuda(const torch::Tensor &src,
                  const torch::Tensor &indices,
                  const torch::Tensor &bins,
                  const torch::Tensor &weights,
                  torch::Tensor &y,
                  int64_t T,
                  int64_t E,
                  int64_t C,
                  int64_t top_k);
torch::Tensor index_select_out_cuda(torch::Tensor out,
                                    torch::Tensor in,
                                    torch::Tensor idx_int32);

// scatter gradients back to expert outputs and routing weights
template <typename scalar_t>
__global__ void binned_scatter_backward_kernel(
    const scalar_t *__restrict__ grad_y,           // [T, H]
    const int *__restrict__ indices,               // [S]
    const int *__restrict__ bins,                  // [E+1]
    const scalar_t *__restrict__ selected_weights, // [S]
    const scalar_t *__restrict__ expert_output,    // [E, C, H]
    scalar_t *__restrict__ grad_expert_output,     // [E, C, H]
    scalar_t *__restrict__ grad_selected_weights,  // [S]
    int T,
    int K,
    int H,
    int E,
    int C) {

  int e = blockIdx.x;
  int i = blockIdx.y;
  if (e >= E || i >= C)
    return;

  const int start = (e == 0) ? 0 : bins[e - 1];
  const int end = bins[e];
  const int n_all = end - start;
  const int take = (n_all > 0) ? min(n_all, C) : 0;

  if (take == 0 || i >= take) {
    scalar_t *dst = grad_expert_output + ((size_t)e * C + i) * H;
    for (int h = threadIdx.x; h < H; h += blockDim.x)
      dst[h] = scalar_t(0);
    return;
  }

  const int sorted_pos = start + i;
  const int flat_pos = indices[sorted_pos];
  const int tok = flat_pos / K;

  const scalar_t scale = selected_weights[sorted_pos];

  const scalar_t *grad_y_ptr = grad_y + (size_t)tok * H;
  scalar_t *grad_exp_ptr = grad_expert_output + ((size_t)e * C + i) * H;
  const scalar_t *expert_ptr = expert_output + ((size_t)e * C + i) * H;

  for (int h = threadIdx.x; h < H; h += blockDim.x) {
    grad_exp_ptr[h] += grad_y_ptr[h] * scale;
  }

  if (threadIdx.x == 0) {
    scalar_t sum = scalar_t(0);
    for (int h = 0; h < H; ++h)
      sum += grad_y_ptr[h] * expert_ptr[h];
    gpuAtomicAdd(&grad_selected_weights[flat_pos], sum);
  }
}

// gather gradients back to hidden states
template <typename scalar_t>
__global__ void binned_gather_backward_kernel(
    const scalar_t *__restrict__ grad_x, // [E, C, H]
    const int *__restrict__ indices,     // [S]
    const int *__restrict__ bins,        // [E+1]
    scalar_t *__restrict__ grad_hidden,  // [T, H]
    int T,
    int K,
    int H,
    int E,
    int C) {

  int e = blockIdx.x;
  int i = blockIdx.y;
  if (e >= E || i >= C)
    return;

  const int start = (e == 0) ? 0 : bins[e - 1];
  const int end = bins[e];
  const int n = min(max(end - start, 0), C);
  if (i >= n)
    return;

  const int flat_pos = indices[start + i];
  const int tok = flat_pos / K;

  const scalar_t *gx = grad_x + ((size_t)e * C + i) * H;
  scalar_t *gh = grad_hidden + (size_t)tok * H;

  for (int h = threadIdx.x; h < H; h += blockDim.x) {
    gpuAtomicAdd(&gh[h], gx[h]);
  }
}

std::vector<torch::Tensor> experts_backward_cuda(
    const torch::Tensor &grad_out,
    const torch::Tensor &hidden_states,
    const torch::Tensor &router_indices,
    const torch::Tensor &routing_weights,
    const torch::Tensor &gate_up_proj,
    const torch::Tensor &gate_up_proj_bias,
    const torch::Tensor &down_proj,
    const torch::Tensor &down_proj_bias,
    int64_t expert_capacity,
    int64_t num_experts,
    int64_t top_k) {
  TORCH_CHECK(grad_out.is_cuda(), "grad_out must be CUDA");
  TORCH_CHECK(hidden_states.is_cuda(), "hidden_states must be CUDA");
  TORCH_CHECK(router_indices.is_cuda(), "router_indices must be CUDA");
  TORCH_CHECK(routing_weights.is_cuda(), "routing_weights must be CUDA");
  TORCH_CHECK(gate_up_proj.is_cuda() && down_proj.is_cuda(),
              "weights must be CUDA");
  TORCH_CHECK(gate_up_proj_bias.is_cuda() && down_proj_bias.is_cuda(),
              "biases must be CUDA");

  const at::cuda::OptionalCUDAGuard device_guard(grad_out.device());

  const int64_t T = hidden_states.size(0);
  const int64_t H = hidden_states.size(1);
  const int64_t E = num_experts;
  const int64_t C = expert_capacity;
  const int64_t K = top_k;

  TORCH_CHECK(router_indices.dim() == 2 && router_indices.size(0) == T &&
                  router_indices.size(1) == K,
              "router_indices must be [T, K]");

  auto float_opts = hidden_states.options();
  auto i32_opts = torch::TensorOptions()
                      .device(hidden_states.device())
                      .dtype(torch::kInt32);

  // Sort tokens by expert ID
  torch::Tensor flat_indices =
      router_indices.contiguous().view({-1}).to(torch::kInt32);
  torch::Tensor sorted_values = torch::empty_like(flat_indices);
  torch::Tensor sorted_indices = torch::empty_like(flat_indices);
  sort_cuda(flat_indices, 32, sorted_values, sorted_indices);

  // Compute expert boundaries
  torch::Tensor bins = torch::empty({E + 1}, i32_opts);
  bincount_cumsum_cuda(sorted_values, bins, E);
  cudaDeviceSynchronize();

  // Gather tokens for each expert
  torch::Tensor x = torch::empty({E, C, H}, float_opts);
  gather_cuda(hidden_states.contiguous(), sorted_indices, bins, x, E, C, K);

  // Gate-up projection
  torch::Tensor gate_up = at::bmm(x.contiguous(), gate_up_proj.contiguous());
  gate_up.add_(gate_up_proj_bias.unsqueeze(1));

  // GLU activation (recompute forward)
  auto gu_pair = gate_up.view({E, C, H, 2});
  torch::Tensor pre_gate = gu_pair.select(3, 0);
  torch::Tensor pre_up = gu_pair.select(3, 1);

  const double limit = 7.0;
  const double alpha = 1.702;
  torch::Tensor gate_clamped = at::clamp_max(pre_gate, limit);
  torch::Tensor up_clamped = at::clamp(pre_up, -limit, limit);
  torch::Tensor s = at::sigmoid(gate_clamped * alpha);
  torch::Tensor gate_act = gate_clamped * s;
  torch::Tensor up_out = (1 + up_clamped) * gate_act;

  // Down projection
  torch::Tensor y_expert = at::bmm(up_out.contiguous(), down_proj.contiguous());
  y_expert.add_(down_proj_bias.unsqueeze(1));

  // Get routing weights in sorted order
  torch::Tensor flat_router = router_indices.view({T, K});
  torch::Tensor selected_2d;
  if (routing_weights.size(1) == K) {
    selected_2d = routing_weights.contiguous();
  } else {
    TORCH_CHECK(routing_weights.size(1) == E,
                "routing_weights must be [T,K] or [T,E]");
    selected_2d = at::gather(routing_weights, 1, flat_router.to(torch::kLong));
  }
  torch::Tensor selected_flat = selected_2d.contiguous().view({T * K});
  torch::Tensor weights_sorted = torch::empty_like(selected_flat);
  index_select_out_cuda(weights_sorted, selected_flat, sorted_indices);

  // Initialize gradients
  torch::Tensor dHidden = torch::zeros_like(hidden_states);
  torch::Tensor dRouting;
  torch::Tensor dWgu = torch::zeros_like(gate_up_proj);
  torch::Tensor dbgu = torch::zeros_like(gate_up_proj_bias);
  torch::Tensor dWd = torch::zeros_like(down_proj);
  torch::Tensor dbd = torch::zeros_like(down_proj_bias);

  // Reshape grad_out to [T,H]
  TORCH_CHECK(grad_out.numel() == T * H || grad_out.numel() == T * K * H,
              "grad_out numel must be T*H or T*K*H");
  torch::Tensor grad_y = (grad_out.numel() == T * H)
                             ? grad_out.contiguous().view({T, H})
                             : grad_out.contiguous().view({T, K, H}).sum(1);

  // Backward through scatter
  torch::Tensor grad_expert_output = torch::zeros({E, C, H}, float_opts);
  torch::Tensor grad_selected_weights = torch::zeros({T * K}, float_opts);
  {
    dim3 grid(E, C);
    int threads = 256;
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kHalf,
        at::kBFloat16,
        hidden_states.scalar_type(),
        "binned_scatter_backward",
        [&] {
          using st = scalar_t;
          binned_scatter_backward_kernel<st>
              <<<grid, threads>>>(grad_y.data_ptr<st>(),
                                  sorted_indices.data_ptr<int>(),
                                  bins.data_ptr<int>(),
                                  weights_sorted.data_ptr<st>(),
                                  y_expert.data_ptr<st>(),
                                  grad_expert_output.data_ptr<st>(),
                                  grad_selected_weights.data_ptr<st>(),
                                  (int)T,
                                  (int)K,
                                  (int)H,
                                  (int)E,
                                  (int)C);
        });
    cudaDeviceSynchronize();
  }

  // Route weight gradients
  torch::Tensor grad_selected_flat = torch::zeros({T * K}, float_opts);
  grad_selected_flat.index_add_(0,
                                sorted_indices.to(torch::kLong),
                                grad_selected_weights);

  if (routing_weights.size(1) == E) {
    torch::Tensor flat_grad_routing = torch::zeros_like(routing_weights);
    flat_grad_routing.scatter_add_(1,
                                   flat_router.to(torch::kLong),
                                   grad_selected_flat.view({T, K}));
    dRouting = flat_grad_routing;
  } else {
    dRouting = grad_selected_flat.view({T, K});
  }

  // Backward through down projection
  dbd = grad_expert_output.sum(1);
  torch::Tensor grad_intermediate =
      torch::bmm(grad_expert_output.contiguous(),
                 down_proj.transpose(1, 2).contiguous());
  dWd = torch::bmm(up_out.transpose(1, 2).contiguous(),
                   grad_expert_output.contiguous());

  // Backward through GLU
  torch::Tensor grad_up_plus_1 = grad_intermediate * gate_act;
  torch::Tensor grad_glu = grad_intermediate * (up_clamped + 1);
  torch::Tensor grad_up_clamped = grad_up_plus_1;

  torch::Tensor sigmoid_gate = torch::sigmoid(gate_clamped * alpha);
  torch::Tensor grad_gate_clamped =
      grad_glu *
      (sigmoid_gate + gate_clamped * sigmoid_gate * (1 - sigmoid_gate) * alpha);

  // Unclamp gradients
  torch::Tensor grad_gate = grad_gate_clamped.clone();
  grad_gate.masked_fill_(pre_gate > limit, 0);
  torch::Tensor grad_up = grad_up_clamped.clone();
  grad_up.masked_fill_(pre_up > limit, 0);
  grad_up.masked_fill_(pre_up < -limit, 0);

  // Merge gate/up gradients
  torch::Tensor grad_gate_up_pair = torch::zeros({E, C, H, 2}, float_opts);
  grad_gate_up_pair.select(3, 0).copy_(grad_gate);
  grad_gate_up_pair.select(3, 1).copy_(grad_up);
  torch::Tensor grad_gate_up = grad_gate_up_pair.view({E, C, 2 * H});

  // Backward through gate-up projection
  dbgu = grad_gate_up.sum(1);
  torch::Tensor grad_x = torch::bmm(grad_gate_up.contiguous(),
                                    gate_up_proj.transpose(1, 2).contiguous());
  dWgu = torch::bmm(x.transpose(1, 2).contiguous(), grad_gate_up.contiguous());

  // Backward through gather
  torch::Tensor grad_hidden = torch::zeros({T, H}, float_opts);
  {
    dim3 grid(E, C);
    int threads = 256;
    AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf,
                                    at::kBFloat16,
                                    hidden_states.scalar_type(),
                                    "binned_gather_backward",
                                    [&] {
                                      using st = scalar_t;
                                      binned_gather_backward_kernel<st>
                                          <<<grid, threads>>>(
                                              grad_x.data_ptr<st>(),
                                              sorted_indices.data_ptr<int>(),
                                              bins.data_ptr<int>(),
                                              grad_hidden.data_ptr<st>(),
                                              (int)T,
                                              (int)K,
                                              (int)H,
                                              (int)E,
                                              (int)C);
                                    });
    cudaDeviceSynchronize();
  }
  dHidden += grad_hidden;

  return {dHidden, dRouting, dWgu, dbgu, dWd, dbd};
}
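experts_backward_cuda recomputes the forward activations (sort, gather, gate-up projection, GLU, down projection) and returns six gradients: dHidden, dRouting, dWgu, dbgu, dWd, dbd. A minimal sketch of how the forward/backward op pair could be wrapped in a torch.autograd.Function; this wrapper is not part of the commit, and `ops` is assumed to be the loaded yamoe extension:

import torch

# `ops` is assumed to be the loaded yamoe extension, e.g.
#   from pathlib import Path
#   from kernels import get_local_kernel
#   ops = get_local_kernel(Path("result"), "yamoe").ops


class ExpertsFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hidden, router_indices, routing_weights,
                gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
                expert_capacity, num_experts, top_k):
        # Save everything the backward op needs to recompute activations.
        ctx.save_for_backward(hidden, router_indices, routing_weights,
                              gate_up_proj, gate_up_proj_bias,
                              down_proj, down_proj_bias)
        ctx.dims = (expert_capacity, num_experts, top_k)
        return ops.experts(hidden, router_indices, routing_weights,
                           gate_up_proj, gate_up_proj_bias,
                           down_proj, down_proj_bias,
                           expert_capacity, num_experts, top_k)

    @staticmethod
    def backward(ctx, grad_out):
        expert_capacity, num_experts, top_k = ctx.dims
        d_hidden, d_routing, d_wgu, d_bgu, d_wd, d_bd = ops.experts_backward(
            grad_out.contiguous(), *ctx.saved_tensors,
            expert_capacity, num_experts, top_k,
        )
        # No gradient for router_indices or for the three integer arguments.
        return (d_hidden, None, d_routing, d_wgu, d_bgu, d_wd, d_bd,
                None, None, None)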
readme_example.py
CHANGED
@@ -7,8 +7,7 @@
 
 import time
 import torch
-from kernels import get_local_kernel
-from kernels import get_kernel
+from kernels import get_kernel, get_local_kernel
 from pathlib import Path
 from torch.nn import functional as F
 
@@ -83,6 +82,8 @@ torch.cuda.synchronize()
 elapsed_ms = (time.perf_counter() - start) * 1e3
 peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
 
+print(
+    f"Output: sum={output.sum().item():.1f}, min={output.min().item():.1f}, max={output.max().item():.1f}"
+)
 print(f"First 3: {output.view(-1)[:3].tolist()}")
 print(f"Time: {elapsed_ms:.1f}ms, Memory: {peak_mem_mb:.0f}MB")
torch-ext/torch_binding.cpp
CHANGED
@@ -67,6 +67,20 @@ TORCH_LIBRARY_EXPAND(
           "int num_experts, "
           "int top_k) -> Tensor");
   ops.impl("experts", torch::kCUDA, &experts_cuda);
+
+  ops.def("experts_backward("
+          "Tensor grad_out, "
+          "Tensor hidden_states, "
+          "Tensor router_indices, "
+          "Tensor routing_weights, "
+          "Tensor gate_up_proj, "
+          "Tensor gate_up_proj_bias, "
+          "Tensor down_proj, "
+          "Tensor down_proj_bias, "
+          "int expert_capacity, "
+          "int num_experts, "
+          "int top_k) -> Tensor[]");
+  ops.impl("experts_backward", torch::kCUDA, &experts_backward_cuda);
 }
 
 REGISTER_EXTENSION(
torch-ext/torch_binding.h
CHANGED
@@ -53,3 +53,19 @@ torch::Tensor experts_cuda(
     int64_t num_experts, // E - number of experts
     int64_t top_k        // K - top-k routing
 );
+
+std::vector<torch::Tensor> experts_backward_cuda(
+    const torch::Tensor &grad_out,        // [T, H] - gradient from output
+    const torch::Tensor &hidden_states,   // [T, H] - original input
+    const torch::Tensor &router_indices,  // [T, K] - expert indices per token
+    const torch::Tensor &routing_weights, // [T, K] or [T, E] - routing weights
+    const torch::Tensor
+        &gate_up_proj,                    // [E, H, 2*H] - gate/up projection weights
+    const torch::Tensor
+        &gate_up_proj_bias,               // [E, 2*H] - gate/up projection bias
+    const torch::Tensor &down_proj,       // [E, H, H] - down projection weights
+    const torch::Tensor &down_proj_bias,  // [E, H] - down projection bias
+    int64_t expert_capacity,              // C - capacity per expert
+    int64_t num_experts,                  // E - number of experts
+    int64_t top_k                         // K - top-k routing
+);
torch-ext/yamoe/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 from ._ops import ops
+from .layers import Yamoe
+from .vendored import yamoe_ref
+from .vendored import gpt_oss_mlp
 
 gather = ops.gather
 scatter = ops.scatter
@@ -7,8 +9,14 @@ sort = ops.sort
 bincount_cumsum = ops.bincount_cumsum
 batch_mm = ops.batch_mm
 experts = ops.experts
+experts_backward = ops.experts_backward
 
 __all__ = [
+    # Debug
+    "ops",
+    # Layer (nn module)
+    "Yamoe",
+    # Functions
     "shuffle",
     "gather",
     "scatter",
@@ -16,6 +24,9 @@ __all__ = [
     "bincount_cumsum",
     "batch_mm",
     "experts",
+    "experts_backward",
+    # Vendored reference implementations
     "reference",
+    "yamoe_ref",
+    "gpt_oss_mlp",
 ]
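With these exports, the package surface matches what compare_example.py consumes. A short sketch of loading it locally, using the same loader and path that compare_example.py uses:

from pathlib import Path
from kernels import get_local_kernel

yamoe = get_local_kernel(Path("result"), "yamoe")

experts = yamoe.experts                    # fused forward op
experts_backward = yamoe.experts_backward  # new backward op
Yamoe = yamoe.Yamoe                        # nn.Module wrapper from layers.py
binned_experts_ref = yamoe.vendored.yamoe_ref.binned_experts_ref
GptOssExperts = yamoe.vendored.gpt_oss_mlp.GptOssExperts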
torch-ext/yamoe/layers.py
ADDED
@@ -0,0 +1,104 @@
import torch
from ._ops import ops


class Yamoe(torch.nn.Module):
    """Yamoe MoE layer with routing and expert computation"""

    can_torch_compile: bool = True

    def __init__(self):
        super().__init__()
        # Pre-allocate buffers to avoid repeated allocations
        self._routing_weights_buffer = None
        self._batch_indices_buffer = None
        self._last_batch_seq = None
        self._last_num_experts = None

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        batch_seq = batch_size * seq_len

        num_experts = getattr(self, "num_experts", 128)
        top_k = getattr(self, "top_k", 4)

        # Route tokens to experts
        x_flat = hidden_states.view(-1, hidden_dim)
        logits = torch.nn.functional.linear(
            x_flat, self.router.weight, self.router.bias
        )

        # Compute top-k
        if top_k == 1:
            routing_weights, router_indices = logits.max(dim=-1, keepdim=True)
        else:
            routing_weights, router_indices = torch.topk(logits, top_k, dim=-1)

        routing_weights = routing_weights.softmax(dim=-1)

        # Create router scores
        router_scores = (
            torch.zeros_like(logits)
            .scatter_(1, router_indices, routing_weights)
            .transpose(0, 1)
        )

        # Convert routing_weights to sparse format [batch_seq, num_experts]
        # Reuse buffer if possible to reduce allocations
        if (
            self._routing_weights_buffer is None
            or self._last_batch_seq != batch_seq
            or self._last_num_experts != num_experts
            or self._routing_weights_buffer.device != routing_weights.device
        ):
            self._routing_weights_buffer = torch.zeros(
                batch_seq,
                num_experts,
                device=routing_weights.device,
                dtype=routing_weights.dtype,
            )
            self._batch_indices_buffer = (
                torch.arange(batch_seq, device=routing_weights.device)
                .unsqueeze(1)
                .expand(-1, top_k)
            )
            self._last_batch_seq = batch_seq
            self._last_num_experts = num_experts
        else:
            self._routing_weights_buffer.zero_()

        # Fill sparse routing weights
        flat_indices = router_indices.view(batch_seq, top_k)
        flat_weights = routing_weights.view(batch_seq, top_k)
        self._routing_weights_buffer[self._batch_indices_buffer, flat_indices] = (
            flat_weights
        )

        # FIX: Use the correct expert projections
        gate_up = self.experts.gate_up_proj[:, :, : hidden_dim * top_k].contiguous()
        gate_up_bias = self.experts.gate_up_proj_bias[
            :, : hidden_dim * top_k
        ].contiguous()

        down_proj = self.experts.down_proj[:, :hidden_dim, :].contiguous()

        expert_capacity = batch_seq * top_k // num_experts * 2

        with torch.no_grad():
            # Compute expert output
            output = ops.experts(
                hidden_states.view(-1, hidden_dim),
                router_indices,
                self._routing_weights_buffer,
                gate_up,
                gate_up_bias,
                down_proj,
                self.experts.down_proj_bias,
                expert_capacity,
                num_experts,
                top_k,
            )

        # Reshape output back to [B, S, H]
        output = output.view(batch_size, seq_len, hidden_dim)
        return output, router_scores
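The Yamoe layer reads `self.router` (with `weight`/`bias`) and `self.experts` (with `gate_up_proj`, `gate_up_proj_bias`, `down_proj`, `down_proj_bias`), so it expects those submodules to already be attached, as they would be when the layer is grafted onto an existing MoE block. A hedged sketch of driving it standalone; the shapes, the attached modules, and the small dimensions below are illustrative assumptions (not part of the commit), and running it requires the compiled yamoe ops and a CUDA device:

import torch
from yamoe import Yamoe  # assumes torch-ext is on sys.path or the built package is installed

hidden_dim, num_experts, top_k = 64, 4, 2  # toy sizes; top_k=2 keeps the gate_up slice full-width

layer = Yamoe()
layer.num_experts = num_experts
layer.top_k = top_k
layer.router = torch.nn.Linear(hidden_dim, num_experts, device="cuda")
layer.experts = torch.nn.Module()  # illustrative container for the expert weights
layer.experts.gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim, device="cuda")
layer.experts.gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim, device="cuda")
layer.experts.down_proj = torch.randn(num_experts, hidden_dim, hidden_dim, device="cuda")
layer.experts.down_proj_bias = torch.zeros(num_experts, hidden_dim, device="cuda")

x = torch.randn(2, 16, hidden_dim, device="cuda")
out, router_scores = layer(x)  # out: [2, 16, hidden_dim]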
torch-ext/yamoe/{reference.py → vendored/gpt_oss_mlp.py}
RENAMED
@@ -1,5 +1,14 @@
 import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def info(tensor, name):
+    print(name)
+    print(tensor.shape)
+    print(tensor.cpu())
+    print()
+
 
 class GptOssExperts(nn.Module):
     def __init__(self, config):
@@ -8,16 +17,26 @@ class GptOssExperts(nn.Module):
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
+        self.gate_up_proj = nn.Parameter(
+            torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)
+        )
+        self.gate_up_proj_bias = nn.Parameter(
+            torch.empty(self.num_experts, 2 * self.expert_dim)
+        )
+        self.down_proj = nn.Parameter(
+            torch.empty((self.num_experts, self.expert_dim, self.hidden_size))
+        )
+        self.down_proj_bias = nn.Parameter(
+            torch.empty(self.num_experts, self.hidden_size)
+        )
         self.alpha = 1.702
         self.limit = 7.0
 
+    def forward(
+        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
+    ) -> torch.Tensor:
         """
+        When training it is more efficient to just loop over the experts and compute the output for each expert
         as otherwise the memory would explode.
 
         For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
@@ -29,45 +48,112 @@ class GptOssExperts(nn.Module):
         Returns:
             torch.Tensor
         """
-        # import ipdb; ipdb.set_trace()
         batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(
+            -1, self.hidden_size
+        )  # (num_tokens, hidden_size)
         num_experts = routing_weights.shape[1]
+
+        if hidden_states.device.type == "cpu" or self.training:
+            next_states = torch.zeros_like(
+                hidden_states, dtype=hidden_states.dtype, device=hidden_states.device
+            )
             with torch.no_grad():
+                expert_mask = torch.nn.functional.one_hot(
+                    router_indices, num_classes=num_experts
+                )
                 expert_mask = expert_mask.permute(2, 1, 0)
+                # we sum on the top_k and on the sequence length to get which experts
                 # are hit this time around
+                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit[:]:
+                # expert_idx only have 1 element, so we can use scale for fast indexing
+                expert_idx = expert_idx[0]
                 with torch.no_grad():
+                    _, token_idx = torch.where(expert_mask[expert_idx])
                 current_state = hidden_states[token_idx]
+                gate_up = (
+                    current_state @ self.gate_up_proj[expert_idx]
+                    + self.gate_up_proj_bias[expert_idx]
+                )
                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                 gate = gate.clamp(min=None, max=self.limit)
                 up = up.clamp(min=-self.limit, max=self.limit)
                 glu = gate * torch.sigmoid(gate * self.alpha)
                 gated_output = (up + 1) * glu
+                out = (
+                    gated_output @ self.down_proj[expert_idx]
+                    + self.down_proj_bias[expert_idx]
+                )
+                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                next_states.index_add_(
+                    0, token_idx, weighted_output.to(hidden_states.dtype)
+                )
             next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
             hidden_states = hidden_states.repeat(num_experts, 1)
             hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            gate_up = (
+                torch.bmm(hidden_states, self.gate_up_proj)
+                + self.gate_up_proj_bias[..., None, :]
+            )
             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
             gate = gate.clamp(min=None, max=self.limit)
             up = up.clamp(min=-self.limit, max=self.limit)
             glu = gate * torch.sigmoid(gate * self.alpha)
             next_states = torch.bmm(((up + 1) * glu), self.down_proj)
             next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.view(
+                num_experts, batch_size, -1, self.hidden_size
+            )
+            next_states = (
+                next_states
+                * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[
+                    ..., None
+                ]
+            )
             next_states = next_states.sum(dim=0)
         return next_states
+
+
+class GptOssTopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
+        self.bias = nn.Parameter(torch.empty(self.num_experts))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(
+            hidden_states, self.weight, self.bias
+        )  # (seq_len, num_experts)
+        router_top_value, router_indices = torch.topk(
+            router_logits, self.top_k, dim=-1
+        )  # (seq_len, top_k)
+        router_top_value = torch.nn.functional.softmax(
+            router_top_value, dim=1, dtype=router_top_value.dtype
+        )
+        router_scores = torch.zeros_like(router_logits).scatter_(
+            1, router_indices, router_top_value
+        )
+        return router_scores, router_indices
+
+
+# @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
+class GptOssMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.router = GptOssTopKRouter(config)
+        self.experts = GptOssExperts(config)
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(
+            hidden_states
+        )  # (num_experts, seq_len)
+        routed_out = self.experts(
+            hidden_states, router_indices=router_indices, routing_weights=router_scores
+        )
+        return routed_out, router_scores
torch-ext/yamoe/vendored/yamoe_ref.py
ADDED
@@ -0,0 +1,82 @@
import torch


def binned_gather(x, indices, bins, expert_capacity, top_k):
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            out[e, i] = x[tok]
    return out


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot)
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[1], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
    flat_router = router_indices.view(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]

    return y.view(B, S, H)
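binned_experts_ref mirrors the CUDA pipeline (sort tokens by expert, binned gather, GLU MLP, binned scatter) in plain Python loops, so it is easy to run on tiny inputs. A small usage sketch with toy sizes; shapes follow the function's own expectations, and the import assumes the package is importable as `yamoe`:

import torch
from yamoe.vendored.yamoe_ref import binned_experts_ref  # or use the functions above directly

B, S, H, E, K = 2, 4, 8, 4, 2
expert_capacity = B * S * K // E * 2  # generous capacity, as in compare_example.py

hidden = torch.randn(B, S, H)
router_indices = torch.randint(0, E, (B * S, K))   # top-K expert ids per token
routing_weights = torch.rand(B * S, E)             # dense [tokens, experts] weights
gate_up_proj = torch.randn(E, H, 2 * H)
gate_up_proj_bias = torch.zeros(E, 2 * H)
down_proj = torch.randn(E, H, H)
down_proj_bias = torch.zeros(E, H)

y = binned_experts_ref(
    hidden, router_indices, routing_weights,
    gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
    expert_capacity,
)
print(y.shape)  # torch.Size([2, 4, 8])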