File size: 1,086 Bytes
281d8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// csrc/batch_mm.cu

#include <torch/torch.h>

// Simply use a standard bmm for now but this can be adapted for
// faster batched expert matrix multiply if needed
torch::Tensor batch_mm(
    torch::Tensor x,
    torch::Tensor weights,
    torch::Tensor batch_sizes,
    torch::Tensor output,
    bool trans_b) {
  // Validate inputs
  TORCH_CHECK(x.is_cuda(), "x must be on CUDA");
  TORCH_CHECK(weights.is_cuda(), "weights must be on CUDA");
  TORCH_CHECK(batch_sizes.is_cuda(), "batch_sizes must be on CUDA");

  TORCH_CHECK(x.ndimension() == 3, "x must be 3D tensor"); // [E, C, H]
  TORCH_CHECK(weights.ndimension() == 3,
              "weights must be 3D tensor"); // [E, H, H_out]
  TORCH_CHECK(batch_sizes.ndimension() == 1,
              "batch_sizes must be 1D tensor"); // [E]

  TORCH_CHECK(x.size(0) == weights.size(0) && x.size(0) == batch_sizes.size(0));
  TORCH_CHECK(x.size(2) == weights.size(1)); // H dimension match

  // For now, just fall back to bmm to test the binding
  // torch::bmm(x, weights, output);
  torch::bmm_out(output, x, weights);
  return output;
}