refactor(activation): change fused_add_rms_norm and fused_add_rms_norm_backward to out-of-place operations
Files changed:
- activation/fused_add_rms_norm.cu       +26 -21
- tests/test_fused_add_rms_norm.py        +1 -1
- torch-ext/activation/rms_norm.py       +32 -12
- torch-ext/activation/rms_norm_meta.py   +2 -10
- torch-ext/torch_binding.cpp             +5 -8
- torch-ext/torch_binding.h               +8 -10
activation/fused_add_rms_norm.cu  CHANGED

@@ -295,20 +295,19 @@ __global__ std::enable_if_t<(width == 0)> fused_add_rms_norm_backward_kernel(
         weight.data_ptr<scalar_t>(), eps, d); \
   });
 
-void fused_add_rms_norm(torch::Tensor &out, // [..., d]
-…
+std::tuple<torch::Tensor, torch::Tensor>
+fused_add_rms_norm(const torch::Tensor &input,    // [..., d]
+                   const torch::Tensor &residual, // [..., d]
+                   const torch::Tensor &weight,   // [d]
+                   double eps) {
+
+  torch::Tensor out = torch::empty_like(input);
+  torch::Tensor add_out = torch::empty_like(input);
+
   AssertTensorShapeEqual(input, residual, "input", "residual");
-  AssertTensorShapeEqual(input, out, "input", "out");
-  AssertTensorShapeEqual(input, add_out, "input", "result");
   AssertTensorNotNull(weight, "weight");
   // TODO shape check
 
-  AssertTensorContiguous(out, "out");
-  AssertTensorContiguous(add_out, "add_out");
   AssertTensorContiguous(input, "input");
   AssertTensorContiguous(residual, "residual");
   AssertTensorContiguous(weight, "weight");

@@ -326,6 +325,8 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
   }
+
+  return {out, add_out};
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM_BWD(width) \

@@ -340,22 +341,24 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
         weight.data_ptr<scalar_t>(), eps, d); \
   });
 
-void fused_add_rms_norm_backward(
-…
+std::tuple<torch::Tensor, torch::Tensor>
+fused_add_rms_norm_backward(const torch::Tensor &output_grad,     // [..., d]
+                            const torch::Tensor &add_output_grad, // [..., d]
+                            const torch::Tensor &input,           // [..., d]
+                            const torch::Tensor &weight,          // [d]
+                            double eps, bool need_input_grad) {
+
+  torch::Tensor input_grad;
+  if (need_input_grad) {
+    input_grad = torch::empty_like(input);
+  }
+  torch::Tensor weight_grad = torch::empty_like(weight);
+
   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
   AssertTensorShapeEqual(input, output_grad, "input", "add_output_grad");
   AssertTensorNotNull(weight, "weight");
 
   constexpr bool ALLOW_NULL = true;
-  AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
-  AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
   AssertTensorContiguous(output_grad, "output_grad");
   AssertTensorContiguous(add_output_grad, "add_output_grad");
   AssertTensorContiguous(input, "input");

@@ -386,4 +389,6 @@ void fused_add_rms_norm_backward(
     at::sum_out(acc, temp_weight_grad, {0});
     weight_grad.copy_(acc);
   }
+
+  return {input_grad, weight_grad};
 }
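Taken together, the kernel entry points no longer write into caller-provided buffers. A minimal sketch of the calling convention before and after, in Python: the `ops` handle is the same one torch-ext/activation/rms_norm.py uses, while the import path, device, shapes, and dtypes below are illustrative assumptions.

import torch
from activation import ops  # assumed import path for the compiled extension

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.randn(1024, device="cuda", dtype=torch.float16)
eps = 1e-6

# Old in-place form: the caller allocated both outputs and passed them first.
# out = torch.empty_like(x)
# add_out = torch.empty_like(x)
# ops.fused_add_rms_norm(out, add_out, x, residual, weight, eps)

# New out-of-place form: the op allocates and returns (out, add_out).
out, add_out = ops.fused_add_rms_norm(x, residual, weight, eps)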
tests/test_fused_add_rms_norm.py  CHANGED

@@ -81,7 +81,7 @@ def test_fused_add_rms_norm(
 
     out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
     add_out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
-    opcheck(op, (…
+    opcheck(op, (x, residual, weight, eps))
 
     out = fn(x, residual, weight, eps)
     mod_out, mod_a_out = layer(x, residual)
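With the functional schema, opcheck only needs the logical inputs. A hedged, standalone approximation of what the updated test exercises; the torch.ops namespace below is an assumption, not taken from this repository (the test obtains `op` from the extension itself):

import torch
from torch.library import opcheck

# Assumed op namespace for illustration only.
op = torch.ops.activation.fused_add_rms_norm.default

x = torch.randn(8, 512, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.randn(512, device="cuda", dtype=torch.float16)

# opcheck validates the schema, fake-tensor (meta) propagation, and
# autograd registration for the given sample inputs.
opcheck(op, (x, residual, weight, 1e-6))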
torch-ext/activation/rms_norm.py  CHANGED

@@ -38,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     # Note that forward, setup_context, and backward are @staticmethods
     @staticmethod
     def forward(input, residual, weight, eps):
-        output = …
-        …
-        ops.fused_add_rms_norm(output, add_output, input, residual, weight,
-                               eps)
+        output, add_output = ops.fused_add_rms_norm(input, residual, weight,
+                                                    eps)
         return output, add_output
 
     @staticmethod

@@ -61,20 +59,42 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
-        grad = …
-        …
-            eps)
+        grad, weight_grad = ops.fused_add_rms_norm_backward(
+            output_grad,
+            add_output_grad,
+            add_output,
+            weight,
+            eps,
+            need_input_grad=need_in or need_res)
         input_grad = grad if need_in else None
         residual_grad = grad if need_res else None
 
         return input_grad, residual_grad, weight_grad, None
 
 
+@torch.library.register_fake(ops.rms_norm.default)
+def rms_norm_abstract(x, weight, eps):
+    return torch.empty_like(x)
+
+
+@torch.library.register_fake(ops.rms_norm_backward.default)
+def rms_norm_backward_abstract(output_grad, x, weight, eps):
+    return torch.empty_like(x), torch.empty_like(weight)
+
+
+@torch.library.register_fake(ops.fused_add_rms_norm.default)
+def fused_add_rms_norm_abstract(x, residual, weight, eps):
+    return torch.empty_like(x), torch.empty_like(x)
+
+
+@torch.library.register_fake(ops.fused_add_rms_norm_backward.default)
+def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad,
+                                         add_output, weight, eps,
+                                         need_input_grad: bool):
+    return torch.empty_like(x) if need_input_grad else None, torch.empty_like(
+        weight)
+
+
 if version.parse(torch.__version__) >= version.parse("2.8"):
     from .rms_norm_meta import register_rms_norm_meta
     register_rms_norm_meta()
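For reference, the fake kernels above only have to mirror output shapes and dtypes, which correspond to the usual fused residual-add + RMSNorm pattern. The sketch below is an assumed eager-mode reference, not code from this repository, and the CUDA kernel's accumulation dtype may differ:

import torch

def fused_add_rms_norm_reference(x, residual, weight, eps):
    # add_out is the residual sum; out is RMSNorm(add_out) scaled by weight.
    add_out = x + residual
    acc = add_out.float()
    variance = acc.pow(2).mean(dim=-1, keepdim=True)
    out = (acc * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)
    return out, add_out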
torch-ext/activation/rms_norm_meta.py  CHANGED

@@ -1,3 +1,5 @@
+from collections.abc import Sequence
+
 import torch
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,

@@ -17,16 +19,6 @@ def register_rms_norm_meta():
         pass
 
 
-@torch.library.register_fake(ops.rms_norm.default)
-def rms_norm_abstract(x, weight, eps):
-    return torch.empty_like(x)
-
-
-@torch.library.register_fake(ops.rms_norm_backward.default)
-def rms_norm_backward_abstract(output_grad, x, weight, eps):
-    return torch.empty_like(x), torch.empty_like(weight)
-
-
 def _replicate_dims_start_at(placements: Sequence[Placement],
                              start_dim: int = 0) -> tuple[Placement, ...]:
     new_placements: list[Placement] = []
torch-ext/torch_binding.cpp  CHANGED

@@ -38,17 +38,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           &fused_mul_poly_norm_backward);
 
   // fused_add_rms_norm
-  ops.def(
-      …
-      "residual, Tensor "
-      "weight, float eps) -> ()");
+  ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor "
+          "weight, float eps) -> (Tensor, Tensor)");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
   ops.def(
-      "fused_add_rms_norm_backward(Tensor …
-      "Tensor "
-      "…
-      "eps) -> ()");
+      "fused_add_rms_norm_backward(Tensor output_grad, Tensor add_output_grad,"
+      "Tensor input, Tensor weight, float eps, bool need_input_grad) -> "
+      "(Tensor, Tensor)");
   ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
            &fused_add_rms_norm_backward);
 }
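Because both schemas now end in "-> (Tensor, Tensor)", the ops unpack as Python tuples when called through the dispatcher. A hedged sketch of a direct call against the new backward schema; the torch.ops.activation namespace, shapes, and dtypes are assumptions, and the autograd Function in rms_norm.py is the intended caller:

import torch

# Assumed op handles; the real code goes through the extension's `ops` module.
fwd = torch.ops.activation.fused_add_rms_norm.default
bwd = torch.ops.activation.fused_add_rms_norm_backward.default

x = torch.randn(8, 512, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(512, device="cuda")

out, add_out = fwd(x, residual, weight, 1e-6)

# Backward consumes the saved residual sum (add_out) plus both upstream grads
# and returns (input_grad, weight_grad); need_input_grad=True requests both.
grad_out = torch.ones_like(out)
grad_add_out = torch.zeros_like(add_out)
input_grad, weight_grad = bwd(grad_out, grad_add_out, add_out, weight, 1e-6, True)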
torch-ext/torch_binding.h  CHANGED

@@ -27,13 +27,11 @@ void fused_mul_poly_norm_backward(
     const torch::Tensor &mul, const torch::Tensor &weight,
     const torch::Tensor &bias, double eps);
 
-…
-    const torch::Tensor &input,
-    const torch::Tensor &weight, double eps);
+std::tuple<torch::Tensor, torch::Tensor>
+fused_add_rms_norm(const torch::Tensor &input, const torch::Tensor &residual,
+                   const torch::Tensor &weight, double eps);
+
+std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
+    const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
+    const torch::Tensor &input, const torch::Tensor &weight, double eps,
+    bool need_input_grad);