19 changes: 17 additions & 2 deletions iree/turbine/kernel/boo/driver/README.md
@@ -51,7 +51,10 @@ conv = get_launchable(sample_signature)

## Benchmarking

The `driver.py` script allows for running kernels from the command line. It uses the same interface as `MIOpenDriver`:
The `driver.py` script allows for running kernels from the command line. It uses
an interface similar to that of `MIOpenDriver`; some additional flags and flag
values are added to support scenarios the original driver does not cover, such
as non-default layouts.

```console
$ python driver.py convbfp16 -n 128 -c 128 -H 24 -W 48 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1 --iter 100 --in_layout NHWC --out_layout NHWC --fil_layout NHWC
@@ -76,7 +79,19 @@ convbfp16 -n 128 -c 35 -H 48 -W 32 -k 35 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1
...
```

The `--time 1` (or `-t 1` for short) option to collect timing is implemented by launching the kernel, which is then profiled using `torch.profiler`. Only the actual IREE kernel dispatch time is reported. Note: you can output `min_time (us)` to a csv file with `--csv=results.csv`.
The `--time 1` (or `-t 1` for short) option to collect timing is implemented by
launching the kernel, which is then profiled using `torch.profiler`. Overall GPU
time is reported, including memory operations and other work not necessarily
part of the kernel itself. Note: you can output statistics to a CSV file with
`--csv=results.csv`.
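
For example, this hypothetical invocation (reusing the flags from the example
above) collects timing over 100 iterations and writes the aggregate statistics
to `results.csv`:

```console
$ python driver.py convbfp16 -n 128 -c 128 -H 24 -W 48 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1 --iter 100 --csv=results.csv
```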

BOO operations can be compared against a set of reference backends by providing
one or more `--reference-backend` flags; an example invocation follows the
list. Currently supported backends include:

- `torch`: Eager PyTorch.
- `inductor`: PyTorch Inductor (the `torch.compile` default).
- `iree_boo_inductor`: BOO where applicable and Inductor otherwise.
- `iree_boo_legacy`: a direct call of the BOO kernel without `torch.compile`.
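
For example, a hypothetical invocation comparing the BOO kernel against eager
PyTorch and Inductor references (the flag may be repeated, since it appends):

```console
$ python driver.py convbfp16 -n 128 -c 128 -H 24 -W 48 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1 --iter 100 --reference-backend torch --reference-backend inductor
```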

#### Misc requirements Q&A:

114 changes: 80 additions & 34 deletions iree/turbine/kernel/boo/driver/driver.py
@@ -12,6 +12,7 @@
import os
import shlex
import statistics
from functools import partial

import torch
from torch.autograd.profiler_util import FunctionEvent
@@ -56,7 +57,7 @@ def _get_main_driver_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--reference-backend",
type=str,
choices=["torch", "torch-compile"],
choices=[c for c in BACKEND_TO_FUNC_GENERATOR.keys() if c != "iree_boo"],
action="append",
default=[],
required=False,
@@ -112,11 +113,6 @@ def main():

# Check the reference backend
ref_backends = meta_args.reference_backend
# TODO: Add ability to benchmark against torch-compile (inductor).
if "torch-compile" in ref_backends:
raise NotImplementedError(
"Comparing against torch-compiled reference not yet implemented."
)

# Setup a csv output file with headers.
csv_stats = ALL_STATS
@@ -149,17 +145,10 @@
)
output_num_bytes = signature.get_output_size()

profile_context = (
profile(activities=[ProfilerActivity.CUDA])
if timing_args.time
else nullcontext()
)

for backend in backends:
_func = BACKEND_TO_FUNC_GENERATOR[backend](signature)
try:
with profile_context as prof:
run(_func, timing_args.iter, output_num_bytes, sample_inputs)
prof = run(_func, timing_args, output_num_bytes, sample_inputs)
except Exception as exc:
print(f">>> ERROR: {exc}")
csv_file.write("N.A.," * len(csv_stats))
@@ -217,33 +206,80 @@ def get_aggregate_stats(

def run(
func: Callable,
iter: int,
timing_args: argparse.Namespace,
output_num_bytes: int,
per_device_args: Sequence[tuple[torch.Tensor, ...]],
) -> None:
"""Distributes `iter`-many applications of `func` to `per_device_args`."""
) -> profile | None:
"""Distributes `iter`-many applications of `func` to `per_device_args`. If
timing is requested, returns a torch profiler object that can be inspected
to recover time-related information."""
num_devices = len(per_device_args)
iter_per_device = iter // num_devices
rem_iter = iter % num_devices
# This is a rough threshold: MI300X 192 GB memory divided by 2.
mem_bytes_threshold = 96 * (10**9)
iter_thresh = int(mem_bytes_threshold // output_num_bytes)
assert (
iter_thresh > 1 or not timing_args.time
), "Cannot reliably profile if cleanup is needed after every step."

# Cleanup is performed after every iter_thresh steps, including the
# initialization step and the cleanup steps themselves:
# num_cleanups = (iter // num_devices + num_cleanups + 1) // iter_thresh
# Rearranging gives num_cleanups * (iter_thresh - 1) ~= iter // num_devices + 1,
# which leads to the form below.
num_cleanups = (timing_args.iter // num_devices + 1) // (iter_thresh - 1)

# The total number of iterations includes the cleanup steps and the initial
# warm-up run on each device, none of which are profiled, so the profiler
# records the expected number of iterations. When not profiling, just run as
# many times as requested.
total_num_iters = (
timing_args.iter + (num_cleanups + 1) * num_devices
if timing_args.time
else timing_args.iter
)

def needs_cleanup(step: int) -> bool:
per_device_step = step // num_devices
return (per_device_step + 1) % iter_thresh == 0

def schedule_fn(step: int) -> torch.profiler.ProfilerAction:
"""Scheduling function for the profiler. Ensures it doesn't capture the
first iteration on each device or the cleanup iterations, where additional
overhead may occur."""
# Skip the first run on each device.
if step < num_devices:
return torch.profiler.ProfilerAction.NONE

# Skip the step on which cleanup happens.
if needs_cleanup(step):
return torch.profiler.ProfilerAction.NONE

# Save the results at the last iteration.
if step == total_num_iters:
return torch.profiler.ProfilerAction.RECORD_AND_SAVE
return torch.profiler.ProfilerAction.RECORD

profile_context = (
profile(activities=[ProfilerActivity.CUDA], schedule=schedule_fn)
if timing_args.time
else nullcontext()
)

results: tuple[torch.Tensor, ...] | torch.Tensor | None = None
for iter in range(iter_per_device + 1):
for device_idx, launch_args in enumerate(per_device_args):
if iter == iter_per_device and device_idx >= rem_iter:
break
with profile_context as prof:
for iter in range(total_num_iters):
device_idx = iter % num_devices
launch_args = per_device_args[device_idx]
results = func(*launch_args)
if (iter + 1) % iter_thresh == 0:
print(
f">>>\tSynchronizing all devices on iter {iter} and collecting garbage."
)
for i in range(num_devices):
torch.cuda.synchronize(torch.device(f"cuda:{i}"))
gc.collect()

torch.cuda.synchronize()
if needs_cleanup(iter):
print(
f">>>\tSynchronizing all devices on iter {iter} and collecting garbage."
)
for i in range(num_devices):
torch.cuda.synchronize(torch.device(f"cuda:{i}"))
gc.collect()
# `prof` is None when timing is disabled (nullcontext), so only step
# the profiler when profiling.
if timing_args.time:
prof.step()

torch.cuda.synchronize()
if results is None:
results = ()
if isinstance(results, torch.Tensor):
@@ -252,12 +288,22 @@ def run(
print(
f">>>\tresult #{i} shape: {result.shape}; dtype: {result.dtype}; device type: {result.device.type}"
)
return
return prof if timing_args.time else None
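
A minimal sketch of inspecting the returned profiler object (hypothetical
usage; the driver's own aggregation is handled by `get_aggregate_stats`):

```python
prof = run(func, timing_args, output_num_bytes, per_device_args)
if prof is not None:
    # Summarize the CUDA events recorded by torch.profiler.
    print(prof.key_averages().table(sort_by="self_cuda_time_total"))
```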


def get_torch_compiled_module(signature: OpSignature, backend: str) -> Callable:
mod = signature.get_nn_module(use_custom=False)
return torch.compile(mod, dynamic=False, backend=backend)


BACKEND_TO_FUNC_GENERATOR: dict[str, Callable[[OpSignature], Callable]] = {
"iree_boo": get_launchable,
"iree_boo_legacy": get_launchable,
"torch": (lambda signature: signature.get_nn_module(use_custom=False)),
"inductor": partial(get_torch_compiled_module, backend="inductor"),
"iree_boo": partial(get_torch_compiled_module, backend="iree_boo"),
"iree_boo_inductor": partial(
get_torch_compiled_module, backend="iree_boo_inductor"
),
}


132 changes: 35 additions & 97 deletions iree/turbine/kernel/boo/op_exports/conv.py
@@ -5,23 +5,20 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
from typing import (
Any,
Sequence,
TypeVar,
)
from collections.abc import Collection
from typing import Any
from enum import IntEnum
from functools import lru_cache
import math
import warnings

import torch

from .utils import Permutation, permute_layout
from ..exports.signature import OpSignature, ModeBase
from ..exports.parser import OpCLIParser
from ....ops.conv_fwd import conv_2d_nhwc_fhwc, generic_conv
from ....ops.insert_slice import insert_slice
from typing import Sequence

__all__ = [
"Mode",
@@ -34,93 +31,6 @@
"get_conv_func_name",
]

_T = TypeVar("_T")


class Permutation:
"""Composable and invertible lists which represent the second argument of `torch.permute`."""

def __init__(self, ordering: Sequence[int]):
assert list(sorted(ordering)) == list(
range(len(ordering))
), "ordering must be rearragement of [0,1,2,...,n-1]"
self._items = tuple(ordering)

@property
def size(self) -> int:
return len(self._items)

@property
def items(self) -> tuple[int, ...]:
return self._items

def __getitem__(self, n: int) -> int:
return self.items[n]

def __repr__(self) -> str:
return f"Permutation of {self.size} : {self.items}"

def __eq__(self, other: object) -> bool:
if not isinstance(other, Permutation):
return False
return self.items == other.items

def __len__(self) -> int:
return self.size

def __mul__(self, other: "Permutation") -> "Permutation":
"""mimics composition `torch.permute(torch.permute(a, p1), p0) = torch.permute(a, p0*p1)"""
assert self.size == other.size, "permutations must be the same size"
return Permutation([other.items[element] for element in self.items])

def __call__(self, other: torch.Tensor | Collection[_T]) -> torch.Tensor | list[_T]:
"""apply the permutation to a tensor or iterable (e.g., a shape)"""
if isinstance(other, torch.Tensor):
assert (
len(other.shape) == self.size
), f"permutation must match the rank of the tensor being permuted, got permutation size {self.size} for tensor of shape {other.shape}"
return torch.permute(other, self.items)
if isinstance(other, Collection):
assert len(other) == self.size
return [other[item] for item in self.items]
raise TypeError(f"Unexpected argument type: {type(other)}.")

def __truediv__(self, other: "Permutation") -> "Permutation":
return self * other.inv()

def inv(self) -> "Permutation":
"""inverts the permutation x*inv(x) = inv(x)*x = Permutation.identity(x.size)"""
inverse = list(range(self.size))
for i in range(self.size):
index = self.items[i]
inverse[index] = i
return Permutation(inverse)

@staticmethod
def identity(size: int) -> "Permutation":
"""creates an identity permutation"""
assert size > 0, "size must be positive integer"
return Permutation(list(range(size)))

@staticmethod
def get(src: Collection[_T], target: Collection[_T]) -> "Permutation":
"""Gets a permutation p such that `torch.permute(a, p) = b` where `a.shape = src` and `b.shape = target`"""
n = len(src)
assert n > 0 and n == len(
target
), "source and target iterables must share the same positive length"
d = {t: i for i, t in enumerate(target)}
inverse = []
try:
for item in src:
value = d.pop(item)
inverse.append(value)
except KeyError as e:
raise ValueError(
f"src and target should be permutations of a common set of unique items, got {src=}, {target=}"
)
return Permutation(inverse).inv()
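
For reference, a minimal sketch of the layout semantics of `Permutation`
(now imported from `.utils`), assuming the behavior of the removed code above;
the fully qualified import path is inferred from the relative import and may
differ:

```python
import torch

# Inferred path: the diff imports `Permutation` via `from .utils import ...`.
from iree.turbine.kernel.boo.op_exports.utils import Permutation

# Permutation.get(src, target) returns p such that torch.permute(a, p)
# takes a layout-`src` tensor to layout `target`.
p = Permutation.get("NCHW", "NHWC")
assert p.items == (0, 2, 3, 1)
assert "".join(p("NCHW")) == "NHWC"

x = torch.randn(2, 3, 4, 5)  # NCHW
assert tuple(p(x).shape) == (2, 4, 5, 3)  # NHWC

# Composition and inversion: p / p == p * p.inv() is the identity.
assert (p / p).items == Permutation.identity(4).items
```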


class Mode(ModeBase, IntEnum):
FORWARD = 0
@@ -441,19 +351,47 @@ def get(shape):
return torch.ones(shape, dtype=self.dtype, device=device) * splat_value
return torch.randn(shape, generator=gen, dtype=self.dtype, device=device)

def get_permuted(
shape: Sequence[int], order: Sequence[int] | Permutation
) -> torch.Tensor:
tensor = get(shape)
items = order if not isinstance(order, Permutation) else order.items
# `sorted` returns a list, so normalize `items` (possibly a tuple) first.
if list(items) == sorted(items):
return tensor
return permute_layout(tensor, order)

# Argument tensors have to be permuted according to the layouts
# specified in the signature.
input_permutation = Permutation.get("NCHW", self.input_layout)
kernel_permutation = Permutation.get("NCHW", self.kernel_layout)
output_permutation = Permutation.get("NCHW", self.output_layout)

if self.mode == Mode.FORWARD:
# (x, w, b) or (x, w)
return (
(get(self.input_shape), get(self.kernel_shape), get(out_channels))
(
get_permuted(self.input_shape, input_permutation),
get_permuted(self.kernel_shape, kernel_permutation),
get_permuted(out_channels, output_permutation),
)
if self.bias
else (get(self.input_shape), get(self.kernel_shape))
else (
get_permuted(self.input_shape, input_permutation),
get_permuted(self.kernel_shape, kernel_permutation),
)
)
if self.mode == Mode.WEIGHT_BACKWARD:
# (dLdy, x)
return (get(self.output_shape), get(self.input_shape))
return (
get_permuted(self.output_shape, output_permutation),
get_permuted(self.input_shape, input_permutation),
)
if self.mode == Mode.INPUT_BACKWARD:
# (dLdy, w)
return (get(self.output_shape), get(self.kernel_shape))
return (
get_permuted(self.output_shape, output_permutation),
get_permuted(self.kernel_shape, kernel_permutation),
)
raise ValueError(f"Unknown mode: {self.mode}")

@property