# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from .library import define_schema, register_impl, register_meta
from ..layer_norm_exports.layer_norm import LayerNormSignature, Mode
from ..driver.launch import get_launchable
from ..runtime import LaunchableRuntimeCache
from .utils import *
from typing import Sequence

__all__ = [
    "boo_layer_norm",
]

# TODO(azinenko): can this be automated? The PyTorch docs say these can be
# inferred from type information.
define_schema(
    "layer_norm",
    "(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float? eps) -> (Tensor, Tensor, Tensor)",
)
# "(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-5) -> Tensor")

# TODO(azinenko,zjgarvey): this should eventually be generalized with non-boo registration.


@register_impl("layer_norm")
def _boo_layer_norm_impl(
    input: torch.Tensor,
    normalized_shape: Sequence[int],
    weight: torch.Tensor | None,
    bias: torch.Tensor | None,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    signature = LayerNormSignature.get(input, normalized_shape, weight, bias, eps=eps)

    # TODO: support non-contiguous memory formats via permutations

    func_name = signature.func_name
    args = tuple(x.data for x in (input, weight, bias) if x is not None)
    cache_hit = LaunchableRuntimeCache.get(func_name)
    if cache_hit:
        return cache_hit(*args)

    layer_norm = get_launchable(signature)
    return layer_norm(*args)


@register_meta("layer_norm")
def _boo_layer_norm_meta(
    input: torch.Tensor,
    normalized_shape: Sequence[int],
    weight: torch.Tensor | None,
    bias: torch.Tensor | None,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    signature = LayerNormSignature.get(input, normalized_shape, weight, bias, eps=eps)

    # TODO: support non-contiguous memory formats via permutations

    return (
        torch.empty_like(input),
        torch.empty(
            signature.aggregate_shape, dtype=signature.dtype, device=input.device
        ),
        torch.empty(
            signature.aggregate_shape, dtype=signature.dtype, device=input.device
        ),
    )


define_schema(
    "layer_norm_backward",
    "(Tensor grad_output, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor weight, Tensor bias, bool[3] mask) -> (Tensor?, Tensor?, Tensor?)",
)


@register_impl("layer_norm_backward")
def _boo_layer_norm_backward_impl(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    normalized_shape: int | Sequence[int] | torch.Size,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    mask: tuple[bool, bool, bool],
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:

    input_grad: torch.Tensor | None = None
    weight_grad: torch.Tensor | None = None
    bias_grad: torch.Tensor | None = None

    # TODO(azinenko): it is unclear why the convolution implementation computes
    # each derivative with a separate kernel, but we follow the same pattern here.

    def data_tuple(*args: torch.Tensor):
        return tuple(a.data for a in args)

    if mask[0]:
        signature = LayerNormSignature.get(
            input, normalized_shape, weight, bias, Mode.INPUT_BACKWARD
        )
        launchable = get_launchable(signature)
        input_grad = launchable(*data_tuple(grad_output, input, weight, mean, rstd))

    if mask[1]:
        signature = LayerNormSignature.get(
            input, normalized_shape, weight, bias, Mode.WEIGHT_BACKWARD
        )
        launchable = get_launchable(signature)
        weight_grad = launchable(*data_tuple(grad_output, input, mean, rstd))

    if mask[2]:
        signature = LayerNormSignature.get(
            input, normalized_shape, weight, bias, Mode.BIAS_BACKWARD
        )
        launchable = get_launchable(signature)
        bias_grad = launchable(*data_tuple(grad_output))

    return input_grad, weight_grad, bias_grad


@register_meta("layer_norm_backward")
def _boo_layer_norm_backward_meta(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    normalized_shape: int | Sequence[int] | torch.Size,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    mask: tuple[bool, bool, bool],
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    input_grad: torch.Tensor | None = None
    weight_grad: torch.Tensor | None = None
    bias_grad: torch.Tensor | None = None

    if mask[0]:
        input_grad = torch.empty_like(input)
    if mask[1]:
        weight_grad = torch.empty_like(weight)
    if mask[2]:
        bias_grad = torch.empty_like(bias)
    return input_grad, weight_grad, bias_grad


def pytorch_layer_norm_backward(ctx, grad_output: torch.Tensor):
    """ATen/PyTorch fallback implementation for backward."""

    input, weight, bias, mean, rstd = ctx.saved_tensors
    # Skip the grad flag of `normalized_shape` (position 1 in the forward
    # signature); aten expects the mask as (input, weight, bias), mirroring the
    # index handling in `_BooLayerNorm.backward`.
    mask = (
        ctx.needs_input_grad[0],
        ctx.needs_input_grad[2],
        ctx.needs_input_grad[3],
    )

    input_grad, weight_grad, bias_grad = torch.ops.aten.native_layer_norm_backward(
        grad_output, input, ctx.normalized_shape, mean, rstd, weight, bias, mask
    )

    return input_grad, None, weight_grad, bias_grad, None


class _BooLayerNorm(torch.autograd.Function):
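    """Autograd function routing layer norm through the `boo` custom ops.

    The forward pass saves `(input, weight, bias, mean, rstd)`; the backward
    pass uses the BOO kernels when enabled and otherwise falls back to
    `torch.ops.aten.native_layer_norm_backward`.
    """
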
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        input: torch.Tensor,
        normalized_shape: int | Sequence[int] | torch.Size,
        weight: torch.Tensor | None,
        bias: torch.Tensor | None,
        eps: float,
    ) -> torch.Tensor:
        result, mean, rstd = torch.ops.boo.layer_norm(
            input, normalized_shape, weight, bias, eps
        )
        ctx.save_for_backward(input, weight, bias, mean, rstd)
        ctx.normalized_shape = normalized_shape
        return result

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
    ) -> tuple[
        torch.Tensor | None, None, torch.Tensor | None, torch.Tensor | None, None
    ]:
        if not is_boo_backward_enabled():
            return pytorch_layer_norm_backward(ctx, grad_output)

        input, weight, bias, mean, rstd = ctx.saved_tensors
        # Note that the context contains grad flags for every forward argument
        # in order, including non-differentiable attributes like
        # `normalized_shape`. The indices below correspond to the positions of
        # input, weight and bias in the forward signature.
        mask = (
            ctx.needs_input_grad[0],
            ctx.needs_input_grad[2],
            ctx.needs_input_grad[3],
        )
        input_grad, weight_grad, bias_grad = torch.ops.boo.layer_norm_backward(
            grad_output, input, ctx.normalized_shape, mean, rstd, weight, bias, mask
        )

        return input_grad, None, weight_grad, bias_grad, None


def boo_layer_norm(
    input: torch.Tensor,
    normalized_shape: Sequence[int],
    weight: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    use_autograd = torch._C.is_grad_enabled() and any(
        x is not None and x.requires_grad for x in (input, weight, bias)
    )
    if use_autograd:
        return _BooLayerNorm.apply(input, normalized_shape, weight, bias, eps)
    result, _, _ = torch.ops.boo.layer_norm(input, normalized_shape, weight, bias, eps)
    return result