// File size: 1,951 bytes; revision 281d8ba.
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
// Registers this extension's custom operators under the library namespace
// named by TORCH_EXTENSION_NAME. TORCH_LIBRARY_EXPAND is provided by
// registration.h; presumably it expands its macro argument before forwarding
// to TORCH_LIBRARY — TODO confirm against registration.h.
//
// Layout: every operator schema is declared first, then each schema gets its
// CUDA implementation attached. In the schema language, `Tensor!` marks an
// argument that the kernel mutates in place, and `-> ()` means the op returns
// nothing.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // ---- Operator schemas ---------------------------------------------------

  // gather: mutates `output`; returns nothing.
  ops.def(
      "gather(Tensor x, Tensor indices, Tensor bins, Tensor! output, "
      "int E, int C, int top_k) -> ()");

  // scatter: mutates `y`; returns nothing.
  ops.def(
      "scatter(Tensor src, Tensor indices, Tensor bins, Tensor weights, "
      "Tensor! y, int T, int E, int C, int top_k) -> ()");

  // sort: mutates both `x_out` and `iota_out`; returns nothing.
  ops.def(
      "sort(Tensor x, int end_bit, Tensor! x_out, Tensor! iota_out) -> ()");

  // bincount_cumsum: mutates `output`; returns nothing.
  ops.def(
      "bincount_cumsum(Tensor input, Tensor! output, int minlength) -> ()");

  // index_select_out: mutates `out` AND returns a Tensor.
  // NOTE(review): if index_select_out_cuda returns `out` itself, the schema
  // should carry an alias annotation (`Tensor(a!) out ... -> Tensor(a!)`) so
  // PyTorch does not treat the result as a fresh tensor — confirm against the
  // kernel before changing it.
  ops.def(
      "index_select_out(Tensor! out, Tensor input, Tensor idx_int32) -> Tensor");

  // batch_mm: mutates `output` AND returns a Tensor; `trans_b` defaults to
  // False. NOTE(review): same alias-annotation question as index_select_out.
  ops.def(
      "batch_mm(Tensor x, Tensor weights, Tensor batch_sizes, Tensor! output, "
      "bool trans_b=False) -> Tensor");

  // experts: no mutated arguments; returns a Tensor.
  ops.def(
      "experts(Tensor hidden_states, Tensor router_indices, "
      "Tensor routing_weights, Tensor gate_up_proj, Tensor gate_up_proj_bias, "
      "Tensor down_proj, Tensor down_proj_bias, int expert_capacity, "
      "int num_experts, int top_k) -> Tensor");

  // ---- CUDA implementations -----------------------------------------------
  // The kernel functions are not defined here; presumably they are declared
  // in torch_binding.h.
  ops.impl("gather", torch::kCUDA, &gather_cuda);
  ops.impl("scatter", torch::kCUDA, &scatter_cuda);
  ops.impl("sort", torch::kCUDA, &sort_cuda);
  ops.impl("bincount_cumsum", torch::kCUDA, &bincount_cumsum_cuda);
  ops.impl("index_select_out", torch::kCUDA, &index_select_out_cuda);
  // Unlike the other kernels, batch_mm's implementation has no `_cuda` suffix.
  ops.impl("batch_mm", torch::kCUDA, &batch_mm);
  ops.impl("experts", torch::kCUDA, &experts_cuda);
}
// Emits the extension's module boilerplate for TORCH_EXTENSION_NAME. The
// macro comes from registration.h; presumably it defines the Python module
// init entry point — TODO confirm against registration.h.
// (A stray `|` that followed this invocation was removed: a lone `|` at file
// scope is not valid C++.)
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)