
Commit 8581686

Address review

- plumb `use_aten` through the driver
- use bitmask-style values for the mode enum

Signed-off-by: Alex Zinenko <git@ozinenko.com>

1 parent 53f5924
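For context on the second bullet, here is a minimal standalone sketch of what the bitmask-style values provide (illustrative only; the repo's enum also derives from ModeBase): the full-backward mode is the OR of the per-gradient modes, so an integer flag such as `-F 14` maps directly onto a member, and anything else is rejected.

# Standalone illustration, not the repo's class.
from enum import IntEnum

class Mode(IntEnum):
    FORWARD = 1
    INPUT_BACKWARD = 2
    WEIGHT_BACKWARD = 4
    BIAS_BACKWARD = 8
    FULL_BACKWARD = INPUT_BACKWARD | WEIGHT_BACKWARD | BIAS_BACKWARD  # == 14

assert Mode(14) is Mode.FULL_BACKWARD            # what Mode(args.forw) relies on
assert Mode.FULL_BACKWARD & Mode.BIAS_BACKWARD   # bitwise membership tests still work
try:
    Mode(5)  # 5 == FORWARD | WEIGHT_BACKWARD is not a defined member
except ValueError:
    pass  # get_signature re-raises this as "Unsupported mode 5."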


2 files changed: 36 additions & 26 deletions


iree/turbine/kernel/boo/op_exports/layer_norm.py

Lines changed: 29 additions & 21 deletions
@@ -17,11 +17,11 @@
 class Mode(ModeBase, IntEnum):
     """Mode selector for layer normalization, with each gradient being its own mode."""

-    FORWARD = 0
-    INPUT_BACKWARD = 1
-    WEIGHT_BACKWARD = 2
-    BIAS_BACKWARD = 3
-    FULL_BACKWARD = 4
+    FORWARD = 1
+    INPUT_BACKWARD = 2
+    WEIGHT_BACKWARD = 4
+    BIAS_BACKWARD = 8
+    FULL_BACKWARD = INPUT_BACKWARD | WEIGHT_BACKWARD | BIAS_BACKWARD


 class LayerNormSignature(OpSignature):
@@ -46,6 +46,7 @@ def __init__(
         dtype=torch.bfloat16,
         mode: str | Mode = Mode.FORWARD,
         forwarded_args_dtype: torch.dtype | None = None,
+        use_aten: bool = True,
     ):
         if (
             len(normalized_shape) > len(input_shape)
@@ -63,6 +64,7 @@ def __init__(
         self.dtype = dtype
         self.mode = Mode.parse(mode)
         self.forwarded_args_dtype = forwarded_args_dtype or dtype
+        self.use_aten = use_aten

     @property
     def output_shape(self) -> list[int]:
@@ -124,6 +126,7 @@ def func_name(self) -> str:
             "x".join(str(i) for i in self.input_shape),
             "w" if self.elementwise_affine is not None else "",
             "b" if self.bias is not None else "",
+            "aten" if self.use_aten else "",
         ]
         return "_".join(name_items)

@@ -183,6 +186,7 @@ def as_init_kwargs(self) -> dict[str, Any]:
             "dtype": self.dtype,
             "mode": self.Mode,
             "forwarded_args_dtype": self.forwarded_args_dtype,
+            "use_aten": self.use_aten,
         }

     def get_output_size(self) -> int:
@@ -395,9 +399,9 @@ class LayerNormBackwardFull(torch.nn.Module):
     weights, and bias of the layer normalization given the gradient of its
     output."""

-    def __init__(self, signature: LayerNormSignature, *, use_aten=True):
+    def __init__(self, signature: LayerNormSignature):
         super().__init__()
-        self.use_aten = use_aten
+        self.use_aten = signature.use_aten
         self.normalized_shape = signature.normalized_shape
         self.need_bias = signature.bias
         self.need_weight = signature.elementwise_affine
@@ -438,12 +442,14 @@ def forward(
         # Recompute norm instead of saving it. Judging by the signature, this is the same
         # decision as ATen.
         norm = (input - mean) * rstd
+        # norm = norm.to(dtype=input.dtype)
         dnorm = grad_output * weight if weight is not None else grad_output
         dx = (
             dnorm
             - dnorm.mean(dim=self.normalized_dim, keepdim=True)
             - norm * (dnorm * norm).mean(dim=self.normalized_dim, keepdim=True)
         ) * rstd
+        # dx = dx.to(dtype=input.dtype)
         dw = None
         if self.need_weight:
             dw = (grad_output * norm).sum(self.keep_dim)
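The manual path above implements the standard layer-norm input gradient, dx = rstd * (dnorm - mean(dnorm) - norm * mean(dnorm * norm)). A quick, self-contained way to convince yourself of that formula, checking the no-affine case against autograd in float64 (this snippet is illustrative and not part of the commit; shapes and names are arbitrary):

import torch

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)

# Reference: autograd through the functional op (no weight/bias).
y = torch.nn.functional.layer_norm(x, (8,), eps=eps)
grad_output = torch.randn_like(y)
(dx_ref,) = torch.autograd.grad(y, x, grad_output)

# Manual gradient, mirroring the formula in the hunk above.
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
rstd = (var + eps).rsqrt()
norm = (x - mean) * rstd
dnorm = grad_output  # no elementwise-affine weight in this check
dx = (
    dnorm
    - dnorm.mean(dim=-1, keepdim=True)
    - norm * (dnorm * norm).mean(dim=-1, keepdim=True)
) * rstd
assert torch.allclose(dx, dx_ref)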
@@ -489,19 +495,10 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature:
        ), "Can only normalize one trailing dimension for now (MIOpen limitation)."
        normalized_shape = shape[args.normalized_dim :]

-        match args.forw:
-            case 1:
-                mode = Mode.FORWARD
-            case 2:
-                mode = Mode.INPUT_BACKWARD
-            case 3:
-                mode = Mode.WEIGHT_BACKWARD
-            case 4:
-                mode = Mode.BIAS_BACKWARD
-            case 5:
-                mode = Mode.FULL_BACKWARD
-            case _:
-                raise ValueError(f"Unsupported mode {args.forw}.")
+        try:
+            mode = Mode(args.forw)
+        except Exception as e:
+            raise ValueError(f"Unsupported mode {args.forw}.") from e

        return LayerNormSignature(
            input_shape=shape,
@@ -511,6 +508,7 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature:
            bias=True,
            dtype=_DTypeCommandDispatcher.get_dtype(args.command),
            mode=mode,
+            use_aten=args.use_aten,
        )

    def get_miopen_parser() -> argparse.ArgumentParser:
@@ -519,7 +517,11 @@ def get_miopen_parser() -> argparse.ArgumentParser:
            "command", default="layernorm", choices=_DTypeCommandDispatcher.choices()
        )
        parser.add_argument(
-            "--forw", "-F", type=int, default=1, help="Run only forward LayerNorm"
+            "--forw",
+            "-F",
+            type=int,
+            default=1,
+            help="Kind of kernel to run, not compatible with MIOpen (1 forward, 2 backward input, 4 backward weight, 8 backward bias, 14 full backward)",
        )
        parser.add_argument(
            "--input",
@@ -539,6 +541,12 @@ def get_miopen_parser() -> argparse.ArgumentParser:
        parser.add_argument(
            "--normalized_dim", "-o", type=int, default=3, help="Normalized dim"
        )
+        parser.add_argument(
+            "--use-aten",
+            type=bool,
+            default=True,
+            help="Use core ATen op instead of a manual implementation",
+        )
        return parser

    @classmethod
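Taken together, the changes in this file thread `use_aten` from the command line into the signature and from the signature into the backward module, which no longer takes its own keyword. A condensed, illustrative stand-in for that flow (not the repo's real classes, which carry many more fields; the "layer_norm" name prefix here is an assumption for the example):

from dataclasses import dataclass


@dataclass
class Signature:
    input_shape: tuple[int, ...]
    use_aten: bool = True

    @property
    def func_name(self) -> str:
        parts = ["layer_norm", "x".join(str(d) for d in self.input_shape)]
        if self.use_aten:
            parts.append("aten")
        return "_".join(parts)


class BackwardFull:
    def __init__(self, signature: Signature):
        # Read the flag from the signature instead of a separate keyword,
        # so the driver only has to set it in one place.
        self.use_aten = signature.use_aten


sig = Signature(input_shape=(16, 32, 64), use_aten=False)
assert BackwardFull(sig).use_aten is False
assert sig.func_name == "layer_norm_16x32x64"  # no "aten" tag for the manual path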

tests/kernel/boo/op_exports/layer_norm_backward_impl_test.py

Lines changed: 7 additions & 5 deletions
@@ -114,14 +114,16 @@ def test_layer_norm_impl(
 @pytest.mark.parametrize(
     "elementwise_affine_bias", [(False, False), (True, False), (True, True)]
 )
+@pytest.mark.parametrize("use_aten", [True, False])
 def test_layer_norm_combined_impl(
     input_shape: tuple[int, ...],
     device: str,
     dtype: torch.dtype,
     elementwise_affine_bias: tuple[bool, bool],
+    use_aten: bool,
 ):
     # Account for ATen weirdness on GPU.
-    if device == "cuda" and dtype == torch.bfloat16:
+    if device == "cuda" and dtype == torch.bfloat16 and use_aten:
         forwarded_args_dtype = torch.float32
     else:
         forwarded_args_dtype = dtype
@@ -134,6 +136,7 @@ def test_layer_norm_combined_impl(
         "bias": bias,
         "dtype": dtype,
         "forwarded_args_dtype": forwarded_args_dtype,
+        "use_aten": use_aten,
     }
     fwd_sig = LayerNormSignature(**kwargs)
     args = fwd_sig.get_sample_args(seed=1)
@@ -150,15 +153,14 @@ def test_layer_norm_combined_impl(

     main_result = fwd_results[fwd_sig.main_result_index]
     main_result.retain_grad()
-    # TODO: this is not a good loss function (#1021).
-    loss = main_result.sum()
+    loss = main_result.mean() / main_result.numel()
     loss.backward(retain_graph=True)

     bwd_input_args = bwd_sig.arrange_backward_launch_args(args, fwd_results)
     grads = tuple(x for x in bwd(main_result.grad, *bwd_input_args) if x is not None)

-    rtol = 1e-4
-    atol = 1e-4
+    rtol = 1e-6
+    atol = 1e-6
     assert len(grads) == len(args)
     results = [
         torch.allclose(arg.grad, grad, rtol=rtol, atol=atol)
