From d28ae8ba134a812a854d82ac079d2bc6f26c8eaf Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Fri, 5 Sep 2025 15:16:25 +0200 Subject: [PATCH 1/3] [boo] support input permutation in layer norm Add an option to support input permutations in the signature and generate sample arguments accordingly. Refactor the Permutation class from conv into a generic utils. Signed-off-by: Alex Zinenko --- iree/turbine/kernel/boo/op_exports/conv.py | 95 +-------------- .../kernel/boo/op_exports/layer_norm.py | 26 +++- iree/turbine/kernel/boo/op_exports/utils.py | 111 ++++++++++++++++++ tests/kernel/boo/op_exports/utils_test.py | 21 ++++ 4 files changed, 157 insertions(+), 96 deletions(-) create mode 100644 iree/turbine/kernel/boo/op_exports/utils.py create mode 100644 tests/kernel/boo/op_exports/utils_test.py diff --git a/iree/turbine/kernel/boo/op_exports/conv.py b/iree/turbine/kernel/boo/op_exports/conv.py index 6bef52df3..fede7930e 100644 --- a/iree/turbine/kernel/boo/op_exports/conv.py +++ b/iree/turbine/kernel/boo/op_exports/conv.py @@ -5,12 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse -from typing import ( - Any, - Sequence, - TypeVar, -) -from collections.abc import Collection +from typing import Any from enum import IntEnum from functools import lru_cache import math @@ -18,6 +13,7 @@ import torch +from .utils import Permutation from ..exports.signature import OpSignature, ModeBase from ..exports.parser import OpCLIParser from ....ops.conv_fwd import conv_2d_nhwc_fhwc, generic_conv @@ -34,93 +30,6 @@ "get_conv_func_name", ] -_T = TypeVar("_T") - - -class Permutation: - """Composable and invertible lists which represent the second argument of `torch.permute`.""" - - def __init__(self, ordering: Sequence[int]): - assert list(sorted(ordering)) == list( - range(len(ordering)) - ), "ordering must be rearragement of [0,1,2,...,n-1]" - self._items = tuple(ordering) - - @property - def size(self) -> int: - return len(self._items) - - @property - def items(self) -> tuple[int, ...]: - return self._items - - def __getitem__(self, n: int) -> int: - return self.items[n] - - def __repr__(self) -> str: - return f"Permutation of {self.size} : {self.items}" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Permutation): - return False - return self.items == other.items - - def __len__(self) -> int: - return self.size - - def __mul__(self, other: "Permutation") -> "Permutation": - """mimics composition `torch.permute(torch.permute(a, p1), p0) = torch.permute(a, p0*p1)""" - assert self.size == other.size, "permutations must be the same size" - return Permutation([other.items[element] for element in self.items]) - - def __call__(self, other: torch.Tensor | Collection[_T]) -> torch.Tensor | list[_T]: - """apply the permutation to a tensor or iterable (e.g., a shape)""" - if isinstance(other, torch.Tensor): - assert ( - len(other.shape) == self.size - ), f"permutation must match the rank of the tensor being permuted, got permutation size {self.size} for tensor of shape {other.shape}" - return torch.permute(other, self.items) - if isinstance(other, Collection): - assert len(other) == self.size - return [other[item] for item in self.items] - raise TypeError(f"Unexpected argument type: {type(other)}.") - - def __truediv__(self, other: "Permutation") -> "Permutation": - return self * other.inv() - - def inv(self) -> "Permutation": - """inverts the permutation x*inv(x) = inv(x)*x = Permutation.identity(x.size)""" - inverse = list(range(self.size)) - for i in range(self.size): - index = self.items[i] - inverse[index] = i - return Permutation(inverse) - - @staticmethod - def identity(size: int) -> "Permutation": - """creates an identity permutation""" - assert size > 0, "size must be positive integer" - return Permutation(list(range(size))) - - @staticmethod - def get(src: Collection[_T], target: Collection[_T]) -> "Permutation": - """Gets a permutation p such that `torch.permute(a, p) = b` where `a.shape = src` and `b.shape = target`""" - n = len(src) - assert n > 0 and n == len( - target - ), "source and target iterables must share the same positive length" - d = {t: i for i, t in enumerate(target)} - inverse = [] - try: - for item in src: - value = d.pop(item) - inverse.append(value) - except KeyError as e: - raise ValueError( - f"src and target should be permutations of a common set of unique items, got {src=}, {target=}" - ) - return Permutation(inverse).inv() - class Mode(ModeBase, IntEnum): FORWARD = 0 diff --git a/iree/turbine/kernel/boo/op_exports/layer_norm.py b/iree/turbine/kernel/boo/op_exports/layer_norm.py index a33541623..d2a37b22f 100644 --- a/iree/turbine/kernel/boo/op_exports/layer_norm.py +++ b/iree/turbine/kernel/boo/op_exports/layer_norm.py @@ -10,6 +10,7 @@ import torch import math from functools import cached_property +from .utils import Permutation, permute_layout from ..exports.signature import OpSignature, ModeBase from ..exports.parser import OpCLIParser @@ -33,6 +34,7 @@ class LayerNormSignature(OpSignature): bias: bool dtype: torch.dtype mode: Mode + input_permutation: list[int] def __init__( self, @@ -45,6 +47,7 @@ def __init__( dtype=torch.bfloat16, mode: str | Mode = Mode.FORWARD, forwarded_args_dtype: torch.dtype | None = None, + input_permutation: Sequence[int] | None = None, ): if ( len(normalized_shape) > len(input_shape) @@ -62,6 +65,9 @@ def __init__( self.dtype = dtype self.mode = Mode.parse(mode) self.forwarded_args_dtype = forwarded_args_dtype or dtype + self.input_permutation = input_permutation or list( + Permutation.identity(len(input_shape)).items + ) @property def output_shape(self) -> list[int]: @@ -123,6 +129,11 @@ def func_name(self) -> str: "x".join(str(i) for i in self.input_shape), "w" if self.elementwise_affine is not None else "", "b" if self.bias is not None else "", + ( + "perm_" + "".join(self.input_permutation) + if self.input_permutation != sorted(self.input_permutation) + else "" + ), ] return "_".join(name_items) @@ -179,6 +190,7 @@ def as_init_kwargs(self) -> dict[str, Any]: "dtype": self.dtype, "mode": self.Mode, "forwarded_args_dtype": self.forwarded_args_dtype, + "input_permutation": self.input_permutation, } def get_output_size(self) -> int: @@ -219,9 +231,15 @@ def get(shape: Sequence[int]) -> torch.Tensor: return torch.ones(shape, dtype=self.dtype, device=device) * splat_value return torch.randn(shape, generator=gen, dtype=self.dtype, device=device) + def get_permuted(shape: Sequence[int], order: Sequence[int]) -> torch.Tensor: + tensor = get(shape) + if order == sorted(order): + return tensor + return permute_layout(tensor, order) + if self.mode == Mode.FORWARD: # (x, w?, b?) - args = [get(self.input_shape)] + args = [get_permuted(self.input_shape, self.input_permutation)] if self.elementwise_affine: args.append(get(self.normalized_shape)) if self.bias: @@ -231,7 +249,7 @@ def get(shape: Sequence[int]) -> torch.Tensor: # (dLdy, input, weight, mean, rstd) return ( get(self.output_shape), - get(self.input_shape), + get_permuted(self.input_shape, self.input_permutation), get(self.normalized_shape), get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype), get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype), @@ -240,7 +258,7 @@ def get(shape: Sequence[int]) -> torch.Tensor: # (dLdy, input, mean, rstd) return ( get(self.output_shape), - get(self.input_shape), + get_permuted(self.input_shape, self.input_permutation), get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype), get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype), ) @@ -430,6 +448,7 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature: bias=True, dtype=_DTypeCommandDispatcher.get_dtype(args.command), mode=mode, + input_permutation=args.input_permutation, ) def get_miopen_parser() -> argparse.ArgumentParser: @@ -458,6 +477,7 @@ def get_miopen_parser() -> argparse.ArgumentParser: parser.add_argument( "--normalized_dim", "-o", type=int, default=3, help="Normalized dim" ) + parser.add_argument("--input-permutation", type=int, nargs="*", default=None) return parser @classmethod diff --git a/iree/turbine/kernel/boo/op_exports/utils.py b/iree/turbine/kernel/boo/op_exports/utils.py new file mode 100644 index 000000000..b57aef8cc --- /dev/null +++ b/iree/turbine/kernel/boo/op_exports/utils.py @@ -0,0 +1,111 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from typing import Sequence, Collection, TypeVar + +_T = TypeVar("_T") + + +class Permutation: + """Composable and invertible lists which represent the second argument of `torch.permute`.""" + + def __init__(self, ordering: Sequence[int]): + assert list(sorted(ordering)) == list( + range(len(ordering)) + ), "ordering must be rearragement of [0,1,2,...,n-1]" + self._items = tuple(ordering) + + @property + def size(self) -> int: + return len(self._items) + + @property + def items(self) -> tuple[int, ...]: + return self._items + + def __getitem__(self, n: int) -> int: + return self.items[n] + + def __repr__(self) -> str: + return f"Permutation of {self.size} : {self.items}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Permutation): + return False + return self.items == other.items + + def __len__(self) -> int: + return self.size + + def __mul__(self, other: "Permutation") -> "Permutation": + """mimics composition `torch.permute(torch.permute(a, p1), p0) = torch.permute(a, p0*p1)""" + assert self.size == other.size, "permutations must be the same size" + return Permutation([other.items[element] for element in self.items]) + + def __call__(self, other: torch.Tensor | Collection[_T]) -> torch.Tensor | list[_T]: + """apply the permutation to a tensor or iterable (e.g., a shape)""" + if isinstance(other, torch.Tensor): + assert ( + len(other.shape) == self.size + ), f"permutation must match the rank of the tensor being permuted, got permutation size {self.size} for tensor of shape {other.shape}" + return torch.permute(other, self.items) + if isinstance(other, Collection): + assert len(other) == self.size + return [other[item] for item in self.items] + raise TypeError(f"Unexpected argument type: {type(other)}.") + + def __truediv__(self, other: "Permutation") -> "Permutation": + return self * other.inv() + + def inv(self) -> "Permutation": + """inverts the permutation x*inv(x) = inv(x)*x = Permutation.identity(x.size)""" + inverse = list(range(self.size)) + for i in range(self.size): + index = self.items[i] + inverse[index] = i + return Permutation(inverse) + + @staticmethod + def identity(size: int) -> "Permutation": + """creates an identity permutation""" + assert size > 0, "size must be positive integer" + return Permutation(list(range(size))) + + @staticmethod + def get(src: Collection[_T], target: Collection[_T]) -> "Permutation": + """Gets a permutation p such that `torch.permute(a, p) = b` where `a.shape = src` and `b.shape = target`""" + n = len(src) + assert n > 0 and n == len( + target + ), "source and target iterables must share the same positive length" + d = {t: i for i, t in enumerate(target)} + inverse = [] + try: + for item in src: + value = d.pop(item) + inverse.append(value) + except KeyError as e: + raise ValueError( + f"src and target should be permutations of a common set of unique items, got {src=}, {target=}" + ) + return Permutation(inverse).inv() + + +def permute_layout( + tensor: torch.Tensor, permutation: Permutation | Sequence[int] +) -> torch.Tensor: + """Returns a new tensor that is the given permutation of the input tensor. + + The resulting tensor is stored in the contiguous format after permutation + and its shape/strides are adjusted to match the shape of the original + tensor. + """ + if not isinstance(permutation, Permutation): + permutation = Permutation(permutation) + permuted = permutation(tensor) + rematerialized = permuted.clone(memory_format=torch.contiguous_format) + return permutation.inv()(rematerialized) diff --git a/tests/kernel/boo/op_exports/utils_test.py b/tests/kernel/boo/op_exports/utils_test.py new file mode 100644 index 000000000..5b44927a4 --- /dev/null +++ b/tests/kernel/boo/op_exports/utils_test.py @@ -0,0 +1,21 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from iree.turbine.kernel.boo.op_exports.utils import permute_layout + + +def test_permute_layout(): + assert permute_layout(torch.empty((2, 3, 4, 5)), [0, 2, 3, 1]).is_contiguous( + memory_format=torch.channels_last + ) + assert permute_layout(torch.empty((2, 3, 4, 5, 6)), [0, 2, 3, 4, 1]).is_contiguous( + memory_format=torch.channels_last_3d + ) + + permuted_1 = permute_layout(torch.empty((2, 3, 4)), [0, 2, 1]) + assert list(permuted_1.shape) == [2, 3, 4] + assert permuted_1.stride() == (12, 1, 3) From 575b8a766ffa3b102bc256e3b6ce32dc601b3d76 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 3 Sep 2025 12:26:01 +0000 Subject: [PATCH 2/3] [boo] use torch compile backend in driver Go through `torch.compile` in the boo driver instead of calling the kernel directly. This allows to collect more realistic execution time statistics that are close to what the final user will see after integration. The direct kernel invocation is kept as a reference option to evaluate the overhead. Tweak the profiler configuration to ignore initialization and cleanup steps so as to avoid skewing aggregate statistics. Signed-off-by: Alex Zinenko --- iree/turbine/kernel/boo/driver/README.md | 19 +++- iree/turbine/kernel/boo/driver/driver.py | 114 ++++++++++++++++------- 2 files changed, 97 insertions(+), 36 deletions(-) diff --git a/iree/turbine/kernel/boo/driver/README.md b/iree/turbine/kernel/boo/driver/README.md index 6b654becd..7fa710080 100644 --- a/iree/turbine/kernel/boo/driver/README.md +++ b/iree/turbine/kernel/boo/driver/README.md @@ -51,7 +51,10 @@ conv = get_launchable(sample_signature) ## Benchmarking -The `driver.py` script allows for running kernels from the command line. It uses the same interface as `MIOpenDriver`: +The `driver.py` script allows for running kernels from the command line. It uses +an interface similar to that of `MIOpenDriver`: some additional flags or flag +values are added to support scenarios not supported by the driver such as +non-default layouts. ```console $ python driver.py convbfp16 -n 128 -c 128 -H 24 -W 48 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1 --iter 100 --in_layout NHWC --out_layout NHWC --fil_layout NHWC @@ -76,7 +79,19 @@ convbfp16 -n 128 -c 35 -H 48 -W 32 -k 35 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 ... ``` -The `--time 1` (or `-t 1` for short) option to collect timing is implemented by launching the kernel, which is then profiled using `torch.profiler`. Only the actual IREE kernel dispatch time is reported. Note: you can output `min_time (us)` to a csv file with `--csv=results.csv`. +The `--time 1` (or `-t 1` for short) option to collect timing is implemented by +launching the kernel, which is then profiled using `torch.profiler`. Overall GPU +time is reported, including memory and other operations not necessarily included +in the kernel itself. Note: you can statistics to a csv file with +`--csv=results.csv`. + +BOO operations can be compared against a set of reference backends by providing +one or more `--reference-backend` flags. Currently supported backends include: + +- `torch`: Eager Pytorch. +- `inductor`: Pytorch Inductor (`torch.compile` default). +- `iree_boo_inductor`: BOO where applicable and Inductor otherwise. +- `iree_boo_legacy`: direct call of BOO kernel without `torch.compile`. #### Misc requirements Q&A: diff --git a/iree/turbine/kernel/boo/driver/driver.py b/iree/turbine/kernel/boo/driver/driver.py index 2519047ba..84fdd311d 100644 --- a/iree/turbine/kernel/boo/driver/driver.py +++ b/iree/turbine/kernel/boo/driver/driver.py @@ -12,6 +12,7 @@ import os import shlex import statistics +from functools import partial import torch from torch.autograd.profiler_util import FunctionEvent @@ -56,7 +57,7 @@ def _get_main_driver_parser() -> argparse.ArgumentParser: parser.add_argument( "--reference-backend", type=str, - choices=["torch", "torch-compile"], + choices=[c for c in BACKEND_TO_FUNC_GENERATOR.keys() if c != "iree_boo"], action="append", default=[], required=False, @@ -112,11 +113,6 @@ def main(): # Check the reference backend ref_backends = meta_args.reference_backend - # TODO: Add ability to benchmark against torch-compile (inductor). - if "torch-compile" in ref_backends: - raise NotImplementedError( - "Comparing against torch-compiled reference not yet implemented." - ) # Setup a csv output file with headers. csv_stats = ALL_STATS @@ -149,17 +145,10 @@ def main(): ) output_num_bytes = signature.get_output_size() - profile_context = ( - profile(activities=[ProfilerActivity.CUDA]) - if timing_args.time - else nullcontext() - ) - for backend in backends: _func = BACKEND_TO_FUNC_GENERATOR[backend](signature) try: - with profile_context as prof: - run(_func, timing_args.iter, output_num_bytes, sample_inputs) + prof = run(_func, timing_args, output_num_bytes, sample_inputs) except Exception as exc: print(f">>> ERROR: {exc}") csv_file.write("N.A.," * len(csv_stats)) @@ -217,33 +206,80 @@ def get_aggregate_stats( def run( func: Callable, - iter: int, + timing_args: argparse.ArgumentParser, output_num_bytes: int, per_device_args: Sequence[tuple[torch.Tensor, ...]], -) -> None: - """Distributes `iter`-many applications of `func` to `per_device_args`.""" +) -> profile | None: + """Distributes `iter`-many applications of `func` to `per_device_args`. If + timing is requested, returns a torch profiler object that can be inspected + to recover time-related information.""" num_devices = len(per_device_args) - iter_per_device = iter // num_devices - rem_iter = iter % num_devices # This is a rough threshold: Mi300x 192 GB memory divided by 2. mem_bytes_threshold = 96 * (10**9) iter_thresh = int(mem_bytes_threshold // output_num_bytes) + assert ( + iter_thresh > 1 or not timing_args.time + ), "Cannot reliably profile if cleanup is needed after every step." + + # Cleanup is performed after every iter_thresh steps, including the + # initialization step and the cleanup steps themselves: + # num_cleanups = (iter // num_devices + num_cleanups + 1) // iter_thresh + # Solving which leads to the form below. + num_cleanups = (timing_args.iter // num_devices + 1) // (iter_thresh - 1) + + # The total number of iterations includes cleanups and the initial + # initialization operation that are not profiled, so the profiler records + # the expected number of iterations. When not profiling, just run as many + # times as requested. + total_num_iters = ( + timing_args.iter + (num_cleanups + 1) * num_devices + if timing_args.time + else timing_args.iter + ) + + def needs_cleanup(step: int) -> bool: + per_device_step = step // num_devices + return (per_device_step + 1) % iter_thresh == 0 + + def schedule_fn(step: int) -> torch.profiler.ProfilerAction: + """Scheduling function for the profiler. Ensures it doesn't capture the + first iteration and the cleanup iterations where additional overhead may + happen.""" + # Skip fist run on each device. + if step < num_devices: + return torch.profiler.ProfilerAction.NONE + + # Skip the step on which cleanup happens. + if needs_cleanup(step): + return torch.profiler.ProfilerAction.NONE + + # Save the results at the last iteration. + if step == total_num_iters: + return torch.profiler.ProfilerAction.RECORD_AND_SAVE + return torch.profiler.ProfilerAction.RECORD + + profile_context = ( + profile(activities=[ProfilerActivity.CUDA], schedule=schedule_fn) + if timing_args.time + else nullcontext() + ) results: tuple[torch.Tensor, ...] | torch.Tensor | None = None - for iter in range(iter_per_device + 1): - for device_idx, launch_args in enumerate(per_device_args): - if iter == iter_per_device and device_idx >= rem_iter: - break + with profile_context as prof: + for iter in range(total_num_iters): + device_idx = iter % num_devices + launch_args = per_device_args[device_idx] results = func(*launch_args) - if (iter + 1) % iter_thresh == 0: - print( - f">>>\tSynchronizing all devices on iter {iter} and collecting garbage." - ) - for i in range(num_devices): - torch.cuda.synchronize(torch.device(f"cuda:{i}")) - gc.collect() - - torch.cuda.synchronize() + if needs_cleanup(iter): + print( + f">>>\tSynchronizing all devices on iter {iter} and collecting garbage." + ) + for i in range(num_devices): + torch.cuda.synchronize(torch.device(f"cuda:{i}")) + gc.collect() + prof.step() + + torch.cuda.synchronize() if results is None: results = () if isinstance(results, torch.Tensor): @@ -252,12 +288,22 @@ def run( print( f">>>\tresult #{i} shape: {result.shape}; dtype: {result.dtype}; device type: {result.device.type}" ) - return + return prof if timing_args.time else None + + +def get_torch_compiled_module(signature: OpSignature, backend: str) -> Callable: + mod = signature.get_nn_module(use_custom=False) + return torch.compile(mod, dynamic=False, backend=backend) BACKEND_TO_FUNC_GENERATOR: dict[str, Callable[[OpSignature], Callable]] = { - "iree_boo": get_launchable, + "iree_boo_legacy": get_launchable, "torch": (lambda signature: signature.get_nn_module(use_custom=False)), + "inductor": partial(get_torch_compiled_module, backend="inductor"), + "iree_boo": partial(get_torch_compiled_module, backend="iree_boo"), + "iree_boo_inductor": partial( + get_torch_compiled_module, backend="iree_boo_inductor" + ), } From dc27b67277621d0cfd56d3be30f7723c7c45c308 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 5 Sep 2025 13:57:15 +0000 Subject: [PATCH 3/3] [boo] permute sample tensors for convolutions This is more representative of the real workload Signed-off-by: Alex Zinenko --- iree/turbine/kernel/boo/op_exports/conv.py | 39 +++++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/iree/turbine/kernel/boo/op_exports/conv.py b/iree/turbine/kernel/boo/op_exports/conv.py index fede7930e..bc438deab 100644 --- a/iree/turbine/kernel/boo/op_exports/conv.py +++ b/iree/turbine/kernel/boo/op_exports/conv.py @@ -13,11 +13,12 @@ import torch -from .utils import Permutation +from .utils import Permutation, permute_layout from ..exports.signature import OpSignature, ModeBase from ..exports.parser import OpCLIParser from ....ops.conv_fwd import conv_2d_nhwc_fhwc, generic_conv from ....ops.insert_slice import insert_slice +from typing import Sequence __all__ = [ "Mode", @@ -350,19 +351,47 @@ def get(shape): return torch.ones(shape, dtype=self.dtype, device=device) * splat_value return torch.randn(shape, generator=gen, dtype=self.dtype, device=device) + def get_permuted( + shape: Sequence[int], order: Sequence[int] | Permutation + ) -> torch.Tensor: + tensor = get(shape) + items = order if not isinstance(order, Permutation) else order.items + if items == sorted(items): + return tensor + return permute_layout(tensor, order) + + # Argument tensors will have to be permuted accordingly to the layouts + # specified in the signature. + input_permutation = Permutation.get("NCHW", self.input_layout) + kernel_permutation = Permutation.get("NCHW", self.kernel_layout) + output_permutation = Permutation.get("NCHW", self.output_layout) + if self.mode == Mode.FORWARD: # (x, w, b) or (x, w) return ( - (get(self.input_shape), get(self.kernel_shape), get(out_channels)) + ( + get_permuted(self.input_shape, input_permutation), + get_permuted(self.kernel_shape, kernel_permutation), + get_permuted(out_channels, output_permutation), + ) if self.bias - else (get(self.input_shape), get(self.kernel_shape)) + else ( + get_permuted(self.input_shape, input_permutation), + get_permuted(self.kernel_shape, kernel_permutation), + ) ) if self.mode == Mode.WEIGHT_BACKWARD: # (dLdy, x) - return (get(self.output_shape), get(self.input_shape)) + return ( + get_permuted(self.output_shape, output_permutation), + get_permuted(self.input_shape, input_permutation), + ) if self.mode == Mode.INPUT_BACKWARD: # (dLdy, w) - return (get(self.output_shape), get(self.kernel_shape)) + return ( + get_permuted(self.output_shape, output_permutation), + get_permuted(self.kernel_shape, kernel_permutation), + ) raise ValueError(f"Unknown mode: {self.mode}") @property