[refactor] Flux/Chroma single file implementation + Attention Dispatcher #11916
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>
```diff
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2

     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
```
Marking as a TODO for myself: we will no longer need `_ATTENTION_CLASSES` once everything uses `AttentionModuleMixin`.
```python
if image_rotary_emb is not None:
    query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
    key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
)
```
cc @yiyixuxu as well for apply_rotary_emb related changes, and because of the attention processor rewrite and minification
Benchmarking the changes with `unflatten`/`flatten`/`split_with_sizes`, I see a consistent 2-4% speedup compared to `main` on an A100.
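For context, a rough sketch of the fused-QKV split and head reshape pattern the benchmark refers to; the shapes and the stand-in `nn.Linear` projection are hypothetical, not the PR's exact code:

```python
# Rough sketch of the fused-QKV split + head reshape pattern referred to above.
# Shapes and the stand-in Linear layer are hypothetical, not the PR's exact code.
import torch
import torch.nn as nn

batch, seq_len, heads, head_dim = 2, 16, 8, 64
inner_dim = heads * head_dim

to_qkv = nn.Linear(inner_dim, 3 * inner_dim)  # fused QKV projection
hidden_states = torch.randn(batch, seq_len, inner_dim)

qkv = to_qkv(hidden_states)
# Splitting with a list of sizes (equivalent to split_with_sizes) yields Q, K, V views.
query, key, value = torch.split(qkv, [inner_dim, inner_dim, inner_dim], dim=-1)

# unflatten turns the channel dim into (heads, head_dim) while keeping the BSHD layout.
query = query.unflatten(-1, (heads, -1))  # (batch, seq, heads, head_dim)
key = key.unflatten(-1, (heads, -1))
value = value.unflatten(-1, (heads, -1))
```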
```python
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
    attn, hidden_states, encoder_hidden_states
)

query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))

query = attn.norm_q(query)
key = attn.norm_k(key)
```
The permutation from BSHD to BHSD for torch-native SDPA has been moved into the dispatcher. Most of the available attention backend implementations expect a BSHD tensor, so if we did the permute/transpose here in the processor, we would have to undo it in the dispatcher (for, say, flash attention), which slows down inference. Moving the permute into the torch-specific dispatch methods keeps the original SDPA performance while significantly speeding up the other backends.
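To make that concrete, here is a minimal sketch of the dispatch idea (hypothetical backend names and registry, not the actual `dispatch_attention_fn` implementation): processors hand over BSHD tensors, and only the torch-native SDPA branch pays for the transpose.

```python
# Minimal sketch of the layout-aware dispatch described above (not the actual
# diffusers code): processors always pass BSHD tensors, and only the torch-native
# SDPA branch pays for the BSHD <-> BHSD transposes, while BSHD-native backends
# (e.g. flash-attn, xformers) receive the tensors untouched.
import torch
import torch.nn.functional as F


def _native_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # torch SDPA expects (batch, heads, seq, dim): transpose in and out here only.
    query, key, value = (t.transpose(1, 2) for t in (query, key, value))
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )
    return out.transpose(1, 2)


def _bshd_native_backend(query, key, value, **kwargs):
    # Stand-in for a backend that consumes (batch, seq, heads, dim) directly.
    raise NotImplementedError


_BACKENDS = {"native": _native_sdpa, "flash": _bshd_native_backend}


def dispatch_attention_sketch(query, key, value, backend="native", **kwargs):
    # query/key/value arrive in BSHD layout, straight from the attention processor.
    return _BACKENDS[backend](query, key, value, **kwargs)
```

This way the transpose cost is only paid on the code path that actually needs it.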
Benchmarking only the transformer with this change yields an average speedup of ~3% (over 5 runs) on non-PyTorch attention backends compared to not having the change, i.e. doing a permute and then undoing it within the backend, which was the behaviour in the previous dispatcher PR. I only tested xformers and flash-attn-2, but in theory it should speed up all backends that expect the BSHD format.
```diff
@@ -42,6 +42,309 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
```
@DN6 I've moved these functions here instead of keeping them in the attention processor, so we can re-use them in both the IP-Adapter and the normal processors. LMK if you have something else in mind.
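Roughly, a shared module-level helper of that kind could look like the sketch below; the attribute names (`to_q`/`to_k`/`to_v`, `add_*_proj`) follow common diffusers conventions but are assumptions here, not the PR's exact code:

```python
# Hedged sketch of a shared, module-level QKV helper that both the regular and
# the IP-Adapter processors can call. Attribute names are assumptions based on
# common diffusers conventions, not the PR's exact implementation.
def _get_qkv_projections_sketch(attn, hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value
```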
I think we should have `FeedForward` and the adaptive layernorms specialized per model too (more true to the single-file implementation). LMK what you think and I'll update accordingly. For Chroma (which is mostly Flux), we should also do a full single-file implementation instead of importing from Flux. We can add `# Copied from`s.
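For reference, a sketch of the `# Copied from` marker being suggested; the module path is an assumption based on diffusers' layout, and the class body is illustrative only:

```python
# Sketch of the "Copied from" convention: the Chroma module re-declares the code
# and marks its origin so `make fix-copies` keeps it in sync with the Flux source.
# The module path is an assumption based on diffusers' layout; the body is a stub.
import torch.nn as nn


# Copied from diffusers.models.transformers.transformer_flux.FluxAttention with Flux->Chroma
class ChromaAttention(nn.Module):
    ...
```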
Nice 👍🏽 Some very small things to take care of and we can merge.
The failing test is unrelated (the WanVACE tests are very flaky for some reason).
…her (huggingface#11916)

* update
* update
* add coauthor

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test
* handle ip adapter params correctly
* fix chroma qkv fusion test
* fix fastercache implementation
* fix more tests
* fight more tests
* add back set_attention_backend
* update
* update
* make style
* make fix-copies
* make ip adapter processor compatible with attention dispatcher
* refactor chroma as well
* remove rmsnorm assert
* minify and deprecate npu/xla processors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@a-r-r-o-w @DN6 @tolgacangoz I wanted to also see if it would still compile, but got the following error logs: compiling with both … and … yields the same result after …
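For reference, a minimal sketch of that kind of compile check; the checkpoint id, dtype, and compile flags below are placeholders rather than the exact ones from the report:

```python
# Minimal sketch of compiling the transformer for such a check; checkpoint id and
# compile flags are placeholders, not the ones from the report above.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)

# The first call triggers compilation (and would surface graph breaks / errors).
_ = pipe("a photo of a cat", num_inference_steps=2)
```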
This PR is a continuation of #11368.
Using `set_attention_backend`:
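A sketch of the usage (checkpoint, backend name, and prompt are placeholders; `set_attention_backend` is the method name referenced in this PR):

```python
# Sketch of switching the attention backend on the model itself; checkpoint id,
# backend name, and prompt are placeholders.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Switch the transformer's attention implementation for all subsequent calls.
pipe.transformer.set_attention_backend("flash")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```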
Using the context manager `attention_backend`:
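And a sketch of the context-manager form (the import path and backend name are assumptions based on the names mentioned in this PR):

```python
# Sketch of the attention_backend context manager; the import path and backend
# name are assumptions based on the names mentioned in this PR.
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import attention_backend

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# The backend override only applies inside the context; outside it, the model
# falls back to whatever was configured (e.g. via set_attention_backend).
with attention_backend("flash"):
    image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```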