
Commit 1e4a294

[boo] layer norm support (#979)
Add support for layer normalization kernels in boo. This patch adds the signature, parser and implementation for forward/backward kernels + autograd support. Smoke tests are provided for correctness and correct caching of kernel IR. Additional testing can be performed via the dedicated numerics script, similar to how convolutions are handled.

Not yet supported:
- Non-contiguous tensors.
- Op replacement in PyTorch graph.

These will be added separately.

---------

Signed-off-by: Alex Zinenko <git@ozinenko.com>
1 parent c5a4044 commit 1e4a294
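The commit message mentions forward/backward layer-norm kernels with autograd support and a dedicated numerics script for further testing. As a point of reference for what such a numerics check would compare against, here is a minimal sketch using stock PyTorch only; it calls no boo API, and the shapes and normalized dimension are illustrative assumptions:

import torch

# Illustrative shapes: a (batch, seq, hidden) activation normalized over the last dim.
x = torch.randn(4, 128, 768, dtype=torch.float32, requires_grad=True)
weight = torch.randn(768, requires_grad=True)
bias = torch.randn(768, requires_grad=True)

# Forward: eager PyTorch layer norm serves as the reference implementation.
y_ref = torch.nn.functional.layer_norm(x, (768,), weight, bias, eps=1e-5)

# Backward: autograd on the reference produces reference gradients for input,
# weight, and bias, which a numerics check would compare against the kernel
# outputs within some tolerance.
y_ref.backward(torch.ones_like(y_ref))
print(x.grad.shape, weight.grad.shape, bias.grad.shape)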

File tree: 14 files changed (+613 −33 lines)


iree/turbine/kernel/boo/conv_exports/conv.py

Lines changed: 5 additions & 17 deletions
@@ -13,12 +13,13 @@
 )

 from enum import IntEnum
+from functools import cached_property
 import math

 import torch

 from .utils import Permutation
-from ..exports.signature import OpSignature
+from ..exports.signature import OpSignature, ModeBase
 from ....ops.conv_fwd import conv_2d_nhwc_fhwc, generic_conv
 from ....ops.insert_slice import insert_slice

@@ -32,7 +33,7 @@
 ]


-class Mode(IntEnum):
+class Mode(ModeBase, IntEnum):
     FORWARD = 0
     INPUT_BACKWARD = 1
     WEIGHT_BACKWARD = 2
@@ -42,20 +43,6 @@ class Mode(IntEnum):
     BWD = INPUT_BACKWARD
     WRW = WEIGHT_BACKWARD

-    @staticmethod
-    def parse(spec: Union[str, None, "Mode"]) -> "Mode":
-        if spec is None:
-            return Mode.FORWARD
-        if isinstance(spec, Mode):
-            return spec
-        spec = spec.upper().replace("-", "_")
-        if spec not in Mode.__members__:
-            raise ValueError(
-                f"For mode= argument, expected one of: "
-                f"{', '.join(Mode.__members__.keys())}"
-            )
-        return Mode[spec]
-
     def __str__(self):
         return self.name

@@ -325,7 +312,8 @@ def get(shape):
             return (get(self.output_shape), get(self.kernel_shape))
         raise ValueError(f"Unknown mode: {self.mode}")

-    def get_func_name(self):
+    @cached_property
+    def func_name(self) -> str:
         name_items = [
             "conv",
             f"{self.num_spatial_dims}d",

iree/turbine/kernel/boo/driver/README.md

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ populator.run()
 sample_signature = populator.signatures[0]

 # One can also check the cache for this signature
-cache_status = populator.get_cache_status(sample_signature.get_func_name())
+cache_status = populator.get_cache_status(sample_signature.func_name)
 print(cache_status)

 # You can also get the list of failed signatures via:

iree/turbine/kernel/boo/driver/driver.py

Lines changed: 20 additions & 8 deletions
@@ -28,9 +28,11 @@ def main():
     parser = argparse.ArgumentParser(
         usage="%(prog)s [-h] [... MIOpenDriver command ...] [--commands-file COMMANDS_FILE]",
         description="""
-Run a convolution with the IREE runtime. Command line arguments mirror the
+Run a kernel with the IREE runtime. Command line arguments mirror the
 arguments to MIOpenDriver.

+Currently supports convolution and layernorm.
+
 If COMMANDS_FILE is specified, driver commands are read from the file. Each
 line is treated as a separate invocation of the driver, and any additional
 command-line arguments are appended to the arguments from the file.
@@ -115,10 +117,16 @@ def run(cli_args: Sequence[str], gpu_id: int):
     from iree.turbine.kernel.boo.exports.parser import OpCLIParser

     def dispatch(cli_args: Sequence[str]) -> type[OpCLIParser]:
-        if any(map(lambda x: "conv" in x, cli_args)):
+        if any("conv" in x for x in cli_args):
             from iree.turbine.kernel.boo.conv_exports.miopen_parser import ConvParser

             return ConvParser
+        if any("layernorm" in x for x in cli_args):
+            from iree.turbine.kernel.boo.layer_norm_exports.miopen_parser import (
+                LayerNormParser,
+            )
+
+            return LayerNormParser
         raise ValueError("unsupported operation kind in " + shlex.join(cli_args))

     from iree.turbine.kernel.boo.driver.launch import get_launchable
@@ -165,24 +173,28 @@ def dispatch(cli_args: Sequence[str]) -> type[OpCLIParser]:
     mem_bytes_threshold = 96 * (10**9)
     iter_thresh = int(mem_bytes_threshold // res_mem_bytes)

-    result = None
+    results: tuple[torch.Tensor, ...] | torch.Tensor | None = None
     for iter in range(iter_per_device + 1):
         for device_idx, launch_args in enumerate(per_device_data):
             if iter == iter_per_device and device_idx >= rem_iter:
                 break
-            result = launchable(*launch_args)
+            results = launchable(*launch_args)
         if (iter + 1) % iter_thresh == 0:
             print(f"Synchronizing all devices on iter {iter} and collecting garbage.")
             for i in range(num_devices):
                 torch.cuda.synchronize(torch.device(f"cuda:{i}"))
             gc.collect()

     torch.cuda.synchronize()
-    print(
-        f">>> result shape: {result.shape}; dtype: {result.dtype}; device type: {result.device.type}"
-    )
+    results = results or ()
+    if isinstance(results, torch.Tensor):
+        results = (results,)
+    for i, result in enumerate(results):
+        print(
+            f">>> result #{i} shape: {result.shape}; dtype: {result.dtype}; device type: {result.device.type}"
+        )

-    return sig.get_func_name()
+    return sig.func_name


 TRACY_PORT = str(random.randint(40_000, 50_000))
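The last hunk generalizes result reporting from a single tensor to any number of outputs, since some kernels produce one tensor (e.g. conv forward) while others may return several, as a layer-norm backward computing input, weight, and bias gradients would. A standalone sketch of that tensor-or-tuple normalization, shown here with made-up tensors rather than actual launchable outputs:

import torch


def report(results: tuple[torch.Tensor, ...] | torch.Tensor | None) -> None:
    # Normalize to a tuple: no result -> (), single tensor -> 1-tuple, tuple unchanged.
    if results is None:
        normalized: tuple[torch.Tensor, ...] = ()
    elif isinstance(results, torch.Tensor):
        normalized = (results,)
    else:
        normalized = results
    for i, result in enumerate(normalized):
        print(
            f">>> result #{i} shape: {result.shape}; dtype: {result.dtype}; "
            f"device type: {result.device.type}"
        )


# Illustrative calls with made-up tensors, not actual kernel outputs:
report(torch.zeros(2, 3))                    # single-output kernel
report((torch.zeros(2, 3), torch.zeros(3)))  # multi-output kernel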

iree/turbine/kernel/boo/driver/launch.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 def get_module_asm(
     signature: OpSignature, func_name: str | None = None, use_custom: bool = True
 ) -> str:
-    func_name = func_name or signature.get_func_name()
+    func_name = func_name or signature.func_name
     module_factory = lambda: signature.get_nn_module(use_custom=use_custom)
     arg_factory = lambda: signature.get_sample_args(splat_value=0)
     return generic_get_module_asm(
@@ -24,7 +24,7 @@ def get_module_asm(
 def get_launchable(
     signature: OpSignature, *, use_custom=True, cache_only=False
 ) -> Launchable:
-    func_name = signature.get_func_name()
+    func_name = signature.func_name
     module_factory = lambda: signature.get_nn_module(use_custom=use_custom)
     arg_factory = lambda: signature.get_sample_args(splat_value=0)
     return generic_get_launchable(
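Both helpers above now read the function name from the func_name property rather than calling get_func_name(). A minimal usage sketch, assuming a signature object (for example one produced by ConvParser or LayerNormParser) is obtained elsewhere; how to construct one is not shown in this diff:

from iree.turbine.kernel.boo.driver.launch import get_launchable
from iree.turbine.kernel.boo.exports.signature import OpSignature


def run_once(signature: OpSignature):
    # Build (or fetch from cache) a launchable for this signature; use_custom=True
    # mirrors the default used by the helpers above.
    launchable = get_launchable(signature, use_custom=True)
    # Sample arguments come from the signature itself, exactly as the arg_factory
    # lambdas above do; splat_value=0 fills the tensors with zeros.
    args = signature.get_sample_args(splat_value=0)
    return launchable(*args)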

iree/turbine/kernel/boo/driver/preload.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def get_failures(self) -> dict[str, str]:


 def mlir_import(sig: OpSignature) -> tuple[str, bool]:
-    func_name = sig.get_func_name()
+    func_name = sig.func_name
     success = False
     try:
         get_module_asm(sig, func_name)

iree/turbine/kernel/boo/exports/signature.py

Lines changed: 3 additions & 2 deletions
@@ -42,9 +42,10 @@ def get_sample_args(
         """Generates sample arguments as PyTorch tensors for the operation."""
         ...

+    @property
     @abstractmethod
-    def get_func_name(self) -> str:
-        """Generates an MLIR function name to use for the operation, unique across operations."""
+    def func_name(self) -> str:
+        """MLIR function name to use for the operation, unique across operations."""
         ...

     @abstractmethod
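OpSignature now declares func_name as an abstract read-only property instead of a get_func_name() method, and the conv.py hunk above satisfies it with a cached_property. A self-contained sketch of that pattern; the class, its fields, and the name format below are illustrative stand-ins, not the actual boo sources:

from abc import ABC, abstractmethod
from functools import cached_property


class OpSignatureSketch(ABC):
    """Illustrative stand-in for the OpSignature base class."""

    @property
    @abstractmethod
    def func_name(self) -> str:
        """MLIR function name to use for the operation, unique across operations."""
        ...


class LayerNormSignatureSketch(OpSignatureSketch):
    """Hypothetical concrete signature; field names are made up for illustration."""

    def __init__(self, shape: tuple[int, ...], dtype: str, mode: str):
        self.shape = shape
        self.dtype = dtype
        self.mode = mode

    # cached_property satisfies the abstract property and computes the name once.
    @cached_property
    def func_name(self) -> str:
        shape_str = "x".join(str(d) for d in self.shape)
        return f"layer_norm_{self.mode}_{self.dtype}_{shape_str}"


sig = LayerNormSignatureSketch((4, 128, 768), "f32", "fwd")
print(sig.func_name)  # layer_norm_fwd_f32_4x128x768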
