|
|
#include <torch/library.h> |
|
|
|
|
|
#include "registration.h" |
|
|
#include "torch_binding.h" |
|
|
|
|
|
// Operator registrations for this extension.
//
// Each operator is declared with `ops.def(<schema>)` and then bound to its
// CUDA implementation with `ops.impl(<name>, torch::kCUDA, <function>)`.
// Only the CUDA dispatch key is registered in this block.
//
// Schema notation: an argument written as `Tensor!` is one the operator is
// allowed to write to; `-> ()` means the operator returns nothing.
//
// NOTE(review): the short integer arguments (T, E, C, top_k) are not described
// anywhere in this file — presumably token count, expert count and capacity;
// confirm against the kernel implementations declared in torch_binding.h.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // gather: writes into `output`, returns nothing.
  ops.def(
      "gather("
      "Tensor x, Tensor indices, Tensor bins, "
      "Tensor! output, "
      "int E, int C, int top_k) -> ()");
  ops.impl("gather", torch::kCUDA, &gather_cuda);

  // scatter: writes into `y`, returns nothing.
  ops.def(
      "scatter("
      "Tensor src, Tensor indices, Tensor bins, Tensor weights, "
      "Tensor! y, "
      "int T, int E, int C, int top_k) -> ()");
  ops.impl("scatter", torch::kCUDA, &scatter_cuda);

  // sort: writes into both `x_out` and `iota_out`, returns nothing.
  ops.def(
      "sort("
      "Tensor x, int end_bit, "
      "Tensor! x_out, Tensor! iota_out) -> ()");
  ops.impl("sort", torch::kCUDA, &sort_cuda);

  // bincount_cumsum: writes into `output`, returns nothing.
  ops.def(
      "bincount_cumsum("
      "Tensor input, "
      "Tensor! output, "
      "int minlength) -> ()");
  ops.impl("bincount_cumsum", torch::kCUDA, &bincount_cumsum_cuda);

  // index_select_out: takes a writable `out` and also returns a Tensor.
  // NOTE(review): if the implementation returns `out` itself, the return type
  // should carry a matching alias annotation — confirm against
  // index_select_out_cuda before changing the schema.
  ops.def(
      "index_select_out("
      "Tensor! out, "
      "Tensor input, Tensor idx_int32) -> Tensor");
  ops.impl("index_select_out", torch::kCUDA, &index_select_out_cuda);

  // batch_mm: takes a writable `output` and also returns a Tensor; `trans_b`
  // defaults to False.
  // NOTE(review): same aliasing question as index_select_out — confirm whether
  // the returned tensor is `output`.
  ops.def(
      "batch_mm("
      "Tensor x, Tensor weights, Tensor batch_sizes, "
      "Tensor! output, "
      "bool trans_b=False) -> Tensor");
  ops.impl("batch_mm", torch::kCUDA, &batch_mm);

  // experts: no writable arguments are declared; the result is returned as a
  // new Tensor value.
  ops.def(
      "experts("
      "Tensor hidden_states, "
      "Tensor router_indices, "
      "Tensor routing_weights, "
      "Tensor gate_up_proj, "
      "Tensor gate_up_proj_bias, "
      "Tensor down_proj, "
      "Tensor down_proj_bias, "
      "int expert_capacity, "
      "int num_experts, "
      "int top_k) -> Tensor");
  ops.impl("experts", torch::kCUDA, &experts_cuda);
}
|
|
|
|
|
// Registers the extension module under TORCH_EXTENSION_NAME.
// NOTE(review): REGISTER_EXTENSION is not defined in this file; presumably it
// comes from registration.h and emits the Python module entry point — confirm.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)