Add fusion (#3)

* vectorized and optimized block reduce
* added benchmark tests (without a README update)
* implemented fused_mul_poly_norm
Signed-off-by: taehyun <[email protected]>
* added fused_add_rms_norm
* removed the backward pass from fused add RMS norm; split tests and benchmarks
Signed-off-by: taehyun <[email protected]>
* refactored benchmarks
* added README
* fixed README
* added build artifacts
* fixed README again
* added MI250 results
* highlighted that our own kernels serve as the non-fused baseline in the fused performance plots
* applied yapf
---------
Signed-off-by: taehyun <[email protected]>
Co-authored-by: taehyun <[email protected]>
- README.md +178 -7
- activation/block_reduce.h +0 -21
- activation/fused_add_rms_norm.cu +157 -0
- activation/fused_mul_poly_norm.cu +642 -0
- activation/poly_norm.cu +88 -61
- activation/rms_norm.cu +243 -51
- benchmarks/README.md +35 -0
- benchmarks/cases/__init__.py +1 -0
- benchmarks/cases/add_rms.py +55 -0
- benchmarks/cases/mul_poly.py +53 -0
- benchmarks/cases/poly.py +58 -0
- benchmarks/cases/rms.py +35 -0
- benchmarks/common/__init__.py +1 -0
- benchmarks/common/bench_framework.py +220 -0
- benchmarks/common/diff_engine.py +85 -0
- benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png +0 -0
- benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png +0 -0
- benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png +0 -0
- benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png +0 -0
- benchmarks/plots/h100/poly/plot_poly-bwd-perf.png +0 -0
- benchmarks/plots/h100/poly/plot_poly-fwd-perf.png +0 -0
- benchmarks/plots/h100/rms/plot_rms-bwd-perf.png +0 -0
- benchmarks/plots/h100/rms/plot_rms-fwd-perf.png +0 -0
- benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png +0 -0
- benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png +0 -0
- benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png +0 -0
- benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png +0 -0
- benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png +0 -0
- benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png +0 -0
- benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png +0 -0
- benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png +0 -0
- benchmarks/run_cases.py +143 -0
- build.toml +4 -2
- build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py +24 -2
- tests/perf.png → build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so +2 -2
- build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py +48 -2
- build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py +37 -0
- build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +47 -0
- build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py +24 -2
- build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so +3 -0
- build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py +48 -2
- build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py +37 -0
- build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +47 -0
- build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py +24 -2
- build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so +3 -0
- build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py +48 -2
- build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py +37 -0
README.md
CHANGED

@@ -11,6 +11,37 @@ Activation is a python package that contains custom CUDA-based activation kernels
 Currently implemented
 - [PolyNorm](https://arxiv.org/html/2411.03884v1)
 - [RMSNorm](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
+- **FusedAddRMSNorm**
+
+  A fused operator that combines **residual addition** (`x + residual`) with **RMSNorm** in a single kernel.
+  Instead of:
+
+  ```python
+  y = x + residual
+  out = rms_norm(y, weight, eps)
+  ```
+
+  Fused as:
+
+  ```python
+  out = fused_add_rms_norm(x, residual, weight, eps)
+  ```
+
+- **FusedMulPolyNorm**
+
+  A fused operator that combines **PolyNorm** with an **element-wise multiplication** by a Tensor.
+  Instead of:
+
+  ```python
+  y = poly_norm(x, weight, bias, eps)
+  out = y * a
+  ```
+
+  Fused as:
+
+  ```python
+  out = fused_mul_poly_norm(x, a, weight, bias, eps)
+  ```

 ## Usage

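A side note on the two fused operators added above: their semantics can be pinned down with a small PyTorch reference. The sketch below is illustrative only — the `*_ref` helper names are mine, not part of this PR — but it mirrors the unfused compositions shown in the README and is handy for comparing against the kernels numerically:

```python
import torch

def rms_norm_ref(x, weight, eps):
    # RMSNorm: scale x by the reciprocal RMS over the last dimension.
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # The fused kernel produces both the normalized output and the
    # residual sum (add_out), which feeds the next layer.
    y = x + residual
    return rms_norm_ref(y, weight, eps), y

def poly_norm_ref(x, weight, bias, eps):
    # PolyNorm: a weighted sum of RMS-normalized powers of x, plus a bias.
    def norm(u):
        return u * torch.rsqrt(u.pow(2).mean(-1, keepdim=True) + eps)
    return (weight[0] * norm(x**3) + weight[1] * norm(x**2)
            + weight[2] * norm(x) + bias)

def fused_mul_poly_norm_ref(x, a, weight, bias, eps):
    # PolyNorm followed by the element-wise multiplication.
    return poly_norm_ref(x, weight, bias, eps) * a
```

Note that `fused_add_rms_norm` also materializes the residual sum (`add_out`); the reference returns it as the second value for the same reason.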
@@ -28,18 +59,158 @@ print(poly_norm(x))
 ```

 ## Performance
+
+- Test cases are from the Motif LLM.
+- The results can be reproduced using the provided benchmarking tools.
+- For details on how to use the benchmarking tools, please refer to the [benchmarks README](./benchmarks/README.md).
+- The benchmark results may show fluctuations, especially in the backward pass and when the dimension size is small.
+
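On the fluctuation caveat: short normalization kernels are dominated by launch overhead and clock behavior, so the timing method matters. Below is a minimal event-based timing sketch with warmup, for illustration only; the PR's actual harness lives in benchmarks/common/bench_framework.py and may measure differently:

```python
import torch

def time_cuda(fn, warmup=10, iters=100):
    # Warm up to exclude one-time costs (JIT, caching allocator, clocks).
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call
```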
+### RMSNorm
+
+#### H100 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![RMS Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![RMS Backward]()
+
+</details>
+
+#### MI250 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![RMS Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![RMS Backward]()
+
+</details>
+
+---
+
+### FusedAddRMSNorm
+
+> [!NOTE]
+> For fusion case performance, the **non-fused baseline** was implemented with our **custom kernels**.
+
+#### H100 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![Fused Add RMS Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![Fused Add RMS Backward]()
+
+</details>
+
+#### MI250 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![Fused Add RMS Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![Fused Add RMS Backward]()
+
+</details>
+
+---

 ### PolyNorm

-You can reproduce the results with:
+#### H100 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![Poly Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![Poly Backward]()
+
+</details>
+
+#### MI250 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![Poly Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![Poly Backward]()
+
+</details>
+
+---
+
+### FusedMulPolyNorm
+
+> [!NOTE]
+> For fusion case performance, the **non-fused baseline** was implemented with our **custom kernels**.
+
+#### H100 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![Fused Mul Poly Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![Fused Mul Poly Backward]()
+
+</details>
+
+#### MI250 Results
+
+<details>
+<summary>Forward Performance</summary>
+
+![Fused Mul Poly Forward]()
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+![Fused Mul Poly Backward]()
+
+</details>

 ## Pre-commit Hooks

activation/block_reduce.h
DELETED

@@ -1,21 +0,0 @@
-namespace motif {
-
-template <typename acc_t, int BLOCK_SIZE>
-__device__ acc_t _block_reduce_sum(acc_t *shared, const float val,
-                                   const int d) {
-  // TODO: Optimize with warp-level primitives
-  __syncthreads();
-
-  shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
-  __syncthreads();
-  for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
-    if (threadIdx.x < stride) {
-      shared[threadIdx.x] += shared[threadIdx.x + stride];
-    }
-    __syncthreads();
-  }
-
-  return shared[0];
-}
-
-} // namespace motif
activation/fused_add_rms_norm.cu
ADDED

@@ -0,0 +1,157 @@
#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

#include "assert_utils.h"
#include "atomic_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace motif {

template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
  type data[N];
};

template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)>
fused_add_rms_norm_kernel(scalar_t *__restrict__ out,            // [..., d]
                          scalar_t *__restrict__ add_out,        // [..., d]
                          const scalar_t *__restrict__ input,    // [..., d]
                          const scalar_t *__restrict__ residual, // [..., d]
                          const scalar_t *__restrict__ weight,   // [d]
                          const float eps, const int d) {
  using vec_t = type_vec_t<scalar_t, width>;

  const int vec_d = d / width;
  const int64_t vec_offset = blockIdx.x * vec_d;
  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
  const vec_t *__restrict__ residual_vec =
      reinterpret_cast<const vec_t *>(residual);
  vec_t *__restrict__ add_out_vec = reinterpret_cast<vec_t *>(add_out);
  acc_t sum_square = 0.0f;

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t res_vec = residual_vec[vec_offset + idx];
    vec_t add_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x = x_vec.data[i] + res_vec.data[i];
      sum_square += x * x;
      add_vec.data[i] = x;
    }
    add_out_vec[vec_offset + idx] = add_vec;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);

  __shared__ acc_t s_scale;

  if (threadIdx.x == 0) {
    s_scale = rsqrtf(sum_square / d + eps);
  }
  __syncthreads();

  const vec_t *__restrict__ weight_vec =
      reinterpret_cast<const vec_t *>(weight);
  vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = add_out_vec[vec_offset + idx];
    vec_t w_vec = weight_vec[idx];
    vec_t y_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x = x_vec.data[i];
      acc_t w = w_vec.data[i];

      y_vec.data[i] = w * x * s_scale;
    }
    output_vec[vec_offset + idx] = y_vec;
  }
}

template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width == 0)>
fused_add_rms_norm_kernel(scalar_t *__restrict__ out,            // [..., d]
                          scalar_t *__restrict__ add_out,        // [..., d]
                          const scalar_t *__restrict__ input,    // [..., d]
                          const scalar_t *__restrict__ residual, // [..., d]
                          const scalar_t *__restrict__ weight,   // [d]
                          const float eps, const int d) {
  const int64_t token_idx = blockIdx.x;
  const int64_t vec_idx = threadIdx.x;
  acc_t sum_square = 0.0f;

  for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
    acc_t x = input[token_idx * d + idx] + residual[token_idx * d + idx];
    sum_square += x * x;
    add_out[token_idx * d + idx] = x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);

  __shared__ acc_t s_scale;

  if (vec_idx == 0) {
    s_scale = rsqrtf(sum_square / d + eps);
  }
  __syncthreads();

  for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
    acc_t x = add_out[token_idx * d + idx];
    acc_t w = weight[idx];
    out[token_idx * d + idx] = w * x * s_scale;
  }
}

} // namespace motif

#define LAUNCH_RMS_NORM(width)                                                 \
  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        motif::fused_add_rms_norm_kernel<scalar_t, float, width>               \
            <<<grid, block, 0, stream>>>(                                      \
                out.data_ptr<scalar_t>(), add_out.data_ptr<scalar_t>(),        \
                input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),     \
                weight.data_ptr<scalar_t>(), eps, d);                          \
      });

void fused_add_rms_norm(torch::Tensor &out,            // [..., d]
                        torch::Tensor &add_out,        // [..., d]
                        const torch::Tensor &input,    // [..., d]
                        const torch::Tensor &residual, // [..., d]
                        const torch::Tensor &weight,   // [d]
                        double eps) {
  AssertTensorShapeEqual(input, residual, "input", "residual");
  AssertTensorShapeEqual(input, out, "input", "out");
  AssertTensorShapeEqual(input, add_out, "input", "result");
  AssertTensorNotNull(weight, "weight");
  // TODO shape check

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(d, max_block_size));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (d % 8 == 0) {
    LAUNCH_RMS_NORM(8);
  } else {
    LAUNCH_RMS_NORM(0);
  }
}
activation/fused_mul_poly_norm.cu
ADDED

@@ -0,0 +1,642 @@
#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

#include "assert_utils.h"
#include "atomic_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace motif {

template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
  type data[N];
};

struct SumOp {
  __device__ float3 operator()(const float3 &a, const float3 &b) const {
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
  }
};

struct SumOp4 {
  __device__ float4 operator()(const float4 &a, const float4 &b) const {
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  }
};

template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)>
fused_mul_poly_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
                           const scalar_t *__restrict__ input,  // [..., d]
                           const scalar_t *__restrict__ mul,    // [..., d]
                           const scalar_t *__restrict__ weight, // [3]
                           const scalar_t *__restrict__ bias,   // [1]
                           const float eps, const int d) {
  using vec_t = type_vec_t<scalar_t, width>;

  const int vec_d = d / width;
  const int64_t vec_offset = blockIdx.x * vec_d;
  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);

  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x4 = x2 * x2;
      acc_t x6 = x4 * x2;

      sum2 += x2;
      sum4 += x4;
      sum6 += x6;
    }
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  __shared__ acc_t s_bias;

  __shared__ acc_t s_w2_inv_std1;
  __shared__ acc_t s_w1_inv_std2;
  __shared__ acc_t s_w0_inv_std3;

  if (threadIdx.x == 0) {
    acc_t w0 = weight[0];
    acc_t w1 = weight[1];
    acc_t w2 = weight[2];
    s_bias = bias[0];

    s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
    s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
    s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
  }
  __syncthreads();

  acc_t w2_inv_std1 = s_w2_inv_std1;
  acc_t w1_inv_std2 = s_w1_inv_std2;
  acc_t w0_inv_std3 = s_w0_inv_std3;
  acc_t bias_reg = s_bias;

  vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);
  const vec_t *__restrict__ mul_vec = reinterpret_cast<const vec_t *>(mul);

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t m_vec = mul_vec[vec_offset + idx];
    vec_t y_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      scalar_t m = m_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      scalar_t poly_norm_result =
          x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
      y_vec.data[i] = poly_norm_result * m;
    }
    output_vec[vec_offset + idx] = y_vec;
  }
}

template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width == 0)>
fused_mul_poly_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
                           const scalar_t *__restrict__ input,  // [..., d]
                           const scalar_t *__restrict__ mul,    // [..., d]
                           const scalar_t *__restrict__ weight, // [3]
                           const scalar_t *__restrict__ bias,   // [1]
                           const float eps, const int d) {
  const int64_t token_idx = blockIdx.x;

  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x1 = input[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x4 = x2 * x2;
    acc_t x6 = x4 * x2;

    sum2 += x2;
    sum4 += x4;
    sum6 += x6;
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  __shared__ acc_t s_bias;

  __shared__ acc_t s_w2_inv_std1;
  __shared__ acc_t s_w1_inv_std2;
  __shared__ acc_t s_w0_inv_std3;

  if (threadIdx.x == 0) {
    acc_t w0 = weight[0];
    acc_t w1 = weight[1];
    acc_t w2 = weight[2];
    s_bias = bias[0];

    s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
    s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
    s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
  }
  __syncthreads();

  acc_t w2_inv_std1 = s_w2_inv_std1;
  acc_t w1_inv_std2 = s_w1_inv_std2;
  acc_t w0_inv_std3 = s_w0_inv_std3;
  acc_t bias_reg = s_bias;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x1 = input[token_idx * d + idx];
    scalar_t m = mul[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x3 = x2 * x1;
    scalar_t poly_norm_result =
        x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
    out[token_idx * d + idx] = poly_norm_result * m;
  }
}

template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)> fused_mul_poly_norm_backward_kernel(
    scalar_t *__restrict__ input_grad,        // [..., d]
    scalar_t *__restrict__ mul_grad,          // [..., d]
    acc_t *__restrict__ temp_weight_grad,     // [..., 3]
    acc_t *__restrict__ temp_bias_grad,       // [..., 1]
    const scalar_t *__restrict__ output_grad, // [..., d]
    const scalar_t *__restrict__ input,       // [..., d]
    const scalar_t *__restrict__ mul,         // [..., d]
    const scalar_t *__restrict__ weight,      // [3]
    const scalar_t *__restrict__ bias,        // [1]
    const float eps, const int d) {
  using vec_t = type_vec_t<scalar_t, width>;

  const int vec_d = d / width;
  const int64_t vec_offset = blockIdx.x * vec_d;
  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
  const vec_t *__restrict__ mul_vec = reinterpret_cast<const vec_t *>(mul);
  const vec_t *__restrict__ output_grad_vec =
      reinterpret_cast<const vec_t *>(output_grad);

  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  acc_t sum_dx1 = 0.0f;
  acc_t sum_dx2 = 0.0f;
  acc_t sum_dx3 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t dy_fused_vec = output_grad_vec[vec_offset + idx];
    vec_t m_vec = mul_vec[vec_offset + idx];

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      acc_t x4 = x2 * x2;
      acc_t x6 = x3 * x3;

      sum2 += x2;
      sum4 += x4;
      sum6 += x6;

      acc_t dy = dy_fused_vec.data[i] * m_vec.data[i];

      sum_dx1 += dy * x1;
      sum_dx2 += dy * x2;
      sum_dx3 += dy * x3;
    }
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
  __syncthreads();
  float3 block_sum_dxs =
      BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);

  sum_dx1 = block_sum_dxs.x;
  sum_dx2 = block_sum_dxs.y;
  sum_dx3 = block_sum_dxs.z;

  __shared__ acc_t s_mean2;
  __shared__ acc_t s_mean4;
  __shared__ acc_t s_mean6;
  __shared__ acc_t s_sdx1;
  __shared__ acc_t s_sdx2;
  __shared__ acc_t s_sdx3;

  const acc_t inv_d = acc_t(1) / d;

  if (threadIdx.x == 0) {
    s_mean2 = sum2 * inv_d + eps;
    s_mean4 = sum4 * inv_d + eps;
    s_mean6 = sum6 * inv_d + eps;

    s_sdx1 = sum_dx1 * inv_d;
    s_sdx2 = sum_dx2 * inv_d;
    s_sdx3 = sum_dx3 * inv_d;
  }
  __syncthreads();

  acc_t w0 = weight[0];
  acc_t w1 = weight[1];
  acc_t w2 = weight[2];
  acc_t bias_reg = bias[0];

  acc_t mean2 = s_mean2;
  acc_t mean4 = s_mean4;
  acc_t mean6 = s_mean6;
  acc_t sdx1 = s_sdx1;
  acc_t sdx2 = s_sdx2;
  acc_t sdx3 = s_sdx3;

  acc_t inv_std1 = rsqrtf(mean2);
  acc_t inv_std2 = rsqrtf(mean4);
  acc_t inv_std3 = rsqrtf(mean6);

  acc_t w2_inv_std1 = inv_std1 * w2;
  acc_t w1_inv_std2 = inv_std2 * w1;
  acc_t w0_inv_std3 = inv_std3 * w0;

  // inv_std / mean == powf(mean, -1.5)
  acc_t c1 = w2_inv_std1 / mean2;
  acc_t c2 = acc_t(2) * w1_inv_std2 / mean4;
  acc_t c3 = acc_t(3) * w0_inv_std3 / mean6;

  acc_t sum_dy = 0;
  acc_t sum_dw0 = 0;
  acc_t sum_dw1 = 0;
  acc_t sum_dw2 = 0;

  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
  vec_t *__restrict__ mul_grad_vec = reinterpret_cast<vec_t *>(mul_grad);

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t dy_fused_vec = output_grad_vec[vec_offset + idx];
    vec_t m_vec = mul_vec[vec_offset + idx];
    vec_t dx_vec;
    vec_t dm_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      acc_t dy = dy_fused_vec.data[i] * m_vec.data[i];

      // For register optimization, the order of the following logic matters.
      // The input_grad related logic must be placed at the very end.
      sum_dy += dy;
      sum_dw0 += dy * (x3 * inv_std3);
      sum_dw1 += dy * (x2 * inv_std2);
      sum_dw2 += dy * (x1 * inv_std1);

      if (mul_grad) {
        scalar_t poly_norm_result =
            x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
        dm_vec.data[i] = poly_norm_result * dy_fused_vec.data[i];
      }

      if (input_grad) {
        acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
        acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
        acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
        dx_vec.data[i] = dx1 + dx2 + dx3;
      }
    }

    if (input_grad) {
      input_grad_vec[vec_offset + idx] = dx_vec;
    }
    if (mul_grad) {
      mul_grad_vec[vec_offset + idx] = dm_vec;
    }
  }

  using BlockReduce4 = cub::BlockReduce<float4, 1024>;
  __shared__ typename BlockReduce4::TempStorage reduceStore4;

  float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
  float4 block_sum_ds =
      BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);

  sum_dy = block_sum_ds.x;
  sum_dw0 = block_sum_ds.y;
  sum_dw1 = block_sum_ds.z;
  sum_dw2 = block_sum_ds.w;

  if (threadIdx.x == 0) {
    temp_bias_grad[blockIdx.x] = sum_dy;
    temp_weight_grad[blockIdx.x * 3 + 0] = sum_dw0;
    temp_weight_grad[blockIdx.x * 3 + 1] = sum_dw1;
    temp_weight_grad[blockIdx.x * 3 + 2] = sum_dw2;
  }
}

template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width == 0)> fused_mul_poly_norm_backward_kernel(
    scalar_t *__restrict__ input_grad,        // [..., d]
    scalar_t *__restrict__ mul_grad,          // [..., d]
    acc_t *__restrict__ temp_weight_grad,     // [..., 3]
    acc_t *__restrict__ temp_bias_grad,       // [..., 1]
    const scalar_t *__restrict__ output_grad, // [..., d]
    const scalar_t *__restrict__ input,       // [..., d]
    const scalar_t *__restrict__ mul,         // [..., d]
    const scalar_t *__restrict__ weight,      // [3]
    const scalar_t *__restrict__ bias,        // [1]
    const float eps, const int d) {
  const int64_t token_idx = blockIdx.x;

  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  acc_t sum_dx1 = 0.0f;
  acc_t sum_dx2 = 0.0f;
  acc_t sum_dx3 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t dy = output_grad[token_idx * d + idx] * mul[token_idx * d + idx];

    acc_t x1 = input[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x3 = x2 * x1;
    acc_t x4 = x2 * x2;
    acc_t x6 = x3 * x3;

    sum2 += x2;
    sum4 += x4;
    sum6 += x6;

    sum_dx1 += dy * x1;
    sum_dx2 += dy * x2;
    sum_dx3 += dy * x3;
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
  __syncthreads();
  float3 block_sum_dxs =
      BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);

  sum_dx1 = block_sum_dxs.x;
  sum_dx2 = block_sum_dxs.y;
  sum_dx3 = block_sum_dxs.z;

  __shared__ acc_t s_mean2;
  __shared__ acc_t s_mean4;
  __shared__ acc_t s_mean6;
  __shared__ acc_t s_sdx1;
  __shared__ acc_t s_sdx2;
  __shared__ acc_t s_sdx3;

  const acc_t inv_d = acc_t(1) / d;

  if (threadIdx.x == 0) {
    s_mean2 = sum2 * inv_d + eps;
    s_mean4 = sum4 * inv_d + eps;
    s_mean6 = sum6 * inv_d + eps;

    s_sdx1 = sum_dx1 * inv_d;
    s_sdx2 = sum_dx2 * inv_d;
    s_sdx3 = sum_dx3 * inv_d;
  }
  __syncthreads();

  acc_t w0 = weight[0];
  acc_t w1 = weight[1];
  acc_t w2 = weight[2];
  acc_t bias_reg = bias[0];

  acc_t mean2 = s_mean2;
  acc_t mean4 = s_mean4;
  acc_t mean6 = s_mean6;
  acc_t sdx1 = s_sdx1;
  acc_t sdx2 = s_sdx2;
  acc_t sdx3 = s_sdx3;

  acc_t inv_std1 = rsqrtf(mean2);
  acc_t inv_std2 = rsqrtf(mean4);
  acc_t inv_std3 = rsqrtf(mean6);

  acc_t w2_inv_std1 = inv_std1 * w2;
  acc_t w1_inv_std2 = inv_std2 * w1;
  acc_t w0_inv_std3 = inv_std3 * w0;

  // inv_std / mean == powf(mean, -1.5)
  acc_t c1 = w2_inv_std1 / mean2;
  acc_t c2 = acc_t(2) * w1_inv_std2 / mean4;
  acc_t c3 = acc_t(3) * w0_inv_std3 / mean6;

  acc_t sum_dy = 0;
  acc_t sum_dw0 = 0;
  acc_t sum_dw1 = 0;
  acc_t sum_dw2 = 0;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t dy_fused = output_grad[token_idx * d + idx];
    acc_t dy = dy_fused * mul[token_idx * d + idx];
    acc_t x1 = input[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x3 = x2 * x1;

    if (input_grad) {
      acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
      acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
      acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
      input_grad[token_idx * d + idx] = dx1 + dx2 + dx3;
    }

    if (mul_grad) {
      scalar_t poly_norm_result =
          x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
      mul_grad[token_idx * d + idx] = poly_norm_result * dy_fused;
    }

    sum_dy += dy;
    sum_dw0 += dy * (x3 * inv_std3);
    sum_dw1 += dy * (x2 * inv_std2);
    sum_dw2 += dy * (x1 * inv_std1);
  }

  using BlockReduce4 = cub::BlockReduce<float4, 1024>;
  __shared__ typename BlockReduce4::TempStorage reduceStore4;

  float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
  float4 block_sum_ds =
      BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);

  sum_dy = block_sum_ds.x;
  sum_dw0 = block_sum_ds.y;
  sum_dw1 = block_sum_ds.z;
  sum_dw2 = block_sum_ds.w;

  if (threadIdx.x == 0) {
    temp_bias_grad[token_idx] = sum_dy;
    temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
    temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
    temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
  }
}

} // namespace motif

#define LAUNCH_FUSED_MUL_POLY_NORM(width)                                      \
  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "fused_mul_poly_norm_kernel", [&] {                 \
        motif::fused_mul_poly_norm_kernel<scalar_t, float, width>              \
            <<<grid, block, 0, stream>>>(                                      \
                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),          \
                mul.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),         \
                bias.data_ptr<scalar_t>(), eps, d);                            \
      });

void fused_mul_poly_norm(torch::Tensor &out,          // [..., d]
                         const torch::Tensor &input,  // [..., d]
                         const torch::Tensor &mul,    // [..., d]
                         const torch::Tensor &weight, // [3]
                         const torch::Tensor &bias,   // [1]
                         double eps) {
  AssertTensorShapeEqual(input, out, "input", "out");
  AssertTensorShapeEqual(input, mul, "input", "mul");
  AssertTensorNotNull(weight, "weight");
  AssertTensorNotNull(bias, "bias");
  // TODO shape check

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(d, max_block_size));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (d % 8 == 0) {
    LAUNCH_FUSED_MUL_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_MUL_POLY_NORM(0);
  }
}

#define LAUNCH_POLY_NORM_BACKWARD(width)                                       \
  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "fused_mul_poly_norm_backward_kernel", [&] {        \
        motif::fused_mul_poly_norm_backward_kernel<scalar_t, float, width>     \
            <<<grid, block, 0, stream>>>(                                      \
                input_grad.data_ptr<scalar_t>(),                               \
                mul_grad.data_ptr<scalar_t>(),                                 \
                temp_weight_grad.data_ptr<float>(),                            \
                temp_bias_grad.data_ptr<float>(),                              \
                output_grad.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),  \
                mul.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),         \
                bias.data_ptr<scalar_t>(), eps, d);                            \
      });

void fused_mul_poly_norm_backward(torch::Tensor &input_grad,        // [..., d]
                                  torch::Tensor &mul_grad,          // [..., d]
                                  torch::Tensor &weight_grad,       // [3]
                                  torch::Tensor &bias_grad,         // [1]
                                  const torch::Tensor &output_grad, // [..., d]
                                  const torch::Tensor &input,       // [..., d]
                                  const torch::Tensor &mul,         // [..., d]
                                  const torch::Tensor &weight,      // [3]
                                  const torch::Tensor &bias,        // [1]
                                  double eps) {
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
  AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad");
  AssertTensorShapeEqual(input, mul, "input", "mul");
  AssertTensorNotNull(weight, "weight");
  // TODO shape check
  // weight_grad, bias_grad, mul_grad and input_grad can be nullable

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(d, max_block_size));

  torch::Tensor temp_weight_grad =
      torch::empty({num_tokens, 3}, input.options().dtype(torch::kFloat));
  torch::Tensor temp_bias_grad =
      torch::empty({num_tokens, 1}, output_grad.options().dtype(torch::kFloat));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (d % 8 == 0 && input.element_size() == 2) {
    LAUNCH_POLY_NORM_BACKWARD(8);
  } else if (d % 4 == 0 && input.element_size() == 4) {
    LAUNCH_POLY_NORM_BACKWARD(4);
  } else {
    LAUNCH_POLY_NORM_BACKWARD(0);
  }

  if (bias_grad.defined()) {
    torch::Tensor acc = torch::empty_like(bias_grad, temp_bias_grad.options());
    at::sum_out(acc, temp_bias_grad, {0});
    bias_grad.copy_(acc);
  }

  if (weight_grad.defined()) {
    torch::Tensor acc =
        torch::empty_like(weight_grad, temp_weight_grad.options());
    at::sum_out(acc, temp_weight_grad, {0});
    weight_grad.copy_(acc);
  }
}
activation/poly_norm.cu
CHANGED
|
@@ -7,7 +7,6 @@
|
|
| 7 |
|
| 8 |
#include "assert_utils.h"
|
| 9 |
#include "atomic_utils.h"
|
| 10 |
-
#include "block_reduce.h"
|
| 11 |
#include "cuda_compat.h"
|
| 12 |
#include "dispatch_utils.h"
|
| 13 |
|
|
@@ -17,6 +16,18 @@ template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
|
|
| 17 |
type data[N];
|
| 18 |
};
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
template <typename scalar_t, typename acc_t, int width>
|
| 21 |
__global__ std::enable_if_t<(width > 0)>
|
| 22 |
poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
|
|
@@ -39,7 +50,7 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
|
|
| 39 |
|
| 40 |
#pragma unroll
|
| 41 |
for (int i = 0; i < width; ++i) {
|
| 42 |
-
acc_t x1 =
|
| 43 |
acc_t x2 = x1 * x1;
|
| 44 |
acc_t x4 = x2 * x2;
|
| 45 |
acc_t x6 = x4 * x2;
|
|
@@ -50,14 +61,16 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
|
|
| 50 |
}
|
| 51 |
}
|
| 52 |
|
| 53 |
-
using BlockReduce = cub::BlockReduce<
|
| 54 |
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
|
| 62 |
__shared__ acc_t s_bias;
|
| 63 |
|
|
@@ -90,14 +103,12 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
|
|
| 90 |
|
| 91 |
#pragma unroll
|
| 92 |
for (int i = 0; i < width; ++i) {
|
| 93 |
-
acc_t x1 =
|
| 94 |
acc_t x2 = x1 * x1;
|
| 95 |
acc_t x3 = x2 * x1;
|
| 96 |
|
| 97 |
-
|
| 98 |
x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
|
| 99 |
-
|
| 100 |
-
y_vec.data[i] = static_cast<scalar_t>(y);
|
| 101 |
}
|
| 102 |
output_vec[vec_offset + idx] = y_vec;
|
| 103 |
}
|
|
@@ -127,14 +138,16 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
|
|
| 127 |
sum6 += x6;
|
| 128 |
}
|
| 129 |
|
| 130 |
-
using BlockReduce = cub::BlockReduce<
|
| 131 |
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
|
| 139 |
__shared__ acc_t s_bias;
|
| 140 |
|
|
@@ -199,7 +212,7 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 199 |
|
| 200 |
#pragma unroll
|
| 201 |
for (int i = 0; i < width; ++i) {
|
| 202 |
-
acc_t x1 =
|
| 203 |
acc_t x2 = x1 * x1;
|
| 204 |
acc_t x3 = x2 * x1;
|
| 205 |
acc_t x4 = x2 * x2;
|
|
@@ -209,7 +222,7 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 209 |
sum4 += x4;
|
| 210 |
sum6 += x6;
|
| 211 |
|
| 212 |
-
acc_t dy =
|
| 213 |
|
| 214 |
sum_dx1 += dy * x1;
|
| 215 |
sum_dx2 += dy * x2;
|
|
@@ -217,22 +230,25 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 217 |
}
|
| 218 |
}
|
| 219 |
|
| 220 |
-
using BlockReduce = cub::BlockReduce<
|
| 221 |
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
|
| 227 |
-
__syncthreads();
|
| 228 |
-
sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
__syncthreads();
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
| 236 |
|
| 237 |
__shared__ acc_t s_mean2;
|
| 238 |
__shared__ acc_t s_mean4;
|
|
@@ -288,16 +304,16 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 288 |
|
| 289 |
#pragma unroll
|
| 290 |
for (int i = 0; i < width; ++i) {
|
| 291 |
-
acc_t x1 =
|
| 292 |
acc_t x2 = x1 * x1;
|
| 293 |
acc_t x3 = x2 * x1;
|
| 294 |
-
acc_t dy =
|
| 295 |
|
| 296 |
if (input_grad) {
|
| 297 |
acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
|
| 298 |
acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
|
| 299 |
acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
|
| 300 |
-
dx_vec.data[i] =
|
| 301 |
}
|
| 302 |
|
| 303 |
sum_dy += dy;
|
|
@@ -311,13 +327,17 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 311 |
}
|
| 312 |
}
|
| 313 |
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
if (threadIdx.x == 0) {
|
| 323 |
temp_bias_grad[blockIdx.x] = sum_dy;
|
|
@@ -364,22 +384,25 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 364 |
sum_dx3 += dy * x3;
|
| 365 |
}
|
| 366 |
|
| 367 |
-
using BlockReduce = cub::BlockReduce<
|
| 368 |
__shared__ typename BlockReduce::TempStorage reduceStore;
|
| 369 |
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
|
| 374 |
-
__syncthreads();
|
| 375 |
-
sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
|
| 376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
__syncthreads();
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
|
|
|
| 383 |
|
| 384 |
__shared__ acc_t s_mean2;
|
| 385 |
__shared__ acc_t s_mean4;
|
|
@@ -445,13 +468,17 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
|
| 445 |
sum_dw2 += dy * (x1 * inv_std1);
|
| 446 |
}
|
| 447 |
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
if (threadIdx.x == 0) {
|
| 457 |
temp_bias_grad[token_idx] = sum_dy;
|
|
|
|
| 7 |
|
| 8 |
#include "assert_utils.h"
|
| 9 |
#include "atomic_utils.h"
|
|
|
|
| 10 |
#include "cuda_compat.h"
|
| 11 |
#include "dispatch_utils.h"
|
| 12 |
|
|
|
|
| 16 |
type data[N];
|
| 17 |
};
|
| 18 |
|
| 19 |
+
struct SumOp {
|
| 20 |
+
__device__ float3 operator()(const float3 &a, const float3 &b) const {
|
| 21 |
+
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 22 |
+
}
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
struct SumOp4 {
|
| 26 |
+
__device__ float4 operator()(const float4 &a, const float4 &b) const {
|
| 27 |
+
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 28 |
+
}
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)>
poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
...
#pragma unroll
    for (int i = 0; i < width; ++i) {
+     acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x4 = x2 * x2;
      acc_t x6 = x4 * x2;
...
    }
  }

+ using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+     BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;

  __shared__ acc_t s_bias;
...
#pragma unroll
    for (int i = 0; i < width; ++i) {
+     acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;

+     y_vec.data[i] =
          x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
    }
    output_vec[vec_offset + idx] = y_vec;
  }
...
      sum6 += x6;
    }

+ using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+     BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;

  __shared__ acc_t s_bias;
...
#pragma unroll
    for (int i = 0; i < width; ++i) {
+     acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      acc_t x4 = x2 * x2;
...
      sum4 += x4;
      sum6 += x6;

+     acc_t dy = dy_vec.data[i];

      sum_dx1 += dy * x1;
      sum_dx2 += dy * x2;
...
    }
  }

+ using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+     BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;
+
+ float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
  __syncthreads();
+ float3 block_sum_dxs =
+     BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);
+
+ sum_dx1 = block_sum_dxs.x;
+ sum_dx2 = block_sum_dxs.y;
+ sum_dx3 = block_sum_dxs.z;

  __shared__ acc_t s_mean2;
  __shared__ acc_t s_mean4;
...
#pragma unroll
    for (int i = 0; i < width; ++i) {
+     acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
+     acc_t dy = dy_vec.data[i];

      if (input_grad) {
        acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
        acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
        acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
+       dx_vec.data[i] = dx1 + dx2 + dx3;
      }

      sum_dy += dy;
...
    }
  }

+ using BlockReduce4 = cub::BlockReduce<float4, 1024>;
+ __shared__ typename BlockReduce4::TempStorage reduceStore4;
+
+ float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
+ float4 block_sum_ds =
+     BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);
+
+ sum_dy = block_sum_ds.x;
+ sum_dw0 = block_sum_ds.y;
+ sum_dw1 = block_sum_ds.z;
+ sum_dw2 = block_sum_ds.w;

  if (threadIdx.x == 0) {
    temp_bias_grad[blockIdx.x] = sum_dy;
...
      sum_dx3 += dy * x3;
    }

+ using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+     BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;
+
+ float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
  __syncthreads();
+ float3 block_sum_dxs =
+     BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);
+
+ sum_dx1 = block_sum_dxs.x;
+ sum_dx2 = block_sum_dxs.y;
+ sum_dx3 = block_sum_dxs.z;

  __shared__ acc_t s_mean2;
  __shared__ acc_t s_mean4;
...
      sum_dw2 += dy * (x1 * inv_std1);
    }

+ using BlockReduce4 = cub::BlockReduce<float4, 1024>;
+ __shared__ typename BlockReduce4::TempStorage reduceStore4;
+
+ float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
+ float4 block_sum_ds =
+     BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);
+
+ sum_dy = block_sum_ds.x;
+ sum_dw0 = block_sum_ds.y;
+ sum_dw1 = block_sum_ds.z;
+ sum_dw2 = block_sum_ds.w;

  if (threadIdx.x == 0) {
    temp_bias_grad[token_idx] = sum_dy;
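The `SumOp{}` and `SumOp4{}` functors used with `cub::BlockReduce` above are defined elsewhere in `poly_norm.cu` and are collapsed out of this view. A minimal standalone sketch of the pattern (the functor body and the demo kernel are mine, not part of the diff): packing the three partial sums into one `float3` lets a single `cub::BlockReduce` pass replace three separate scalar reductions.

```cuda
#include <cub/cub.cuh>

// Elementwise sum over float3; mirrors what SumOp{} must do above.
struct SumOp {
  __device__ float3 operator()(const float3 &a, const float3 &b) const {
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
  }
};

// Hypothetical demo kernel: one fused block-wide reduction of three moments.
__global__ void block_sums_demo(const float *in, float3 *out, int n) {
  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float sum2 = 0.f, sum4 = 0.f, sum6 = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float x2 = in[i] * in[i];
    sum2 += x2;
    sum4 += x2 * x2;
    sum6 += x2 * x2 * x2;
  }
  float3 totals = BlockReduce(reduceStore)
                      .Reduce(make_float3(sum2, sum4, sum6), SumOp{}, blockDim.x);
  if (threadIdx.x == 0) *out = totals;  // aggregate is only valid on thread 0
}
```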
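For reference, the function these kernels implement (matching the reference `PolyNorm` module in `benchmarks/cases/poly.py` further down) is, with powers taken elementwise and the mean over the last dimension $d$:

```latex
\mathrm{PolyNorm}(x) = w_0\,\hat{n}(x^3) + w_1\,\hat{n}(x^2) + w_2\,\hat{n}(x) + b,
\qquad
\hat{n}(u) = \frac{u}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} u_i^2 + \epsilon}}
```

The `sum2`, `sum4`, `sum6` accumulators above are the block-wide $\sum_i x_i^2$, $\sum_i x_i^4$, $\sum_i x_i^6$ needed for the three normalizers.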
activation/rms_norm.cu
CHANGED
@@ -7,18 +7,76 @@

 #include "assert_utils.h"
 #include "atomic_utils.h"
-#include "block_reduce.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

 namespace motif {

-template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void rms_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
-                                const scalar_t *__restrict__ input,  // [..., d]
-                                const scalar_t *__restrict__ weight, // [d]
-                                const float eps, const int d) {
+template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
+  type data[N];
+};
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width > 0)>
+rms_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                const scalar_t *__restrict__ input,  // [..., d]
+                const scalar_t *__restrict__ weight, // [d]
+                const float eps, const int d) {
+  using vec_t = type_vec_t<scalar_t, width>;
+
+  const int vec_d = d / width;
+  const int64_t vec_offset = blockIdx.x * vec_d;
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+  acc_t sum_square = 0.0f;
+
+  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + idx];
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x = x_vec.data[i];
+      sum_square += x * x;
+    }
+  }
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+
+  __shared__ acc_t s_scale;
+
+  if (threadIdx.x == 0) {
+    s_scale = rsqrtf(sum_square / d + eps);
+  }
+  __syncthreads();
+
+  const vec_t *__restrict__ weight_vec =
+      reinterpret_cast<const vec_t *>(weight);
+  vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);
+
+  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + idx];
+    vec_t w_vec = weight_vec[idx];
+    vec_t y_vec;
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x = x_vec.data[i];
+      acc_t w = w_vec.data[i];
+
+      y_vec.data[i] = w * x * s_scale;
+    }
+    output_vec[vec_offset + idx] = y_vec;
+  }
+}
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width == 0)>
+rms_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                const scalar_t *__restrict__ input,  // [..., d]
+                const scalar_t *__restrict__ weight, // [d]
+                const float eps, const int d) {
   const int64_t token_idx = blockIdx.x;
   const int64_t vec_idx = threadIdx.x;
   acc_t sum_square = 0.0f;

@@ -28,20 +86,123 @@ __global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
     sum_square += x * x;
   }

-  acc_t variance =
-      _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
-  acc_t scale = rsqrt(variance + eps);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+
+  __shared__ acc_t s_scale;
+
+  if (vec_idx == 0) {
+    s_scale = rsqrtf(sum_square / d + eps);
+  }
+  __syncthreads();
+
   for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
     acc_t x = input[token_idx * d + idx];
     acc_t w = weight[idx];
-    out[token_idx * d + idx] = w * x * scale;
+    out[token_idx * d + idx] = w * x * s_scale;
   }
 }
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width > 0)>
+rms_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                         acc_t *__restrict__ temp_weight_grad,     // [..., d]
+                         const scalar_t *__restrict__ output_grad, // [..., d]
+                         const scalar_t *__restrict__ input,       // [..., d]
+                         const scalar_t *__restrict__ weight,      // [d]
+                         const float eps, const int d) {
+  using vec_t = type_vec_t<scalar_t, width>;
+  using dw_vec_t = type_vec_t<acc_t, width>;
+
+  const int64_t token_idx = blockIdx.x;
+  const int64_t vec_idx = threadIdx.x;
+
+  const int vec_d = d / width;
+  const int64_t vec_offset = token_idx * vec_d;
+
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+  const vec_t *__restrict__ output_grad_vec =
+      reinterpret_cast<const vec_t *>(output_grad);
+  const vec_t *__restrict__ weight_vec =
+      reinterpret_cast<const vec_t *>(weight);
+
+  acc_t d_sum = 0.0f;
+  acc_t sum_square = 0.0f;
+
+  for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + vidx];
+    vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+    vec_t w_vec = weight_vec[vidx];
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x = x_vec.data[i];
+      acc_t dy = dy_vec.data[i];
+      acc_t w = w_vec.data[i];
+      d_sum += dy * x * w;
+      sum_square += x * x;
+    }
+  }
+
+  using BlockReduce = cub::BlockReduce<float2, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  struct SumOp {
+    __device__ float2 operator()(const float2 &a, const float2 &b) const {
+      return make_float2(a.x + b.x, a.y + b.y);
+    }
+  };
+  float2 thread_sums = make_float2(d_sum, sum_square);
+  float2 block_sums =
+      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+  d_sum = block_sums.x;
+  sum_square = block_sums.y;
+
+  __shared__ acc_t s_scale;
+  __shared__ acc_t s_dxx;
+
+  if (threadIdx.x == 0) {
+    acc_t scale = rsqrtf(sum_square / d + eps);
+    s_dxx = d_sum * scale * scale * scale / d;
+    s_scale = scale;
+  }
+  __syncthreads();
+  acc_t scale = s_scale;
+  acc_t dxx = s_dxx;
+  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
+  dw_vec_t *__restrict__ temp_weight_grad_vec =
+      reinterpret_cast<dw_vec_t *>(temp_weight_grad);
+
+  for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + vidx];
+    vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+    vec_t w_vec = weight_vec[vidx];
+
+    vec_t in_grad_vec;
+    dw_vec_t tw_grad_vec;
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x = x_vec.data[i];
+      acc_t dy = dy_vec.data[i];
+      acc_t w = w_vec.data[i];
+
+      if (input_grad) {
+        in_grad_vec.data[i] = scale * dy * w - dxx * x;
+      }
+      tw_grad_vec.data[i] = dy * x * scale;
+    }
+
+    if (input_grad) {
+      input_grad_vec[vec_offset + vidx] = in_grad_vec;
+    }
+    temp_weight_grad_vec[vec_offset + vidx] = tw_grad_vec;
+  }
+}
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width == 0)>
 rms_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
                          acc_t *__restrict__ temp_weight_grad,     // [..., d]
                          const scalar_t *__restrict__ output_grad, // [..., d]

@@ -61,30 +222,55 @@ rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
     sum_square += x * x;
   }

-  d_sum = ...
-  acc_t ...
-  acc_t ...
+  using BlockReduce = cub::BlockReduce<float2, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  struct SumOp {
+    __device__ float2 operator()(const float2 &a, const float2 &b) const {
+      return make_float2(a.x + b.x, a.y + b.y);
+    }
+  };
+  float2 thread_sums = make_float2(d_sum, sum_square);
+  float2 block_sums =
+      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+  d_sum = block_sums.x;
+  sum_square = block_sums.y;
+
+  __shared__ acc_t s_scale;
+  __shared__ acc_t s_dxx;
+
+  if (threadIdx.x == 0) {
+    acc_t scale = rsqrtf(sum_square / d + eps);
+    s_dxx = d_sum * scale * scale * scale / d;
+    s_scale = scale;
+  }
+  __syncthreads();
+
+  acc_t scale = s_scale;
+  acc_t dxx = s_dxx;

   for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
     acc_t x = input[token_idx * d + idx];
     acc_t dy = output_grad[token_idx * d + idx];
     acc_t w = weight[idx];

-    input_grad ...
-    if (temp_weight_grad) {
-      temp_weight_grad[token_idx * d + idx] = dy * x * scale;
+    if (input_grad) {
+      input_grad[token_idx * d + idx] = scale * dy * w - dxx * x;
     }
+    temp_weight_grad[token_idx * d + idx] = dy * x * scale;
   }
 }

 } // namespace motif

+#define LAUNCH_RMS_NORM(width)                                                 \
+  MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {  \
+    motif::rms_norm_kernel<scalar_t, float, width>                             \
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
+                                     input.data_ptr<scalar_t>(),               \
+                                     weight.data_ptr<scalar_t>(), eps, d);     \
+  });
+
 void rms_norm(torch::Tensor &out,          // [..., d]
              const torch::Tensor &input,   // [..., d]
              const torch::Tensor &weight,  // [d]

@@ -93,27 +279,36 @@ void rms_norm(torch::Tensor &out, // [..., d]
   AssertTensorNotNull(weight, "weight");
   // TODO shape check

-  constexpr int BLOCK_SIZE = 256;
-
   int d = input.size(-1);
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(d, max_block_size));

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  MOTIF_DISPATCH_FLOATING_TYPES(...
-  });
+  if (d % 8 == 0) {
+    LAUNCH_RMS_NORM(8);
+  } else {
+    LAUNCH_RMS_NORM(0);
+  }
 }

+#define LAUNCH_RMS_NORM_BWD(width)                                             \
+  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
+      input.scalar_type(), "rms_norm_backward_kernel", [&] {                   \
+        motif::rms_norm_backward_kernel<scalar_t, float, width>                \
+            <<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(),      \
+                                         temp_weight_grad.data_ptr<float>(),   \
+                                         output_grad.data_ptr<scalar_t>(),     \
+                                         input.data_ptr<scalar_t>(),           \
+                                         weight.data_ptr<scalar_t>(), eps, d); \
+      });
+
 void rms_norm_backward(torch::Tensor &input_grad,         // [..., d]
-                       torch::Tensor &weight_grad,        // [...
-                       const torch::Tensor &output_grad,  // [d]
-                       const torch::Tensor &input,        // [d]
+                       torch::Tensor &weight_grad,        // [d]
+                       const torch::Tensor &output_grad,  // [..., d]
+                       const torch::Tensor &input,        // [..., d]
                        const torch::Tensor &weight,       // [d]
                        double eps) {
   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");

@@ -122,30 +317,27 @@ void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
   // TODO shape check
   // weight_grad, input_grad can be nullable

-  constexpr int BLOCK_SIZE = 256;
-
   int d = input.size(-1);
   int64_t num_tokens = input.numel() / input.size(-1);
   dim3 grid(num_tokens);
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(d, max_block_size));

   torch::Tensor temp_weight_grad =
       torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));

-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  MOTIF_DISPATCH_FLOATING_TYPES(...
-          input.data_ptr<scalar_t>(),
-          weight.data_ptr<scalar_t>(), eps, d);
-  });
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (d % 8 == 0) {
+    LAUNCH_RMS_NORM_BWD(8);
+  } else {
+    LAUNCH_RMS_NORM_BWD(0);
+  }

   if (weight_grad.defined()) {
-    ...
+    torch::Tensor acc =
+        torch::empty_like(weight_grad, temp_weight_grad.options());
+    at::sum_out(acc, temp_weight_grad, {0});
+    weight_grad.copy_(acc);
   }
 }
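A side note on the `type_vec_t` trick used by the `width > 0` kernels above: the `alignas(sizeof(type) * N)` forces vector alignment, so a whole-struct copy compiles to a single wide load or store rather than `N` scalar accesses. A small compile-time check (standalone sketch, assuming 2-byte `__half` scalars with `width = 8`):

```cuda
#include <cuda_fp16.h>

template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
  type data[N];
};

// width = 8 halves -> one 16-byte (128-bit) aligned transaction per access,
// which is why the launchers dispatch LAUNCH_RMS_NORM(8) when d % 8 == 0
// and fall back to the scalar width == 0 kernel otherwise.
static_assert(sizeof(type_vec_t<__half, 8>) == 16, "one 128-bit load/store");
static_assert(alignof(type_vec_t<__half, 8>) == 16, "vector-aligned");
```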
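The backward kernels implement the standard RMSNorm gradient. In the code's names, `scale` is $r$, `d_sum` is $\sum_j g_j x_j w_j$, and `dxx` is `d_sum * scale^3 / d` (derivation notation mine; $g$ is the incoming output gradient):

```latex
r = \Big(\tfrac{1}{d}\textstyle\sum_j x_j^2 + \epsilon\Big)^{-1/2},
\qquad y_i = w_i\, x_i\, r
```

```latex
\frac{\partial L}{\partial x_i} = r\, g_i w_i \;-\; \frac{r^3}{d}\, x_i \sum_j g_j x_j w_j,
\qquad
\frac{\partial L}{\partial w_i} = \sum_{\text{tokens}} g_i\, x_i\, r
```

`temp_weight_grad` stores the per-token terms $g_i x_i r$ in fp32; the host-side `at::sum_out(acc, temp_weight_grad, {0})` then reduces over tokens before `weight_grad.copy_(acc)` casts back to the weight dtype.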
benchmarks/README.md
ADDED
@@ -0,0 +1,35 @@
# Benchmark Runner

This script benchmarks **forward/backward performance** of several operations (`rms`, `add_rms`, `poly`, `mul_poly`).
Results can be saved as **CSV files** or **plots**.

> **Note**<br>
> To run the benchmarks, you must select the appropriate Torch version along with the corresponding CUDA/ROCm build from within the `build` directory.
>
> **Example:**
>
> ```bash
> export PYTHONPATH=$PYTHONPATH:<YOUR_PATH>/activation/build/torch27-cxx11-cu128-x86_64-linux
> ```

## Usage

```bash
python main.py --case <CASE> [--plot] [--save-path <DIR>]
```

- `--case` (required): one of `rms`, `add_rms`, `poly`, `mul_poly`
- `--plot`: save plots instead of CSVs
- `--save-path`: output directory (default: `./configs/`)

## Examples

```bash
python main.py --case add_rms --save-path ./results/
python main.py --case poly --plot --save-path ./plots/
```

## Output

- CSV: `<case>-fwd-perf.csv`, `<case>-bwd-perf.csv`
- Plots: `plot_<case>-fwd-perf.png`, `plot_<case>-bwd-perf.png`
benchmarks/cases/__init__.py
ADDED
@@ -0,0 +1 @@
+
benchmarks/cases/add_rms.py
ADDED
@@ -0,0 +1,55 @@
import torch
from common.diff_engine import DiffCase

import activation


class FusedAddRMSNorm(torch.nn.Module):

    def __init__(self, d, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(d, dtype=dtype))
        self.eps = eps

    def forward(self, x, residual):
        return activation.rms_norm((x + residual), self.weight, self.eps)


class AddRMS(DiffCase):

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        return {
            "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
            "residual": torch.randn(bs, sl, hidden, dtype=dtype,
                                    requires_grad=True),
            "weight": torch.ones(hidden, dtype=dtype),
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    def make_naive(self, I):
        m = FusedAddRMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        return m

    def make_cuda(self, I):
        m = activation.layers.FusedAddRMSNorm(I["dim"],
                                              I["eps"],
                                              dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        return m

    def forward(self, obj, I):
        return obj(I["x"], I["residual"])

    def grad_inputs(self, I):
        return [I["x"], I["residual"]]


CASE = AddRMS()
benchmarks/cases/mul_poly.py
ADDED
@@ -0,0 +1,53 @@
import torch
from common.diff_engine import DiffCase

import activation


class FusedMulPolyNorm(torch.nn.Module):

    def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
        self.eps = eps

    def forward(self, x, mul):
        output = activation.poly_norm(x, self.weight, self.bias, self.eps)
        return output * mul


class MulPoly(DiffCase):

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        return {
            "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
            "mul": torch.randn(bs, sl, hidden, dtype=dtype,
                               requires_grad=True),
            "weight": torch.ones(3, dtype=dtype),
            "bias": torch.ones(1, dtype=dtype),
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    def make_naive(self, I):
        m = FusedMulPolyNorm(I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        m.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return m

    def make_cuda(self, I):
        m = activation.layers.FusedMulPolyNorm(I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        m.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return m

    def forward(self, obj, I):
        return obj(I["x"], I["mul"])

    def grad_inputs(self, I):
        return [I["x"], I["mul"]]


CASE = MulPoly()
benchmarks/cases/poly.py
ADDED
@@ -0,0 +1,58 @@
import torch
from common.diff_engine import DiffCase

import activation


class PolyNorm(torch.nn.Module):

    def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
        self.eps = eps

    def _norm(self, x):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        orig_dtype = x.dtype
        x_float = x.to(torch.float32)
        output = (self.weight[0] * self._norm(x_float**3) +
                  self.weight[1] * self._norm(x_float**2) +
                  self.weight[2] * self._norm(x_float) + self.bias)
        return output.to(orig_dtype)


class Poly(DiffCase):

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        return {
            "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
            "weight": torch.ones(3, dtype=dtype),
            "bias": torch.ones(1, dtype=dtype),
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    def make_naive(self, I):
        m = PolyNorm(I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        m.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return m

    def make_cuda(self, I):
        m = activation.layers.PolyNorm(I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        m.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return m

    def forward(self, obj, I):
        return obj(I["x"])

    def grad_inputs(self, I):
        return [I["x"]]


CASE = Poly()
benchmarks/cases/rms.py
ADDED
@@ -0,0 +1,35 @@
import torch
from common.diff_engine import DiffCase

import activation


class RMS(DiffCase):

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        return {
            "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
            "weight": torch.ones(hidden, dtype=dtype),
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    def make_naive(self, I):
        m = torch.nn.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        return m

    def make_cuda(self, I):
        m = activation.layers.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        return m

    def forward(self, obj, I):
        return obj(I["x"])

    def grad_inputs(self, I):
        return [I["x"]]


CASE = RMS()
benchmarks/common/__init__.py
ADDED
@@ -0,0 +1 @@
+
benchmarks/common/bench_framework.py
ADDED
@@ -0,0 +1,220 @@
import collections
import math
import re
from typing import Any, Dict, Sequence

import torch
import triton

from .diff_engine import DiffCase


def make_fwd_key(batch_size, seq_len, dim):
    return f"forward : ({batch_size}, {seq_len}, {dim})"


def make_bwd_key(batch_size, seq_len, dim):
    return f"backward : ({batch_size}, {seq_len}, {dim})"


def parse_config_string(config_str):
    match = re.match(r"(\w+)\s*:\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)",
                     config_str)
    if not match:
        raise ValueError(f"Invalid config string: {config_str}")
    _, bs, sl, d = match.groups()
    return int(bs), int(sl), int(d)


def make_fwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [list(_) for _ in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(dim, batch_size, seq_len, provider):
        key = make_fwd_key(dim, batch_size, seq_len)
        I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        if provider == "speedup":
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
        run = lambda: case.forward(obj, I)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench


def make_fwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    timings_ms = collections.defaultdict(dict)
    spdup_ratio = list()
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_fwd_key(*_) for _ in configs]
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["config"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(config, provider):
        if config == "Geometric Mean":
            if provider == "cuda":
                return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2)
            else:
                return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
        run = lambda: case.forward(obj, I)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider == "cuda":
            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
            spdup_ratio.append(ratio)
            return round(ratio, 2)
        else:
            return 1.00

    return bench


def make_bwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [list(_) for _ in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(dim, batch_size, seq_len, provider):
        key = make_bwd_key(dim, batch_size, seq_len)
        I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        if provider == "speedup":
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
        y = case.forward(obj, I)
        gin = list(case.grad_inputs(I)) + list(obj.parameters())
        g = torch.randn_like(y)
        run = lambda: torch.autograd.grad(y,
                                          gin,
                                          g,
                                          retain_graph=True,
                                          create_graph=False,
                                          allow_unused=False)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench


def make_bwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: Dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    timings_ms = collections.defaultdict(dict)
    spdup_ratio = list()
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_bwd_key(*_) for _ in configs]
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(x_names=["config"],
                                 x_vals=x_vals,
                                 line_arg="provider",
                                 line_vals=line_vals,
                                 line_names=[line_names[v] for v in line_vals],
                                 ylabel=ylabel,
                                 plot_name=plot_name,
                                 args={}))
    def bench(config, provider):
        if config == "Geometric Mean":
            if provider == "cuda":
                return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2)
            else:
                return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
        y = case.forward(obj, I)
        gin = list(case.grad_inputs(I)) + list(obj.parameters())
        g = torch.randn_like(y)
        run = lambda: torch.autograd.grad(y,
                                          gin,
                                          g,
                                          retain_graph=True,
                                          create_graph=False,
                                          allow_unused=False)
        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider == "cuda":
            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
            spdup_ratio.append(ratio)
            return round(ratio, 2)
        else:
            return 1.00

    return bench
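A minimal sketch of driving one of these factories directly (the config values are made up; `run_cases.py` below is the real entry point, and this assumes it is run from the `benchmarks` directory so `cases` and `common` are importable):

```python
import torch
from common.bench_framework import make_fwd_benchmark_for_case
from cases.rms import CASE

torch.set_default_device("cuda")
# configs follow the x_names order ("dim", "batch_size", "seq_len")
bench = make_fwd_benchmark_for_case(
    case=CASE,
    configs=[(2048, 1, 1024), (4096, 1, 1024)],
    plot_name="rms-fwd-perf",
)
bench.run(print_data=True, save_path="./results/rms")
```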
benchmarks/common/diff_engine.py
ADDED
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence

import torch


class DiffCase(ABC):

    @abstractmethod
    def build_inputs(self, bs: int, sl: int, hidden: int, dtype: torch.dtype,
                     eps: float) -> Dict[str, Any]:
        ...

    @abstractmethod
    def make_naive(self, I: Dict[str, Any]) -> Any:
        ...

    @abstractmethod
    def make_cuda(self, I: Dict[str, Any]) -> Any:
        ...

    @abstractmethod
    def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor:
        ...

    @abstractmethod
    def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]:
        ...


def _clone_payload(d, device):
    out = {}
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            t = v.detach().clone().to(device)
            t.requires_grad_(v.requires_grad)
            out[k] = t
        else:
            out[k] = v
    return out


def _unit_grad_like(y):
    g = torch.randn_like(y)
    n = g.norm()
    return g if n == 0 else g / n


def calculate_diff(
    case: DiffCase,
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    device="cuda",
) -> None:
    # Argument order (bs, sl, hidden) matches the concrete cases'
    # build_inputs signature and the calls in bench_framework.
    base = case.build_inputs(batch_size, seq_len, hidden_size, dtype, eps)
    I_n = _clone_payload(base, device)
    I_c = _clone_payload(base, device)
    obj_n = case.make_naive(I_n)
    obj_c = case.make_cuda(I_c)
    y_n = case.forward(obj_n, I_n)
    y_c = case.forward(obj_c, I_c)
    torch.testing.assert_close(y_n, y_c, atol=atol, rtol=rtol)
    gin_n = list(case.grad_inputs(I_n)) + list(obj_n.parameters())
    gin_c = list(case.grad_inputs(I_c)) + list(obj_c.parameters())
    g = _unit_grad_like(y_n).to(device)
    ng = torch.autograd.grad(y_n,
                             gin_n,
                             g,
                             retain_graph=False,
                             create_graph=False,
                             allow_unused=False)
    cg = torch.autograd.grad(y_c,
                             gin_c,
                             g,
                             retain_graph=False,
                             create_graph=False,
                             allow_unused=False)
    torch.testing.assert_close(ng, cg, atol=atol, rtol=rtol)
    print("✅ forward + backward match")
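Sketch of wiring a case through the engine (this mirrors the smoke test in `run_cases.py` below):

```python
import torch
from cases.rms import CASE
from common.diff_engine import calculate_diff

torch.set_default_device("cuda")
# Raises via torch.testing.assert_close on any forward/backward mismatch.
calculate_diff(CASE, batch_size=2, seq_len=128, hidden_size=4096)
```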
benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png  ADDED
benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png  ADDED
benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png  ADDED
benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png  ADDED
benchmarks/plots/h100/poly/plot_poly-bwd-perf.png  ADDED
benchmarks/plots/h100/poly/plot_poly-fwd-perf.png  ADDED
benchmarks/plots/h100/rms/plot_rms-bwd-perf.png  ADDED
benchmarks/plots/h100/rms/plot_rms-fwd-perf.png  ADDED
benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png  ADDED
benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png  ADDED
benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png  ADDED
benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png  ADDED
benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png  ADDED
benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png  ADDED
benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png  ADDED
benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png  ADDED
benchmarks/run_cases.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import glob
import importlib
import itertools
import os

import torch
from common.bench_framework import (make_bwd_benchmark_for_case,
                                    make_bwd_benchmark_plot_for_case,
                                    make_fwd_benchmark_for_case,
                                    make_fwd_benchmark_plot_for_case)
from common.diff_engine import DiffCase, calculate_diff


def make_title_tag():
    if torch.cuda.is_available():
        dev_name = torch.cuda.get_device_name(0)
    else:
        dev_name = "CPU"

    torch_ver = torch.__version__

    return f"[{dev_name} | torch {torch_ver}]"


def plot_result(r_path):
    import matplotlib.pyplot as plt
    import pandas as pd
    df = pd.read_csv(r_path + ".csv")
    plt.figure(figsize=(12, 6))
    ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca())
    ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
                 fontsize=14,
                 fontweight="bold")
    ax.set_ylabel("Relative Speedup", fontsize=14)
    ax.set_xlabel("")
    plt.xticks(rotation=45, fontsize=12, ha="right", rotation_mode="anchor")
    for container in ax.containers:
        labels = [f"x{v.get_height():.2f}" for v in container]
        ax.bar_label(container, labels=labels, label_type="edge", fontsize=10)
    plt.tight_layout()
    plt.savefig(r_path + ".png", bbox_inches="tight")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--case",
                    choices=["rms", "add_rms", "poly", "mul_poly"],
                    required=True)
    ap.add_argument("--plot", action="store_true")
    ap.add_argument(
        "--save-path",
        type=str,
        default="./configs/",
        help="Path to save benchmark results",
    )
    args = ap.parse_args()

    torch.set_default_device("cuda")
    mod = importlib.import_module(f"cases.{args.case}")
    case: DiffCase = mod.CASE

    calculate_diff(
        case,
        batch_size=2,
        seq_len=128,
        hidden_size=4096,
    )

    save_dir = os.path.join(args.save_path, args.case)
    if args.plot:
        batch_size_range = [1]
        seq_length_range = [4096, 8192, 16384]
        dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
        configs = list(
            itertools.product(batch_size_range, seq_length_range, dim))
        plot_name = f"plot_{args.case}-fwd-perf"
        bench = make_fwd_benchmark_plot_for_case(
            case=case,
            configs=configs,
            plot_name=plot_name,
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
            },
        )
        bench.run(print_data=True, save_path=save_dir)
        plot_result(os.path.join(save_dir, plot_name))

        plot_name = f"plot_{args.case}-bwd-perf"
        bench = make_bwd_benchmark_plot_for_case(
            case=case,
            configs=configs,
            plot_name=plot_name,
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
            },
        )
        bench.run(print_data=True, save_path=save_dir)
        plot_result(os.path.join(save_dir, plot_name))
        for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
                os.path.join(save_dir, "*.csv")):
            os.remove(f)
    else:
        batch_size_range = [2**i for i in range(0, 4, 1)]
        seq_length_range = [2**i for i in range(10, 14, 1)]
        dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
        configs = list(
            itertools.product(dim, batch_size_range, seq_length_range))

        bench = make_fwd_benchmark_for_case(
            case=case,
            configs=configs,
            plot_name=f"{args.case}-fwd-perf",
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
                "speedup": "SpeedUp"
            },
        )

        bench.run(print_data=True, save_path=save_dir)

        bench = make_bwd_benchmark_for_case(
            case=case,
            configs=configs,
            plot_name=f"{args.case}-bwd-perf",
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
                "speedup": "SpeedUp"
            },
        )

        bench.run(print_data=True, save_path=save_dir)
        for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
                os.path.join(save_dir, "*.png")):
            os.remove(f)


if __name__ == "__main__":
    main()
build.toml
CHANGED
@@ -13,9 +13,10 @@ backend = "rocm"
 rocm-archs = [ "gfx90a", "gfx942" ]
 src = [
     "activation/poly_norm.cu",
+    "activation/fused_mul_poly_norm.cu",
     "activation/rms_norm.cu",
+    "activation/fused_add_rms_norm.cu",
     "activation/cuda_compat.h",
-    "activation/block_reduce.h",
     "activation/dispatch_utils.h",
     "activation/assert_utils.h",
     "activation/atomic_utils.h",
@@ -26,9 +27,10 @@ depends = [ "torch" ]
 backend = "cuda"
 src = [
     "activation/poly_norm.cu",
+    "activation/fused_mul_poly_norm.cu",
     "activation/rms_norm.cu",
+    "activation/fused_add_rms_norm.cu",
     "activation/cuda_compat.h",
-    "activation/block_reduce.h",
     "activation/dispatch_utils.h",
     "activation/assert_utils.h",
     "activation/atomic_utils.h",
build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py
CHANGED
@@ -2,8 +2,8 @@ import torch

 from . import layers
 from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction


 def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
     return PolyNormFunction.apply(x, weight, bias, eps)


+def fused_mul_poly_norm(
+    x: torch.Tensor,
+    mul: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float = 1e-6,
+) -> None:
+    return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
 def rms_norm(
     x: torch.Tensor,
     weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
     return RMSNormFunction.apply(x, weight, eps)


+def fused_add_rms_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+) -> None:
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
 __all__ = [
     "poly_norm",
+    "fused_mul_poly_norm",
+    "rms_norm",
+    "fused_add_rms_norm",
     "layers",
     "ops",
 ]
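A hedged usage sketch of the two new functional entry points exported above (shapes and dtypes are examples, not requirements taken from the diff):

```python
import torch
import activation

x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
# rms_norm(x + residual) computed in a single fused kernel
y = activation.fused_add_rms_norm(x, residual, w, eps=1e-6)

mul = torch.randn_like(x)
pw = torch.ones(3, device="cuda", dtype=torch.bfloat16) / 3
pb = torch.zeros(1, device="cuda", dtype=torch.bfloat16)
# poly_norm(x) * mul computed in a single fused kernel
z = activation.fused_mul_poly_norm(x, mul, pw, pb, eps=1e-6)
```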
tests/perf.png → build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:...
-size ...
+oid sha256:d21a85bf21aa74f1281541e658acfd4f4326d902efe3578b059eccf054443284
+size 8089696
build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import ...
-ops = torch.ops....
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"..."
+    return f"_activation_20250907180255::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py
CHANGED
@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 from torch.nn import init

-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction


 class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
         init.zeros_(self.bias)


+class FusedMulPolyNorm(nn.Module):
+
+    def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+        self.eps = eps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mul: torch.Tensor,
+    ):
+        return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+                                              self.eps)
+
+    def reset_parameters(self) -> None:
+        """
+        Resets parameters based on their initialization used in __init__.
+        """
+        init.ones_(self.weight)
+        init.zeros_(self.bias)
+
+
 class RMSNorm(nn.Module):

     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
         Resets parameters based on their initialization used in __init__.
         """
         init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+    def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+        self.eps = eps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ):
+        return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+                                             self.eps)[0]
+
+    def reset_parameters(self) -> None:
+        """
+        Resets parameters based on their initialization used in __init__.
+        """
+        init.ones_(self.weight)
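The new modules wrap the fused autograd Functions with learnable parameters, mirroring the existing PolyNorm and RMSNorm wrappers. A brief sketch (hypothetical shapes; the comment on FusedMulPolyNorm's semantics is inferred from the op name, not from the kernel source):

import torch
from activation.layers import FusedAddRMSNorm, FusedMulPolyNorm

hidden = 1024
add_norm = FusedAddRMSNorm(hidden, eps=1e-6, dtype=torch.float16).cuda()
mul_norm = FusedMulPolyNorm(eps=1e-6, dtype=torch.float16).cuda()

x = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
gate = torch.randn_like(x)

y = add_norm(x, residual)  # rms-norm applied to (x + residual), one kernel
z = mul_norm(x, gate)      # poly-norm of x scaled elementwise by gate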
build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py
CHANGED
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
                                  input, weight, eps)

         return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+    # Note that forward, setup_context, and backward are @staticmethods
+    @staticmethod
+    def forward(input, mul, weight, bias, eps):
+        output = torch.empty_like(input)
+        ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+        return output
+
+    # inputs is a Tuple of all of the inputs passed to forward.
+    # output is the output of the forward().
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        input, mul, weight, bias, eps = inputs
+        ctx.save_for_backward(input, mul, weight, bias)
+        ctx.eps = eps
+
+    # This function has only a single output, so it gets only one gradient
+    @staticmethod
+    def backward(ctx, output_grad):
+        input, mul, weight, bias = ctx.saved_tensors
+        eps = ctx.eps
+
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[2] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[3] else None)
+
+        ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+                                         bias_grad, output_grad, input, mul,
+                                         weight, bias, eps)
+
+        return input_grad, mul_grad, weight_grad, bias_grad, None
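For readability, an eager-mode reference of what the fused forward presumably computes, inferred from the op name and the unfused PolyNorm (the _norm helper and the exact formula are assumptions for illustration, not the kernel's source):

import torch

def _norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    # RMS-style normalization over the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def mul_poly_norm_ref(x, mul, weight, bias, eps=1e-6):
    # PolyNorm: a weighted sum of normalized powers of x plus a bias,
    # followed by an elementwise scale by `mul` -- the two steps the
    # CUDA kernel fuses into a single pass.
    poly = (weight[0] * _norm(x ** 3, eps) + weight[1] * _norm(x ** 2, eps)
            + weight[2] * _norm(x, eps) + bias)
    return poly * mul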
build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py
CHANGED
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
                               weight, eps)

         return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+    # Note that forward, setup_context, and backward are @staticmethods
+    @staticmethod
+    def forward(input, residual, weight, eps):
+        output = torch.empty_like(input)
+        add_output = torch.empty_like(input)
+        ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+                               eps)
+        return output, add_output
+
+    # inputs is a Tuple of all of the inputs passed to forward.
+    # outputs is the tuple of outputs of the forward().
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        _, _, weight, eps = inputs
+        _, add_output = outputs
+        ctx.mark_non_differentiable(add_output)
+        ctx.set_materialize_grads(False)
+        ctx.save_for_backward(weight, add_output)
+        ctx.eps = eps
+
+    # This function only needs one gradient
+    @staticmethod
+    def backward(ctx, output_grad, _):
+        weight, add_output = ctx.saved_tensors
+        eps = ctx.eps
+
+        if output_grad is None:
+            output_grad = torch.zeros_like(add_output)
+
+        need_in = ctx.needs_input_grad[0]
+        need_res = ctx.needs_input_grad[1]
+
+        grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[2] else None
+
+        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+                              weight, eps)
+        input_grad = grad if need_in else None
+        residual_grad = grad if need_res else None
+
+        return input_grad, residual_grad, weight_grad, None
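Note how the backward reuses rms_norm_backward against the saved add_output: because add_output = input + residual, both inputs see an identity Jacobian through the sum, so a single gradient buffer serves input_grad and residual_grad alike. As an eager-mode reference of the forward contract (a sketch, not the kernel source):

import torch

def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    # The CUDA kernel produces both tensors in one pass over each row.
    add_output = x + residual
    variance = add_output.pow(2).mean(dim=-1, keepdim=True)
    output = add_output * torch.rsqrt(variance + eps) * weight
    return output, add_output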
build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py
CHANGED
@@ -2,8 +2,8 @@ import torch

 from . import layers
 from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction


 def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
     return PolyNormFunction.apply(x, weight, bias, eps)


+def fused_mul_poly_norm(
+    x: torch.Tensor,
+    mul: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
 def rms_norm(
     x: torch.Tensor,
     weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
     return RMSNormFunction.apply(x, weight, eps)


+def fused_add_rms_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
 __all__ = [
     "poly_norm",
+    "fused_mul_poly_norm",
+    "rms_norm",
+    "fused_add_rms_norm",
     "layers",
     "ops",
 ]
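Since the fused Functions ship custom backwards, torch.autograd.gradcheck is the natural smoke test. A sketch in float64, gradcheck's preferred precision (assuming the kernels accept double inputs):

import torch
from activation import fused_add_rms_norm

x = torch.randn(2, 64, device="cuda", dtype=torch.float64, requires_grad=True)
res = torch.randn_like(x, requires_grad=True)
w = torch.randn(64, device="cuda", dtype=torch.float64, requires_grad=True)

# Compares the custom backward against numerical finite differences.
torch.autograd.gradcheck(fused_add_rms_norm, (x, res, w))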
build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d4955271509451b946495da75f69a0f978e7258b8303fe3c077e585c0d3e6a
+size 8272456
build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py
CHANGED
(same diff as the cu118 variant of this file above)
build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py
CHANGED
(same diff as the cu118 variant of this file above)
build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py
CHANGED
(same diff as the cu118 variant of this file above)
build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py
CHANGED
(same diff as the cu118 variant of this file above)
build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py
CHANGED
(same diff as the cu126 variant of this file above)
build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf0d2ab5ff5520704e0b0c959b61d0043d360cfd4335950e69677873a87e436
+size 12792112
build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py
CHANGED
(same diff as the cu118 variant of this file above)
build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py
CHANGED
(same diff as the cu118 variant of this file above)
build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py
CHANGED
(same diff as the cu118 variant of this file above)