Kernels
wyldecat commited on
Commit
f19f8f4
·
1 Parent(s): 151bb5a

fix(rms_norm.py): add assertion for input gradients to handle unsupported cases in backward pass

Browse files
Files changed (1) hide show
  1. torch-ext/activation/rms_norm.py +3 -0
torch-ext/activation/rms_norm.py CHANGED
@@ -59,6 +59,9 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
59
  need_in = ctx.needs_input_grad[0]
60
  need_res = ctx.needs_input_grad[1]
61
 
 
 
 
62
  grad, weight_grad = ops.fused_add_rms_norm_backward(
63
  output_grad,
64
  add_output_grad,
 
59
  need_in = ctx.needs_input_grad[0]
60
  need_res = ctx.needs_input_grad[1]
61
 
62
+ # TODO(ai-system): kernels currently do not support no input gradients
63
+ assert need_in or need_res, "Not implemented for no input gradients yet"
64
+
65
  grad, weight_grad = ops.fused_add_rms_norm_backward(
66
  output_grad,
67
  add_output_grad,