
Conversation

Contributor
@ftynse ftynse commented Jul 31, 2025

Add an op export for combined computation of all gradients in layer norm. This may be more efficient than executing them one by one in some cases, and it requires separate testing.

Comment on lines 501 to 502
case 5:
mode = Mode.FULL_BACKWARD
Member

do these match the MIOpen interface? for conv at least, forw is actually a bitset (1 << 0 is forward, 1 << 1 is input backward, 1 << 2 is weight backward, so all backwards would be 0b110 == 6)
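(For reference, a minimal sketch of the bitset convention described above, using illustrative names rather than the actual enum in this PR or the MIOpen API.)

```python
from enum import IntFlag

# Illustrative names only; the real Mode enum / MIOpen flag may be spelled differently.
class Direction(IntFlag):
    FORWARD = 1 << 0           # 0b001
    BACKWARD_DATA = 1 << 1     # 0b010, gradient w.r.t. the input
    BACKWARD_WEIGHTS = 1 << 2  # 0b100, gradient w.r.t. the weights

ALL_BACKWARD = Direction.BACKWARD_DATA | Direction.BACKWARD_WEIGHTS
assert ALL_BACKWARD == 0b110 == 6
```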

Contributor Author

The interface simply ignores the -F flag right now and cannot compute backward at all. I can update all of these to bit flags in anticipation of that, if desired.

Member

You can leave it as is for now/revisit later, I don't have a strong opinion. But since aten.native_group_norm_backward uses a mask to specify the outputs to compute, I suspect that reflecting that in the signature here could simplify some code.
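(For context, a minimal sketch of the mask-style interface being referenced, with made-up shapes; the trailing bool[3] argument is the output mask that selects which gradients get computed. The saved mean/rstd come from the forward op.)

```python
import torch

# Made-up sizes for illustration.
N, C, HxW, G = 2, 6, 16, 3
x = torch.randn(N, C, HxW)
w = torch.randn(C)
b = torch.randn(C)
grad_out = torch.randn(N, C, HxW)

# Forward returns (output, mean, rstd); the saved stats feed the backward op.
out, mean, rstd = torch.ops.aten.native_group_norm(x, w, b, N, C, HxW, G, 1e-5)

# The trailing bool[3] output_mask selects (grad_input, grad_weight, grad_bias).
grads = torch.ops.aten.native_group_norm_backward(
    grad_out, x, mean, rstd, w, N, C, HxW, G, [True, True, True]
)
```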

Contributor Author

Updated to a bit-style flag.

bias is None
), "Bias must be provided if its gradient is requested."
if self.use_aten:
return torch.ops.aten.native_layer_norm_backward(
Member

In the torch.compile integration I assume torch.ops.aten.native_layer_norm_backward is what we'll actually see. Is there any reason to not use it for all cases? I'm a bit concerned that anything else won't accurately reflect the actual model performance.

Contributor Author

use_aten defaults to True, so we will use this. The override is needed because apparently integrations may have a flag to use layernorm as rmsnorm, which does less computation and needs to be reflected in the impl.

Member

@rkayaith rkayaith Jul 31, 2025

> use_aten defaults to True, so we will use this

right, can the aten op be used for the other backwards variations too?

> The override is needed because apparently integrations may have a flag to use layernorm as rmsnorm, which does less computation and needs to be reflected in the impl.

I don't really follow how this would work; how do you see the use_aten option being used? Is it intended to be plumbed through to the driver at some point? If so, it seems to me like it should be its own command rather than an option to layernorm.
As for the implementation here, what would we see when going through torch.compile? I assume it'll be some combination of aten ops; could we use that as the implementation instead?

(let me know if you'd like to have a quick chat about this, there may be some context I'm missing)

Contributor Author

> right, can the aten op be used for the other backwards variations too?

good point, yes we could use that too

> I don't really follow how this would work; how do you see the use_aten option being used? Is it intended to be plumbed through to the driver at some point?

So far I was just changing the default and rerunning the benchmark script to generate a second CSV table, then aligning columns to get aten and non-aten results side by side. But we can indeed expose this through the driver.
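(A hypothetical sketch of what exposing the toggle through a benchmark driver could look like; the flag spelling and the argparse-based driver here are assumptions, not what the repo actually ships.)

```python
import argparse

# Hypothetical flag name; the real driver option may differ.
parser = argparse.ArgumentParser(description="layernorm benchmark driver (sketch)")
parser.add_argument(
    "--use-aten",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Use torch.ops.aten.native_layer_norm_backward instead of the manual implementation.",
)

args = parser.parse_args(["--no-use-aten"])
print(args.use_aten)  # False
```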

> As for the implementation here, what would we see when going through torch.compile? I assume it'll be some combination of aten ops; could we use that as the implementation instead?

This is part of the core aten IR according to https://docs.pytorch.org/docs/stable/torch.compiler_ir.html, so we should see it as is. Somewhere in the pipeline, torch-mlir would expand it into ops similar to those in the manual implementation below.
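(A minimal sketch of that call, with made-up shapes; the saved mean/rstd come from the forward op, and the final bool[3] output mask selects which of (grad_input, grad_weight, grad_bias) to compute.)

```python
import torch

# Made-up sizes for illustration.
x = torch.randn(8, 16)
w = torch.randn(16)
b = torch.randn(16)
grad_out = torch.randn(8, 16)

# Forward returns (output, mean, rstd).
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [16], w, b, 1e-5)

# Combined backward: all three gradients in one call.
grad_in, grad_w, grad_b = torch.ops.aten.native_layer_norm_backward(
    grad_out, x, [16], mean, rstd, w, b, [True, True, True]
)
```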

> (let me know if you'd like to have a quick chat about this, there may be some context I'm missing)

Time zones are hard... The context is just me pushing out the bits I needed to get layernorm numbers for the reference sizes and discovering fun edge cases. Your comments make sense to me though, I just haven't thought through everything yet :)

Member

hm okay I would strongly prefer that we avoid adding the manual implementation, but if it's necessary to reproduce some of the numbers we need then that's fine (though I think it should be plumbed through to the driver in that case).

Contributor Author

It is currently needed for the numbers. We can later investigate the difference in IRs between this and the torch-mlir lowering of the aten native op. I plumbed the flag through for what is added here and made it systematic in #1115.

@ftynse ftynse force-pushed the users/ftynse/explicit-dtype branch from b3065b6 to 19a0dcb on August 6, 2025 13:26
@ftynse ftynse force-pushed the users/ftynse/combined-layernorm-backward branch from 0b5b671 to 6993e80 on August 6, 2025 21:40
Base automatically changed from users/ftynse/explicit-dtype to main on September 2, 2025 12:33
Add an op export for combined computation of all gradients in layer
norm. This may be more efficient than executing them one by one in some
cases and requires separate testing.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
- plumb `use_aten` through the driver
- use bitmask-style values for the mode enum

Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse force-pushed the users/ftynse/combined-layernorm-backward branch from 6993e80 to 8581686 on September 2, 2025 12:36