Commit 0b5b671

[boo] add a combined backward layer norm
Add an op export for combined computation of all gradients in layer norm. This may be more efficient than executing them one by one in some cases and requires separate testing.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
1 parent b3065b6 commit 0b5b671
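
For context, a minimal sketch (not part of the commit, illustrative shapes) of what the combined backward amounts to: a single torch.ops.aten.native_layer_norm_backward call returns all three gradients that the existing INPUT_BACKWARD, WEIGHT_BACKWARD, and BIAS_BACKWARD modes otherwise produce in separate launches.

    import torch

    # Illustrative shapes; normalization over the last dimension.
    x = torch.randn(4, 8)
    w = torch.randn(8)
    b = torch.randn(8)
    # The forward op also returns mean and rstd, which the backward reuses.
    y, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], w, b, 1e-5)
    grad_y = torch.ones_like(y)
    # One call produces (grad_input, grad_weight, grad_bias) per the output mask.
    dx, dw, db = torch.ops.aten.native_layer_norm_backward(
        grad_y, x, [8], mean, rstd, w, b, (True, True, True)
    )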

File tree (2 files changed, +158 -0 lines changed):
iree/turbine/kernel/boo/op_exports/layer_norm.py
tests/kernel/boo/op_exports/layer_norm_backward_impl_test.py


iree/turbine/kernel/boo/op_exports/layer_norm.py

Lines changed: 81 additions & 0 deletions
@@ -21,6 +21,7 @@ class Mode(ModeBase, IntEnum):
     INPUT_BACKWARD = 1
     WEIGHT_BACKWARD = 2
     BIAS_BACKWARD = 3
+    FULL_BACKWARD = 4


 class LayerNormSignature(OpSignature):
@@ -160,13 +161,16 @@ def arrange_backward_launch_args(
         input = forward_args[0]
         # TODO: is this possible at this level?
         weight = forward_args[1] if len(forward_args) > 1 else None
+        bias = forward_args[2] if len(forward_args) > 2 else None
         _, mean, rstd = forward_results
         if self.mode == Mode.INPUT_BACKWARD:
             return (input, weight, mean, rstd)
         if self.mode == Mode.WEIGHT_BACKWARD:
             return (input, mean, rstd)
         if self.mode == Mode.BIAS_BACKWARD:
             return ()
+        if self.mode == Mode.FULL_BACKWARD:
+            return (input, weight, bias, mean, rstd)
         assert False, "Unsupported mode."

     def as_init_kwargs(self) -> dict[str, Any]:
@@ -202,6 +206,9 @@ def get_nn_module(self, use_custom: bool) -> torch.nn.Module:
             return LayerNormBackwardWeight(self)
         if self.mode == Mode.BIAS_BACKWARD:
             return LayerNormBackwardBias(self)
+        if self.mode == Mode.FULL_BACKWARD:
+            return LayerNormBackwardFull(self)
+        assert False, f"Unknown mode: {self.mode.name}."

     def get_sample_args(
         self,
@@ -247,6 +254,15 @@ def get(shape: Sequence[int]) -> torch.Tensor:
         if self.mode == Mode.BIAS_BACKWARD:
             # (dLdy,)
             return (get(self.output_shape),)
+        if self.mode == Mode.FULL_BACKWARD:
+            return (
+                get(self.output_shape),
+                get(self.input_shape),
+                get(self.normalized_shape) if self.elementwise_affine else None,
+                get(self.normalized_shape) if self.bias else None,
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+            )
         raise ValueError(f"Unknown mode: {self.mode}")

@@ -374,6 +390,69 @@ def forward(self, grad_output: torch.Tensor) -> torch.Tensor:
         return grad_output.sum(dim=self.keep_dim)


+class LayerNormBackwardFull(torch.nn.Module):
+    """Module computing, as its forward computation, the gradients of the input,
+    weights, and bias of the layer normalization given the gradient of its
+    output."""
+
+    def __init__(self, signature: LayerNormSignature, *, use_aten=True):
+        super().__init__()
+        self.use_aten = use_aten
+        self.normalized_shape = signature.normalized_shape
+        self.need_bias = signature.bias
+        self.need_weight = signature.elementwise_affine
+        self.normalized_dim = list(
+            range(len(signature.input_shape))[-len(self.normalized_shape) :]
+        )
+        self.keep_dim = list(
+            range(len(signature.input_shape))[: -len(signature.normalized_shape)]
+        )
+
+    def forward(
+        self,
+        grad_output: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor | None,
+        bias: torch.Tensor | None,
+        mean: torch.Tensor,
+        rstd: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert self.need_weight != (
+            weight is None
+        ), "Weight must be provided if its gradient is requested."
+        assert self.need_bias != (
+            bias is None
+        ), "Bias must be provided if its gradient is requested."
+        if self.use_aten:
+            return torch.ops.aten.native_layer_norm_backward(
+                grad_output,
+                input,
+                self.normalized_shape,
+                mean,
+                rstd,
+                weight,
+                bias,
+                (True, self.need_weight, self.need_bias),
+            )
+
+        # Recompute norm instead of saving it. Judging by the signature, this is the same
+        # decision as ATen.
+        norm = (input - mean) * rstd
+        dnorm = grad_output * weight if weight is not None else grad_output
+        dx = (
+            dnorm
+            - dnorm.mean(dim=self.normalized_dim, keepdim=True)
+            - norm * (dnorm * norm).mean(dim=self.normalized_dim, keepdim=True)
+        ) * rstd
+        dw = None
+        if self.need_weight:
+            dw = (grad_output * norm).sum(self.keep_dim)
+        db = None
+        if self.need_bias:
+            db = grad_output.sum(dim=self.keep_dim)
+        return dx, dw, db
+
+
 def _parse_shape(shape: str) -> list[int]:
     for symbol in shape:
         assert symbol in "0123456789x", "Unsupported shape syntax."
@@ -419,6 +498,8 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature:
             mode = Mode.WEIGHT_BACKWARD
         case 4:
             mode = Mode.BIAS_BACKWARD
+        case 5:
+            mode = Mode.FULL_BACKWARD
         case _:
             raise ValueError(f"Unsupported mode {args.forw}.")

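For reference, a sketch of the identities the manual (non-ATen) branch of LayerNormBackwardFull implements, writing \hat{x} = (x - \mu)\,\mathrm{rstd} and g = \frac{\partial L}{\partial y} \odot w (or just \frac{\partial L}{\partial y} when there is no weight), with means \overline{\,\cdot\,} taken over the normalized dimensions and sums over the remaining batch dimensions:

    \frac{\partial L}{\partial x} = \mathrm{rstd}\left(g - \overline{g} - \hat{x}\,\overline{g\hat{x}}\right), \qquad
    \frac{\partial L}{\partial w} = \sum_{\text{batch}} \frac{\partial L}{\partial y} \odot \hat{x}, \qquad
    \frac{\partial L}{\partial b} = \sum_{\text{batch}} \frac{\partial L}{\partial y}

These correspond to the dx, dw, and db expressions above; the ATen path computes the same quantities in one fused call.
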
tests/kernel/boo/op_exports/layer_norm_backward_impl_test.py

Lines changed: 77 additions & 0 deletions
@@ -13,6 +13,16 @@
 )


+def _marked_xfail(*args):
+    return pytest.param(
+        *args,
+        marks=pytest.mark.xfail(
+            condition=not torch.cuda.is_available(),
+            reason="Cannot run on GPU with no GPU.",
+        ),
+    )
+
+
 # Note that elementwise_affine and bias flags are grouped together to avoid an
 # invalid combination.
 @pytest.mark.parametrize("dtype", [torch.float32])
@@ -96,3 +106,70 @@ def test_layer_norm_impl(
         print(f"Expected for gradient #{i}: ", args[i].grad)
         print(f"Actual for gradient #{i}: ", grads[i])
     raise RuntimeError(f"Tensor matches: {results}")
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("device", ["cpu", _marked_xfail("cuda")])
+@pytest.mark.parametrize("input_shape", [(10, 12, 14, 16), (11, 13, 15)])
+@pytest.mark.parametrize(
+    "elementwise_affine_bias", [(False, False), (True, False), (True, True)]
+)
+def test_layer_norm_combined_impl(
+    input_shape: tuple[int, ...],
+    device: str,
+    dtype: torch.dtype,
+    elementwise_affine_bias: tuple[bool, bool],
+):
+    # Account for ATen weirdness on GPU.
+    if device == "cuda" and dtype == torch.bfloat16:
+        forwarded_args_dtype = torch.float32
+    else:
+        forwarded_args_dtype = dtype
+
+    elementwise_affine, bias = elementwise_affine_bias
+    kwargs = {
+        "input_shape": input_shape,
+        "normalized_shape": input_shape[-1:],
+        "elementwise_affine": elementwise_affine,
+        "bias": bias,
+        "dtype": dtype,
+        "forwarded_args_dtype": forwarded_args_dtype,
+    }
+    fwd_sig = LayerNormSignature(**kwargs)
+    args = fwd_sig.get_sample_args(seed=1)
+
+    args = tuple(
+        arg.to(device=device).requires_grad_(True) if arg is not None else None
+        for arg in args
+    )
+    fwd = fwd_sig.get_nn_module(use_custom=True).to(device=device)
+    bwd_sig = LayerNormSignature(**kwargs, mode=Mode.FULL_BACKWARD)
+    bwd = bwd_sig.get_nn_module(use_custom=True).to(device=device)
+
+    fwd_results = fwd(*args)
+
+    main_result = fwd_results[fwd_sig.main_result_index]
+    main_result.retain_grad()
+    # TODO: this is not a good loss function (#1021).
+    loss = main_result.sum()
+    loss.backward(retain_graph=True)
+
+    bwd_input_args = bwd_sig.arrange_backward_launch_args(args, fwd_results)
+    grads = tuple(x for x in bwd(main_result.grad, *bwd_input_args) if x is not None)
+
+    rtol = 1e-4
+    atol = 1e-4
+    assert len(grads) == len(args)
+    results = [
+        torch.allclose(arg.grad, grad, rtol=rtol, atol=atol)
+        for arg, grad in zip(args, grads)
+    ]
+    if all(results):
+        return
+
+    for i, r in enumerate(results):
+        if r:
+            continue
+        print(f"Expected for gradient #{i}: ", args[i].grad)
+        print(f"Actual for gradient #{i}: ", grads[i])
+    raise RuntimeError(f"Tensor matches: {results}")
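
With a standard pytest setup, the new test can be selected by keyword, e.g. pytest tests/kernel/boo/op_exports/layer_norm_backward_impl_test.py -k combined; per _marked_xfail, the CUDA parametrizations are expected to xfail on machines without a GPU.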
