Commit b3065b6

[boo] add explicit casts for layer norm intermediates
Torch chooses to store the intermediate values passed from the forward to the backward pass using different dtypes depending on the device: on CPU, the same dtype as the input is used even if the computation uses double the bitwidth; on GPU, the double-width dtype is used instead. Guard against this by allowing the user to explicitly specify the dtype to use for intermediates in the layer norm signature. Note that exporting from torch will still introduce truncations, but these are expected to cancel out during compilation.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
1 parent dc34ab9 commit b3065b6
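
Below is a minimal sketch (not part of the commit) illustrating the device-dependent behavior the commit message describes, using the same aten op called in the diff; the printed dtypes follow the commit message's claim for CPU vs. GPU and may vary across torch versions.

import torch

# bfloat16 input: layer norm computes its statistics in a wider dtype internally.
x = torch.randn(4, 8, dtype=torch.bfloat16)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], None, None, 1e-5)
# On CPU, mean/rstd are stored back in the input dtype (torch.bfloat16 here).
print(mean.dtype, rstd.dtype)

if torch.cuda.is_available():
    out, mean, rstd = torch.ops.aten.native_layer_norm(
        x.cuda(), [8], None, None, 1e-5
    )
    # On GPU, mean/rstd keep the double-width dtype (torch.float32 here).
    print(mean.dtype, rstd.dtype)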

File tree: 2 files changed (+28, -5 lines)

iree/turbine/kernel/boo/op_exports/layer_norm.py

Lines changed: 22 additions & 5 deletions
@@ -44,6 +44,7 @@ def __init__(
         bias: bool = True,
         dtype=torch.bfloat16,
         mode: str | Mode = Mode.FORWARD,
+        forwarded_args_dtype: torch.dtype | None = None,
     ):
         if (
             len(normalized_shape) > len(input_shape)
@@ -60,6 +61,7 @@ def __init__(
         self.bias = bias
         self.dtype = dtype
         self.mode = Mode.parse(mode)
+        self.forwarded_args_dtype = forwarded_args_dtype or dtype
 
     @property
     def output_shape(self) -> list[int]:
@@ -116,6 +118,7 @@ def func_name(self) -> str:
             "layer_norm",
             f"{len(self.normalized_shape)}d",
             str(self.dtype).removeprefix("torch."),
+            str(self.forwarded_args_dtype).removeprefix("torch."),
             self.mode.name.lower(),
             "x".join(str(i) for i in self.input_shape),
             "w" if self.elementwise_affine is not None else "",
@@ -175,6 +178,7 @@ def as_init_kwargs(self) -> dict[str, Any]:
             "bias": self.bias,
             "dtype": self.dtype,
             "mode": self.Mode,
+            "forwarded_args_dtype": self.forwarded_args_dtype,
         }
 
     def get_output_size(self) -> int:
@@ -229,16 +233,16 @@ def get(shape: Sequence[int]) -> torch.Tensor:
                 get(self.output_shape),
                 get(self.input_shape),
                 get(self.normalized_shape),
-                get(self.aggregate_shape),
-                get(self.aggregate_shape),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
             )
         if self.mode == Mode.WEIGHT_BACKWARD:
             # (dLdy, input, mean, rstd)
             return (
                 get(self.output_shape),
                 get(self.input_shape),
-                get(self.aggregate_shape),
-                get(self.aggregate_shape),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
             )
         if self.mode == Mode.BIAS_BACKWARD:
             # (dLdy,)
@@ -253,6 +257,7 @@ def __init__(self, signature: LayerNormSignature):
         super().__init__()
         self.normalized_shape = signature.normalized_shape
         self.eps = signature.eps
+        self.forwarded_args_dtype = signature.forwarded_args_dtype
 
     def forward(
         self,
@@ -272,9 +277,14 @@ def forward(
         # torch.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
         #
         # wrapper hides. We want those too so we can save them for backward.
-        return torch.ops.aten.native_layer_norm(
+        output, mean, rstd = torch.ops.aten.native_layer_norm(
            input, self.normalized_shape, weight, bias, self.eps
        )
+        return (
+            output,
+            mean.to(dtype=self.forwarded_args_dtype),
+            rstd.to(dtype=self.forwarded_args_dtype),
+        )
 
 
 class LayerNormBackwardInput(torch.nn.Module):
@@ -285,6 +295,7 @@ def __init__(self, signature: LayerNormSignature):
         super().__init__()
         self.normalized_shape = signature.normalized_shape
         self.eps = signature.eps
+        self.dtype = signature.dtype
 
         self.normalized_dim = list(
             range(len(signature.input_shape))[-len(self.normalized_shape) :]
@@ -303,6 +314,8 @@ def forward(
     ) -> torch.Tensor:
         # Recompute norm instead of saving it. Judging by the signature, this is the same
         # decision as ATen.
+        mean = mean.to(dtype=self.dtype)
+        rstd = rstd.to(dtype=self.dtype)
         norm = (input - mean) * rstd
         dnorm = grad_output * weight if weight is not None else grad_output
         dx = (
@@ -321,6 +334,8 @@ def __init__(self, signature: LayerNormSignature):
         super().__init__()
         self.normalized_shape = signature.normalized_shape
         self.eps = signature.eps
+        self.dtype = signature.dtype
+
         self.normalized_dim = list(
             range(len(signature.input_shape))[-len(self.normalized_shape) :]
         )
@@ -337,6 +352,8 @@ def forward(
     ):
         # Recompute norm instead of saving it. Judging by the signature, this is the same
         # decision as ATen.
+        mean = mean.to(dtype=self.dtype)
+        rstd = rstd.to(dtype=self.dtype)
         norm = (input - mean) * rstd
         return (grad_output * norm).sum(self.keep_dim)
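
For reference, a hypothetical usage sketch of the new parameter; the import path, constructor keyword names, and defaults are assumed from the diff and the test below, not verified against the full file.

import torch

from iree.turbine.kernel.boo.op_exports.layer_norm import LayerNormSignature

# Compute in bfloat16 but forward float32 mean/rstd to the backward kernels,
# mirroring what a GPU forward pass would hand over.
sig = LayerNormSignature(
    input_shape=(10, 12, 14, 16),
    normalized_shape=(16,),
    dtype=torch.bfloat16,
    forwarded_args_dtype=torch.float32,
)

fwd = sig.get_nn_module(use_custom=True)
output, mean, rstd = fwd(*sig.get_sample_args(seed=1))
assert mean.dtype == torch.float32 and rstd.dtype == torch.float32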

tests/kernel/boo/op_exports/layer_norm_backward_impl_test.py

Lines changed: 6 additions & 0 deletions
@@ -16,12 +16,14 @@
 # Note that elementwise_affine and bias flags are grouped together to avoid an
 # invalid combination.
 @pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("forwarded_dtype", [torch.float32, torch.float64])
 @pytest.mark.parametrize("input_shape", [(10, 12, 14, 16), (11, 13, 15)])
 @pytest.mark.parametrize(
     "elementwise_affine_bias", [(False, False), (True, False), (True, True)]
 )
 def test_layer_norm_impl(
     dtype: torch.dtype,
+    forwarded_dtype: torch.dtype,
     input_shape: tuple[int, ...],
     elementwise_affine_bias: tuple[bool, bool],
 ):
@@ -37,6 +39,7 @@ def test_layer_norm_impl(
         "elementwise_affine": elementwise_affine,
         "bias": bias,
         "dtype": dtype,
+        "forwarded_args_dtype": forwarded_dtype,
     }
     fwd_sig = LayerNormSignature(**kwargs)
     args = fwd_sig.get_sample_args(seed=1)
@@ -53,6 +56,9 @@ def test_layer_norm_impl(
     bwd_bias = bwd_bias_sig.get_nn_module(use_custom=True).to(device="cpu")
 
     fwd_results = fwd(*args)
+    assert fwd_results[1].dtype == forwarded_dtype
+    assert fwd_results[2].dtype == forwarded_dtype
+
     main_result = fwd_results[fwd_sig.main_result_index]
     main_result.retain_grad()
     # TODO: this is not a good loss function (#1021).
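
The new assertions run as part of the existing parametrized test; to exercise them locally, the file can be run directly with pytest, e.g.:

python -m pytest tests/kernel/boo/op_exports/layer_norm_backward_impl_test.py -v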
