@@ -44,6 +44,7 @@ def __init__(
         bias: bool = True,
         dtype=torch.bfloat16,
         mode: str | Mode = Mode.FORWARD,
+        forwarded_args_dtype: torch.dtype | None = None,
     ):
         if (
             len(normalized_shape) > len(input_shape)
@@ -60,6 +61,7 @@ def __init__(
         self.bias = bias
         self.dtype = dtype
         self.mode = Mode.parse(mode)
+        self.forwarded_args_dtype = forwarded_args_dtype or dtype

     @property
     def output_shape(self) -> list[int]:
@@ -116,6 +118,7 @@ def func_name(self) -> str:
             "layer_norm",
             f"{len(self.normalized_shape)}d",
             str(self.dtype).removeprefix("torch."),
+            str(self.forwarded_args_dtype).removeprefix("torch."),
             self.mode.name.lower(),
             "x".join(str(i) for i in self.input_shape),
             "w" if self.elementwise_affine is not None else "",
@@ -175,6 +178,7 @@ def as_init_kwargs(self) -> dict[str, Any]:
             "bias": self.bias,
             "dtype": self.dtype,
             "mode": self.Mode,
+            "forwarded_args_dtype": self.forwarded_args_dtype,
         }

     def get_output_size(self) -> int:
@@ -229,16 +233,16 @@ def get(shape: Sequence[int]) -> torch.Tensor:
                 get(self.output_shape),
                 get(self.input_shape),
                 get(self.normalized_shape),
-                get(self.aggregate_shape),
-                get(self.aggregate_shape),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
             )
         if self.mode == Mode.WEIGHT_BACKWARD:
             # (dLdy, input, mean, rstd)
             return (
                 get(self.output_shape),
                 get(self.input_shape),
-                get(self.aggregate_shape),
-                get(self.aggregate_shape),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
             )
         if self.mode == Mode.BIAS_BACKWARD:
             # (dLdy,)
@@ -253,6 +257,7 @@ def __init__(self, signature: LayerNormSignature):
         super().__init__()
         self.normalized_shape = signature.normalized_shape
         self.eps = signature.eps
+        self.forwarded_args_dtype = signature.forwarded_args_dtype

     def forward(
         self,
@@ -272,9 +277,14 @@ def forward(
         # torch.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
         #
         # wrapper hides. We want those too so we can save them for backward.
-        return torch.ops.aten.native_layer_norm(
+        output, mean, rstd = torch.ops.aten.native_layer_norm(
             input, self.normalized_shape, weight, bias, self.eps
         )
+        return (
+            output,
+            mean.to(dtype=self.forwarded_args_dtype),
+            rstd.to(dtype=self.forwarded_args_dtype),
+        )


 class LayerNormBackwardInput(torch.nn.Module):
@@ -285,6 +295,7 @@ def __init__(self, signature: LayerNormSignature):
         super().__init__()
         self.normalized_shape = signature.normalized_shape
         self.eps = signature.eps
+        self.dtype = signature.dtype

         self.normalized_dim = list(
             range(len(signature.input_shape))[-len(self.normalized_shape) :]
@@ -303,6 +314,8 @@ def forward(
     ) -> torch.Tensor:
         # Recompute norm instead of saving it. Judging by the signature, this is the same
         # decision as ATen.
+        mean = mean.to(dtype=self.dtype)
+        rstd = rstd.to(dtype=self.dtype)
         norm = (input - mean) * rstd
         dnorm = grad_output * weight if weight is not None else grad_output
         dx = (
@@ -321,6 +334,8 @@ def __init__(self, signature: LayerNormSignature):
         super().__init__()
         self.normalized_shape = signature.normalized_shape
         self.eps = signature.eps
+        self.dtype = signature.dtype
+
         self.normalized_dim = list(
             range(len(signature.input_shape))[-len(self.normalized_shape) :]
         )
@@ -337,6 +352,8 @@ def forward(
     ):
         # Recompute norm instead of saving it. Judging by the signature, this is the same
         # decision as ATen.
+        mean = mean.to(dtype=self.dtype)
+        rstd = rstd.to(dtype=self.dtype)
         norm = (input - mean) * rstd
         return (grad_output * norm).sum(self.keep_dim)

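Usage sketch (not part of the diff): the new forwarded_args_dtype lets the layer norm compute in one dtype while handing mean/rstd to the backward modules in another; it falls back to dtype when left as None. The keyword names input_shape and normalized_shape below are assumptions inferred from the attributes in the diff, not confirmed parameter names.

import torch

# Hypothetical construction: only bias, dtype, mode, and forwarded_args_dtype
# are visible as __init__ parameters in the diff; the rest is illustrative.
sig = LayerNormSignature(
    input_shape=[32, 1024],        # assumed keyword name
    normalized_shape=[1024],       # assumed keyword name
    bias=True,
    dtype=torch.bfloat16,          # compute/storage dtype of the op
    mode=Mode.FORWARD,
    forwarded_args_dtype=torch.float32,  # mean/rstd forwarded in fp32; None would default to dtype
)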