drbh committed
Commit 281d8ba (0 parents)

feat: yet another moe
.clang-format ADDED
@@ -0,0 +1,288 @@
1
+ ---
2
+ Language: Cpp
3
+ AccessModifierOffset: -2
4
+ AlignAfterOpenBracket: Align
5
+ AlignArrayOfStructures: None
6
+ AlignConsecutiveAssignments:
7
+ Enabled: false
8
+ AcrossEmptyLines: false
9
+ AcrossComments: false
10
+ AlignCompound: false
11
+ AlignFunctionDeclarations: false
12
+ AlignFunctionPointers: false
13
+ PadOperators: true
14
+ AlignConsecutiveBitFields:
15
+ Enabled: false
16
+ AcrossEmptyLines: false
17
+ AcrossComments: false
18
+ AlignCompound: false
19
+ AlignFunctionDeclarations: false
20
+ AlignFunctionPointers: false
21
+ PadOperators: false
22
+ AlignConsecutiveDeclarations:
23
+ Enabled: false
24
+ AcrossEmptyLines: false
25
+ AcrossComments: false
26
+ AlignCompound: false
27
+ AlignFunctionDeclarations: true
28
+ AlignFunctionPointers: false
29
+ PadOperators: false
30
+ AlignConsecutiveMacros:
31
+ Enabled: false
32
+ AcrossEmptyLines: false
33
+ AcrossComments: false
34
+ AlignCompound: false
35
+ AlignFunctionDeclarations: false
36
+ AlignFunctionPointers: false
37
+ PadOperators: false
38
+ AlignConsecutiveShortCaseStatements:
39
+ Enabled: false
40
+ AcrossEmptyLines: false
41
+ AcrossComments: false
42
+ AlignCaseArrows: false
43
+ AlignCaseColons: false
44
+ AlignConsecutiveTableGenBreakingDAGArgColons:
45
+ Enabled: false
46
+ AcrossEmptyLines: false
47
+ AcrossComments: false
48
+ AlignCompound: false
49
+ AlignFunctionDeclarations: false
50
+ AlignFunctionPointers: false
51
+ PadOperators: false
52
+ AlignConsecutiveTableGenCondOperatorColons:
53
+ Enabled: false
54
+ AcrossEmptyLines: false
55
+ AcrossComments: false
56
+ AlignCompound: false
57
+ AlignFunctionDeclarations: false
58
+ AlignFunctionPointers: false
59
+ PadOperators: false
60
+ AlignConsecutiveTableGenDefinitionColons:
61
+ Enabled: false
62
+ AcrossEmptyLines: false
63
+ AcrossComments: false
64
+ AlignCompound: false
65
+ AlignFunctionDeclarations: false
66
+ AlignFunctionPointers: false
67
+ PadOperators: false
68
+ AlignEscapedNewlines: Right
69
+ AlignOperands: Align
70
+ AlignTrailingComments:
71
+ Kind: Always
72
+ OverEmptyLines: 0
73
+ AllowAllArgumentsOnNextLine: false
74
+ AllowAllParametersOfDeclarationOnNextLine: false
75
+ AllowBreakBeforeNoexceptSpecifier: Never
76
+ AllowShortBlocksOnASingleLine: Never
77
+ AllowShortCaseExpressionOnASingleLine: true
78
+ AllowShortCaseLabelsOnASingleLine: false
79
+ AllowShortCompoundRequirementOnASingleLine: true
80
+ AllowShortEnumsOnASingleLine: true
81
+ AllowShortFunctionsOnASingleLine: All
82
+ AllowShortIfStatementsOnASingleLine: Never
83
+ AllowShortLambdasOnASingleLine: All
84
+ AllowShortLoopsOnASingleLine: false
85
+ AllowShortNamespacesOnASingleLine: false
86
+ AlwaysBreakAfterDefinitionReturnType: None
87
+ AlwaysBreakBeforeMultilineStrings: false
88
+ AttributeMacros:
89
+ - __capability
90
+ BinPackArguments: false
91
+ BinPackParameters: false
92
+ BitFieldColonSpacing: Both
93
+ BraceWrapping:
94
+ AfterCaseLabel: false
95
+ AfterClass: false
96
+ AfterControlStatement: Never
97
+ AfterEnum: false
98
+ AfterExternBlock: false
99
+ AfterFunction: false
100
+ AfterNamespace: false
101
+ AfterObjCDeclaration: false
102
+ AfterStruct: false
103
+ AfterUnion: false
104
+ BeforeCatch: false
105
+ BeforeElse: false
106
+ BeforeLambdaBody: false
107
+ BeforeWhile: false
108
+ IndentBraces: false
109
+ SplitEmptyFunction: true
110
+ SplitEmptyRecord: true
111
+ SplitEmptyNamespace: true
112
+ BreakAdjacentStringLiterals: true
113
+ BreakAfterAttributes: Leave
114
+ BreakAfterJavaFieldAnnotations: false
115
+ BreakAfterReturnType: None
116
+ BreakArrays: true
117
+ BreakBeforeBinaryOperators: None
118
+ BreakBeforeConceptDeclarations: Always
119
+ BreakBeforeBraces: Attach
120
+ BreakBeforeInlineASMColon: OnlyMultiline
121
+ BreakBeforeTernaryOperators: true
122
+ BreakBinaryOperations: Never
123
+ BreakConstructorInitializers: AfterColon
124
+ BreakFunctionDefinitionParameters: true
125
+ BreakInheritanceList: BeforeColon
126
+ BreakStringLiterals: true
127
+ BreakTemplateDeclarations: MultiLine
128
+ ColumnLimit: 80
129
+ CommentPragmas: '^ IWYU pragma:'
130
+ CompactNamespaces: false
131
+ ConstructorInitializerIndentWidth: 4
132
+ ContinuationIndentWidth: 4
133
+ Cpp11BracedListStyle: true
134
+ DerivePointerAlignment: false
135
+ DisableFormat: false
136
+ EmptyLineAfterAccessModifier: Never
137
+ EmptyLineBeforeAccessModifier: LogicalBlock
138
+ ExperimentalAutoDetectBinPacking: false
139
+ FixNamespaceComments: true
140
+ ForEachMacros:
141
+ - foreach
142
+ - Q_FOREACH
143
+ - BOOST_FOREACH
144
+ IfMacros:
145
+ - KJ_IF_MAYBE
146
+ IncludeBlocks: Preserve
147
+ IncludeCategories:
148
+ - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
149
+ Priority: 2
150
+ SortPriority: 0
151
+ CaseSensitive: false
152
+ - Regex: '^(<|"(gtest|gmock|isl|json)/)'
153
+ Priority: 3
154
+ SortPriority: 0
155
+ CaseSensitive: false
156
+ - Regex: '.*'
157
+ Priority: 1
158
+ SortPriority: 0
159
+ CaseSensitive: false
160
+ IncludeIsMainRegex: '(Test)?$'
161
+ IncludeIsMainSourceRegex: ''
162
+ IndentAccessModifiers: false
163
+ IndentCaseBlocks: false
164
+ IndentCaseLabels: false
165
+ IndentExportBlock: true
166
+ IndentExternBlock: AfterExternBlock
167
+ IndentGotoLabels: true
168
+ IndentPPDirectives: None
169
+ IndentRequiresClause: true
170
+ IndentWidth: 2
171
+ IndentWrappedFunctionNames: false
172
+ InsertBraces: false
173
+ InsertNewlineAtEOF: false
174
+ InsertTrailingCommas: None
175
+ IntegerLiteralSeparator:
176
+ Binary: 0
177
+ BinaryMinDigits: 0
178
+ Decimal: 0
179
+ DecimalMinDigits: 0
180
+ Hex: 0
181
+ HexMinDigits: 0
182
+ JavaScriptQuotes: Leave
183
+ JavaScriptWrapImports: true
184
+ KeepEmptyLines:
185
+ AtEndOfFile: false
186
+ AtStartOfBlock: true
187
+ AtStartOfFile: true
188
+ KeepFormFeed: false
189
+ LambdaBodyIndentation: Signature
190
+ LineEnding: DeriveLF
191
+ MacroBlockBegin: ''
192
+ MacroBlockEnd: ''
193
+ MainIncludeChar: Quote
194
+ MaxEmptyLinesToKeep: 1
195
+ NamespaceIndentation: None
196
+ ObjCBinPackProtocolList: Auto
197
+ ObjCBlockIndentWidth: 2
198
+ ObjCBreakBeforeNestedBlockParam: true
199
+ ObjCSpaceAfterProperty: false
200
+ ObjCSpaceBeforeProtocolList: true
201
+ PackConstructorInitializers: BinPack
202
+ PenaltyBreakAssignment: 2
203
+ PenaltyBreakBeforeFirstCallParameter: 0
204
+ PenaltyBreakBeforeMemberAccess: 150
205
+ PenaltyBreakComment: 300
206
+ PenaltyBreakFirstLessLess: 120
207
+ PenaltyBreakOpenParenthesis: 0
208
+ PenaltyBreakScopeResolution: 500
209
+ PenaltyBreakString: 1000
210
+ PenaltyBreakTemplateDeclaration: 10
211
+ PenaltyExcessCharacter: 1000000
212
+ PenaltyIndentedWhitespace: 0
213
+ PenaltyReturnTypeOnItsOwnLine: 60
214
+ PointerAlignment: Right
215
+ PPIndentWidth: -1
216
+ QualifierAlignment: Leave
217
+ ReferenceAlignment: Pointer
218
+ ReflowComments: Always
219
+ RemoveBracesLLVM: false
220
+ RemoveEmptyLinesInUnwrappedLines: false
221
+ RemoveParentheses: Leave
222
+ RemoveSemicolon: false
223
+ RequiresClausePosition: OwnLine
224
+ RequiresExpressionIndentation: OuterScope
225
+ SeparateDefinitionBlocks: Leave
226
+ ShortNamespaceLines: 1
227
+ SkipMacroDefinitionBody: false
228
+ SortIncludes: CaseSensitive
229
+ SortJavaStaticImport: Before
230
+ SortUsingDeclarations: LexicographicNumeric
231
+ SpaceAfterCStyleCast: false
232
+ SpaceAfterLogicalNot: false
233
+ SpaceAfterTemplateKeyword: true
234
+ SpaceAroundPointerQualifiers: Default
235
+ SpaceBeforeAssignmentOperators: true
236
+ SpaceBeforeCaseColon: false
237
+ SpaceBeforeCpp11BracedList: false
238
+ SpaceBeforeCtorInitializerColon: true
239
+ SpaceBeforeInheritanceColon: true
240
+ SpaceBeforeJsonColon: false
241
+ SpaceBeforeParens: ControlStatements
242
+ SpaceBeforeParensOptions:
243
+ AfterControlStatements: true
244
+ AfterForeachMacros: true
245
+ AfterFunctionDefinitionName: false
246
+ AfterFunctionDeclarationName: false
247
+ AfterIfMacros: true
248
+ AfterOverloadedOperator: false
249
+ AfterPlacementOperator: true
250
+ AfterRequiresInClause: false
251
+ AfterRequiresInExpression: false
252
+ BeforeNonEmptyParentheses: false
253
+ SpaceBeforeRangeBasedForLoopColon: true
254
+ SpaceBeforeSquareBrackets: false
255
+ SpaceInEmptyBlock: false
256
+ SpacesBeforeTrailingComments: 1
257
+ SpacesInAngles: Never
258
+ SpacesInContainerLiterals: true
259
+ SpacesInLineCommentPrefix:
260
+ Minimum: 1
261
+ Maximum: -1
262
+ SpacesInParens: Never
263
+ SpacesInParensOptions:
264
+ ExceptDoubleParentheses: false
265
+ InCStyleCasts: false
266
+ InConditionalStatements: false
267
+ InEmptyParentheses: false
268
+ Other: false
269
+ SpacesInSquareBrackets: false
270
+ Standard: Latest
271
+ StatementAttributeLikeMacros:
272
+ - Q_EMIT
273
+ StatementMacros:
274
+ - Q_UNUSED
275
+ - QT_REQUIRE_VERSION
276
+ TableGenBreakInsideDAGArg: DontBreak
277
+ TabWidth: 8
278
+ UseTab: Never
279
+ VerilogBreakBetweenInstancePorts: true
280
+ WhitespaceSensitiveMacros:
281
+ - BOOST_PP_STRINGIZE
282
+ - CF_SWIFT_NAME
283
+ - NS_SWIFT_NAME
284
+ - PP_STRINGIZE
285
+ - STRINGIZE
286
+ WrapNamespaceBodyWithEmptyLines: Leave
287
+ ...
288
+
.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.so filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@
1
+ .bak
2
+ .ruff_cache
3
+ .venv
4
+ cmake
5
+ result
6
+ scripts
7
+ __pycache__
8
+ CMakeLists.txt
9
+ setup.py
10
+ pyproject.toml
11
+ tests
12
+ torch-ext/registration.h
13
+ torch-ext/yamoe/_ops.py
14
+ csrc/batch_mm.cu
15
+ torch-ext/yamoe/*.abi3.so
.pre-commit-config.yaml ADDED
@@ -0,0 +1,7 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/mirrors-clang-format
3
+ rev: v20.1.8
4
+ hooks:
5
+ - id: clang-format
6
+ files: ^(csrc/|torch-ext/).*\.(?:c|cc|cpp|cxx|h|hh|hpp|hxx|cu|cuh)$
7
+ args: [-i]
README.md ADDED
@@ -0,0 +1,117 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - kernel
5
+ ---
6
+
7
+ ```
8
+
9
+ oooo ooo .oooo. ooo. .oo. .oo. .ooooo. .ooooo.
10
+ `88. .8' `P )88b `888P"Y88bP"Y88b d88' `88b d88' `88b
11
+ `88..8' .oP"888 888 888 888 888 888 888ooo888
12
+ `888' d8( 888 888 888 888 888 888 888 .o
13
+ .8' `Y888""8o o888o o888o o888o `Y8bod8P' `Y8bod8P'
14
+ .o..P'
15
+ `Y8P'
16
+
17
+ Yet Another Mixture of Experts
18
+ ```
19
+
20
+ `yamoe` is a no-nonsense, straightforward implementation of Mixture of Experts (MoE) kernels, designed to be easy to use and computationally efficient.
21
+
22
+ ### Design goals
23
+ - simplicity: easy to read and understand the code
24
+ - efficiency: optimized for high throughput and low latency
25
+ - low memory usage: optimized to handle large batch sizes
26
+ - reproducibility: easy to reproduce results, no special new `sm` requirements
27
+
28
+
29
+ ### How to use
30
+
31
+ ```python
32
+ # /// script
33
+ # requires-python = "==3.10"
34
+ # dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
35
+ # [tool.uv.sources]
36
+ # kernels = { git = "https://github.com/huggingface/kernels.git" }
37
+ # ///
38
+
39
+ import time
40
+ import torch
41
+ from kernels import get_kernel
42
+ from pathlib import Path
43
+ from torch.nn import functional as F
44
+
45
+ yamoe = get_kernel("drbh/yamoe")
46
+
47
+ # Configuration
48
+ torch.manual_seed(0)
49
+ batch_size, seq_len, hidden_dim = 128, 2048, 2880
50
+ num_experts, top_k = 32, 4
51
+
52
+ # Create routing weights
53
+ logits = torch.randn(batch_size, seq_len, num_experts)
54
+ probs = F.softmax(logits, dim=-1)
55
+ weights, indices = torch.topk(probs, top_k, dim=-1)
56
+
57
+ batch_seq = batch_size * seq_len
58
+ routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
59
+ flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
60
+ batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
61
+ routing_weights[batch_indices, flat_indices] = flat_weights
62
+
63
+ # Create model tensors (scaled to prevent overflow)
64
+ hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda().half() * 0.1
65
+ gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda().half() * 0.02
66
+ gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda().half()
67
+ down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda().half() * 0.02
68
+ down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda().half()
69
+ routing_weights = routing_weights.cuda().half()
70
+ router_indices = flat_indices.cuda()
71
+
72
+ # Warmup
73
+ for _ in range(5):
74
+ _ = yamoe.experts(
75
+ hidden_states.view(-1, hidden_dim),
76
+ router_indices,
77
+ routing_weights.view(-1, num_experts),
78
+ gate_up_proj,
79
+ gate_up_proj_bias,
80
+ down_proj,
81
+ down_proj_bias,
82
+ seq_len,
83
+ num_experts,
84
+ top_k,
85
+ )
86
+
87
+ # Benchmark
88
+ torch.cuda.synchronize()
89
+ torch.cuda.reset_peak_memory_stats()
90
+ start = time.perf_counter()
91
+
92
+ with torch.no_grad():
93
+ output = yamoe.experts(
94
+ hidden_states.view(-1, hidden_dim),
95
+ router_indices,
96
+ routing_weights.view(-1, num_experts),
97
+ gate_up_proj,
98
+ gate_up_proj_bias,
99
+ down_proj,
100
+ down_proj_bias,
101
+ seq_len,
102
+ num_experts,
103
+ top_k,
104
+ )
105
+
106
+ torch.cuda.synchronize()
107
+ elapsed_ms = (time.perf_counter() - start) * 1e3
108
+ peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
109
+
110
+ print(f"Output sum: {output.sum().item():.4f}")
111
+ print(f"Kernel time: {elapsed_ms:.3f} ms")
112
+ print(f"Peak GPU memory: {peak_mem_mb:.2f} MB")
113
+ # Output sum: 124.2500
114
+ # Kernel time: 85.722 ms
115
+ # Peak GPU memory: 8403.40 MB
116
+
117
+ ```
build.toml ADDED
@@ -0,0 +1,36 @@
1
+ [general]
2
+ name = "yamoe"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h"
9
+ ]
10
+
11
+ [kernel.yamoe]
12
+ backend = "cuda"
13
+ cuda-capabilities = [
14
+ "7.0",
15
+ "7.2",
16
+ "7.5",
17
+ "8.0",
18
+ "8.6",
19
+ "8.7",
20
+ "8.9",
21
+ "9.0",
22
+ "10.0",
23
+ "10.1",
24
+ "11.8",
25
+ "12.0"
26
+ ]
27
+ depends = ["torch", "cutlass_3_8"]
28
+ src = [
29
+ "csrc/index_select.cu",
30
+ "csrc/gather.cu",
31
+ "csrc/scatter.cu",
32
+ "csrc/sort.cu",
33
+ "csrc/bincount_cumsum.cu",
34
+ "csrc/batch_mm.cu",
35
+ "csrc/moe.cpp"
36
+ ]
csrc/batch_mm.cu ADDED
@@ -0,0 +1,31 @@
1
+ // csrc/batch_mm.cu
2
+
3
+ #include <torch/torch.h>
4
+
5
+ // For now this simply uses a standard bmm, but it can be adapted for a
6
+ // faster batched expert matrix multiply if needed
7
+ torch::Tensor batch_mm(
8
+ torch::Tensor x,
9
+ torch::Tensor weights,
10
+ torch::Tensor batch_sizes,
11
+ torch::Tensor output,
12
+ bool trans_b) {
13
+ // Validate inputs
14
+ TORCH_CHECK(x.is_cuda(), "x must be on CUDA");
15
+ TORCH_CHECK(weights.is_cuda(), "weights must be on CUDA");
16
+ TORCH_CHECK(batch_sizes.is_cuda(), "batch_sizes must be on CUDA");
17
+
18
+ TORCH_CHECK(x.ndimension() == 3, "x must be 3D tensor"); // [E, C, H]
19
+ TORCH_CHECK(weights.ndimension() == 3,
20
+ "weights must be 3D tensor"); // [E, H, H_out]
21
+ TORCH_CHECK(batch_sizes.ndimension() == 1,
22
+ "batch_sizes must be 1D tensor"); // [E]
23
+
24
+ TORCH_CHECK(x.size(0) == weights.size(0) && x.size(0) == batch_sizes.size(0));
25
+ TORCH_CHECK(x.size(2) == weights.size(1)); // H dimension match
26
+
27
+ // For now, just fall back to a standard bmm; batch_sizes and trans_b are not yet used
28
+ // torch::bmm(x, weights, output);
29
+ torch::bmm_out(output, x, weights);
30
+ return output;
31
+ }
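Since the current implementation is a thin wrapper over `torch::bmm_out`, here is a minimal Python sketch of the semantics it provides today (ignoring `batch_sizes` and `trans_b`, as the fallback above does); the function name is illustrative only.

```python
import torch

def batch_mm_sketch(x, weights, batch_sizes, output, trans_b=False):
    # Current fallback semantics: a plain batched matmul written into the
    # pre-allocated output buffer. batch_sizes and trans_b are not yet used,
    # mirroring the C++ code above.
    torch.bmm(x, weights, out=output)  # [E, C, H] @ [E, H, H_out] -> [E, C, H_out]
    return output
```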
csrc/bincount_cumsum.cu ADDED
@@ -0,0 +1,98 @@
1
+ // csrc/bincount_cumsum.cu
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+ #include <torch/torch.h>
6
+
7
+ template <typename scalar_t>
8
+ __global__ void bincount_cumsum_kernel(
9
+ const scalar_t *__restrict__ input,
10
+ int32_t *__restrict__ bins_out,
11
+ const int n_input,
12
+ const int n_bins) {
13
+ // Shared memory for local bincount
14
+ extern __shared__ int shared_counts[];
15
+
16
+ int tid = threadIdx.x;
17
+ int bid = blockIdx.x;
18
+ int threads_per_block = blockDim.x;
19
+
20
+ // Initialize shared memory
21
+ for (int i = tid; i < n_bins; i += threads_per_block) {
22
+ shared_counts[i] = 0;
23
+ }
24
+ __syncthreads();
25
+
26
+ // Each block processes a chunk of input
27
+ int start = bid * threads_per_block;
28
+ int end = min(start + threads_per_block, n_input);
29
+
30
+ // Bincount phase - each thread processes its elements
31
+ for (int i = start + tid; i < end; i += threads_per_block) {
32
+ if (i < n_input) {
33
+ int bin = static_cast<int>(input[i]);
34
+ if (bin >= 0 && bin < n_bins) {
35
+ atomicAdd(&shared_counts[bin], 1);
36
+ }
37
+ }
38
+ }
39
+ __syncthreads();
40
+
41
+ // Write block results to global memory
42
+ for (int i = tid; i < n_bins; i += threads_per_block) {
43
+ atomicAdd(&bins_out[i], shared_counts[i]);
44
+ }
45
+ __syncthreads();
46
+
47
+ // The cumulative sum over bins_out is finished by the host caller after
48
+ // this kernel completes; blocks cannot synchronize with each other here,
49
+ // so summing inside the kernel would race with other blocks' atomicAdds.
56
+ }
57
+
58
+ void bincount_cumsum_cuda(
59
+ torch::Tensor input,
60
+ torch::Tensor &bins_out,
61
+ int64_t minlength) {
62
+ TORCH_CHECK(input.is_cuda(), "Input must be CUDA tensor");
63
+ TORCH_CHECK(input.dtype() == torch::kInt32, "Input must be int32");
64
+ TORCH_CHECK(bins_out.is_cuda(), "Output must be CUDA tensor");
65
+
66
+ const auto n_input = input.numel();
67
+ const auto n_bins = static_cast<int>(minlength);
68
+
69
+ // Validate output tensor dimensions and clear it
70
+ TORCH_CHECK(bins_out.numel() >= n_bins,
71
+ "Output tensor must have at least minlength elements");
72
+ bins_out.zero_();
73
+
74
+ const int threads_per_block = 256;
75
+ const int n_blocks = (n_input + threads_per_block - 1) / threads_per_block;
76
+
77
+ // Launch kernel with shared memory for bincount
78
+ const size_t shared_mem_size = n_bins * sizeof(int);
79
+
80
+ AT_DISPATCH_INTEGRAL_TYPES(
81
+ input.scalar_type(),
82
+ "bincount_cumsum_cuda",
83
+ ([&] {
84
+ bincount_cumsum_kernel<scalar_t>
85
+ <<<n_blocks, threads_per_block, shared_mem_size>>>(
86
+ input.data_ptr<scalar_t>(),
87
+ bins_out.data_ptr<int32_t>(),
88
+ n_input,
89
+ n_bins);
90
+ }));
91
+
92
+ cudaError_t err = cudaGetLastError();
93
+ TORCH_CHECK(err == cudaSuccess,
94
+ "CUDA kernel failed: ",
95
+ cudaGetErrorString(err));
96
+
97
+ // Finish with an inclusive cumulative sum over the first n_bins entries.
98
+ // Doing it here on the tensor (ordered after the kernel on the stream)
99
+ // avoids the cross-block race of summing inside the kernel itself.
100
+ auto counts = bins_out.narrow(0, 0, n_bins);
101
+ counts.copy_(counts.cumsum(0));
102
+
103
+ // No return needed - output is modified in-place
104
+ }
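Functionally this op is a histogram of expert ids followed by an inclusive cumulative sum. A one-line PyTorch sketch of the same result (assuming non-negative int32 expert ids) is:

```python
import torch

def bincount_cumsum_sketch(expert_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # bins[e] = number of assignments routed to experts 0..e (inclusive cumsum).
    return torch.bincount(expert_ids, minlength=num_experts).cumsum(0).to(torch.int32)
```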
csrc/gather.cu ADDED
@@ -0,0 +1,109 @@
1
+ // csrc/gather.cu
2
+
3
+ #include <cuda_runtime.h>
4
+ #include <torch/torch.h>
5
+
6
+ template <typename scalar_t>
7
+ __global__ void gather_kernel(
8
+ const scalar_t *__restrict__ x, // [T,H]
9
+ const int *__restrict__ idx, // [S]
10
+ const int *__restrict__ bins, // [E] cumulative
11
+ scalar_t *__restrict__ out, // [E,C,H]
12
+ int T,
13
+ int H,
14
+ int E,
15
+ int C,
16
+ int top_k) {
17
+ int e = blockIdx.x; // expert
18
+ int i = blockIdx.y; // row within capacity
19
+ if (e >= E || i >= C)
20
+ return;
21
+
22
+ const int end = bins[e];
23
+ const int start = (e == 0) ? 0 : bins[e - 1];
24
+ const int n = end - start;
25
+
26
+ bool valid = (i < n);
27
+ int tok = 0;
28
+ if (valid) {
29
+ int flat = idx[start + i];
30
+ tok = flat / top_k;
31
+ if (tok < 0 || tok >= T)
32
+ valid = false; // guard
33
+ }
34
+
35
+ const scalar_t *src = valid ? (x + (size_t)tok * H) : nullptr;
36
+ scalar_t *dst = out + ((size_t)e * C + i) * H;
37
+
38
+ int t = threadIdx.x;
39
+
40
+ // Try vectorized 16B moves if H is multiple of 4 and pointers are aligned
41
+ // (only for float type)
42
+ if constexpr (std::is_same<scalar_t, float>::value) {
43
+ if ((H % 4) == 0 && ((reinterpret_cast<uintptr_t>(dst) & 0xF) == 0) &&
44
+ (!valid || (reinterpret_cast<uintptr_t>(src) & 0xF) == 0)) {
45
+ const int HV = H / 4;
46
+ using F4 = float4;
47
+ const F4 *src4 = reinterpret_cast<const F4 *>(src);
48
+ F4 *dst4 = reinterpret_cast<F4 *>(dst);
49
+
50
+ for (int j = t; j < HV; j += blockDim.x) {
51
+ F4 v;
52
+ if (valid)
53
+ v = src4[j];
54
+ else
55
+ v = make_float4(0.f, 0.f, 0.f, 0.f);
56
+ dst4[j] = v;
57
+ }
58
+ return;
59
+ }
60
+ }
61
+
62
+ // Fallback to scalar copy
63
+ for (int j = t; j < H; j += blockDim.x) {
64
+ dst[j] = valid ? src[j] : scalar_t(0);
65
+ }
66
+ }
67
+
68
+ void gather_cuda(
69
+ torch::Tensor const &x, // [T, H]
70
+ torch::Tensor const &indices, // [S]
71
+ torch::Tensor const &bins, // [E] cumulative
72
+ torch::Tensor &output, // [E, C, H] pre-allocated output buffer
73
+ int64_t E, // number of experts
74
+ int64_t C, // expert capacity
75
+ int64_t top_k // top-k value
76
+ ) {
77
+ // Get dimensions
78
+ int64_t T = x.size(0);
79
+ int64_t H = x.size(1);
80
+
81
+ // Validate output tensor dimensions
82
+ TORCH_CHECK(output.size(0) == E && output.size(1) == C && output.size(2) == H,
83
+ "Output tensor must have shape [E, C, H]");
84
+
85
+ // Launch kernel with 2D grid (E, C)
86
+ dim3 grid(E, C);
87
+ int threads = 256;
88
+
89
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf,
90
+ at::kBFloat16,
91
+ x.scalar_type(),
92
+ "gather_cuda",
93
+ ([&] {
94
+ using scalar_t_ =
95
+ scalar_t; // avoid shadowing surprises
96
+ gather_kernel<scalar_t_><<<grid, threads>>>(
97
+ x.data_ptr<scalar_t_>(),
98
+ indices.data_ptr<int>(),
99
+ bins.data_ptr<int>(),
100
+ output.data_ptr<scalar_t_>(),
101
+ (int)T,
102
+ (int)H,
103
+ (int)E,
104
+ (int)C,
105
+ (int)top_k);
106
+ }));
107
+
108
+ // No return needed - output is modified in-place
109
+ }
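For readability, a pure-PyTorch sketch of the gather semantics: the token id is recovered as `flat_index // top_k`, each expert receives at most `C` rows, and unused capacity rows stay zero. Names are illustrative, not part of the kernel API.

```python
import torch

def gather_sketch(x, sorted_indices, bins, E, C, top_k):
    T, H = x.shape
    out = torch.zeros(E, C, H, dtype=x.dtype, device=x.device)
    start = 0
    for e in range(E):
        end = int(bins[e])                                      # bins is an inclusive cumsum
        rows = sorted_indices[start:end][:C].long() // top_k    # token ids for expert e
        out[e, : rows.numel()] = x[rows]
        start = end
    return out
```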
csrc/index_select.cu ADDED
@@ -0,0 +1,51 @@
1
+ // csrc/index_select.cu
2
+
3
+ #include <c10/cuda/CUDAStream.h>
4
+ #include <cuda_runtime.h>
5
+ #include <torch/torch.h>
6
+
7
+ template <typename scalar_t>
8
+ __global__ void index_select_kernel(
9
+ const scalar_t *__restrict__ in,
10
+ const int32_t *__restrict__ idx,
11
+ scalar_t *__restrict__ out,
12
+ int64_t N) {
13
+ int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
14
+ if (i < N)
15
+ out[i] = in[(int64_t)idx[i]];
16
+ }
17
+
18
+ torch::Tensor index_select_out_cuda(
19
+ torch::Tensor out, // [N], same dtype/device as in
20
+ torch::Tensor in, // [M], contiguous
21
+ torch::Tensor idx_int32) // [N], int32, contiguous
22
+ {
23
+ TORCH_CHECK(in.is_cuda() && idx_int32.is_cuda() && out.is_cuda(),
24
+ "cuda only");
25
+ TORCH_CHECK(idx_int32.dtype() == torch::kInt32, "idx must be int32");
26
+ TORCH_CHECK(
27
+ in.is_contiguous() && idx_int32.is_contiguous() && out.is_contiguous(),
28
+ "contiguous required");
29
+
30
+ int64_t N = idx_int32.numel();
31
+ int threads = 256;
32
+ int blocks = (int)((N + threads - 1) / threads);
33
+
34
+ AT_DISPATCH_FLOATING_TYPES_AND2(
35
+ torch::kBFloat16,
36
+ torch::kHalf,
37
+ in.scalar_type(),
38
+ "index_select_int32",
39
+ [&] {
40
+ const scalar_t *pin = in.data_ptr<scalar_t>();
41
+ const int32_t *pidx = idx_int32.data_ptr<int32_t>();
42
+ scalar_t *pout = out.data_ptr<scalar_t>();
43
+ index_select_kernel<scalar_t>
44
+ <<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(pin,
45
+ pidx,
46
+ pout,
47
+ N);
48
+ });
49
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
50
+ return out;
51
+ }
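This is an `index_select` that accepts int32 indices directly (avoiding an int64 conversion); in plain PyTorch the equivalent is simply:

```python
import torch

def index_select_out_sketch(out, inp, idx_int32):
    # out[i] = inp[idx[i]]; stock indexing needs int64, so cast for the sketch.
    out.copy_(inp[idx_int32.long()])
    return out
```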
csrc/moe.cpp ADDED
@@ -0,0 +1,223 @@
1
+ // csrc/moe.cpp
2
+
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <c10/cuda/CUDAStream.h>
5
+ #include <torch/torch.h>
6
+
7
+ // Forward declarations for existing functions
8
+ void sort_cuda(torch::Tensor x,
9
+ int64_t end_bit,
10
+ torch::Tensor x_out,
11
+ torch::Tensor iota_out);
12
+
13
+ void bincount_cumsum_cuda(torch::Tensor input,
14
+ torch::Tensor &output,
15
+ int64_t minlength);
16
+
17
+ torch::Tensor index_select_out_cuda(torch::Tensor out,
18
+ torch::Tensor in,
19
+ torch::Tensor idx_int32);
20
+
21
+ void gather_cuda(torch::Tensor const &x,
22
+ torch::Tensor const &indices,
23
+ torch::Tensor const &bins,
24
+ torch::Tensor &output,
25
+ int64_t E,
26
+ int64_t C,
27
+ int64_t top_k);
28
+
29
+ void scatter_cuda(torch::Tensor const &src,
30
+ torch::Tensor const &indices,
31
+ torch::Tensor const &bins,
32
+ torch::Tensor const &weights,
33
+ torch::Tensor &y,
34
+ int64_t T,
35
+ int64_t E,
36
+ int64_t C,
37
+ int64_t top_k);
38
+
39
+ torch::Tensor batch_mm(torch::Tensor x,
40
+ torch::Tensor weights,
41
+ torch::Tensor batch_sizes,
42
+ torch::Tensor output,
43
+ bool trans_b = false);
44
+
45
+ torch::Tensor experts_cuda(
46
+ torch::Tensor hidden_states, // [B*S, H] - flattened hidden states
47
+ torch::Tensor router_indices, // [B*S, K] - expert indices per token
48
+ torch::Tensor routing_weights, // [B*S, E] or [B*S, K] - routing weights
49
+ torch::Tensor gate_up_proj, // [E, H, 2*H] - gate/up projection weights
50
+ torch::Tensor gate_up_proj_bias, // [E, 2*H] - gate/up projection bias
51
+ torch::Tensor down_proj, // [E, H, H] - down projection weights
52
+ torch::Tensor down_proj_bias, // [E, H] - down projection bias
53
+ int64_t expert_capacity, // C - capacity per expert
54
+ int64_t num_experts, // E - number of experts
55
+ int64_t top_k // K - top-k routing
56
+ ) {
57
+ // Input validation
58
+ TORCH_CHECK(hidden_states.is_cuda(), "hidden_states must be on CUDA");
59
+ TORCH_CHECK(router_indices.is_cuda(), "router_indices must be on CUDA");
60
+ TORCH_CHECK(routing_weights.is_cuda(), "routing_weights must be on CUDA");
61
+ TORCH_CHECK(gate_up_proj.is_cuda(), "gate_up_proj must be on CUDA");
62
+ TORCH_CHECK(gate_up_proj_bias.is_cuda(), "gate_up_proj_bias must be on CUDA");
63
+ TORCH_CHECK(down_proj.is_cuda(), "down_proj must be on CUDA");
64
+ TORCH_CHECK(down_proj_bias.is_cuda(), "down_proj_bias must be on CUDA");
65
+
66
+ TORCH_CHECK(hidden_states.ndimension() == 2,
67
+ "hidden_states must be 2D [T, H]");
68
+ TORCH_CHECK(router_indices.ndimension() == 2,
69
+ "router_indices must be 2D [T, K]");
70
+ TORCH_CHECK(routing_weights.ndimension() == 2,
71
+ "routing_weights must be 2D [T, K]");
72
+ TORCH_CHECK(gate_up_proj.ndimension() == 3,
73
+ "gate_up_proj must be 3D [E, H, 2*H]");
74
+ TORCH_CHECK(gate_up_proj_bias.ndimension() == 2,
75
+ "gate_up_proj_bias must be 2D [E, 2*H]");
76
+ TORCH_CHECK(down_proj.ndimension() == 3, "down_proj must be 3D [E, H, H]");
77
+ TORCH_CHECK(down_proj_bias.ndimension() == 2,
78
+ "down_proj_bias must be 2D [E, H]");
79
+
80
+ const int64_t T = hidden_states.size(0); // Total tokens
81
+ const int64_t H = hidden_states.size(1); // Hidden size
82
+ const int64_t E = num_experts;
83
+ const int64_t C = expert_capacity;
84
+ const int64_t K = top_k;
85
+
86
+ TORCH_CHECK(router_indices.size(0) == T && router_indices.size(1) == K);
87
+ TORCH_CHECK(routing_weights.size(0) == T && (routing_weights.size(1) == K ||
88
+ routing_weights.size(1) == E),
89
+ "routing_weights must be [T, K] or [T, E]");
90
+ TORCH_CHECK(gate_up_proj.size(0) == E && gate_up_proj.size(1) == H &&
91
+ gate_up_proj.size(2) == 2 * H);
92
+ TORCH_CHECK(gate_up_proj_bias.size(0) == E &&
93
+ gate_up_proj_bias.size(1) == 2 * H);
94
+ TORCH_CHECK(down_proj.size(0) == E && down_proj.size(1) == H &&
95
+ down_proj.size(2) == H);
96
+ TORCH_CHECK(down_proj_bias.size(0) == E && down_proj_bias.size(1) == H);
97
+
98
+ // Ensure simple contiguity where helpful
99
+ hidden_states = hidden_states.contiguous();
100
+ router_indices = router_indices.contiguous();
101
+ routing_weights = routing_weights.contiguous();
102
+
103
+ // ALLOCATE
104
+
105
+ auto device_opts = torch::TensorOptions()
106
+ .dtype(torch::kInt32)
107
+ .device(hidden_states.device());
108
+ auto int64_opts = torch::TensorOptions()
109
+ .dtype(torch::kInt64)
110
+ .device(hidden_states.device());
111
+ auto float_opts = torch::TensorOptions()
112
+ .dtype(hidden_states.dtype())
113
+ .device(hidden_states.device());
114
+
115
+ // Buffers for sorting
116
+ torch::Tensor flat_indices =
117
+ router_indices.flatten().to(torch::kInt32, /*non_blocking=*/true);
118
+ torch::Tensor sorted_values = torch::empty_like(flat_indices);
119
+ torch::Tensor sorted_indices = torch::empty_like(flat_indices);
120
+
121
+ // Buffer for bins - use int32 for smaller footprint
122
+ torch::Tensor bins =
123
+ torch::empty({E + 1},
124
+ device_opts); // Pre-allocate for bincount_cumsum result
125
+
126
+ // Buffer for gathered tokens
127
+ torch::Tensor x = torch::empty({E, C, H}, float_opts);
128
+
129
+ // Buffer for expert token counts
130
+ torch::Tensor expert_tokens = torch::empty({E}, device_opts);
131
+
132
+ // Buffers for intermediate results
133
+ torch::Tensor gate_up = torch::empty({E, C, 2 * H}, float_opts);
134
+
135
+ // Final output buffer
136
+ torch::Tensor output = torch::zeros_like(hidden_states);
137
+
138
+ // COMPUTE
139
+
140
+ // Sort tokens by expert
141
+ sort_cuda(flat_indices, 32, sorted_values, sorted_indices);
142
+
143
+ // Compute bins using bincount_cumsum
144
+ bincount_cumsum_cuda(sorted_values, bins, E);
145
+
146
+ // Gather tokens by expert
147
+ // [T, H] -> [E, C, H]
148
+ gather_cuda(hidden_states, sorted_indices, bins, x, E, C, K);
149
+
150
+ // bins holds an inclusive cumulative sum, so the per-expert token counts
151
+ // are the adjacent differences, with the first expert's count at bins[0].
152
+ expert_tokens[0] = bins[0];
153
+ if (E > 1) {
154
+ expert_tokens.slice(0, 1, E) =
155
+ bins.slice(0, 1, E) - bins.slice(0, 0, E - 1);
156
+ }
158
+ // Clamp to expert capacity
159
+ expert_tokens = torch::clamp(expert_tokens, 0, (int32_t)C);
160
+
161
+ batch_mm(x, gate_up_proj, expert_tokens, gate_up, true);
162
+
163
+ // add the gate bias to the output in-place
164
+ gate_up.add_(gate_up_proj_bias.unsqueeze(1));
165
+
166
+ // Compute the clamped GLU activation; the clamps below are out-of-place, so gate/up become standalone tensors rather than views into gate_up
167
+ auto gate = gate_up.index({torch::indexing::Ellipsis,
168
+ torch::indexing::Slice(torch::indexing::None,
169
+ torch::indexing::None,
170
+ 2)});
171
+ auto up =
172
+ gate_up.index({torch::indexing::Ellipsis,
173
+ torch::indexing::Slice(1, torch::indexing::None, 2)});
174
+
175
+ const float limit = 7.0f;
176
+ gate = gate.clamp(/*min=*/c10::nullopt, /*max=*/limit);
177
+ up = up.clamp(/*min=*/-limit, /*max=*/limit);
178
+
179
+ gate.mul_(torch::sigmoid(gate * 1.702f));
180
+ up.add_(1).mul_(gate);
181
+
182
+ // Down projection uses GLU result directly
183
+ gate_up.resize_(0);
184
+ batch_mm(up, down_proj, expert_tokens, gate_up, true);
185
+
186
+ // add the down_bias in-place
187
+ gate_up.add_(down_proj_bias.unsqueeze(1));
188
+
189
+ // Stage allocations right before use
190
+ torch::Tensor selected_weights = torch::empty({T * K}, float_opts);
191
+ torch::Tensor weights_sorted = torch::empty({T * K}, float_opts);
192
+
193
+ torch::Tensor selected_weights_2d =
194
+ selected_weights.view({T, K}); // named lvalue view
195
+ torch::Tensor flat_dense = routing_weights.view({T, E});
196
+ torch::Tensor flat_router = router_indices.view({T, K});
197
+
198
+ // gather_out(out&, self, dim, index, sparse_grad=false)
199
+ at::gather_out(selected_weights_2d,
200
+ flat_dense,
201
+ /*dim=*/1,
202
+ flat_router,
203
+ /*sparse_grad=*/false);
204
+
205
+ // Use int32 index select to avoid dtype conversion
206
+ index_select_out_cuda(weights_sorted, // [T*K], float_opts
207
+ selected_weights.view({T * K}), // const&, ok as rvalue
208
+ sorted_indices // int32 indices, no conversion needed
209
+ );
210
+
211
+ // Scatter back to original positions with weights applied
212
+ scatter_cuda(gate_up.view({E, C, H}),
213
+ sorted_indices,
214
+ bins,
215
+ weights_sorted,
216
+ output,
217
+ T,
218
+ E,
219
+ C,
220
+ K);
221
+
222
+ return output;
223
+ }
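To make the control flow above easier to follow, here is a hedged pure-PyTorch sketch of the pipeline `experts_cuda` orchestrates: sort assignments by expert, build the inclusive histogram, gather into `[E, C, H]`, apply the gate/up projection with the clamped GLU, run the down projection, and scatter the weighted rows back to token order. It mirrors the steps, not the performance, and is not a drop-in replacement.

```python
import torch

def experts_sketch(hidden_states, router_indices, routing_weights,
                   gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
                   expert_capacity, num_experts, top_k, limit=7.0, alpha=1.702):
    T, H = hidden_states.shape
    E, C, K = num_experts, expert_capacity, top_k

    # 1. Sort token-expert assignments by expert id (stable, like the radix sort).
    flat = router_indices.reshape(-1)
    sorted_vals, sorted_idx = torch.sort(flat, stable=True)

    # 2. Inclusive histogram over experts.
    bins = torch.bincount(sorted_vals, minlength=E).cumsum(0)

    # 3. Gather tokens into [E, C, H], zero-padded to capacity.
    x = torch.zeros(E, C, H, dtype=hidden_states.dtype, device=hidden_states.device)
    start = 0
    for e in range(E):
        end = int(bins[e])
        rows = (sorted_idx[start:end][:C] // K).long()
        x[e, : rows.numel()] = hidden_states[rows]
        start = end

    # 4. Gate/up projection followed by the clamped GLU activation.
    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[:, None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    acts = (up + 1) * (gate * torch.sigmoid(gate * alpha))

    # 5. Down projection.
    down = torch.bmm(acts, down_proj) + down_proj_bias[:, None, :]

    # 6. Weighted scatter back to token order.
    out = torch.zeros_like(hidden_states)
    w = routing_weights.view(T, E).gather(1, router_indices.view(T, K)).reshape(-1)
    start = 0
    for e in range(E):
        end = int(bins[e])
        sel = sorted_idx[start:end][:C]
        tok = (sel // K).long()
        out.index_add_(0, tok, down[e, : sel.numel()] * w[sel][:, None])
        start = end
    return out
```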
csrc/scatter.cu ADDED
@@ -0,0 +1,147 @@
1
+ // csrc/scatter.cu
2
+
3
+ #include <cstdint>
4
+ #include <cuda_runtime.h>
5
+ #include <torch/torch.h>
6
+ #include <type_traits>
7
+
8
+ // Minimal atomic add shim:
9
+ // - native CUDA atomics for float/double
10
+ // - 16-bit CAS fallback for Half/BFloat16 (works on all SMs)
11
+
12
+ // CAS-based 16-bit atomic add (for c10::Half / c10::BFloat16)
13
+ template <typename T>
14
+ __device__ inline void atomicAdd16(
15
+ T *addr,
16
+ T val) {
17
+ // Find containing 32-bit word and whether we're the high or low 16 bits
18
+ std::uintptr_t uaddr = reinterpret_cast<std::uintptr_t>(addr);
19
+ unsigned int *base =
20
+ reinterpret_cast<unsigned int *>(uaddr & ~std::uintptr_t(0x3));
21
+ const bool hi_half = (uaddr & 0x2) != 0;
22
+
23
+ unsigned int old32 = *base, assumed;
24
+ do {
25
+ assumed = old32;
26
+
27
+ // Extract current 16-bit payload
28
+ unsigned short cur16 = hi_half ? (assumed >> 16) : (assumed & 0xFFFFu);
29
+
30
+ // Reinterpret those 16 bits as T, then promote to float
31
+ T cur;
32
+ *reinterpret_cast<unsigned short *>(&cur) = cur16;
33
+ float f = static_cast<float>(cur) + static_cast<float>(val);
34
+
35
+ // Convert back to T (rounds appropriately), grab its 16-bit payload
36
+ T res = static_cast<T>(f);
37
+ unsigned short res16 = *reinterpret_cast<unsigned short *>(&res);
38
+
39
+ // Merge back into the correct half and attempt CAS
40
+ unsigned int new32 =
41
+ hi_half ? ((assumed & 0x0000FFFFu) |
42
+ (static_cast<unsigned int>(res16) << 16))
43
+ : ((assumed & 0xFFFF0000u) | static_cast<unsigned int>(res16));
44
+
45
+ old32 = atomicCAS(base, assumed, new32);
46
+ } while (old32 != assumed);
47
+ }
48
+
49
+ // Unified atomicAdd for all scalar_t
50
+ template <typename T>
51
+ __device__ inline void atomicAddT(
52
+ T *addr,
53
+ T val) {
54
+ if constexpr (std::is_same<T, float>::value) {
55
+ atomicAdd(addr, val);
56
+ } else if constexpr (std::is_same<T, double>::value) {
57
+ atomicAdd(addr, val);
58
+ } else {
59
+ // c10::Half or c10::BFloat16
60
+ atomicAdd16(addr, val);
61
+ }
62
+ }
63
+
64
+ // Kernel: y[tok, :] += src[e, i, :] for valid (e,i)
65
+ // where tok = indices[bins[e-1] + i] / top_k
66
+ template <typename scalar_t>
67
+ __global__ void scatter_kernel(
68
+ const scalar_t *__restrict__ src, // [E, C, H]
69
+ const int *__restrict__ idx, // [S]
70
+ const int *__restrict__ bins, // [E] cumulative
71
+ const scalar_t *__restrict__ weights, // [S] routing weights (can be null)
72
+ scalar_t *__restrict__ y, // [T, H] (accumulated)
73
+ int T,
74
+ int H,
75
+ int E,
76
+ int C,
77
+ int top_k) {
78
+ int e = blockIdx.x;
79
+ int i = blockIdx.y;
80
+ if (e >= E || i >= C)
81
+ return;
82
+
83
+ const int end = bins[e];
84
+ const int start = (e == 0) ? 0 : bins[e - 1];
85
+ const int n = end - start;
86
+
87
+ bool valid = (i < n);
88
+ int tok = 0;
89
+ if (valid) {
90
+ int flat = idx[start + i];
91
+ tok = flat / top_k;
92
+ if (tok < 0 || tok >= T)
93
+ valid = false; // guard
94
+ }
95
+ if (!valid)
96
+ return;
97
+
98
+ const scalar_t *src_row = src + ((size_t)e * C + i) * H;
99
+ scalar_t *y_row = y + (size_t)tok * H;
100
+
101
+ // Get the weight/scale factor for this token if provided
102
+ scalar_t scale = (weights != nullptr) ? weights[start + i] : scalar_t(1.0);
103
+
104
+ int t = threadIdx.x;
105
+ for (int h = t; h < H; h += blockDim.x) {
106
+ atomicAddT(&y_row[h], src_row[h] * scale);
107
+ }
108
+ }
109
+
110
+ void scatter_cuda(
111
+ const torch::Tensor &src, // [E, C, H]
112
+ const torch::Tensor &indices, // [S] (int32)
113
+ const torch::Tensor &bins, // [E] cumulative (int32)
114
+ const torch::Tensor &weights, // [S] routing weights (optional)
115
+ torch::Tensor &y, // [T, H] (accumulate into)
116
+ int64_t T, // tokens
117
+ int64_t E, // experts
118
+ int64_t C, // capacity
119
+ int64_t top_k // router top-k
120
+ ) {
121
+ const int64_t H = src.size(2);
122
+
123
+ // Grid over experts x capacity; threads over H
124
+ dim3 grid(E, C);
125
+ int threads = 256;
126
+
127
+ // Include Half + BFloat16 in dispatch
128
+ AT_DISPATCH_FLOATING_TYPES_AND2(
129
+ at::kHalf,
130
+ at::kBFloat16,
131
+ src.scalar_type(),
132
+ "scatter_cuda",
133
+ ([&] {
134
+ using scalar_t_ = scalar_t;
135
+ scatter_kernel<scalar_t_><<<grid, threads>>>(
136
+ src.data_ptr<scalar_t_>(),
137
+ indices.data_ptr<int>(),
138
+ bins.data_ptr<int>(),
139
+ weights.defined() ? weights.data_ptr<scalar_t_>() : nullptr,
140
+ y.data_ptr<scalar_t_>(),
141
+ static_cast<int>(T),
142
+ static_cast<int>(H),
143
+ static_cast<int>(E),
144
+ static_cast<int>(C),
145
+ static_cast<int>(top_k));
146
+ }));
147
+ }
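The scatter step is the inverse of the gather: every expert row is scaled by its routing weight and accumulated back into the token it came from. A compact PyTorch sketch of the same accumulation (names illustrative):

```python
import torch

def scatter_sketch(src, sorted_indices, bins, weights, y, top_k):
    # src: [E, C, H], y: [T, H]; weights are ordered like sorted_indices.
    E, C, _ = src.shape
    start = 0
    for e in range(E):
        end = int(bins[e])
        n = min(end - start, C)
        pos = torch.arange(start, start + n, device=src.device)
        tok = sorted_indices[pos].long() // top_k
        y.index_add_(0, tok, src[e, :n] * weights[pos][:, None])
        start = end
    return y
```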
csrc/sort.cu ADDED
@@ -0,0 +1,93 @@
1
+ // csrc/sort.cu
2
+ // originally from
3
+ // https://github.com/databricks/megablocks/blob/main/csrc/sort.h
4
+
5
+ #include <c10/cuda/CUDAStream.h>
6
+ #include <cstdint>
7
+ #include <cub/cub.cuh>
8
+ #include <torch/torch.h>
9
+
10
+ #define CUDA_CALL(code) \
11
+ do { \
12
+ cudaError_t status = (code); \
13
+ std::string err = cudaGetErrorString(status); \
14
+ TORCH_CHECK(status == cudaSuccess, err); \
15
+ } while (0)
16
+
17
+ template <typename T>
18
+ void cub_radix_sort(
19
+ torch::Tensor x,
20
+ int64_t end_bit,
21
+ torch::Tensor x_out,
22
+ torch::Tensor iota_out) {
23
+ // Get iota for values in sort.
24
+ auto iota_options =
25
+ torch::TensorOptions().dtype(x.scalar_type()).device(x.device());
26
+ torch::Tensor iota = torch::arange(0, x.numel(), iota_options);
27
+
28
+ // Get temporary buffer size.
29
+ size_t scratchpad_bytes = 0;
30
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(
31
+ /*d_temp_storage*/ nullptr,
32
+ /*temp_storage_bytes*/ scratchpad_bytes,
33
+ /*d_keys_in*/ x.data_ptr<T>(),
34
+ /*d_keys_out*/ x_out.data_ptr<T>(),
35
+ /*d_values_in*/ iota.data_ptr<T>(),
36
+ /*d_values_out*/ iota_out.data_ptr<T>(),
37
+ /*num_items*/ x.numel(),
38
+ /*begin_bit*/ 0,
39
+ /*end_bit*/ end_bit,
40
+ /*stream*/ c10::cuda::getCurrentCUDAStream()));
41
+
42
+ // Allocate scratchpad.
43
+ auto options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
44
+ torch::Tensor scratchpad =
45
+ torch::empty(static_cast<long>(scratchpad_bytes), options);
46
+
47
+ // Run the kernel.
48
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(
49
+ /*d_temp_storage*/ scratchpad.data_ptr(),
50
+ /*temp_storage_bytes*/ scratchpad_bytes,
51
+ /*d_keys_in*/ x.data_ptr<T>(),
52
+ /*d_keys_out*/ x_out.data_ptr<T>(),
53
+ /*d_values_in*/ iota.data_ptr<T>(),
54
+ /*d_values_out*/ iota_out.data_ptr<T>(),
55
+ /*num_items*/ x.numel(),
56
+ /*begin_bit*/ 0,
57
+ /*end_bit*/ end_bit,
58
+ /*stream*/ c10::cuda::getCurrentCUDAStream()));
59
+ }
60
+
61
+ void sort_cuda(
62
+ torch::Tensor x,
63
+ int64_t end_bit,
64
+ torch::Tensor x_out,
65
+ torch::Tensor iota_out) {
66
+ TORCH_CHECK(x.is_cuda());
67
+ TORCH_CHECK(x.ndimension() == 1);
68
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
69
+ x.scalar_type() == torch::kInt32 ||
70
+ x.scalar_type() == torch::kInt64);
71
+ TORCH_CHECK(x_out.is_cuda());
72
+ TORCH_CHECK(x_out.ndimension() == 1);
73
+ TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
74
+ TORCH_CHECK(iota_out.is_cuda());
75
+ TORCH_CHECK(iota_out.ndimension() == 1);
76
+ TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
77
+
78
+ // Exit early if there is no work to do.
79
+ if (x_out.numel() == 0)
80
+ return;
81
+
82
+ switch (x.scalar_type()) {
83
+ case torch::kInt16:
84
+ return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
85
+ case torch::kInt32:
86
+ return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
87
+ default:
88
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
89
+ return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
90
+ }
91
+ }
92
+
93
+ #undef CUDA_CALL
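In PyTorch terms this pairs-sort is equivalent to a stable sort of the expert ids that also returns the permutation (the kernel feeds an `arange` in as the values):

```python
import torch

def sort_sketch(expert_ids: torch.Tensor):
    # Stable sort, matching cub::DeviceRadixSort::SortPairs with arange values.
    sorted_vals, perm = torch.sort(expert_ids, stable=True)
    return sorted_vals, perm.to(expert_ids.dtype)
```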
flake.lock ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1754038838,
77
+ "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1756320464,
102
+ "narHash": "sha256-x9LI4h87/Z9UgTQjgeG0fRcdeXl91xIqBlTauGKZM70=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "b4accba4496b28faef19a0487fbcf9686b14e2ef",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1752785354,
117
+ "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
118
+ "owner": "nixos",
119
+ "repo": "nixpkgs",
120
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "nixos",
125
+ "repo": "nixpkgs",
126
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
flake.nix ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ description = "Flake for yamoe kernels";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+
17
+ pythonCheckInputs = pkgs: with pkgs; [
18
+ tqdm
19
+ py-cpuinfo
20
+ importlib-metadata
21
+ torchmetrics
22
+ ];
23
+ };
24
+ }
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,73 @@
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+ TORCH_LIBRARY_EXPAND(
7
+ TORCH_EXTENSION_NAME,
8
+ ops) {
9
+ ops.def("gather("
10
+ "Tensor x, "
11
+ "Tensor indices, "
12
+ "Tensor bins, "
13
+ "Tensor! output, "
14
+ "int E, "
15
+ "int C, "
16
+ "int top_k) -> ()");
17
+ ops.impl("gather", torch::kCUDA, &gather_cuda);
18
+
19
+ ops.def("scatter("
20
+ "Tensor src, "
21
+ "Tensor indices, "
22
+ "Tensor bins, "
23
+ "Tensor weights, "
24
+ "Tensor! y, "
25
+ "int T, "
26
+ "int E, "
27
+ "int C, "
28
+ "int top_k) -> ()");
29
+ ops.impl("scatter", torch::kCUDA, &scatter_cuda);
30
+
31
+ ops.def("sort("
32
+ "Tensor x, "
33
+ "int end_bit, "
34
+ "Tensor! x_out, "
35
+ "Tensor! iota_out) -> ()");
36
+ ops.impl("sort", torch::kCUDA, &sort_cuda);
37
+
38
+ ops.def("bincount_cumsum("
39
+ "Tensor input, "
40
+ "Tensor! output, "
41
+ "int minlength) -> ()");
42
+ ops.impl("bincount_cumsum", torch::kCUDA, &bincount_cumsum_cuda);
43
+
44
+ ops.def("index_select_out("
45
+ "Tensor! out, "
46
+ "Tensor input, "
47
+ "Tensor idx_int32) -> Tensor");
48
+ ops.impl("index_select_out", torch::kCUDA, &index_select_out_cuda);
49
+
50
+ ops.def("batch_mm("
51
+ "Tensor x, "
52
+ "Tensor weights, "
53
+ "Tensor batch_sizes, "
54
+ "Tensor! output, "
55
+ "bool trans_b=False) -> Tensor");
56
+ ops.impl("batch_mm", torch::kCUDA, &batch_mm);
57
+
58
+ ops.def("experts("
59
+ "Tensor hidden_states, "
60
+ "Tensor router_indices, "
61
+ "Tensor routing_weights, "
62
+ "Tensor gate_up_proj, "
63
+ "Tensor gate_up_proj_bias, "
64
+ "Tensor down_proj, "
65
+ "Tensor down_proj_bias, "
66
+ "int expert_capacity, "
67
+ "int num_experts, "
68
+ "int top_k) -> Tensor");
69
+ ops.impl("experts", torch::kCUDA, &experts_cuda);
70
+ }
71
+
72
+ REGISTER_EXTENSION(
73
+ TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,55 @@
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+
5
+ void gather_cuda(torch::Tensor const &x,
6
+ torch::Tensor const &indices,
7
+ torch::Tensor const &bins,
8
+ torch::Tensor &output,
9
+ int64_t E,
10
+ int64_t C,
11
+ int64_t top_k);
12
+
13
+ void scatter_cuda(torch::Tensor const &src,
14
+ torch::Tensor const &indices,
15
+ torch::Tensor const &bins,
16
+ torch::Tensor const &weights,
17
+ torch::Tensor &y,
18
+ int64_t T,
19
+ int64_t E,
20
+ int64_t C,
21
+ int64_t top_k);
22
+
23
+ void sort_cuda(torch::Tensor x,
24
+ int64_t end_bit,
25
+ torch::Tensor x_out,
26
+ torch::Tensor iota_out);
27
+
28
+ void bincount_cumsum_cuda(torch::Tensor input,
29
+ torch::Tensor &output,
30
+ int64_t minlength);
31
+
32
+ torch::Tensor index_select_out_cuda(torch::Tensor out,
33
+ torch::Tensor in,
34
+ torch::Tensor idx_int32);
35
+
36
+ torch::Tensor
37
+ batch_mm(torch::Tensor x, // [E, C, H] - expert tokens
38
+ torch::Tensor weights, // [E, H, H_out] - expert weight matrices
39
+ torch::Tensor batch_sizes, // [E] - actual tokens per expert (<=C)
40
+ torch::Tensor output, // [E, C, H_out] - output buffer
41
+ bool trans_b = false // transpose weights if needed
42
+ );
43
+
44
+ torch::Tensor experts_cuda(
45
+ torch::Tensor hidden_states, // [T, H] - flattened hidden states
46
+ torch::Tensor router_indices, // [T, K] - expert indices per token
47
+ torch::Tensor routing_weights, // [T, E] or [T, K] - routing weights
48
+ torch::Tensor gate_up_proj, // [E, H, 2*H] - gate/up projection weights
49
+ torch::Tensor gate_up_proj_bias, // [E, 2*H] - gate/up projection bias
50
+ torch::Tensor down_proj, // [E, H, H] - down projection weights
51
+ torch::Tensor down_proj_bias, // [E, H] - down projection bias
52
+ int64_t expert_capacity, // C - capacity per expert
53
+ int64_t num_experts, // E - number of experts
54
+ int64_t top_k // K - top-k routing
55
+ );
torch-ext/yamoe/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from ._ops import ops
2
+ from . import reference
3
+
4
+ gather = ops.gather
5
+ scatter = ops.scatter
6
+ sort = ops.sort
7
+ bincount_cumsum = ops.bincount_cumsum
8
+ batch_mm = ops.batch_mm
9
+ experts = ops.experts
10
+
11
+ __all__ = [
13
+ "gather",
14
+ "scatter",
15
+ "sort",
16
+ "bincount_cumsum",
17
+ "batch_mm",
18
+ "experts",
19
+ # Export the reference implementation
20
+ "reference",
21
+ ]
torch-ext/yamoe/reference.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class GptOssExperts(nn.Module):
5
+ def __init__(self, config):
6
+ super().__init__()
7
+ self.intermediate_size = config.intermediate_size
8
+ self.num_experts = config.num_local_experts
9
+ self.hidden_size = config.hidden_size
10
+ self.expert_dim = self.intermediate_size
11
+ self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
12
+ self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
13
+ self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
14
+ self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
15
+ self.alpha = 1.702
16
+ self.limit = 7.0
17
+
18
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
19
+ """
20
+ When training, it is more efficient to just loop over the experts and compute the output for each expert,
21
+ as otherwise the memory would explode.
22
+
23
+ For inference, we can sacrifice some memory and compute the output for all experts at once by repeating the inputs.
24
+
25
+ Args:
26
+ hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
27
+ router_indices (torch.Tensor): (batch_size * token_num, top_k)
28
+ routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
29
+ Returns:
30
+ torch.Tensor
31
+ """
32
+
35
+ batch_size = hidden_states.shape[0]
36
+ hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
37
+ num_experts = routing_weights.shape[1]
38
+ if self.training:
39
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
40
+ with torch.no_grad():
41
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
42
+ expert_mask = expert_mask.permute(2, 1, 0)
43
+ # we sum over the top_k and the sequence length to get which experts
44
+ # are hit this time around
45
+ expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
46
+ for expert_idx in expert_hitted[:]:
47
+ with torch.no_grad():
48
+ _, token_idx = torch.where(expert_mask[expert_idx[0]])
49
+ current_state = hidden_states[token_idx]
50
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
51
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
52
+ gate = gate.clamp(min=None, max=self.limit)
53
+ up = up.clamp(min=-self.limit, max=self.limit)
54
+ glu = gate * torch.sigmoid(gate * self.alpha)
55
+ gated_output = (up + 1) * glu
56
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
57
+ weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
58
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
59
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
60
+ else:
61
+ hidden_states = hidden_states.repeat(num_experts, 1)
62
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
63
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
64
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
65
+ gate = gate.clamp(min=None, max=self.limit)
66
+ up = up.clamp(min=-self.limit, max=self.limit)
67
+ glu = gate * torch.sigmoid(gate * self.alpha)
68
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
69
+ next_states = next_states + self.down_proj_bias[..., None, :]
70
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
71
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
72
+ next_states = next_states.sum(dim=0)
73
+ return next_states
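A minimal usage sketch for this reference module; the `SimpleNamespace` config and the shapes below are illustrative assumptions (only the three attributes read by `__init__` are required), and the weights are randomly initialized since the class allocates empty parameters.

```python
import torch
from types import SimpleNamespace
from yamoe import reference  # assumes the built yamoe package is importable

# Hypothetical config; only these three attributes are read by __init__.
cfg = SimpleNamespace(hidden_size=64, intermediate_size=64, num_local_experts=4)

experts = reference.GptOssExperts(cfg).eval()
for p in experts.parameters():
    torch.nn.init.normal_(p, std=0.02)

tokens, top_k = 8, 2
hidden = torch.randn(1, tokens, cfg.hidden_size)
logits = torch.randn(tokens, cfg.num_local_experts)
weights, indices = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
dense = torch.zeros(tokens, cfg.num_local_experts).scatter_(1, indices, weights)

with torch.no_grad():
    out = experts(hidden, router_indices=indices, routing_weights=dense)
print(out.shape)  # (1, tokens, hidden_size)
```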