17
17
class Mode (ModeBase , IntEnum ):
18
18
"""Mode selector for layer normalization, with each gradient being its own mode."""
19
19
20
- FORWARD = 0
21
- INPUT_BACKWARD = 1
22
- WEIGHT_BACKWARD = 2
23
- BIAS_BACKWARD = 3
24
- FULL_BACKWARD = 4
20
+ FORWARD = 1
21
+ INPUT_BACKWARD = 2
22
+ WEIGHT_BACKWARD = 4
23
+ BIAS_BACKWARD = 8
24
+ FULL_BACKWARD = INPUT_BACKWARD | WEIGHT_BACKWARD | BIAS_BACKWARD
25
25
26
26
27
27
class LayerNormSignature (OpSignature ):
@@ -46,6 +46,7 @@ def __init__(
46
46
dtype = torch .bfloat16 ,
47
47
mode : str | Mode = Mode .FORWARD ,
48
48
forwarded_args_dtype : torch .dtype | None = None ,
49
+ use_aten : bool = True ,
49
50
):
50
51
if (
51
52
len (normalized_shape ) > len (input_shape )
@@ -63,6 +64,7 @@ def __init__(
63
64
self .dtype = dtype
64
65
self .mode = Mode .parse (mode )
65
66
self .forwarded_args_dtype = forwarded_args_dtype or dtype
67
+ self .use_aten = use_aten
66
68
67
69
@property
68
70
def output_shape (self ) -> list [int ]:
@@ -124,6 +126,7 @@ def func_name(self) -> str:
124
126
"x" .join (str (i ) for i in self .input_shape ),
125
127
"w" if self .elementwise_affine is not None else "" ,
126
128
"b" if self .bias is not None else "" ,
129
+ "aten" if self .use_aten else "" ,
127
130
]
128
131
return "_" .join (name_items )
129
132
@@ -183,6 +186,7 @@ def as_init_kwargs(self) -> dict[str, Any]:
183
186
"dtype" : self .dtype ,
184
187
"mode" : self .Mode ,
185
188
"forwarded_args_dtype" : self .forwarded_args_dtype ,
189
+ "use_aten" : self .use_aten ,
186
190
}
187
191
188
192
def get_output_size (self ) -> int :
@@ -395,9 +399,9 @@ class LayerNormBackwardFull(torch.nn.Module):
395
399
weights, and bias of the layer normalization given the gradient of its
396
400
output."""
397
401
398
- def __init__ (self , signature : LayerNormSignature , * , use_aten = True ):
402
+ def __init__ (self , signature : LayerNormSignature ):
399
403
super ().__init__ ()
400
- self .use_aten = use_aten
404
+ self .use_aten = signature . use_aten
401
405
self .normalized_shape = signature .normalized_shape
402
406
self .need_bias = signature .bias
403
407
self .need_weight = signature .elementwise_affine
@@ -438,12 +442,14 @@ def forward(
438
442
# Recompute norm instead of saving it. Judging by the signature, this is the same
439
443
# decision as ATen.
440
444
norm = (input - mean ) * rstd
445
+ # norm = norm.to(dtype=input.dtype)
441
446
dnorm = grad_output * weight if weight is not None else grad_output
442
447
dx = (
443
448
dnorm
444
449
- dnorm .mean (dim = self .normalized_dim , keepdim = True )
445
450
- norm * (dnorm * norm ).mean (dim = self .normalized_dim , keepdim = True )
446
451
) * rstd
452
+ # dx = dx.to(dtype=input.dtype)
447
453
dw = None
448
454
if self .need_weight :
449
455
dw = (grad_output * norm ).sum (self .keep_dim )
@@ -489,19 +495,10 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature:
489
495
), "Can only normalize one trailing dimension for now (MIOpen limitation)."
490
496
normalized_shape = shape [args .normalized_dim :]
491
497
492
- match args .forw :
493
- case 1 :
494
- mode = Mode .FORWARD
495
- case 2 :
496
- mode = Mode .INPUT_BACKWARD
497
- case 3 :
498
- mode = Mode .WEIGHT_BACKWARD
499
- case 4 :
500
- mode = Mode .BIAS_BACKWARD
501
- case 5 :
502
- mode = Mode .FULL_BACKWARD
503
- case _:
504
- raise ValueError (f"Unsupported mode { args .forw } ." )
498
+ try :
499
+ mode = Mode (args .forw )
500
+ except Exception as e :
501
+ raise ValueError (f"Unsupported mode { args .forw } ." ) from e
505
502
506
503
return LayerNormSignature (
507
504
input_shape = shape ,
@@ -511,6 +508,7 @@ def get_signature(args: argparse.Namespace) -> LayerNormSignature:
511
508
bias = True ,
512
509
dtype = _DTypeCommandDispatcher .get_dtype (args .command ),
513
510
mode = mode ,
511
+ use_aten = args .use_aten ,
514
512
)
515
513
516
514
def get_miopen_parser () -> argparse .ArgumentParser :
@@ -519,7 +517,11 @@ def get_miopen_parser() -> argparse.ArgumentParser:
519
517
"command" , default = "layernorm" , choices = _DTypeCommandDispatcher .choices ()
520
518
)
521
519
parser .add_argument (
522
- "--forw" , "-F" , type = int , default = 1 , help = "Run only forward LayerNorm"
520
+ "--forw" ,
521
+ "-F" ,
522
+ type = int ,
523
+ default = 1 ,
524
+ help = "Kind of kernel to run, not compatible with MIOpen (1 forward, 2 backward input, 4 backward weight, 8 backward bias, 14 full backward)" ,
523
525
)
524
526
parser .add_argument (
525
527
"--input" ,
@@ -539,6 +541,12 @@ def get_miopen_parser() -> argparse.ArgumentParser:
539
541
parser .add_argument (
540
542
"--normalized_dim" , "-o" , type = int , default = 3 , help = "Normalized dim"
541
543
)
544
+ parser .add_argument (
545
+ "--use-aten" ,
546
+ type = bool ,
547
+ default = True ,
548
+ help = "Use core ATen op instead of a manual implementation" ,
549
+ )
542
550
return parser
543
551
544
552
@classmethod
0 commit comments