|
|
#pragma once |
|
|
|
|
|
#include <cstdint>
#include <vector>

#include <torch/torch.h>
|
|
|
|
|
/// Launcher for the gather CUDA kernel (defined in the .cu file).
///
/// Copies rows of `x` into `output` according to `indices` and the
/// per-expert `bins` boundaries. NOTE(review): `E`, `C` and `top_k`
/// presumably mean expert count, row width (columns) and routing
/// fan-out — confirm against the kernel implementation.
///
/// @param x       source tensor (read-only)
/// @param indices routing/permutation indices (read-only)
/// @param bins    per-expert bin offsets (read-only)
/// @param output  destination tensor, written in place
void gather_cuda(torch::Tensor const &x, torch::Tensor const &indices,
                 torch::Tensor const &bins, torch::Tensor &output, int64_t E,
                 int64_t C, int64_t top_k);
|
|
|
|
|
/// Launcher for the scatter CUDA kernel (defined in the .cu file).
///
/// Inverse of `gather_cuda`: writes rows of `src` back into `y` using
/// `indices`/`bins`, scaling by `weights`. NOTE(review): `T`, `E`, `C`
/// and `top_k` presumably mean token count, expert count, row width and
/// routing fan-out — confirm against the kernel implementation.
///
/// @param src     gathered rows to scatter back (read-only)
/// @param indices routing/permutation indices (read-only)
/// @param bins    per-expert bin offsets (read-only)
/// @param weights per-assignment scaling weights (read-only)
/// @param y       destination tensor, written in place
void scatter_cuda(torch::Tensor const &src, torch::Tensor const &indices,
                  torch::Tensor const &bins, torch::Tensor const &weights,
                  torch::Tensor &y, int64_t T, int64_t E, int64_t C,
                  int64_t top_k);
|
|
|
|
|
/// Launcher for the sort CUDA kernel (defined in the .cu file).
///
/// Sorts `x` into `x_out`, writing the accompanying permutation into
/// `iota_out`. NOTE(review): `end_bit` presumably bounds the number of
/// significant key bits (radix-sort style) and `iota_out` presumably
/// holds the argsort of an iota sequence — confirm in the .cu file.
void sort_cuda(torch::Tensor x, int64_t end_bit, torch::Tensor x_out,
               torch::Tensor iota_out);
|
|
|
|
|
/// Launcher for a fused bincount + cumulative-sum CUDA kernel.
///
/// Writes the result into `output` in place; `minlength` mirrors
/// torch.bincount's minimum bin count. NOTE(review): presumably used to
/// build the `bins` argument consumed by gather/scatter — confirm
/// against the kernel implementation.
void bincount_cumsum_cuda(torch::Tensor input, torch::Tensor &output,
                          int64_t minlength);
|
|
|
|
|
/// Out-variant index-select launcher (defined in the .cu file).
///
/// Fills `out` with rows of `in` selected by `idx_int32` and returns
/// `out`. NOTE(review): the name suggests the index tensor must be
/// int32 — confirm against the kernel's dtype checks.
torch::Tensor index_select_out_cuda(torch::Tensor out, torch::Tensor in,
                                    torch::Tensor idx_int32);
|
|
|
|
|
/// Grouped/batched matrix multiply launcher (defined in the .cu file).
///
/// Multiplies `x` against `weights` in groups described by
/// `batch_sizes`, writing into (and returning) `output`. When `trans_b`
/// is true the right-hand operand is presumably used transposed —
/// confirm against the implementation.
torch::Tensor batch_mm(torch::Tensor x, torch::Tensor weights,
                       torch::Tensor batch_sizes, torch::Tensor output,
                       bool trans_b = false);
|
|
|
|
|
/// Forward pass for the fused mixture-of-experts CUDA path.
///
/// Routes `hidden_states` through the selected experts (gate/up
/// projection, then down projection, each with a bias) weighted by
/// `routing_weights`, and returns the combined output.
/// NOTE(review): parameter semantics inferred from names — confirm the
/// exact activation/combination scheme in the .cu implementation.
///
/// @param hidden_states    input activations
/// @param router_indices   per-token expert assignments
/// @param routing_weights  per-assignment combination weights
/// @param gate_up_proj     fused gate+up projection weights
/// @param gate_up_proj_bias  fused gate+up projection bias
/// @param down_proj        down projection weights
/// @param down_proj_bias   down projection bias
/// @param expert_capacity  max tokens routed to a single expert
/// @param num_experts      total number of experts
/// @param top_k            experts selected per token
torch::Tensor experts_cuda(torch::Tensor hidden_states,
                           torch::Tensor router_indices,
                           torch::Tensor routing_weights,
                           torch::Tensor gate_up_proj,
                           torch::Tensor gate_up_proj_bias,
                           torch::Tensor down_proj,
                           torch::Tensor down_proj_bias,
                           int64_t expert_capacity, int64_t num_experts,
                           int64_t top_k);
|
|
|
|
|
std::vector<torch::Tensor> experts_backward_cuda( |
|
|
const torch::Tensor &grad_out, |
|
|
const torch::Tensor &hidden_states, |
|
|
const torch::Tensor &router_indices, |
|
|
const torch::Tensor &routing_weights, |
|
|
const torch::Tensor |
|
|
&gate_up_proj, |
|
|
const torch::Tensor |
|
|
&gate_up_proj_bias, |
|
|
const torch::Tensor &down_proj, |
|
|
const torch::Tensor &down_proj_bias, |
|
|
int64_t expert_capacity, |
|
|
int64_t num_experts, |
|
|
int64_t top_k |
|
|
); |
|
|
|