Kernels
wyldecat committed
Commit 7e4334d · 1 Parent(s): 66b3c5e

refactor(activation): change fused_add_rms_norm and fused_add_rms_norm_backward to out-of-place operations

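In short, the ops stop writing into caller-provided output tensors and instead allocate and return their results. A minimal before/after sketch of the Python-side call (the `activation.ops` import path is an assumption; the op names and argument order are taken from the files changed below):

import torch
from activation import ops  # assumption: the extension's Python ops wrapper

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.randn(128, device="cuda", dtype=torch.float16)
eps = 1e-6

# Before this commit (in-place): the caller preallocates both outputs.
#   out = torch.empty_like(x); add_out = torch.empty_like(x)
#   ops.fused_add_rms_norm(out, add_out, x, residual, weight, eps)

# After this commit (out-of-place): the op allocates and returns them.
out, add_out = ops.fused_add_rms_norm(x, residual, weight, eps)

# The backward op follows the same pattern and returns (input_grad, weight_grad).
output_grad = torch.randn_like(out)
add_output_grad = torch.randn_like(add_out)
grad, weight_grad = ops.fused_add_rms_norm_backward(
    output_grad, add_output_grad, add_out, weight, eps, need_input_grad=True)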
activation/fused_add_rms_norm.cu CHANGED
@@ -295,20 +295,19 @@ __global__ std::enable_if_t<(width == 0)> fused_add_rms_norm_backward_kernel(
         weight.data_ptr<scalar_t>(), eps, d); \
       });
 
-void fused_add_rms_norm(torch::Tensor &out,            // [..., d]
-                        torch::Tensor &add_out,        // [..., d]
-                        const torch::Tensor &input,    // [..., d]
-                        const torch::Tensor &residual, // [..., d]
-                        const torch::Tensor &weight,   // [d]
-                        double eps) {
+std::tuple<torch::Tensor, torch::Tensor>
+fused_add_rms_norm(const torch::Tensor &input,    // [..., d]
+                   const torch::Tensor &residual, // [..., d]
+                   const torch::Tensor &weight,   // [d]
+                   double eps) {
+
+  torch::Tensor out = torch::empty_like(input);
+  torch::Tensor add_out = torch::empty_like(input);
+
   AssertTensorShapeEqual(input, residual, "input", "residual");
-  AssertTensorShapeEqual(input, out, "input", "out");
-  AssertTensorShapeEqual(input, add_out, "input", "result");
   AssertTensorNotNull(weight, "weight");
   // TODO shape check
 
-  AssertTensorContiguous(out, "out");
-  AssertTensorContiguous(add_out, "add_out");
   AssertTensorContiguous(input, "input");
   AssertTensorContiguous(residual, "residual");
   AssertTensorContiguous(weight, "weight");
@@ -326,6 +325,8 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
   }
+
+  return {out, add_out};
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM_BWD(width) \
@@ -340,22 +341,24 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
         weight.data_ptr<scalar_t>(), eps, d); \
       });
 
-void fused_add_rms_norm_backward(
-    torch::Tensor &input_grad,            // [..., d]
-    torch::Tensor &weight_grad,           // [d]
-    const torch::Tensor &output_grad,     // [..., d]
-    const torch::Tensor &add_output_grad, // [..., d]
-    const torch::Tensor &input,           // [..., d]
-    const torch::Tensor &weight,          // [d]
-    double eps) {
-  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
+std::tuple<torch::Tensor, torch::Tensor>
+fused_add_rms_norm_backward(const torch::Tensor &output_grad,     // [..., d]
+                            const torch::Tensor &add_output_grad, // [..., d]
+                            const torch::Tensor &input,           // [..., d]
+                            const torch::Tensor &weight,          // [d]
+                            double eps, bool need_input_grad) {
+
+  torch::Tensor input_grad;
+  if (need_input_grad) {
+    input_grad = torch::empty_like(input);
+  }
+  torch::Tensor weight_grad = torch::empty_like(weight);
+
   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
   AssertTensorShapeEqual(input, output_grad, "input", "add_output_grad");
   AssertTensorNotNull(weight, "weight");
 
   constexpr bool ALLOW_NULL = true;
-  AssertTensorContiguous(input_grad, "input_grad", ALLOW_NULL);
-  AssertTensorContiguous(weight_grad, "weight_grad", ALLOW_NULL);
   AssertTensorContiguous(output_grad, "output_grad");
   AssertTensorContiguous(add_output_grad, "add_output_grad");
   AssertTensorContiguous(input, "input");
@@ -386,4 +389,6 @@ void fused_add_rms_norm_backward(
     at::sum_out(acc, temp_weight_grad, {0});
     weight_grad.copy_(acc);
   }
+
+  return {input_grad, weight_grad};
 }
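The wrapper above only allocates the outputs and launches the existing kernels, so the math is unchanged by this commit. For intuition, a plain-PyTorch reference of what the fused op is expected to produce (this is an assumption about the usual fused-add RMSNorm semantics, not code taken from this repository):

import torch

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # add_out is the residual-added activation; out is its RMS-normalized,
    # weight-scaled version. Accumulate in float32 for numerical stability.
    add_out = x + residual
    var = add_out.float().pow(2).mean(dim=-1, keepdim=True)
    out = add_out.float() * torch.rsqrt(var + eps) * weight.float()
    return out.to(x.dtype), add_out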
tests/test_fused_add_rms_norm.py CHANGED
@@ -81,7 +81,7 @@ def test_fused_add_rms_norm(
 
     out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
     add_out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
-    opcheck(op, (out, add_out, x, residual, weight, eps))
+    opcheck(op, (x, residual, weight, eps))
 
     out = fn(x, residual, weight, eps)
     mod_out, mod_a_out = layer(x, residual)
torch-ext/activation/rms_norm.py CHANGED
@@ -38,10 +38,8 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     # Note that forward, setup_context, and backward are @staticmethods
     @staticmethod
     def forward(input, residual, weight, eps):
-        output = torch.empty_like(input)
-        add_output = torch.empty_like(input)
-        ops.fused_add_rms_norm(output, add_output, input, residual, weight,
-                               eps)
+        output, add_output = ops.fused_add_rms_norm(input, residual, weight,
+                                                    eps)
         return output, add_output
 
     @staticmethod
@@ -61,20 +59,42 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
-        grad = torch.empty_like(output_grad) if need_in or need_res else None
-
-        weight_grad = torch.empty_like(
-            weight) if ctx.needs_input_grad[2] else None
-
-        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad,
-                                        add_output_grad, add_output, weight,
-                                        eps)
+        grad, weight_grad = ops.fused_add_rms_norm_backward(
+            output_grad,
+            add_output_grad,
+            add_output,
+            weight,
+            eps,
+            need_input_grad=need_in or need_res)
         input_grad = grad if need_in else None
         residual_grad = grad if need_res else None
 
         return input_grad, residual_grad, weight_grad, None
 
 
+@torch.library.register_fake(ops.rms_norm.default)
+def rms_norm_abstract(x, weight, eps):
+    return torch.empty_like(x)
+
+
+@torch.library.register_fake(ops.rms_norm_backward.default)
+def rms_norm_backward_abstract(output_grad, x, weight, eps):
+    return torch.empty_like(x), torch.empty_like(weight)
+
+
+@torch.library.register_fake(ops.fused_add_rms_norm.default)
+def fused_add_rms_norm_abstract(x, residual, weight, eps):
+    return torch.empty_like(x), torch.empty_like(x)
+
+
+@torch.library.register_fake(ops.fused_add_rms_norm_backward.default)
+def fused_add_rms_norm_backward_abstract(output_grad, add_output_grad,
+                                         add_output, weight, eps,
+                                         need_input_grad: bool):
+    return (torch.empty_like(output_grad) if need_input_grad else None,
+            torch.empty_like(weight))
+
+
 if version.parse(torch.__version__) >= version.parse("2.8"):
     from .rms_norm_meta import register_rms_norm_meta
     register_rms_norm_meta()
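With the out-of-place ops, the autograd wrapper no longer preallocates gradient buffers; it unpacks the tuple returned by the backward op and routes `grad` to the input and/or residual slots. A short usage sketch (the import path and tensor shapes are assumptions; the class and call pattern come from rms_norm.py above):

import torch
from activation.rms_norm import FusedAddRMSNormFunction  # assumed import path

x = torch.randn(2, 64, device="cuda", requires_grad=True)
residual = torch.randn(2, 64, device="cuda", requires_grad=True)
weight = torch.randn(64, device="cuda", requires_grad=True)

out, add_out = FusedAddRMSNormFunction.apply(x, residual, weight, 1e-6)
(out.sum() + add_out.sum()).backward()

# backward() calls ops.fused_add_rms_norm_backward with need_input_grad=True
# (both x and residual require grad) and fills x.grad, residual.grad and
# weight.grad from the returned (grad, weight_grad) tuple.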
torch-ext/activation/rms_norm_meta.py CHANGED
@@ -1,3 +1,5 @@
+from collections.abc import Sequence
+
 import torch
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._op_schema import (OpSchema, OpSpec, OpStrategy,
@@ -17,16 +19,6 @@ def register_rms_norm_meta():
     pass
 
 
-@torch.library.register_fake(ops.rms_norm.default)
-def rms_norm_abstract(x, weight, eps):
-    return torch.empty_like(x)
-
-
-@torch.library.register_fake(ops.rms_norm_backward.default)
-def rms_norm_backward_abstract(output_grad, x, weight, eps):
-    return torch.empty_like(x), torch.empty_like(weight)
-
-
 def _replicate_dims_start_at(placements: Sequence[Placement],
                              start_dim: int = 0) -> tuple[Placement, ...]:
     new_placements: list[Placement] = []
torch-ext/torch_binding.cpp CHANGED
@@ -38,17 +38,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
            &fused_mul_poly_norm_backward);
 
   // fused_add_rms_norm
-  ops.def(
-      "fused_add_rms_norm(Tensor! out, Tensor! add_out, Tensor input, Tensor "
-      "residual, Tensor "
-      "weight, float eps) -> ()");
+  ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor "
+          "weight, float eps) -> (Tensor, Tensor)");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
   ops.def(
-      "fused_add_rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, "
-      "Tensor "
-      "output_grad, Tensor add_output_grad, Tensor input, Tensor weight, float "
-      "eps) -> ()");
+      "fused_add_rms_norm_backward(Tensor output_grad, Tensor add_output_grad,"
+      "Tensor input, Tensor weight, float eps, bool need_input_grad) -> "
+      "(Tensor, Tensor)");
   ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
            &fused_add_rms_norm_backward);
 }
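The schema change is what the rest of the refactor hinges on: the `Tensor!` alias annotations (mutable, caller-provided buffers) are gone and both ops now return `(Tensor, Tensor)`, so PyTorch can treat them as functional. A quick sanity sketch (again assuming the `activation.ops` wrapper; the assertions are illustrative and not part of the test suite):

import torch
from activation import ops  # assumption: the extension's Python ops wrapper

x = torch.randn(2, 64, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.randn(64, device="cuda", dtype=torch.float16)

x_before = x.clone()
out, add_out = ops.fused_add_rms_norm(x, residual, weight, 1e-6)
assert torch.equal(x, x_before)        # inputs are no longer written in place
assert out.data_ptr() != x.data_ptr()  # outputs are freshly allocated tensors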
torch-ext/torch_binding.h CHANGED
@@ -27,13 +27,11 @@ void fused_mul_poly_norm_backward(
     const torch::Tensor &mul, const torch::Tensor &weight,
     const torch::Tensor &bias, double eps);
 
-void fused_add_rms_norm(torch::Tensor &out, torch::Tensor &add_out,
-                        const torch::Tensor &input,
-                        const torch::Tensor &residual,
-                        const torch::Tensor &weight, double eps);
-void fused_add_rms_norm_backward(torch::Tensor &input_grad,
-                                 torch::Tensor &weight_grad,
-                                 const torch::Tensor &output_grad,
-                                 const torch::Tensor &add_output_grad,
-                                 const torch::Tensor &input,
-                                 const torch::Tensor &weight, double eps);
+std::tuple<torch::Tensor, torch::Tensor>
+fused_add_rms_norm(const torch::Tensor &input, const torch::Tensor &residual,
+                   const torch::Tensor &weight, double eps);
+
+std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
+    const torch::Tensor &output_grad, const torch::Tensor &add_output_grad,
+    const torch::Tensor &input, const torch::Tensor &weight, double eps,
+    bool need_input_grad);