@@ -21,6 +21,7 @@ class Mode(ModeBase, IntEnum):
     INPUT_BACKWARD = 1
     WEIGHT_BACKWARD = 2
     BIAS_BACKWARD = 3
+    FULL_BACKWARD = 4
 
 
 class LayerNormSignature(OpSignature):
@@ -160,13 +161,16 @@ def arrange_backward_launch_args(
         input = forward_args[0]
         # TODO: is this possible at this level?
         weight = forward_args[1] if len(forward_args) > 1 else None
+        bias = forward_args[2] if len(forward_args) > 2 else None
         _, mean, rstd = forward_results
         if self.mode == Mode.INPUT_BACKWARD:
             return (input, weight, mean, rstd)
         if self.mode == Mode.WEIGHT_BACKWARD:
             return (input, mean, rstd)
         if self.mode == Mode.BIAS_BACKWARD:
             return ()
+        if self.mode == Mode.FULL_BACKWARD:
+            return (input, weight, bias, mean, rstd)
         assert False, "Unsupported mode."
 
     def as_init_kwargs(self) -> dict[str, Any]:
@@ -202,6 +206,9 @@ def get_nn_module(self, use_custom: bool) -> torch.nn.Module:
             return LayerNormBackwardWeight(self)
         if self.mode == Mode.BIAS_BACKWARD:
             return LayerNormBackwardBias(self)
+        if self.mode == Mode.FULL_BACKWARD:
+            return LayerNormBackwardFull(self)
+        assert False, f"Unknown mode: {self.mode.name}."
 
     def get_sample_args(
         self,
@@ -247,6 +254,15 @@ def get(shape: Sequence[int]) -> torch.Tensor:
         if self.mode == Mode.BIAS_BACKWARD:
             # (dLdy,)
             return (get(self.output_shape),)
+        if self.mode == Mode.FULL_BACKWARD:
+            return (
+                get(self.output_shape),
+                get(self.input_shape),
+                get(self.normalized_shape) if self.elementwise_affine else None,
+                get(self.normalized_shape) if self.bias else None,
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+                get(self.aggregate_shape).to(dtype=self.forwarded_args_dtype),
+            )
         raise ValueError(f"Unknown mode: {self.mode}")
 
 
@@ -374,6 +390,69 @@ def forward(self, grad_output: torch.Tensor) -> torch.Tensor:
         return grad_output.sum(dim=self.keep_dim)
 
 
+class LayerNormBackwardFull(torch.nn.Module):
+    """Module computing, as its forward computation, the gradients of the input,
+    weights, and bias of the layer normalization given the gradient of its
+    output."""
+
+    def __init__(self, signature: LayerNormSignature, *, use_aten=True):
+        super().__init__()
+        self.use_aten = use_aten
+        self.normalized_shape = signature.normalized_shape
+        self.need_bias = signature.bias
+        self.need_weight = signature.elementwise_affine
+        self.normalized_dim = list(
+            range(len(signature.input_shape))[-len(self.normalized_shape) :]
+        )
+        self.keep_dim = list(
+            range(len(signature.input_shape))[: -len(signature.normalized_shape)]
+        )
+
+    def forward(
+        self,
+        grad_output: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor | None,
+        bias: torch.Tensor | None,
+        mean: torch.Tensor,
+        rstd: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert self.need_weight != (
+            weight is None
+        ), "Weight must be provided if its gradient is requested."
+        assert self.need_bias != (
+            bias is None
+        ), "Bias must be provided if its gradient is requested."
+        if self.use_aten:
+            return torch.ops.aten.native_layer_norm_backward(
+                grad_output,
+                input,
+                self.normalized_shape,
+                mean,
+                rstd,
+                weight,
+                bias,
+                (True, self.need_weight, self.need_bias),
+            )
+
+        # Recompute norm instead of saving it. Judging by the signature, this is the same
+        # decision as ATen.
+        norm = (input - mean) * rstd
+        dnorm = grad_output * weight if weight is not None else grad_output
+        dx = (
+            dnorm
+            - dnorm.mean(dim=self.normalized_dim, keepdim=True)
+            - norm * (dnorm * norm).mean(dim=self.normalized_dim, keepdim=True)
+        ) * rstd
+        dw = None
+        if self.need_weight:
+            dw = (grad_output * norm).sum(self.keep_dim)
+        db = None
+        if self.need_bias:
+            db = grad_output.sum(dim=self.keep_dim)
+        return dx, dw, db
+
+
 def _parse_shape(shape: str) -> list[int]:
     for symbol in shape:
         assert symbol in "0123456789x", "Unsupported shape syntax."
@@ -419,6 +498,8 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature:
             mode = Mode.WEIGHT_BACKWARD
         case 4:
             mode = Mode.BIAS_BACKWARD
+        case 5:
+            mode = Mode.FULL_BACKWARD
         case _:
             raise ValueError(f"Unsupported mode {args.forw}.")
 
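
The non-ATen branch of LayerNormBackwardFull implements the standard layer-norm backward formulas by hand. A minimal standalone sketch (not part of the PR) can cross-check those formulas against the same ATen kernel the use_aten branch calls; the shapes, eps, and dtype below are arbitrary test choices, and here the normalized dim is the last one, so keep_dim is (0, 1).

import torch

# Arbitrary test shapes: batch dims (4, 8), one normalized dim of size 16.
x = torch.randn(4, 8, 16, dtype=torch.float64)
w = torch.randn(16, dtype=torch.float64)
b = torch.randn(16, dtype=torch.float64)
dy = torch.randn_like(x)

# Forward pass; native_layer_norm returns (output, mean, rstd).
_, mean, rstd = torch.ops.aten.native_layer_norm(x, [16], w, b, 1e-5)

# Reference gradients from ATen, as in the use_aten branch.
dx_ref, dw_ref, db_ref = torch.ops.aten.native_layer_norm_backward(
    dy, x, [16], mean, rstd, w, b, (True, True, True)
)

# Manual formulas, mirroring the non-ATen branch of LayerNormBackwardFull.
norm = (x - mean) * rstd
dnorm = dy * w
dx = (
    dnorm
    - dnorm.mean(dim=-1, keepdim=True)
    - norm * (dnorm * norm).mean(dim=-1, keepdim=True)
) * rstd
dw = (dy * norm).sum(dim=(0, 1))
db = dy.sum(dim=(0, 1))

torch.testing.assert_close(dx, dx_ref)
torch.testing.assert_close(dw, dw_ref)
torch.testing.assert_close(db, db_ref)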