|
|
|
|
|
|
|
|
#include <torch/torch.h> |
|
|
|
|
|
|
|
|
|
|
|
/// Batched matrix multiply on CUDA tensors:
///   output[i] = x[i] @ weights[i]      when trans_b == false
///   output[i] = x[i] @ weights[i]^T    when trans_b == true
///
/// @param x            3-D CUDA tensor; x.size(2) is the contraction dimension.
/// @param weights      3-D CUDA tensor with the same leading size as x. Its
///                     contraction dimension is size(1) when trans_b is false and
///                     size(2) when trans_b is true.
/// @param batch_sizes  1-D CUDA tensor with one entry per batch (its length must
///                     equal x.size(0)). NOTE(review): it is validated here but not
///                     used in the computation — confirm whether per-batch sizes
///                     were meant to limit the rows multiplied.
/// @param output       Destination tensor written by torch::bmm_out.
/// @param trans_b      If true, multiply by the transpose of each weights[i]
///                     (last two dimensions swapped).
/// @return             `output`, holding the batched product.
/// @throws c10::Error  (via TORCH_CHECK) on device, rank, or shape mismatch.
torch::Tensor batch_mm(torch::Tensor x,
                       torch::Tensor weights,
                       torch::Tensor batch_sizes,
                       torch::Tensor output,
                       bool trans_b) {
  TORCH_CHECK(x.is_cuda(), "x must be on CUDA");
  TORCH_CHECK(weights.is_cuda(), "weights must be on CUDA");
  TORCH_CHECK(batch_sizes.is_cuda(), "batch_sizes must be on CUDA");

  TORCH_CHECK(x.ndimension() == 3, "x must be 3D tensor");
  TORCH_CHECK(weights.ndimension() == 3, "weights must be 3D tensor");
  TORCH_CHECK(batch_sizes.ndimension() == 1, "batch_sizes must be 1D tensor");

  TORCH_CHECK(x.size(0) == weights.size(0) && x.size(0) == batch_sizes.size(0),
              "batch dimension mismatch: x.size(0)=", x.size(0),
              ", weights.size(0)=", weights.size(0),
              ", batch_sizes.size(0)=", batch_sizes.size(0));

  // Honor trans_b: transpose() is a strided view (no copy); bmm handles the
  // non-contiguous layout itself.
  const torch::Tensor rhs = trans_b ? weights.transpose(1, 2) : weights;

  TORCH_CHECK(x.size(2) == rhs.size(1),
              "inner dimension mismatch: x.size(2)=", x.size(2),
              " but weights contraction dim=", rhs.size(1),
              " (trans_b=", trans_b, ")");

  torch::bmm_out(output, x, rhs);
  return output;
}
|
|
|