File size: 1,951 Bytes
281d8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Registers this extension's custom operators: one schema per op plus a CUDA
// kernel for it. TORCH_LIBRARY_EXPAND comes from registration.h (not visible
// here); presumably it expands TORCH_EXTENSION_NAME before forwarding to
// TORCH_LIBRARY -- confirm there.
//
// Schema notation used below:
//   * `Tensor!`  -> the kernel is allowed to write into this argument in place.
//   * `-> ()`    -> the op returns nothing; its results land in the `!` args.
// Only a torch::kCUDA kernel is registered for each op in this file.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
  // gather: result is written into `output`.
  lib.def(
      "gather(Tensor x, Tensor indices, Tensor bins, Tensor! output, "
      "int E, int C, int top_k) -> ()");
  lib.impl("gather", torch::kCUDA, &gather_cuda);

  // scatter: result is written into `y`; also takes a `weights` tensor.
  lib.def(
      "scatter(Tensor src, Tensor indices, Tensor bins, Tensor weights, "
      "Tensor! y, int T, int E, int C, int top_k) -> ()");
  lib.impl("scatter", torch::kCUDA, &scatter_cuda);

  // sort: fills both `x_out` and `iota_out`; `end_bit` is a plain int.
  lib.def("sort(Tensor x, int end_bit, Tensor! x_out, Tensor! iota_out) -> ()");
  lib.impl("sort", torch::kCUDA, &sort_cuda);

  // bincount_cumsum: result is written into `output`.
  lib.def("bincount_cumsum(Tensor input, Tensor! output, int minlength) -> ()");
  lib.impl("bincount_cumsum", torch::kCUDA, &bincount_cumsum_cuda);

  // index_select_out: mutates `out` AND returns a Tensor.
  // NOTE(review): if index_select_out_cuda returns `out` itself, the schema
  // should declare the alias ("Tensor(a!) out ... -> Tensor(a!)") so that
  // functionalization / torch.compile see it -- TODO confirm against the kernel.
  lib.def("index_select_out(Tensor! out, Tensor input, Tensor idx_int32) -> Tensor");
  lib.impl("index_select_out", torch::kCUDA, &index_select_out_cuda);

  // batch_mm: mutates `output` AND returns a Tensor; `trans_b` defaults to
  // False when omitted by the caller.
  // NOTE(review): same aliasing question as index_select_out above -- if the
  // returned tensor is `output`, annotate it as Tensor(a!) -- TODO confirm.
  lib.def(
      "batch_mm(Tensor x, Tensor weights, Tensor batch_sizes, "
      "Tensor! output, bool trans_b=False) -> Tensor");
  lib.impl("batch_mm", torch::kCUDA, &batch_mm);

  // experts: no argument is marked `!`, so per the schema it mutates nothing
  // and returns a Tensor. The projection/bias argument names suggest an
  // expert MLP, but that is inferred from names only -- verify in the kernel.
  lib.def(
      "experts(Tensor hidden_states, Tensor router_indices, "
      "Tensor routing_weights, Tensor gate_up_proj, Tensor gate_up_proj_bias, "
      "Tensor down_proj, Tensor down_proj_bias, "
      "int expert_capacity, int num_experts, int top_k) -> Tensor");
  lib.impl("experts", torch::kCUDA, &experts_cuda);
}

// Extension entry point generated by REGISTER_EXTENSION from registration.h
// (not visible here). It is presumably a Python module-init definition that
// makes the shared library importable. No trailing ';' is present; the macro
// is assumed to supply a complete definition -- confirm in registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)