fix(rms_norm.py): add assertion for input gradients to handle unsupported cases in backward pass
torch-ext/activation/rms_norm.py

@@ -59,6 +59,9 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
+        # TODO(ai-system): kernels currently do not support no input gradients
+        assert need_in or need_res, "Not implemented for no input gradients yet"
+
         grad, weight_grad = ops.fused_add_rms_norm_backward(
             output_grad,
             add_output_grad,
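For context, here is a minimal, kernel-free sketch of the case this assertion guards. This is not the repository's implementation: the real FusedAddRMSNormFunction dispatches to the fused ops.fused_add_rms_norm_backward kernel, whereas the ToyAddRMSNorm class below, its plain-PyTorch backward math, and the EPS value are assumptions made purely so the guard's behavior can be observed without the extension.

import torch

class ToyAddRMSNorm(torch.autograd.Function):
    """Hypothetical stand-in for FusedAddRMSNormFunction, written in plain
    PyTorch (no fused kernels) to illustrate the new assertion."""

    EPS = 1e-6  # assumed epsilon; the real module may use a different value

    @staticmethod
    def forward(ctx, x, residual, weight):
        add_output = x + residual
        inv_rms = add_output.pow(2).mean(-1, keepdim=True).add(ToyAddRMSNorm.EPS).rsqrt()
        ctx.save_for_backward(add_output, inv_rms, weight)
        # Two outputs, mirroring the (output_grad, add_output_grad) pair
        # that the patched backward receives.
        return add_output * inv_rms * weight, add_output

    @staticmethod
    def backward(ctx, output_grad, add_output_grad):
        need_in = ctx.needs_input_grad[0]
        need_res = ctx.needs_input_grad[1]

        # The guard from this commit: there is no code path for the
        # "weight gradient only" case, so fail loudly rather than
        # silently compute the wrong thing.
        assert need_in or need_res, "Not implemented for no input gradients yet"

        add_output, inv_rms, weight = ctx.saved_tensors
        normed = add_output * inv_rms
        weight_grad = (output_grad * normed).reshape(-1, normed.size(-1)).sum(0)

        # Standard RMSNorm input gradient, applied to g = dL/dy * weight.
        g = output_grad * weight
        grad = inv_rms * (g - normed * (g * normed).mean(-1, keepdim=True))
        grad = grad + add_output_grad  # gradient flowing into the residual-add output

        return (grad if need_in else None,
                grad if need_res else None,
                weight_grad if ctx.needs_input_grad[2] else None)

With this toy version, a graph where only the weight requires grad reproduces the newly guarded case, since needs_input_grad is (False, False, True):

x = torch.randn(4, 8)                   # input: requires_grad=False
res = torch.randn(4, 8)                 # residual: requires_grad=False
w = torch.randn(8, requires_grad=True)  # weight: gradient wanted

out, _ = ToyAddRMSNorm.apply(x, res, w)
out.sum().backward()  # AssertionError: "Not implemented for no input gradients yet"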