drbh committed
Commit · 281d8ba · 0 Parent(s)

feat: yet another moe
Browse files:
- .clang-format +288 -0
- .gitattributes +2 -0
- .gitignore +15 -0
- .pre-commit-config.yaml +7 -0
- README.md +117 -0
- build.toml +36 -0
- csrc/batch_mm.cu +31 -0
- csrc/bincount_cumsum.cu +98 -0
- csrc/gather.cu +109 -0
- csrc/index_select.cu +51 -0
- csrc/moe.cpp +223 -0
- csrc/scatter.cu +147 -0
- csrc/sort.cu +93 -0
- flake.lock +168 -0
- flake.nix +24 -0
- torch-ext/torch_binding.cpp +73 -0
- torch-ext/torch_binding.h +55 -0
- torch-ext/yamoe/__init__.py +21 -0
- torch-ext/yamoe/reference.py +73 -0
.clang-format
ADDED
@@ -0,0 +1,288 @@
+---
+Language: Cpp
+AccessModifierOffset: -2
+AlignAfterOpenBracket: Align
+AlignArrayOfStructures: None
+AlignConsecutiveAssignments:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: false
+  AlignFunctionPointers: false
+  PadOperators: true
+AlignConsecutiveBitFields:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: false
+  AlignFunctionPointers: false
+  PadOperators: false
+AlignConsecutiveDeclarations:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: true
+  AlignFunctionPointers: false
+  PadOperators: false
+AlignConsecutiveMacros:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: false
+  AlignFunctionPointers: false
+  PadOperators: false
+AlignConsecutiveShortCaseStatements:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCaseArrows: false
+  AlignCaseColons: false
+AlignConsecutiveTableGenBreakingDAGArgColons:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: false
+  AlignFunctionPointers: false
+  PadOperators: false
+AlignConsecutiveTableGenCondOperatorColons:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: false
+  AlignFunctionPointers: false
+  PadOperators: false
+AlignConsecutiveTableGenDefinitionColons:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionDeclarations: false
+  AlignFunctionPointers: false
+  PadOperators: false
+AlignEscapedNewlines: Right
+AlignOperands: Align
+AlignTrailingComments:
+  Kind: Always
+  OverEmptyLines: 0
+AllowAllArgumentsOnNextLine: false
+AllowAllParametersOfDeclarationOnNextLine: false
+AllowBreakBeforeNoexceptSpecifier: Never
+AllowShortBlocksOnASingleLine: Never
+AllowShortCaseExpressionOnASingleLine: true
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortCompoundRequirementOnASingleLine: true
+AllowShortEnumsOnASingleLine: true
+AllowShortFunctionsOnASingleLine: All
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLambdasOnASingleLine: All
+AllowShortLoopsOnASingleLine: false
+AllowShortNamespacesOnASingleLine: false
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakBeforeMultilineStrings: false
+AttributeMacros:
+  - __capability
+BinPackArguments: false
+BinPackParameters: false
+BitFieldColonSpacing: Both
+BraceWrapping:
+  AfterCaseLabel: false
+  AfterClass: false
+  AfterControlStatement: Never
+  AfterEnum: false
+  AfterExternBlock: false
+  AfterFunction: false
+  AfterNamespace: false
+  AfterObjCDeclaration: false
+  AfterStruct: false
+  AfterUnion: false
+  BeforeCatch: false
+  BeforeElse: false
+  BeforeLambdaBody: false
+  BeforeWhile: false
+  IndentBraces: false
+  SplitEmptyFunction: true
+  SplitEmptyRecord: true
+  SplitEmptyNamespace: true
+BreakAdjacentStringLiterals: true
+BreakAfterAttributes: Leave
+BreakAfterJavaFieldAnnotations: false
+BreakAfterReturnType: None
+BreakArrays: true
+BreakBeforeBinaryOperators: None
+BreakBeforeConceptDeclarations: Always
+BreakBeforeBraces: Attach
+BreakBeforeInlineASMColon: OnlyMultiline
+BreakBeforeTernaryOperators: true
+BreakBinaryOperations: Never
+BreakConstructorInitializers: AfterColon
+BreakFunctionDefinitionParameters: true
+BreakInheritanceList: BeforeColon
+BreakStringLiterals: true
+BreakTemplateDeclarations: MultiLine
+ColumnLimit: 80
+CommentPragmas: '^ IWYU pragma:'
+CompactNamespaces: false
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+Cpp11BracedListStyle: true
+DerivePointerAlignment: false
+DisableFormat: false
+EmptyLineAfterAccessModifier: Never
+EmptyLineBeforeAccessModifier: LogicalBlock
+ExperimentalAutoDetectBinPacking: false
+FixNamespaceComments: true
+ForEachMacros:
+  - foreach
+  - Q_FOREACH
+  - BOOST_FOREACH
+IfMacros:
+  - KJ_IF_MAYBE
+IncludeBlocks: Preserve
+IncludeCategories:
+  - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
+    Priority: 2
+    SortPriority: 0
+    CaseSensitive: false
+  - Regex: '^(<|"(gtest|gmock|isl|json)/)'
+    Priority: 3
+    SortPriority: 0
+    CaseSensitive: false
+  - Regex: '.*'
+    Priority: 1
+    SortPriority: 0
+    CaseSensitive: false
+IncludeIsMainRegex: '(Test)?$'
+IncludeIsMainSourceRegex: ''
+IndentAccessModifiers: false
+IndentCaseBlocks: false
+IndentCaseLabels: false
+IndentExportBlock: true
+IndentExternBlock: AfterExternBlock
+IndentGotoLabels: true
+IndentPPDirectives: None
+IndentRequiresClause: true
+IndentWidth: 2
+IndentWrappedFunctionNames: false
+InsertBraces: false
+InsertNewlineAtEOF: false
+InsertTrailingCommas: None
+IntegerLiteralSeparator:
+  Binary: 0
+  BinaryMinDigits: 0
+  Decimal: 0
+  DecimalMinDigits: 0
+  Hex: 0
+  HexMinDigits: 0
+JavaScriptQuotes: Leave
+JavaScriptWrapImports: true
+KeepEmptyLines:
+  AtEndOfFile: false
+  AtStartOfBlock: true
+  AtStartOfFile: true
+KeepFormFeed: false
+LambdaBodyIndentation: Signature
+LineEnding: DeriveLF
+MacroBlockBegin: ''
+MacroBlockEnd: ''
+MainIncludeChar: Quote
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: None
+ObjCBinPackProtocolList: Auto
+ObjCBlockIndentWidth: 2
+ObjCBreakBeforeNestedBlockParam: true
+ObjCSpaceAfterProperty: false
+ObjCSpaceBeforeProtocolList: true
+PackConstructorInitializers: BinPack
+PenaltyBreakAssignment: 2
+PenaltyBreakBeforeFirstCallParameter: 0
+PenaltyBreakBeforeMemberAccess: 150
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakOpenParenthesis: 0
+PenaltyBreakScopeResolution: 500
+PenaltyBreakString: 1000
+PenaltyBreakTemplateDeclaration: 10
+PenaltyExcessCharacter: 1000000
+PenaltyIndentedWhitespace: 0
+PenaltyReturnTypeOnItsOwnLine: 60
+PointerAlignment: Right
+PPIndentWidth: -1
+QualifierAlignment: Leave
+ReferenceAlignment: Pointer
+ReflowComments: Always
+RemoveBracesLLVM: false
+RemoveEmptyLinesInUnwrappedLines: false
+RemoveParentheses: Leave
+RemoveSemicolon: false
+RequiresClausePosition: OwnLine
+RequiresExpressionIndentation: OuterScope
+SeparateDefinitionBlocks: Leave
+ShortNamespaceLines: 1
+SkipMacroDefinitionBody: false
+SortIncludes: CaseSensitive
+SortJavaStaticImport: Before
+SortUsingDeclarations: LexicographicNumeric
+SpaceAfterCStyleCast: false
+SpaceAfterLogicalNot: false
+SpaceAfterTemplateKeyword: true
+SpaceAroundPointerQualifiers: Default
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeCaseColon: false
+SpaceBeforeCpp11BracedList: false
+SpaceBeforeCtorInitializerColon: true
+SpaceBeforeInheritanceColon: true
+SpaceBeforeJsonColon: false
+SpaceBeforeParens: ControlStatements
+SpaceBeforeParensOptions:
+  AfterControlStatements: true
+  AfterForeachMacros: true
+  AfterFunctionDefinitionName: false
+  AfterFunctionDeclarationName: false
+  AfterIfMacros: true
+  AfterOverloadedOperator: false
+  AfterPlacementOperator: true
+  AfterRequiresInClause: false
+  AfterRequiresInExpression: false
+  BeforeNonEmptyParentheses: false
+SpaceBeforeRangeBasedForLoopColon: true
+SpaceBeforeSquareBrackets: false
+SpaceInEmptyBlock: false
+SpacesBeforeTrailingComments: 1
+SpacesInAngles: Never
+SpacesInContainerLiterals: true
+SpacesInLineCommentPrefix:
+  Minimum: 1
+  Maximum: -1
+SpacesInParens: Never
+SpacesInParensOptions:
+  ExceptDoubleParentheses: false
+  InCStyleCasts: false
+  InConditionalStatements: false
+  InEmptyParentheses: false
+  Other: false
+SpacesInSquareBrackets: false
+Standard: Latest
+StatementAttributeLikeMacros:
+  - Q_EMIT
+StatementMacros:
+  - Q_UNUSED
+  - QT_REQUIRE_VERSION
+TableGenBreakInsideDAGArg: DontBreak
+TabWidth: 8
+UseTab: Never
+VerilogBreakBetweenInstancePorts: true
+WhitespaceSensitiveMacros:
+  - BOOST_PP_STRINGIZE
+  - CF_SWIFT_NAME
+  - NS_SWIFT_NAME
+  - PP_STRINGIZE
+  - STRINGIZE
+WrapNamespaceBodyWithEmptyLines: Leave
+...
.gitattributes
ADDED
@@ -0,0 +1,2 @@
+*.so filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,15 @@
+.bak
+.ruff_cache
+.venv
+cmake
+result
+scripts
+__pycache__
+CMakeLists.txt
+setup.py
+pyproject.toml
+tests
+torch-ext/registration.h
+torch-ext/yamoe/_ops.py
+csrc/batch_mm.cu
+torch-ext/yamoe/*.abi3.so
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,7 @@
+repos:
+  - repo: https://github.com/pre-commit/mirrors-clang-format
+    rev: v20.1.8
+    hooks:
+      - id: clang-format
+        files: ^(csrc/|torch-ext/).*\.(?:c|cc|cpp|cxx|h|hh|hpp|hxx|cu|cuh)$
+        args: [-i]
README.md
ADDED
@@ -0,0 +1,117 @@
+---
+license: mit
+tags:
+- kernel
+---
+
+```
+
+oooo    ooo  .oooo.   ooo. .oo.  .oo.    .ooooo.   .ooooo.
+ `88.  .8'  `P  )88b  `888P"Y88bP"Y88b  d88' `88b d88' `88b
+  `88..8'    .oP"888   888   888   888  888   888 888ooo888
+   `888'    d8(  888   888   888   888  888   888 888    .o
+    .8'     `Y888""8o o888o o888o o888o `Y8bod8P' `Y8bod8P'
+.o..P'
+`Y8P'
+
+Yet Another Mixture of Experts
+```
+
+`yamoe` is a no-nonsense, straightforward implementation of Mixture of Experts (MoE) kernels, designed to be super easy to use and very computationally efficient.
+
+### Design goals
+- simplicity: easy to read and understand the code
+- efficiency: optimized for high throughput and low latency
+- low memory usage: optimized to handle large batch sizes
+- reproducibility: easy to reproduce results, no special new `sm` requirements
+
+
+### How to use
+
+```python
+# /// script
+# requires-python = "==3.10"
+# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
+# [tool.uv.sources]
+# kernels = { git = "https://github.com/huggingface/kernels.git" }
+# ///
+
+import time
+import torch
+from kernels import get_kernel
+from pathlib import Path
+from torch.nn import functional as F
+
+yamoe = get_kernel("drbh/yamoe")
+
+# Configuration
+torch.manual_seed(0)
+batch_size, seq_len, hidden_dim = 128, 2048, 2880
+num_experts, top_k = 32, 4
+
+# Create routing weights
+logits = torch.randn(batch_size, seq_len, num_experts)
+probs = F.softmax(logits, dim=-1)
+weights, indices = torch.topk(probs, top_k, dim=-1)
+
+batch_seq = batch_size * seq_len
+routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
+flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
+batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
+routing_weights[batch_indices, flat_indices] = flat_weights
+
+# Create model tensors (scaled to prevent overflow)
+hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda().half() * 0.1
+gate_up_proj = torch.randn(num_experts, hidden_dim, 2 * hidden_dim).cuda().half() * 0.02
+gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda().half()
+down_proj = torch.randn(num_experts, hidden_dim, hidden_dim).cuda().half() * 0.02
+down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda().half()
+routing_weights = routing_weights.cuda().half()
+router_indices = flat_indices.cuda()
+
+# Warmup
+for _ in range(5):
+    _ = yamoe.experts(
+        hidden_states.view(-1, hidden_dim),
+        router_indices,
+        routing_weights.view(-1, num_experts),
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        seq_len,
+        num_experts,
+        top_k,
+    )
+
+# Benchmark
+torch.cuda.synchronize()
+torch.cuda.reset_peak_memory_stats()
+start = time.perf_counter()
+
+with torch.no_grad():
+    output = yamoe.experts(
+        hidden_states.view(-1, hidden_dim),
+        router_indices,
+        routing_weights.view(-1, num_experts),
+        gate_up_proj,
+        gate_up_proj_bias,
+        down_proj,
+        down_proj_bias,
+        seq_len,
+        num_experts,
+        top_k,
+    )
+
+torch.cuda.synchronize()
+elapsed_ms = (time.perf_counter() - start) * 1e3
+peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
+
+print(f"Output sum: {output.sum().item():.4f}")
+print(f"Kernel time: {elapsed_ms:.3f} ms")
+print(f"Peak GPU memory: {peak_mem_mb:.2f} MB")
+# Output sum: 124.2500
+# Kernel time: 85.722 ms
+# Peak GPU memory: 8403.40 MB
+
+```
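A small aside on the dense routing tensor the script builds: the advanced-indexing assignment above can also be written with `scatter_`, which is exactly equivalent. A minimal standalone sketch (small shapes chosen here for illustration):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
batch_seq, num_experts, top_k = 8, 4, 2
probs = F.softmax(torch.randn(batch_seq, num_experts), dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)

# scatter_(1, index, src) writes src[i, j] into routing_weights[i, index[i, j]],
# the same effect as routing_weights[batch_indices, flat_indices] = flat_weights
routing_weights = torch.zeros(batch_seq, num_experts).scatter_(1, indices, weights)
```

Note also that the positional argument passed as `seq_len` above fills the `expert_capacity` slot of `experts` (see the `experts_cuda` signature in `csrc/moe.cpp` below).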
build.toml
ADDED
@@ -0,0 +1,36 @@
+[general]
+name = "yamoe"
+universal = false
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+
+[kernel.yamoe]
+backend = "cuda"
+cuda-capabilities = [
+  "7.0",
+  "7.2",
+  "7.5",
+  "8.0",
+  "8.6",
+  "8.7",
+  "8.9",
+  "9.0",
+  "10.0",
+  "10.1",
+  "11.8",
+  "12.0"
+]
+depends = ["torch", "cutlass_3_8"]
+src = [
+  "csrc/index_select.cu",
+  "csrc/gather.cu",
+  "csrc/scatter.cu",
+  "csrc/sort.cu",
+  "csrc/bincount_cumsum.cu",
+  "csrc/batch_mm.cu",
+  "csrc/moe.cpp"
+]
csrc/batch_mm.cu
ADDED
@@ -0,0 +1,31 @@
+// csrc/batch_mm.cu
+
+#include <torch/torch.h>
+
+// Simply use a standard bmm for now but this can be adapted for
+// faster batched expert matrix multiply if needed
+torch::Tensor batch_mm(
+    torch::Tensor x,
+    torch::Tensor weights,
+    torch::Tensor batch_sizes,
+    torch::Tensor output,
+    bool trans_b) {
+  // Validate inputs
+  TORCH_CHECK(x.is_cuda(), "x must be on CUDA");
+  TORCH_CHECK(weights.is_cuda(), "weights must be on CUDA");
+  TORCH_CHECK(batch_sizes.is_cuda(), "batch_sizes must be on CUDA");
+
+  TORCH_CHECK(x.ndimension() == 3, "x must be 3D tensor"); // [E, C, H]
+  TORCH_CHECK(weights.ndimension() == 3,
+              "weights must be 3D tensor"); // [E, H, H_out]
+  TORCH_CHECK(batch_sizes.ndimension() == 1,
+              "batch_sizes must be 1D tensor"); // [E]
+
+  TORCH_CHECK(x.size(0) == weights.size(0) && x.size(0) == batch_sizes.size(0));
+  TORCH_CHECK(x.size(2) == weights.size(1)); // H dimension match
+
+  // For now, just fall back to bmm to test the binding
+  // torch::bmm(x, weights, output);
+  torch::bmm_out(output, x, weights);
+  return output;
+}
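Since `batch_mm` currently ignores `batch_sizes` and runs a plain batched matmul over all E expert slabs, its behavior is easy to picture in PyTorch. A minimal sketch of what the fallback computes (shapes are illustrative):

```python
import torch

E, C, H = 4, 8, 16
x = torch.randn(E, C, H)      # [E, C, H] gathered expert inputs
w = torch.randn(E, H, 2 * H)  # [E, H, H_out] per-expert weights
out = torch.empty(E, C, 2 * H)
torch.bmm(x, w, out=out)      # mirrors torch::bmm_out(output, x, weights)
```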
csrc/bincount_cumsum.cu
ADDED
@@ -0,0 +1,98 @@
+// csrc/bincount_cumsum.cu
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+
+template <typename scalar_t>
+__global__ void bincount_cumsum_kernel(
+    const scalar_t *__restrict__ input,
+    int32_t *__restrict__ bins_out,
+    const int n_input,
+    const int n_bins) {
+  // Shared memory for local bincount
+  extern __shared__ int shared_counts[];
+
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int threads_per_block = blockDim.x;
+
+  // Initialize shared memory
+  for (int i = tid; i < n_bins; i += threads_per_block) {
+    shared_counts[i] = 0;
+  }
+  __syncthreads();
+
+  // Each block processes a chunk of input
+  int start = bid * threads_per_block;
+  int end = min(start + threads_per_block, n_input);
+
+  // Bincount phase - each thread processes its elements
+  for (int i = start + tid; i < end; i += threads_per_block) {
+    if (i < n_input) {
+      int bin = static_cast<int>(input[i]);
+      if (bin >= 0 && bin < n_bins) {
+        atomicAdd(&shared_counts[bin], 1);
+      }
+    }
+  }
+  __syncthreads();
+
+  // Write block results to global memory
+  for (int i = tid; i < n_bins; i += threads_per_block) {
+    atomicAdd(&bins_out[i], shared_counts[i]);
+  }
+  __syncthreads();
+
+  // Only first block does the cumsum
+  if (bid == 0) {
+    // Simple cumsum on first block
+    if (tid == 0) {
+      for (int i = 1; i < n_bins; i++) {
+        bins_out[i] += bins_out[i - 1];
+      }
+    }
+  }
+}
+
+void bincount_cumsum_cuda(
+    torch::Tensor input,
+    torch::Tensor &bins_out,
+    int64_t minlength) {
+  TORCH_CHECK(input.is_cuda(), "Input must be CUDA tensor");
+  TORCH_CHECK(input.dtype() == torch::kInt32, "Input must be int32");
+  TORCH_CHECK(bins_out.is_cuda(), "Output must be CUDA tensor");
+
+  const auto n_input = input.numel();
+  const auto n_bins = static_cast<int>(minlength);
+
+  // Validate output tensor dimensions and clear it
+  TORCH_CHECK(bins_out.numel() >= n_bins,
+              "Output tensor must have at least minlength elements");
+  bins_out.zero_();
+
+  const int threads_per_block = 256;
+  const int n_blocks = (n_input + threads_per_block - 1) / threads_per_block;
+
+  // Launch kernel with shared memory for bincount
+  const size_t shared_mem_size = n_bins * sizeof(int);
+
+  AT_DISPATCH_INTEGRAL_TYPES(
+      input.scalar_type(),
+      "bincount_cumsum_cuda",
+      ([&] {
+        bincount_cumsum_kernel<scalar_t>
+            <<<n_blocks, threads_per_block, shared_mem_size>>>(
+                input.data_ptr<scalar_t>(),
+                bins_out.data_ptr<int32_t>(),
+                n_input,
+                n_bins);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  TORCH_CHECK(err == cudaSuccess,
+              "CUDA kernel failed: ",
+              cudaGetErrorString(err));
+
+  // No return needed - output is modified in-place
+}
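This kernel fuses two steps: per-expert counts over the sorted expert ids, then an inclusive prefix sum, so `bins[e]` ends up being the end offset of expert e's segment in the sorted order. A hedged PyTorch equivalent for intuition:

```python
import torch

sorted_experts = torch.tensor([0, 0, 1, 2, 2, 2], dtype=torch.int32)
E = 4
# counts per expert, then inclusive cumsum: bins[e] is the end offset
# of expert e's segment in the sorted token order
bins = torch.bincount(sorted_experts.long(), minlength=E).cumsum(0)
print(bins)  # tensor([2, 3, 6, 6]) -> expert 1 occupies [2, 3)
```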
csrc/gather.cu
ADDED
@@ -0,0 +1,109 @@
+// csrc/gather.cu
+
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+
+template <typename scalar_t>
+__global__ void gather_kernel(
+    const scalar_t *__restrict__ x,   // [T,H]
+    const int *__restrict__ idx,      // [S]
+    const int *__restrict__ bins,     // [E] cumulative
+    scalar_t *__restrict__ out,       // [E,C,H]
+    int T,
+    int H,
+    int E,
+    int C,
+    int top_k) {
+  int e = blockIdx.x; // expert
+  int i = blockIdx.y; // row within capacity
+  if (e >= E || i >= C)
+    return;
+
+  const int end = bins[e];
+  const int start = (e == 0) ? 0 : bins[e - 1];
+  const int n = end - start;
+
+  bool valid = (i < n);
+  int tok = 0;
+  if (valid) {
+    int flat = idx[start + i];
+    tok = flat / top_k;
+    if (tok < 0 || tok >= T)
+      valid = false; // guard
+  }
+
+  const scalar_t *src = valid ? (x + (size_t)tok * H) : nullptr;
+  scalar_t *dst = out + ((size_t)e * C + i) * H;
+
+  int t = threadIdx.x;
+
+  // Try vectorized 16B moves if H is multiple of 4 and pointers are aligned
+  // (only for float type)
+  if constexpr (std::is_same<scalar_t, float>::value) {
+    if ((H % 4) == 0 && ((reinterpret_cast<uintptr_t>(dst) & 0xF) == 0) &&
+        (!valid || (reinterpret_cast<uintptr_t>(src) & 0xF) == 0)) {
+      const int HV = H / 4;
+      using F4 = float4;
+      const F4 *src4 = reinterpret_cast<const F4 *>(src);
+      F4 *dst4 = reinterpret_cast<F4 *>(dst);
+
+      for (int j = t; j < HV; j += blockDim.x) {
+        F4 v;
+        if (valid)
+          v = src4[j];
+        else
+          v = make_float4(0.f, 0.f, 0.f, 0.f);
+        dst4[j] = v;
+      }
+      return;
+    }
+  }
+
+  // Fallback to scalar copy
+  for (int j = t; j < H; j += blockDim.x) {
+    dst[j] = valid ? src[j] : scalar_t(0);
+  }
+}
+
+void gather_cuda(
+    torch::Tensor const &x,       // [T, H]
+    torch::Tensor const &indices, // [S]
+    torch::Tensor const &bins,    // [E] cumulative
+    torch::Tensor &output,        // [E, C, H] pre-allocated output buffer
+    int64_t E,                    // number of experts
+    int64_t C,                    // expert capacity
+    int64_t top_k                 // top-k value
+) {
+  // Get dimensions
+  int64_t T = x.size(0);
+  int64_t H = x.size(1);
+
+  // Validate output tensor dimensions
+  TORCH_CHECK(output.size(0) == E && output.size(1) == C && output.size(2) == H,
+              "Output tensor must have shape [E, C, H]");
+
+  // Launch kernel with 2D grid (E, C)
+  dim3 grid(E, C);
+  int threads = 256;
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf,
+                                  at::kBFloat16,
+                                  x.scalar_type(),
+                                  "gather_cuda",
+                                  ([&] {
+                                    using scalar_t_ =
+                                        scalar_t; // avoid shadowing surprises
+                                    gather_kernel<scalar_t_><<<grid, threads>>>(
+                                        x.data_ptr<scalar_t_>(),
+                                        indices.data_ptr<int>(),
+                                        bins.data_ptr<int>(),
+                                        output.data_ptr<scalar_t_>(),
+                                        (int)T,
+                                        (int)H,
+                                        (int)E,
+                                        (int)C,
+                                        (int)top_k);
+                                  }));
+
+  // No return needed - output is modified in-place
+}
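The addressing here is the heart of the layout: row i of expert e's `[C, H]` slab comes from token `indices[bins[e-1] + i] // top_k`, and unused capacity rows are zero-filled. A minimal Python sketch of the same semantics (`gather_ref` is an illustrative name, not part of the repo):

```python
import torch

def gather_ref(x, indices, bins, E, C, top_k):
    # x: [T, H] tokens; indices: sorted flat (token*K + slot) positions;
    # bins: cumulative per-expert counts
    T, H = x.shape
    out = torch.zeros(E, C, H, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else int(bins[e - 1])
        n = min(int(bins[e]) - start, C)  # rows beyond capacity are dropped
        tok = indices[start:start + n].long() // top_k  # flat index -> token row
        out[e, :n] = x[tok]
    return out
```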
csrc/index_select.cu
ADDED
@@ -0,0 +1,51 @@
+// csrc/index_select.cu
+
+#include <c10/cuda/CUDAStream.h>
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+
+template <typename scalar_t>
+__global__ void index_select_kernel(
+    const scalar_t *__restrict__ in,
+    const int32_t *__restrict__ idx,
+    scalar_t *__restrict__ out,
+    int64_t N) {
+  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < N)
+    out[i] = in[(int64_t)idx[i]];
+}
+
+torch::Tensor index_select_out_cuda(
+    torch::Tensor out,       // [N], same dtype/device as in
+    torch::Tensor in,        // [M], contiguous
+    torch::Tensor idx_int32) // [N], int32, contiguous
+{
+  TORCH_CHECK(in.is_cuda() && idx_int32.is_cuda() && out.is_cuda(),
+              "cuda only");
+  TORCH_CHECK(idx_int32.dtype() == torch::kInt32, "idx must be int32");
+  TORCH_CHECK(
+      in.is_contiguous() && idx_int32.is_contiguous() && out.is_contiguous(),
+      "contiguous required");
+
+  int64_t N = idx_int32.numel();
+  int threads = 256;
+  int blocks = (int)((N + threads - 1) / threads);
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      torch::kBFloat16,
+      torch::kHalf,
+      in.scalar_type(),
+      "index_select_int32",
+      [&] {
+        const scalar_t *pin = in.data_ptr<scalar_t>();
+        const int32_t *pidx = idx_int32.data_ptr<int32_t>();
+        scalar_t *pout = out.data_ptr<scalar_t>();
+        index_select_kernel<scalar_t>
+            <<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(pin,
+                                                                        pidx,
+                                                                        pout,
+                                                                        N);
+      });
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  return out;
+}
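Semantically this is a 1-D `out[i] = in[idx[i]]`, kept on int32 indices so the pipeline avoids an int64 upcast. A tiny CPU sketch of the same contract:

```python
import torch

src = torch.randn(10)
idx = torch.tensor([3, 3, 7], dtype=torch.int32)
out = torch.empty(idx.numel())
out.copy_(src[idx.long()])  # out[i] = src[idx[i]]
```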
csrc/moe.cpp
ADDED
@@ -0,0 +1,223 @@
+// csrc/moe.cpp
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/torch.h>
+
+// Forward declarations for existing functions
+void sort_cuda(torch::Tensor x,
+               int64_t end_bit,
+               torch::Tensor x_out,
+               torch::Tensor iota_out);
+
+void bincount_cumsum_cuda(torch::Tensor input,
+                          torch::Tensor &output,
+                          int64_t minlength);
+
+torch::Tensor index_select_out_cuda(torch::Tensor out,
+                                    torch::Tensor in,
+                                    torch::Tensor idx_int32);
+
+void gather_cuda(torch::Tensor const &x,
+                 torch::Tensor const &indices,
+                 torch::Tensor const &bins,
+                 torch::Tensor &output,
+                 int64_t E,
+                 int64_t C,
+                 int64_t top_k);
+
+void scatter_cuda(torch::Tensor const &src,
+                  torch::Tensor const &indices,
+                  torch::Tensor const &bins,
+                  torch::Tensor const &weights,
+                  torch::Tensor &y,
+                  int64_t T,
+                  int64_t E,
+                  int64_t C,
+                  int64_t top_k);
+
+torch::Tensor batch_mm(torch::Tensor x,
+                       torch::Tensor weights,
+                       torch::Tensor batch_sizes,
+                       torch::Tensor output,
+                       bool trans_b = false);
+
+torch::Tensor experts_cuda(
+    torch::Tensor hidden_states,     // [B*S, H] - flattened hidden states
+    torch::Tensor router_indices,    // [B*S, K] - expert indices per token
+    torch::Tensor routing_weights,   // [B*S, E] or [B*S, K] - routing weights
+    torch::Tensor gate_up_proj,      // [E, H, 2*H] - gate/up projection weights
+    torch::Tensor gate_up_proj_bias, // [E, 2*H] - gate/up projection bias
+    torch::Tensor down_proj,         // [E, H, H] - down projection weights
+    torch::Tensor down_proj_bias,    // [E, H] - down projection bias
+    int64_t expert_capacity,         // C - capacity per expert
+    int64_t num_experts,             // E - number of experts
+    int64_t top_k                    // K - top-k routing
+) {
+  // Input validation
+  TORCH_CHECK(hidden_states.is_cuda(), "hidden_states must be on CUDA");
+  TORCH_CHECK(router_indices.is_cuda(), "router_indices must be on CUDA");
+  TORCH_CHECK(routing_weights.is_cuda(), "routing_weights must be on CUDA");
+  TORCH_CHECK(gate_up_proj.is_cuda(), "gate_up_proj must be on CUDA");
+  TORCH_CHECK(gate_up_proj_bias.is_cuda(), "gate_up_proj_bias must be on CUDA");
+  TORCH_CHECK(down_proj.is_cuda(), "down_proj must be on CUDA");
+  TORCH_CHECK(down_proj_bias.is_cuda(), "down_proj_bias must be on CUDA");
+
+  TORCH_CHECK(hidden_states.ndimension() == 2,
+              "hidden_states must be 2D [T, H]");
+  TORCH_CHECK(router_indices.ndimension() == 2,
+              "router_indices must be 2D [T, K]");
+  TORCH_CHECK(routing_weights.ndimension() == 2,
+              "routing_weights must be 2D [T, K]");
+  TORCH_CHECK(gate_up_proj.ndimension() == 3,
+              "gate_up_proj must be 3D [E, H, 2*H]");
+  TORCH_CHECK(gate_up_proj_bias.ndimension() == 2,
+              "gate_up_proj_bias must be 2D [E, 2*H]");
+  TORCH_CHECK(down_proj.ndimension() == 3, "down_proj must be 3D [E, H, H]");
+  TORCH_CHECK(down_proj_bias.ndimension() == 2,
+              "down_proj_bias must be 2D [E, H]");
+
+  const int64_t T = hidden_states.size(0); // Total tokens
+  const int64_t H = hidden_states.size(1); // Hidden size
+  const int64_t E = num_experts;
+  const int64_t C = expert_capacity;
+  const int64_t K = top_k;
+
+  TORCH_CHECK(router_indices.size(0) == T && router_indices.size(1) == K);
+  TORCH_CHECK(routing_weights.size(0) == T && (routing_weights.size(1) == K ||
+                                               routing_weights.size(1) == E),
+              "routing_weights must be [T, K] or [T, E]");
+  TORCH_CHECK(gate_up_proj.size(0) == E && gate_up_proj.size(1) == H &&
+              gate_up_proj.size(2) == 2 * H);
+  TORCH_CHECK(gate_up_proj_bias.size(0) == E &&
+              gate_up_proj_bias.size(1) == 2 * H);
+  TORCH_CHECK(down_proj.size(0) == E && down_proj.size(1) == H &&
+              down_proj.size(2) == H);
+  TORCH_CHECK(down_proj_bias.size(0) == E && down_proj_bias.size(1) == H);
+
+  // Ensure simple contiguity where helpful
+  hidden_states = hidden_states.contiguous();
+  router_indices = router_indices.contiguous();
+  routing_weights = routing_weights.contiguous();
+
+  // ALLOCATE
+
+  auto device_opts = torch::TensorOptions()
+                         .dtype(torch::kInt32)
+                         .device(hidden_states.device());
+  auto int64_opts = torch::TensorOptions()
+                        .dtype(torch::kInt64)
+                        .device(hidden_states.device());
+  auto float_opts = torch::TensorOptions()
+                        .dtype(hidden_states.dtype())
+                        .device(hidden_states.device());
+
+  // Buffers for sorting
+  torch::Tensor flat_indices =
+      router_indices.flatten().to(torch::kInt32, /*non_blocking=*/true);
+  torch::Tensor sorted_values = torch::empty_like(flat_indices);
+  torch::Tensor sorted_indices = torch::empty_like(flat_indices);
+
+  // Buffer for bins - use int32 for smaller footprint
+  torch::Tensor bins =
+      torch::empty({E + 1},
+                   device_opts); // Pre-allocate for bincount_cumsum result
+
+  // Buffer for gathered tokens
+  torch::Tensor x = torch::empty({E, C, H}, float_opts);
+
+  // Buffer for expert token counts
+  torch::Tensor expert_tokens = torch::empty({E}, device_opts);
+
+  // Buffers for intermediate results
+  torch::Tensor gate_up = torch::empty({E, C, 2 * H}, float_opts);
+
+  // Final output buffer
+  torch::Tensor output = torch::zeros_like(hidden_states);
+
+  // COMPUTE
+
+  // Sort tokens by expert
+  sort_cuda(flat_indices, 32, sorted_values, sorted_indices);
+
+  // Compute bins using bincount_cumsum
+  bincount_cumsum_cuda(sorted_values, bins, E);
+
+  // Gather tokens by expert
+  // [T, H] -> [E, C, H]
+  gather_cuda(hidden_states, sorted_indices, bins, x, E, C, K);
+
+  if (E > 1) {
+    expert_tokens.slice(0, 0, E - 1) =
+        bins.slice(0, 1, E) - bins.slice(0, 0, E - 1);
+    expert_tokens[E - 1] =
+        (int32_t)(flat_indices.size(0) - bins[E - 1].item<int32_t>());
+  } else {
+    expert_tokens[0] = (int32_t)flat_indices.size(0);
+  }
+  // Clamp to expert capacity
+  expert_tokens = torch::clamp(expert_tokens, 0, (int32_t)C);
+
+  batch_mm(x, gate_up_proj, expert_tokens, gate_up, true);
+
+  // add the gate bias to the output in-place
+  gate_up.add_(gate_up_proj_bias.unsqueeze(1));
+
+  // Compute GLU in-place, reusing gate_up buffer for output
+  auto gate = gate_up.index({torch::indexing::Ellipsis,
+                             torch::indexing::Slice(torch::indexing::None,
+                                                    torch::indexing::None,
+                                                    2)});
+  auto up =
+      gate_up.index({torch::indexing::Ellipsis,
+                     torch::indexing::Slice(1, torch::indexing::None, 2)});
+
+  const float limit = 7.0f;
+  gate = gate.clamp(/*min=*/c10::nullopt, /*max=*/limit);
+  up = up.clamp(/*min=*/-limit, /*max=*/limit);
+
+  gate.mul_(torch::sigmoid(gate * 1.702f));
+  up.add_(1).mul_(gate);
+
+  // Down projection uses GLU result directly
+  gate_up.resize_(0);
+  batch_mm(up, down_proj, expert_tokens, gate_up, true);
+
+  // add the down_bias in-place
+  gate_up.add_(down_proj_bias.unsqueeze(1));
+
+  // Stage allocations right before use
+  torch::Tensor selected_weights = torch::empty({T * K}, float_opts);
+  torch::Tensor weights_sorted = torch::empty({T * K}, float_opts);
+
+  torch::Tensor selected_weights_2d =
+      selected_weights.view({T, K}); // named lvalue view
+  torch::Tensor flat_dense = routing_weights.view({T, E});
+  torch::Tensor flat_router = router_indices.view({T, K});
+
+  // gather_out(out&, self, dim, index, sparse_grad=false)
+  at::gather_out(selected_weights_2d,
+                 flat_dense,
+                 /*dim=*/1,
+                 flat_router,
+                 /*sparse_grad=*/false);
+
+  // Use int32 index select to avoid dtype conversion
+  index_select_out_cuda(weights_sorted,                 // [T*K], float_opts
+                        selected_weights.view({T * K}), // const&, ok as rvalue
+                        sorted_indices // int32 indices, no conversion needed
+  );
+
+  // Scatter back to original positions with weights applied
+  scatter_cuda(gate_up.view({E, C, H}),
+               sorted_indices,
+               bins,
+               weights_sorted,
+               output,
+               T,
+               E,
+               C,
+               K);
+
+  return output;
+}
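For intuition, the pipeline above (sort by expert, bin, gather, gate/up projection with interleaved GLU, down projection, weighted scatter) maps onto a small dense per-expert loop. A minimal sketch, assuming a dense `[T, E]` routing_weights tensor and ignoring the capacity clamp (the kernel drops tokens beyond C per expert, this reference does not); `moe_reference` is an illustrative name, and `torch-ext/yamoe/reference.py` in this commit is the repo's own reference:

```python
import torch

def moe_reference(hidden, router_indices, routing_weights,
                  gate_up_proj, gate_up_proj_bias,
                  down_proj, down_proj_bias):
    # hidden: [T, H]; router_indices: [T, K]; routing_weights: dense [T, E]
    out = torch.zeros_like(hidden)
    E = gate_up_proj.shape[0]
    for e in range(E):
        tok = (router_indices == e).nonzero(as_tuple=True)[0]  # tokens routed to e
        if tok.numel() == 0:
            continue
        gu = hidden[tok] @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate, up = gu[:, ::2], gu[:, 1::2]        # interleaved gate/up split
        gate = gate.clamp(max=7.0)
        up = up.clamp(-7.0, 7.0)
        glu = gate * torch.sigmoid(gate * 1.702) * (up + 1.0)
        y = glu @ down_proj[e] + down_proj_bias[e]
        w = routing_weights[tok, e].unsqueeze(1)  # per-token routing weight
        out.index_add_(0, tok, y * w)             # accumulate like scatter_cuda
    return out
```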
csrc/scatter.cu
ADDED
@@ -0,0 +1,147 @@
+// csrc/scatter.cu
+
+#include <cstdint>
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+#include <type_traits>
+
+// Minimal atomic add shim:
+// - native CUDA atomics for float/double
+// - 16-bit CAS fallback for Half/BFloat16 (works on all SMs)
+
+// CAS-based 16-bit atomic add (for c10::Half / c10::BFloat16)
+template <typename T>
+__device__ inline void atomicAdd16(
+    T *addr,
+    T val) {
+  // Find containing 32-bit word and whether we're the high or low 16 bits
+  std::uintptr_t uaddr = reinterpret_cast<std::uintptr_t>(addr);
+  unsigned int *base =
+      reinterpret_cast<unsigned int *>(uaddr & ~std::uintptr_t(0x3));
+  const bool hi_half = (uaddr & 0x2) != 0;
+
+  unsigned int old32 = *base, assumed;
+  do {
+    assumed = old32;
+
+    // Extract current 16-bit payload
+    unsigned short cur16 = hi_half ? (assumed >> 16) : (assumed & 0xFFFFu);
+
+    // Reinterpret those 16 bits as T, then promote to float
+    T cur;
+    *reinterpret_cast<unsigned short *>(&cur) = cur16;
+    float f = static_cast<float>(cur) + static_cast<float>(val);
+
+    // Convert back to T (rounds appropriately), grab its 16-bit payload
+    T res = static_cast<T>(f);
+    unsigned short res16 = *reinterpret_cast<unsigned short *>(&res);
+
+    // Merge back into the correct half and attempt CAS
+    unsigned int new32 =
+        hi_half ? ((assumed & 0x0000FFFFu) |
+                   (static_cast<unsigned int>(res16) << 16))
+                : ((assumed & 0xFFFF0000u) | static_cast<unsigned int>(res16));
+
+    old32 = atomicCAS(base, assumed, new32);
+  } while (old32 != assumed);
+}
+
+// Unified atomicAdd for all scalar_t
+template <typename T>
+__device__ inline void atomicAddT(
+    T *addr,
+    T val) {
+  if constexpr (std::is_same<T, float>::value) {
+    atomicAdd(addr, val);
+  } else if constexpr (std::is_same<T, double>::value) {
+    atomicAdd(addr, val);
+  } else {
+    // c10::Half or c10::BFloat16
+    atomicAdd16(addr, val);
+  }
+}
+
+// Kernel: y[tok, :] += src[e, i, :] for valid (e,i)
+// where tok = indices[bins[e-1] + i] / top_k
+template <typename scalar_t>
+__global__ void scatter_kernel(
+    const scalar_t *__restrict__ src,     // [E, C, H]
+    const int *__restrict__ idx,          // [S]
+    const int *__restrict__ bins,         // [E] cumulative
+    const scalar_t *__restrict__ weights, // [S] routing weights (can be null)
+    scalar_t *__restrict__ y,             // [T, H] (accumulated)
+    int T,
+    int H,
+    int E,
+    int C,
+    int top_k) {
+  int e = blockIdx.x;
+  int i = blockIdx.y;
+  if (e >= E || i >= C)
+    return;
+
+  const int end = bins[e];
+  const int start = (e == 0) ? 0 : bins[e - 1];
+  const int n = end - start;
+
+  bool valid = (i < n);
+  int tok = 0;
+  if (valid) {
+    int flat = idx[start + i];
+    tok = flat / top_k;
+    if (tok < 0 || tok >= T)
+      valid = false; // guard
+  }
+  if (!valid)
+    return;
+
+  const scalar_t *src_row = src + ((size_t)e * C + i) * H;
+  scalar_t *y_row = y + (size_t)tok * H;
+
+  // Get the weight/scale factor for this token if provided
+  scalar_t scale = (weights != nullptr) ? weights[start + i] : scalar_t(1.0);
+
+  int t = threadIdx.x;
+  for (int h = t; h < H; h += blockDim.x) {
+    atomicAddT(&y_row[h], src_row[h] * scale);
+  }
+}
+
+void scatter_cuda(
+    const torch::Tensor &src,     // [E, C, H]
+    const torch::Tensor &indices, // [S] (int32)
+    const torch::Tensor &bins,    // [E] cumulative (int32)
+    const torch::Tensor &weights, // [S] routing weights (optional)
+    torch::Tensor &y,             // [T, H] (accumulate into)
+    int64_t T,                    // tokens
+    int64_t E,                    // experts
+    int64_t C,                    // capacity
+    int64_t top_k                 // router top-k
+) {
+  const int64_t H = src.size(2);
+
+  // Grid over experts x capacity; threads over H
+  dim3 grid(E, C);
+  int threads = 256;
+
+  // Include Half + BFloat16 in dispatch
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::kHalf,
+      at::kBFloat16,
+      src.scalar_type(),
+      "scatter_cuda",
+      ([&] {
+        using scalar_t_ = scalar_t;
+        scatter_kernel<scalar_t_><<<grid, threads>>>(
+            src.data_ptr<scalar_t_>(),
+            indices.data_ptr<int>(),
+            bins.data_ptr<int>(),
+            weights.defined() ? weights.data_ptr<scalar_t_>() : nullptr,
+            y.data_ptr<scalar_t_>(),
+            static_cast<int>(T),
+            static_cast<int>(H),
+            static_cast<int>(E),
+            static_cast<int>(C),
+            static_cast<int>(top_k));
+      }));
+}
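Scatter is the inverse of the gather above: each valid capacity row of the `[E, C, H]` expert output is scaled by its routing weight and accumulated back into its original token row. A hedged Python analogue, where `index_add_` plays the role of the kernel's atomicAdd (`scatter_ref` is an illustrative name):

```python
import torch

def scatter_ref(src, indices, bins, weights, y, top_k):
    # src: [E, C, H] expert outputs; y: [T, H] accumulated in place
    E, C, H = src.shape
    for e in range(E):
        start = 0 if e == 0 else int(bins[e - 1])
        n = min(int(bins[e]) - start, C)
        tok = indices[start:start + n].long() // top_k  # flat index -> token row
        w = weights[start:start + n].unsqueeze(1)       # per-row routing weight
        y.index_add_(0, tok, src[e, :n] * w)
    return y
```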
csrc/sort.cu
ADDED
@@ -0,0 +1,93 @@
+// csrc/sort.cu
+// originally from
+// https://github.com/databricks/megablocks/blob/main/csrc/sort.h
+
+#include <c10/cuda/CUDAStream.h>
+#include <cstdint>
+#include <cub/cub.cuh>
+#include <torch/torch.h>
+
+#define CUDA_CALL(code)                           \
+  do {                                            \
+    cudaError_t status = (code);                  \
+    std::string err = cudaGetErrorString(status); \
+    TORCH_CHECK(status == cudaSuccess, err);      \
+  } while (0)
+
+template <typename T>
+void cub_radix_sort(
+    torch::Tensor x,
+    int64_t end_bit,
+    torch::Tensor x_out,
+    torch::Tensor iota_out) {
+  // Get iota for values in sort.
+  auto iota_options =
+      torch::TensorOptions().dtype(x.scalar_type()).device(x.device());
+  torch::Tensor iota = torch::arange(0, x.numel(), iota_options);
+
+  // Get temporary buffer size.
+  size_t scratchpad_bytes = 0;
+  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
+      /*d_temp_storage*/ nullptr,
+      /*temp_storage_bytes*/ scratchpad_bytes,
+      /*d_keys_in*/ x.data_ptr<T>(),
+      /*d_keys_out*/ x_out.data_ptr<T>(),
+      /*d_values_in*/ iota.data_ptr<T>(),
+      /*d_values_out*/ iota_out.data_ptr<T>(),
+      /*num_items*/ x.numel(),
+      /*begin_bit*/ 0,
+      /*end_bit*/ end_bit,
+      /*stream*/ c10::cuda::getCurrentCUDAStream()));
+
+  // Allocate scratchpad.
+  auto options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
+  torch::Tensor scratchpad =
+      torch::empty(static_cast<long>(scratchpad_bytes), options);
+
+  // Run the kernel.
+  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
+      /*d_temp_storage*/ scratchpad.data_ptr(),
+      /*temp_storage_bytes*/ scratchpad_bytes,
+      /*d_keys_in*/ x.data_ptr<T>(),
+      /*d_keys_out*/ x_out.data_ptr<T>(),
+      /*d_values_in*/ iota.data_ptr<T>(),
+      /*d_values_out*/ iota_out.data_ptr<T>(),
+      /*num_items*/ x.numel(),
+      /*begin_bit*/ 0,
+      /*end_bit*/ end_bit,
+      /*stream*/ c10::cuda::getCurrentCUDAStream()));
+}
+
+void sort_cuda(
+    torch::Tensor x,
+    int64_t end_bit,
+    torch::Tensor x_out,
+    torch::Tensor iota_out) {
+  TORCH_CHECK(x.is_cuda());
+  TORCH_CHECK(x.ndimension() == 1);
+  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
+              x.scalar_type() == torch::kInt32 ||
+              x.scalar_type() == torch::kInt64);
+  TORCH_CHECK(x_out.is_cuda());
+  TORCH_CHECK(x_out.ndimension() == 1);
+  TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
+  TORCH_CHECK(iota_out.is_cuda());
+  TORCH_CHECK(iota_out.ndimension() == 1);
+  TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
+
+  // Exit early if there is no work to do.
+  if (x_out.numel() == 0)
+    return;
+
+  switch (x.scalar_type()) {
+  case torch::kInt16:
+    return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
+  case torch::kInt32:
+    return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
+  default:
+    TORCH_CHECK(x.scalar_type() == torch::kInt64);
+    return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
+  }
+}
+
+#undef CUDA_CALL
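The sort pairs each expert id with its original flat position (the iota) and radix-sorts by id; radix sort is stable, so ties keep their input order. In PyTorch terms, this is roughly:

```python
import torch

flat_indices = torch.tensor([2, 0, 1, 0], dtype=torch.int32)
# stable sort by expert id, keeping each element's original flat position
sorted_values, sorted_indices = torch.sort(flat_indices, stable=True)
print(sorted_values)   # tensor([0, 0, 1, 2], dtype=torch.int32)
print(sorted_indices)  # tensor([1, 3, 2, 0]) -> original positions
```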
flake.lock
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1754038838,
        "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1756320464,
        "narHash": "sha256-x9LI4h87/Z9UgTQjgeG0fRcdeXl91xIqBlTauGKZM70=",
        "owner": "huggingface",
        "repo": "kernel-builder",
        "rev": "b4accba4496b28faef19a0487fbcf9686b14e2ef",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "kernel-builder",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1752785354,
        "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
        "owner": "nixos",
        "repo": "nixpkgs",
        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
        "type": "github"
      },
      "original": {
        "owner": "nixos",
        "repo": "nixpkgs",
        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix ADDED
@@ -0,0 +1,24 @@
{
  description = "Flake for yamoe kernels";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;

      pythonCheckInputs = pkgs: with pkgs; [
        tqdm
        py-cpuinfo
        importlib-metadata
        torchmetrics
      ];
    };
}
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,73 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("gather("
          "Tensor x, "
          "Tensor indices, "
          "Tensor bins, "
          "Tensor! output, "
          "int E, "
          "int C, "
          "int top_k) -> ()");
  ops.impl("gather", torch::kCUDA, &gather_cuda);

  ops.def("scatter("
          "Tensor src, "
          "Tensor indices, "
          "Tensor bins, "
          "Tensor weights, "
          "Tensor! y, "
          "int T, "
          "int E, "
          "int C, "
          "int top_k) -> ()");
  ops.impl("scatter", torch::kCUDA, &scatter_cuda);

  ops.def("sort("
          "Tensor x, "
          "int end_bit, "
          "Tensor! x_out, "
          "Tensor! iota_out) -> ()");
  ops.impl("sort", torch::kCUDA, &sort_cuda);

  ops.def("bincount_cumsum("
          "Tensor input, "
          "Tensor! output, "
          "int minlength) -> ()");
  ops.impl("bincount_cumsum", torch::kCUDA, &bincount_cumsum_cuda);

  ops.def("index_select_out("
          "Tensor! out, "
          "Tensor input, "
          "Tensor idx_int32) -> Tensor");
  ops.impl("index_select_out", torch::kCUDA, &index_select_out_cuda);

  ops.def("batch_mm("
          "Tensor x, "
          "Tensor weights, "
          "Tensor batch_sizes, "
          "Tensor! output, "
          "bool trans_b=False) -> Tensor");
  ops.impl("batch_mm", torch::kCUDA, &batch_mm);

  ops.def("experts("
          "Tensor hidden_states, "
          "Tensor router_indices, "
          "Tensor routing_weights, "
          "Tensor gate_up_proj, "
          "Tensor gate_up_proj_bias, "
          "Tensor down_proj, "
          "Tensor down_proj_bias, "
          "int expert_capacity, "
          "int num_experts, "
          "int top_k) -> Tensor");
  ops.impl("experts", torch::kCUDA, &experts_cuda);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,55 @@
#pragma once

#include <torch/torch.h>

void gather_cuda(torch::Tensor const &x,
                 torch::Tensor const &indices,
                 torch::Tensor const &bins,
                 torch::Tensor &output,
                 int64_t E,
                 int64_t C,
                 int64_t top_k);

void scatter_cuda(torch::Tensor const &src,
                  torch::Tensor const &indices,
                  torch::Tensor const &bins,
                  torch::Tensor const &weights,
                  torch::Tensor &y,
                  int64_t T,
                  int64_t E,
                  int64_t C,
                  int64_t top_k);

void sort_cuda(torch::Tensor x,
               int64_t end_bit,
               torch::Tensor x_out,
               torch::Tensor iota_out);

void bincount_cumsum_cuda(torch::Tensor input,
                          torch::Tensor &output,
                          int64_t minlength);

torch::Tensor index_select_out_cuda(torch::Tensor out,
                                    torch::Tensor in,
                                    torch::Tensor idx_int32);

torch::Tensor
batch_mm(torch::Tensor x,           // [E, C, H] - expert tokens
         torch::Tensor weights,     // [E, H, H_out] - expert weight matrices
         torch::Tensor batch_sizes, // [E] - actual tokens per expert (<= C)
         torch::Tensor output,      // [E, C, H_out] - output buffer
         bool trans_b = false       // transpose weights if needed
);

torch::Tensor experts_cuda(
    torch::Tensor hidden_states,     // [T, H] - flattened hidden states
    torch::Tensor router_indices,    // [T, K] - expert indices per token
    torch::Tensor routing_weights,   // [T, E] or [T, K] - routing weights
    torch::Tensor gate_up_proj,      // [E, H, 2*H] - gate/up projection weights
    torch::Tensor gate_up_proj_bias, // [E, 2*H] - gate/up projection bias
    torch::Tensor down_proj,         // [E, H, H] - down projection weights
    torch::Tensor down_proj_bias,    // [E, H] - down projection bias
    int64_t expert_capacity,         // C - capacity per expert
    int64_t num_experts,             // E - number of experts
    int64_t top_k                    // K - top-k routing
);
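A shape sketch for the fused `experts` op, following the comments above (T tokens, E experts, K top-k, H hidden size, C capacity). The concrete sizes, dtypes, and the choice of intermediate size equal to H are illustrative assumptions, not requirements stated by the header.

    # Hypothetical call matching the declared schema and shape comments.
    import torch
    import yamoe

    T, E, K, H, C = 16, 4, 2, 64, 8

    hidden_states = torch.randn(T, H, device="cuda")
    router_indices = torch.randint(0, E, (T, K), device="cuda")  # index dtype assumed
    routing_weights = torch.rand(T, E, device="cuda")            # [T, E] variant

    gate_up_proj = torch.randn(E, H, 2 * H, device="cuda")
    gate_up_proj_bias = torch.randn(E, 2 * H, device="cuda")
    down_proj = torch.randn(E, H, H, device="cuda")
    down_proj_bias = torch.randn(E, H, device="cuda")

    out = yamoe.experts(
        hidden_states, router_indices, routing_weights,
        gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias,
        C, E, K,
    )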
torch-ext/yamoe/__init__.py ADDED
@@ -0,0 +1,21 @@
from ._ops import ops
from . import reference

gather = ops.gather
scatter = ops.scatter
sort = ops.sort
bincount_cumsum = ops.bincount_cumsum
batch_mm = ops.batch_mm
experts = ops.experts

__all__ = [
    "gather",
    "scatter",
    "sort",
    "bincount_cumsum",
    "batch_mm",
    "experts",
    # Export the reference implementation
    "reference",
]
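These exports compose into the usual capacity-based MoE dispatch pipeline: flatten the top-k assignments, sort them, bin tokens per expert, gather into per-expert buffers, run the expert computation, and scatter weighted results back. The sketch below is a hedged illustration of that composition; buffer dtypes, whether `gather`/`bincount_cumsum` expect the sorted permutation or raw assignments, and all sizes are assumptions inferred from the op schemas.

    import torch
    import yamoe

    T, E, K, H, C = 16, 4, 2, 64, 8  # tokens, experts, top-k, hidden size, capacity

    x = torch.randn(T, H, device="cuda")
    topk_idx = torch.randint(0, E, (T, K), device="cuda", dtype=torch.int32)
    topk_w = torch.rand(T, K, device="cuda")

    # Sort flattened expert assignments; `order` is the sorting permutation.
    flat = topk_idx.flatten()
    sorted_ids = torch.empty_like(flat)
    order = torch.empty_like(flat)
    yamoe.sort(flat, max(1, (E - 1).bit_length()), sorted_ids, order)

    # Per-expert token counts as a cumulative sum ("bins"); output dtype assumed.
    bins = torch.empty(E, device="cuda", dtype=torch.int32)
    yamoe.bincount_cumsum(flat, bins, E)

    # Gather tokens into [E, C, H], then scatter weighted results back to [T, H].
    buf = torch.zeros(E, C, H, device="cuda")
    yamoe.gather(x, order, bins, buf, E, C, K)
    y = torch.zeros(T, H, device="cuda")
    yamoe.scatter(buf, order, bins, topk_w.flatten(), y, T, E, C, K)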
torch-ext/yamoe/reference.py ADDED
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn


class GptOssExperts(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """
        When training, it is more efficient to just loop over the experts and compute the output for each
        expert separately, as otherwise the memory would explode.

        For inference we can sacrifice some memory and compute the output for all experts at once by
        repeating the inputs.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k)
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]
        if self.training:
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)
                # Sum over top_k and the sequence length to find which experts
                # are hit this time around.
                expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hitted:
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
                current_state = hidden_states[token_idx]
                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
        else:
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
            next_states = next_states + self.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            next_states = next_states.sum(dim=0)
        return next_states
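A minimal sketch of exercising the reference module in eval mode (the dense path, which the `experts` kernel is benchmarked against). The `SimpleNamespace` config stub stands in for the real model config, and the sizes and initialization are illustrative assumptions.

    import torch
    from types import SimpleNamespace
    from yamoe.reference import GptOssExperts

    # Hypothetical config: expert_dim == hidden_size for simplicity.
    config = SimpleNamespace(hidden_size=64, intermediate_size=64, num_local_experts=4)
    mod = GptOssExperts(config).eval()
    for p in mod.parameters():
        torch.nn.init.normal_(p, std=0.02)  # parameters start uninitialized

    B, S, E = 2, 8, config.num_local_experts
    hidden = torch.randn(B, S, config.hidden_size)
    routing_weights = torch.softmax(torch.randn(B * S, E), dim=-1)

    out = mod(hidden, routing_weights=routing_weights)
    print(out.shape)  # (B, S, hidden_size)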