
Commit 9ab898a

[boo] layernorm ops
1 parent 1e4a294 commit 9ab898a

5 files changed: +409 -2 lines changed

iree/turbine/kernel/boo/ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .conv import *
+from .layer_norm import *
 from .layout_customizable_conv import *
 from .library import *
 from .utils import *

iree/turbine/kernel/boo/ops/conv.py

Lines changed: 2 additions & 2 deletions
@@ -271,7 +271,7 @@ def pytorch_convolution_backward(ctx, grad_output):
     """Fallback implementation for backward."""
     x, w = ctx.saved_tensors
 
-    mask = tuple((ctx.needs_input_grad[i] for i in range(3)))
+    mask = tuple(ctx.needs_input_grad[0:3])
 
     # return to NCHW if necessary
     input_grad, weight_grad, bias_grad = torch.ops.aten.convolution_backward(
@@ -331,7 +331,7 @@ def backward(ctx, grad_output):
 
         x, w = ctx.saved_tensors
 
-        mask = tuple((ctx.needs_input_grad[i] for i in range(3)))
+        mask = tuple(ctx.needs_input_grad[0:3])
 
         input_grad, weight_grad, bias_grad = torch.ops.boo.convolution_backward(
             x,
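For reference, the old generator expression and the new slice build the same mask; a minimal standalone sketch of the equivalence (the `needs_input_grad` tuple here is just a stand-in for the flags autograd records on `ctx`):

# Illustrative only: emulate ctx.needs_input_grad with a plain tuple.
needs_input_grad = (True, False, True, False)
old_mask = tuple((needs_input_grad[i] for i in range(3)))
new_mask = tuple(needs_input_grad[0:3])
assert old_mask == new_mask == (True, False, True)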
iree/turbine/kernel/boo/ops/layer_norm.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+from .library import define_schema, register_impl, register_meta
+from ..layer_norm_exports.layer_norm import LayerNormSignature, Mode
+from ..driver.launch import get_launchable
+from ..runtime import LaunchableRuntimeCache
+from .utils import *
+from typing import Sequence
+
+__all__ = [
+    "boo_layer_norm",
+]
+
+# TODO(azinenko): can this be automated, pytorch doc says these can be inferred from type information?
+define_schema(
+    "layer_norm",
+    "(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float? eps) -> (Tensor, Tensor, Tensor)",
+)
+# "(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-5) -> Tensor")
+
+# TODO(azinenko,zjgarvey): this should eventually be generalized with non-boo registration.
+
+
+@register_impl("layer_norm")
+def _boo_layer_norm_impl(
+    input: torch.Tensor,
+    normalized_shape: Sequence[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    signature = LayerNormSignature.get(input, normalized_shape, weight, bias, eps=eps)
+
+    # TODO: support non-contiguous memory formats via permutations
+
+    func_name = signature.func_name
+    args = tuple(
+        filter(
+            lambda x: x is not None,
+            map(lambda x: x.data if x is not None else None, (input, weight, bias)),
+        )
+    )
+    cache_hit = LaunchableRuntimeCache.get(func_name)
+    if cache_hit:
+        return cache_hit(*args)
+
+    layer_norm = get_launchable(signature)
+    return layer_norm(*args)
+
+
+@register_meta("layer_norm")
+def _boo_layer_norm_meta(
+    input: torch.Tensor,
+    normalized_shape: Sequence[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    signature = LayerNormSignature.get(input, normalized_shape, weight, bias, eps=eps)
+
+    # TODO: support non-contiguous memory formats via permutations
+
+    return (
+        torch.empty_like(input),
+        torch.empty(
+            signature.aggregate_shape, dtype=signature.dtype, device=input.device
+        ),
+        torch.empty(
+            signature.aggregate_shape, dtype=signature.dtype, device=input.device
+        ),
+    )
+
+
+define_schema(
+    "layer_norm_backward",
+    "(Tensor grad_output, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor weight, Tensor bias, bool[3] mask) -> (Tensor?, Tensor?, Tensor?)",
+)
+
+
+@register_impl("layer_norm_backward")
+def _boo_layer_norm_backward_impl(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+    normalized_shape: int | Sequence[int] | torch.Size,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    mask: tuple[bool, bool, bool],
+) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+
+    input_grad: torch.Tensor | None = None
+    weight_grad: torch.Tensor | None = None
+    bias_grad: torch.Tensor | None = None
+
+    # TODO(azinenko): it is unclear to me why convolution decided to implement
+    # each derivative computation as a single kernel, but cargo-culting it here.
+
+    def data_tuple(*args: torch.Tensor):
+        return tuple(a.data for a in args)
+
+    if mask[0]:
+        signature = LayerNormSignature.get(
+            input, normalized_shape, weight, bias, Mode.INPUT_BACKWARD
+        )
+        launchable = get_launchable(signature)
+        input_grad = launchable(*data_tuple(grad_output, input, weight, mean, rstd))
+
+    if mask[1]:
+        signature = LayerNormSignature.get(
+            input, normalized_shape, weight, bias, Mode.WEIGHT_BACKWARD
+        )
+        launchable = get_launchable(signature)
+        weight_grad = launchable(*data_tuple(grad_output, input, mean, rstd))
+
+    if mask[2]:
+        signature = LayerNormSignature.get(
+            input, normalized_shape, weight, bias, Mode.BIAS_BACKWARD
+        )
+        launchable = get_launchable(signature)
+        bias_grad = launchable(*data_tuple(grad_output))
+
+    return input_grad, weight_grad, bias_grad
+
+
+@register_meta("layer_norm_backward")
+def _boo_layer_norm_backward_meta(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+    normalized_shape: int | Sequence[int] | torch.Size,
+    mean: torch.Tensor,
+    rstd: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    mask: tuple[bool, bool, bool],
+) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+    input_grad: torch.Tensor | None = None
+    weight_grad: torch.Tensor | None = None
+    bias_grad: torch.Tensor | None = None
+
+    if mask[0]:
+        input_grad = torch.empty_like(input)
+    if mask[1]:
+        weight_grad = torch.empty_like(weight)
+    if mask[2]:
+        bias_grad = torch.empty_like(bias)
+    return input_grad, weight_grad, bias_grad
+
+
+def pytorch_layer_norm_backward(ctx, grad_output: torch.Tensor):
+    """ATen/PyTorch fallback implementation for backward."""
+
+    input, weight, bias, mean, rstd = ctx.saved_tensors
+    mask = tuple(ctx.needs_input_grad[0:3])
+
+    input_grad, weight_grad, bias_grad = torch.ops.aten.native_layer_norm_backward(
+        grad_output, input, ctx.normalized_shape, mean, rstd, weight, bias, mask
+    )
+
+    return input_grad, None, weight_grad, bias_grad, None
+
+
+class _BooLayerNorm(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        input: torch.Tensor,
+        normalized_shape: int | Sequence[int] | torch.Size,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float,
+    ) -> torch.Tensor:
+        result, mean, rstd = torch.ops.boo.layer_norm(
+            input, normalized_shape, weight, bias, eps
+        )
+        ctx.save_for_backward(input, weight, bias, mean, rstd)
+        ctx.normalized_shape = normalized_shape
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
+    ) -> tuple[
+        torch.Tensor | None, None, torch.Tensor | None, torch.Tensor | None, None
+    ]:
+        if not is_boo_backward_enabled():
+            return pytorch_layer_norm_backward(ctx, grad_output)
+
+        input, weight, bias, mean, rstd = ctx.saved_tensors
+        # Note that the context contains grad flags for every forward argument
+        # in order, including non-differentiable attributes like
+        # `normalized_shape`. The indices below correspond to the positions of
+        # input, weight and bias in the forward signature.
+        mask = (
+            ctx.needs_input_grad[0],
+            ctx.needs_input_grad[2],
+            ctx.needs_input_grad[3],
+        )
+        input_grad, weight_grad, bias_grad = torch.ops.boo.layer_norm_backward(
+            grad_output, input, ctx.normalized_shape, mean, rstd, weight, bias, mask
+        )
+
+        return input_grad, None, weight_grad, bias_grad, None
+
+
+def boo_layer_norm(
+    input: torch.Tensor,
+    normalized_shape: Sequence[int],
+    weight: torch.Tensor | None = None,
+    bias: torch.Tensor | None = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    use_autograd = torch._C.is_grad_enabled() and any(
+        x is not None and x.requires_grad for x in (input, weight, bias)
+    )
+    if use_autograd:
+        return _BooLayerNorm.apply(input, normalized_shape, weight, bias, eps)
+    result, _, _ = torch.ops.boo.layer_norm(input, normalized_shape, weight, bias, eps)
+    return result
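A minimal usage sketch of the new `boo_layer_norm` entry point. Assumptions: the op is re-exported from `iree.turbine.kernel.boo.ops` via the `__init__.py` change above, a BOO launchable can be built for the current device, and the shapes are purely illustrative.

import torch
from iree.turbine.kernel.boo.ops import boo_layer_norm  # assumed import path

# Normalize over the last (feature) dimension of a (batch, features) tensor.
x = torch.randn(4, 128, requires_grad=True)
weight = torch.ones(128, requires_grad=True)
bias = torch.zeros(128, requires_grad=True)

y = boo_layer_norm(x, [128], weight, bias, eps=1e-5)

# Grad is enabled and the inputs require grad, so the _BooLayerNorm autograd
# path is taken; backward dispatches to torch.ops.boo.layer_norm_backward, or
# to the ATen native_layer_norm_backward fallback when boo backward is disabled.
y.sum().backward()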

tests/kernel/boo/ops/boo_conv_test.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ def testBackwardCachePytorch(x_grad, w_grad):
     )
     y = boo_conv(x, w, shared_layout="NCHW")
 
+    # If none of the gradients are required, backward computation will raise
+    # an error. Tell pytest that this is expected.
     context = (
         contextlib.nullcontext()
         if x_grad or w_grad
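The pattern the new comment documents, sketched in isolation: pick the expectation up front and run backward under it. `RuntimeError` is an assumption here; the exact exception in the real test may differ.

import contextlib
import pytest
import torch

def run_backward_with_expectation(y: torch.Tensor, x_grad: bool, w_grad: bool):
    # No-op context when at least one gradient is requested; otherwise
    # backward() on a result with no differentiable inputs is expected to raise.
    context = (
        contextlib.nullcontext() if x_grad or w_grad else pytest.raises(RuntimeError)
    )
    with context:
        y.sum().backward()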
