[refactor] Flux/Chroma single file implementation + Attention Dispatcher #11916

Merged: 21 commits from to-single-file/flux into main on Jul 17, 2025

Conversation

a-r-r-o-w (Member) commented on Jul 14, 2025

This PR is a continuation of #11368.

Using set_attention_backend:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda")
transformer.set_attention_backend("sage_varlen")
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

prompt = "A cat holding a sign that says 'hello world'"

image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")

Using the attention_backend context manager:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, attention_backend

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda")
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

prompt = "A cat holding a sign that says 'hello world'"

with attention_backend("_native_cudnn"):
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")

a-r-r-o-w requested a review from DN6 on July 14, 2025 at 02:57

@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2

     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
a-r-r-o-w (Member, Author) commented:
Marking as a TODO for myself: we will no longer need _ATTENTION_CLASSES once everything inherits from AttentionModuleMixin.

a-r-r-o-w changed the title from "[refactor] Flux single file implementation" to "[refactor] Flux single file implementation + Attention Dispatcher" on Jul 15, 2025
a-r-r-o-w mentioned this pull request on Jul 15, 2025
Comment on lines +212 to +224
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
a-r-r-o-w (Member, Author) commented:
cc @yiyixuxu as well for apply_rotary_emb related changes, and because of the attention processor rewrite and minification
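
For context, a minimal sketch of what applying rotary embeddings along sequence_dim=1 (i.e. on a BSHD tensor) can look like. This is an illustration with assumed cos/sin inputs and an interleaved rotation, not the actual diffusers helper:

import torch

def apply_rotary_emb_bshd(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq, heads, head_dim] (BSHD); cos/sin: [seq, head_dim]
    # Rotate interleaved channel pairs: (even, odd) -> (even*cos - odd*sin, even*sin + odd*cos)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, 0::2]  # broadcast over batch and heads
    sin = sin[None, :, None, 0::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out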

a-r-r-o-w (Member, Author) commented:
Benchmarking the changes with unflatten/flatten/split_with_sizes, I see a consistent 2-4% speedup compared to main on an A100.
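
For illustration only, a small sketch of the view-based ops mentioned above; the shapes and the fused QKV layout are assumptions, not the PR's exact code:

import torch

batch, seq, heads, head_dim = 2, 4608, 24, 128
fused_qkv = torch.randn(batch, seq, 3 * heads * head_dim)

# split_with_sizes returns views into the fused projection (no copy)
q, k, v = fused_qkv.split_with_sizes([heads * head_dim] * 3, dim=-1)

# unflatten splits the channel dim into (heads, head_dim) while staying in BSHD
q = q.unflatten(-1, (heads, -1))  # [batch, seq, heads, head_dim]
k = k.unflatten(-1, (heads, -1))
v = v.unflatten(-1, (heads, -1))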

Comment on lines +90 to +99
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)

query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))

query = attn.norm_q(query)
key = attn.norm_k(key)
a-r-r-o-w (Member, Author) commented:
The permutation from BSHD to BHSD for torch-native SDPA has been moved into the dispatcher. Most of the available attention backend implementations expect a BSHD tensor, so if we do the permute/transpose here in the processor, we have to undo it in the dispatcher (for, say, flash attention), which slows down inference. Instead, moving the permute into the torch-specific dispatch methods keeps the original SDPA performance while also speeding up the other backends significantly.
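
A rough sketch of the idea, with illustrative function names rather than the actual dispatcher code: processors hand over BSHD tensors, and only the torch SDPA path pays for the transpose.

import torch
import torch.nn.functional as F

def _native_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # The dispatcher receives BSHD ([batch, seq, heads, head_dim]); torch SDPA
    # expects BHSD, so the permute lives here and is undone before returning.
    query, key, value = (t.permute(0, 2, 1, 3) for t in (query, key, value))
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )
    return out.permute(0, 2, 1, 3)  # back to BSHD for the processor

def _flash_like_backend(query, key, value, **kwargs):
    # Backends that already consume BSHD get the tensors untouched, so no
    # permute/un-permute round trip is paid on their hot path.
    raise NotImplementedError("illustrative placeholder")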

a-r-r-o-w (Member, Author) commented on Jul 15, 2025:
Benchmarking only the transformer with this change yields an average speedup of ~3% (over 5 runs) on non-PyTorch attention backends, compared to the older dispatcher behaviour of permuting and then undoing the permute inside the backend. I only tested xformers and flash-attn-2, but in theory it should speed up all backends that expect the BSHD format.

@@ -42,6 +42,309 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
a-r-r-o-w (Member, Author) commented:
@DN6 I've moved these functions here instead of keeping them in the attention processor, so we can reuse them in both the IP-Adapter and the normal processors. LMK if you have something else in mind.
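
Roughly, such a shared helper can look like the sketch below. This is simplified, and the attribute names (fused_projections, to_qkv, add_*_proj) are assumptions about the attention module, not a copy of the PR code:

def _get_qkv_projections(attn, hidden_states, encoder_hidden_states=None):
    # Shared between the regular and IP-Adapter processors so the projection
    # logic lives in a single place.
    if getattr(attn, "fused_projections", False):
        query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
    else:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and getattr(attn, "add_q_proj", None) is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value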

a-r-r-o-w changed the title from "[refactor] Flux single file implementation + Attention Dispatcher" to "[refactor] Flux/Chroma single file implementation + Attention Dispatcher" on Jul 15, 2025
a-r-r-o-w (Member, Author) commented on Jul 15, 2025

I think we should have FeedForward and adaptive-layernorm specialized per model too (more true to the single-file implementation). LMK what you think and I'll update accordingly.

For Chroma (which is mostly Flux), we should do a full single-file implementation as well instead of importing from Flux. We can add "Copied from"s.
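
For context, "Copied from" refers to the diffusers convention where duplicated code carries a marker comment and `make fix-copies` keeps it in sync with the referenced original. A hedged illustration with made-up class names:

import torch.nn as nn

# Copied from diffusers.models.transformers.transformer_flux.FluxFeedForward
class ChromaFeedForward(nn.Module):
    # Hypothetical example: the block is duplicated into the Chroma file instead of
    # being imported from Flux, and `make fix-copies` enforces that the bodies match.
    def __init__(self, dim: int, dim_out: int, mult: int = 4):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU(), nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        return self.net(hidden_states)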

DN6 (Collaborator) left a comment:
Nice 👍🏽 some very small things to take care of and we can merge.

a-r-r-o-w added the "roadmap" (Add to current release roadmap) label on Jul 16, 2025
a-r-r-o-w (Member, Author) commented:
Failing test is unrelated (WanVACE tests are very flaky for some reason)

a-r-r-o-w merged commit 18c8f10 into main on Jul 17, 2025
58 of 100 checks passed
a-r-r-o-w deleted the to-single-file/flux branch on July 17, 2025 at 12:00
The github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.35 on Jul 17, 2025
tolgacangoz pushed a commit to tolgacangoz/diffusers that referenced this pull request Jul 17, 2025
…her (huggingface#11916)

* update

* update

* add coauthor

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* improve test

* handle ip adapter params correctly

* fix chroma qkv fusion test

* fix fastercache implementation

* fix more tests

* fight more tests

* add back set_attention_backend

* update

* update

* make style

* make fix-copies

* make ip adapter processor compatible with attention dispatcher

* refactor chroma as well

* remove rmsnorm assert

* minify and deprecate npu/xla processors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
tolgacangoz pushed a commit to tolgacangoz/diffusers that referenced this pull request Jul 18, 2025
…her (huggingface#11916)

okaris (Contributor) commented on Jul 18, 2025

@a-r-r-o-w @DN6 @tolgacangoz
I got a bit excited about this PR and wanted to give it a go. I love the syntax, both the setter function and the context manager. Great work!

I also wanted to see if it would still compile, but got the following error logs:

  0%|          | 0/30 [00:00<?, ?it/s]
.../torch/_dynamo/variables/functions.py:1601: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
W0718 20:32:55 torch/_dynamo/variables/tensor.py:1048 [0/0] Graph break from `Tensor.item()`, consider setting:
    torch._dynamo.config.capture_scalar_outputs = True
or:
    env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
to include these operations in the captured graph.

Graph break: from user code at:
  File ".../diffusers/models/transformers/transformer_flux.py", line 733, in forward
    encoder_hidden_states, hidden_states = block(
  File ".../diffusers/models/transformers/transformer_flux.py", line 456, in forward
    attention_outputs = self.attn(
  File ".../diffusers/models/transformers/transformer_flux.py", line 343, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File ".../diffusers/models/transformers/transformer_flux.py", line 117, in __call__
    hidden_states = dispatch_attention_fn(
  File ".../diffusers/models/attention_dispatch.py", line 241, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File ".../diffusers/models/attention_dispatch.py", line 962, in _sage_varlen_attention
    _prepare_for_flash_attn_or_sage_varlen(
  File ".../diffusers/models/attention_dispatch.py", line 351, in _prepare_for_flash_attn_or_sage_varlen
    return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
  File ".../torch/_dynamo/polyfills/__init__.py", line 253, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File ".../diffusers/models/attention_dispatch.py", line 321, in _prepare_for_flash_attn_or_sage_varlen_without_mask
    max_seqlen_q = seqlens_q.max().item()

  0%|          | 0/30 [00:00<?, ?it/s]
[ERROR] Traceback (most recent call last):
  File "/server/tasks.py", line 50, in run_task
    output = await result
  File ".../src/inference.py", line 248, in run
    result = self.pipeline(
  File ".../torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File ".../diffusers/pipelines/flux/pipeline_flux_kontext.py", line 1063, in __call__
    noise_pred = self.transformer(
  File ".../torch/_dynamo/eval_frame.py", line 411, in __call__
    return super().__call__(*args, **kwargs)
  File ".../torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File ".../torch/_dynamo/eval_frame.py", line 812, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
torch._dynamo.exc.Unsupported: Unsupported Tensor.item() call with capture_scalar_outputs=False
  Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.

Compiling with both

self.pipeline.transformer.compile_repeated_blocks(fullgraph=True)

and

self.pipeline.transformer.to(memory_format=torch.channels_last)
self.pipeline.transformer = torch.compile(
  self.pipeline.transformer, mode="max-autotune", fullgraph=True
)

yields the same result

after

self.pipeline.transformer.set_attention_backend("sage_varlen")
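
As the Dynamo warning above suggests, one possible mitigation (a hedged sketch, not verified in this thread) is to let Dynamo capture scalar outputs before compiling, so the `seqlens_q.max().item()` call in the varlen preparation helper no longer forces a graph break:

import torch

# Allow `.item()` to be traced as an unbacked scalar instead of breaking the graph.
torch._dynamo.config.capture_scalar_outputs = True

pipeline.transformer.set_attention_backend("sage_varlen")  # `pipeline` stands in for self.pipeline above
pipeline.transformer.compile_repeated_blocks(fullgraph=True)

Whether this compiles cleanly with fullgraph=True for sage_varlen is not confirmed here; other dynamic-shape graph breaks may remain.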
