[boo] add a combined backward layer norm #1108
Conversation
case 5:
    mode = Mode.FULL_BACKWARD
Do these match the MIOpen interface? For conv at least, `forw` is actually a bitset (`1 << 0` is forward, `1 << 1` is input backward, `1 << 2` is weight backward, so all backwards would be `0b110 == 6`).
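For illustration, a minimal sketch of that bitset convention in Python; the `Mode` name and its members are placeholders, not the actual enum in this PR:

```python
from enum import IntFlag

# Hypothetical MIOpen-style bitset; member names are illustrative only.
class Mode(IntFlag):
    FORWARD = 1 << 0          # 0b001
    INPUT_BACKWARD = 1 << 1   # 0b010
    WEIGHT_BACKWARD = 1 << 2  # 0b100

# "All backwards" combines the two backward bits.
FULL_BACKWARD = Mode.INPUT_BACKWARD | Mode.WEIGHT_BACKWARD
assert FULL_BACKWARD == 0b110 == 6
```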
The interface simply ignores the `-F` flag right now and cannot compute backward at all. I can update all of these to be bit flags in anticipation of that, if desired.
You can leave it as is for now and revisit later; I don't have a strong opinion. But since `aten.native_group_norm_backward` uses a mask to specify the outputs to compute, I suspect that reflecting that in the signature here could simplify some code.
Updated to a bit-style flag.
bias is None
), "Bias must be provided if its gradient is requested."
if self.use_aten:
    return torch.ops.aten.native_layer_norm_backward(
In the torch.compile integration I assume `torch.ops.aten.native_layer_norm_backward` is what we'll actually see. Is there any reason not to use it for all cases? I'm a bit concerned that anything else won't accurately reflect the actual model performance.
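For reference, a minimal standalone sketch of driving the aten op directly; the tensor shapes and eps value here are made up for illustration, and the mean/rstd come from the matching `native_layer_norm` forward:

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(8)
b = torch.randn(8)
normalized_shape = [8]

# Forward via the aten op to obtain the saved mean/rstd the backward needs.
out, mean, rstd = torch.ops.aten.native_layer_norm(x, normalized_shape, w, b, 1e-5)

grad_out = torch.randn_like(out)
# output_mask selects which of (grad_input, grad_weight, grad_bias) to compute.
grad_input, grad_weight, grad_bias = torch.ops.aten.native_layer_norm_backward(
    grad_out, x, normalized_shape, mean, rstd, w, b, [True, True, True]
)
```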
`use_aten` defaults to `True`, so we will use this. The override is needed because apparently integrations may have a flag to use layernorm as rmsnorm, which does less computation and needs to be reflected in the impl.
> use_aten defaults to True, so we will use this

Right, can the aten op be used for the other backwards variations too?

> The override is needed because apparently integrations may have a flag to use layernorm as rmsnorm, which does less computation and needs to be reflected in the impl.

I don't really follow how this would work; how do you see the `use_aten` option being used? Is it intended to be plumbed through to the driver at some point? And if so, it seems to me like it should be its own command, rather than an option to layernorm.

As for the implementation here, what would we see when going through torch.compile? I assume it'll be some combination of aten ops; could we use that as the implementation instead?
(let me know if you'd like to have a quick chat about this, there may be some context I'm missing)
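As context for the torch.compile question, one way to check which aten ops the compiled backward actually contains is the AOTAutograd print-compiler pattern. This sketch uses `functorch.compile.aot_function` purely as an inspection tool and is not part of this PR:

```python
import torch
from functorch.compile import aot_function

def layer_norm(x, w, b):
    return torch.nn.functional.layer_norm(x, (x.shape[-1],), w, b)

def print_graph(gm, example_inputs):
    print(gm.code)  # dump the captured aten-level graph
    return gm       # a GraphModule is callable, so it can serve as the compiled fn

compiled = aot_function(layer_norm, fw_compiler=print_graph, bw_compiler=print_graph)

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)
# The first call traces the function; the forward and backward aten graphs are
# printed, so one can see whether aten.native_layer_norm_backward (or its
# decomposition) shows up.
compiled(x, w, b).sum().backward()
```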
> right, can the aten op be used for the other backwards variations too?

Good point, yes we could use that too.

> I don't really follow how this would work, how do you see the use_aten option being used? Is it intended to be plumbed through to the driver at some point?

So far I was just changing the default and rerunning the benchmark script to generate a second CSV table, then aligning columns to get aten and non-aten results side by side, but we can indeed expose this through the driver.

> As for the implementation here, what would we see when going through torch.compile? I assume it'll be some combination of aten ops, could we use that as the implementation instead?

This is part of the core aten IR according to https://docs.pytorch.org/docs/stable/torch.compiler_ir.html, so we should see it as is. Somewhere in the pipeline torch-mlir would expand it to ops similar to the ones in the manual implementation below.

> (let me know if you'd like to have a quick chat about this, there may be some context I'm missing)

Time zones are hard... The context is just me pushing out the bits I needed to get layernorm numbers for the sizes of reference and discovering fun edge cases. Your comments make sense to me, though; I just haven't thought about all the things :)
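To make the reference to the manual implementation concrete, here is a rough sketch of what a decomposed layer norm backward computes for the input gradient, checked against the aten op. This is an illustration of the math, not the PR's actual code:

```python
import torch

def layer_norm_backward_input(grad_out, x, w, mean, rstd):
    # Standard layer norm input-gradient formula:
    #   dx = rstd * (g - mean(g) - xhat * mean(g * xhat))
    # with g = dL/dxhat and reductions over the normalized dimension.
    g = grad_out * w
    xhat = (x - mean) * rstd
    return rstd * (
        g - g.mean(-1, keepdim=True) - xhat * (g * xhat).mean(-1, keepdim=True)
    )

x = torch.randn(4, 8, dtype=torch.float64)
w = torch.randn(8, dtype=torch.float64)
b = torch.randn(8, dtype=torch.float64)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], w, b, 1e-5)
grad_out = torch.randn_like(out)

ref = layer_norm_backward_input(grad_out, x, w, mean, rstd)
aten_dx, _, _ = torch.ops.aten.native_layer_norm_backward(
    grad_out, x, [8], mean, rstd, w, b, [True, False, False]
)
torch.testing.assert_close(ref, aten_dx)
```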
Hm, okay. I would strongly prefer that we avoid adding the manual implementation, but if it's necessary to reproduce some of the numbers we need, then that's fine (though I think it should be plumbed through to the driver in that case).
It is currently needed for the numbers. We can later investigate the difference in IRs between this and the torch-mlir lowering of the native aten op. I plumbed the flag for what is added here and made it systematic in #1115.
Add an op export for combined computation of all gradients in layer norm. This may be more efficient than executing them one by one in some cases and requires separate testing. Signed-off-by: Alex Zinenko <git@ozinenko.com>
- plumb `use_aten` through the driver
- use bitmask-style values for the mode enum

Signed-off-by: Alex Zinenko <git@ozinenko.com>