
Fix Qwen-Image long prompt dimension mismatch error (issue #12083) #12087


Open
robin-ede wants to merge 9 commits into main from fix/qwen-image-long-prompt-issue-12083

Conversation

robin-ede

Fix Qwen-Image long prompt dimension mismatch error

  • Add dynamic expansion capability to QwenEmbedRope pos_freqs buffer
  • Expand buffer when max_vid_index + max_len exceeds current size
  • Prevent RuntimeError when text prompts exceed 1024 tokens with large images
  • Add comprehensive test case for long prompt scenarios
  • Maintain backward compatibility with existing functionality

Fixes: #12083

What does this PR do?

Fixes a critical bug in Qwen-Image where long text prompts (>1024 tokens) combined with large images trigger a RuntimeError: The size of tensor a (1024) must match the size of tensor b (983).

Problem

QwenEmbedRope allocated a fixed 1024-entry buffer of positional frequencies. A large image (1024×1024) combined with a long prompt required slicing pos_freqs[1024:2048], which reaches past the end of that 1024-entry buffer.

Solution

Added dynamic buffer expansion that automatically resizes pos_freqs when needed:

  • Only expands when required (memory efficient)
  • Uses register_buffer() for proper tensor management
  • Maintains backward compatibility and performance

Changes

  • Added _expand_pos_freqs_if_needed() method
  • Modified forward() to check expansion requirements
  • Added test case for long prompt scenarios

Before: pos_freqs[1024:2048] on a 1024-entry buffer → truncated slice → RuntimeError (size mismatch) downstream
After: buffer auto-expands → slice succeeds

Fixes #12083

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
    • Note: This is a bug fix with internal implementation changes only. No public API changes require documentation updates.
  • Did you write any new necessary tests?
    • Added: test_long_prompt_no_error() in tests/pipelines/qwenimage/test_qwenimage.py

Implementation Details

Architecture Analysis

Our solution follows established patterns in the diffusers codebase:

  1. Buffer Management: Uses register_buffer() like other components (modeling_ctx_clip.py, embeddings.py)
  2. Dynamic Computation: Mirrors pattern in get_1d_rotary_pos_embed() which computes frequencies on-demand
  3. Memory Alignment: Rounds to 512-token boundaries following PyTorch optimization practices

Performance Impact

  • Memory: Minimal - only expands when needed (one-time cost)
  • Speed: Negligible - expansion happens once then cached
  • Quality: Zero impact - identical mathematical operations, just larger buffer

Backward Compatibility

  • API: No changes to public interface
  • Behavior: Existing short prompts work exactly as before
  • Performance: Same performance characteristics for existing use cases

Who can review?

This PR affects only the QwenImage transformer implementation; the fix is focused there and follows established PyTorch patterns for dynamic buffer management.

@sayakpaul
Member

@naykun could you also give this a look?

Comment on lines 165 to 182
pos_index = torch.arange(self._current_max_len)
neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
self.register_buffer(
    'pos_freqs',
    torch.cat(
        [
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ],
        dim=1,
    ),
)
self.register_buffer(
    'neg_freqs',
    torch.cat(
        [
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ],
        dim=1,
    ),
)
Member


Some changes seem to be overlapping with #12061

@naykun
Contributor

naykun commented Aug 7, 2025

Hi @sayakpaul , the solution looks good for addressing the runtime error. However, I'd like to point out that the Qwen image model is not trained on prompts longer than 512 tokens, so extremely long prompts may lead to unpredictable behavior. Perhaps we should add a warning to highlight this limitation.

@sayakpaul
Member

Tremendous suggestion! @robin-ede can we incorporate this and modify the test so that we verify we raise the warning?

Here's an example of how we test warnings:

diffusers/tests/lora/utils.py

Lines 1811 to 1820 in 0611631

logger = logging.get_logger("diffusers.utils.peft_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
    pipe.load_lora_weights(state_dict)
# Since the missing key won't contain the adapter name ("default_0").
# Also strip out the component prefix (such as "unet." from `missing_key`).
component = list({k.split(".")[0] for k in state_dict})[0]
self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

@robin-ede
Author


Yea for sure! I'll get this done in a bit.

robin-ede and others added 5 commits August 7, 2025 07:52
…e#12083)

- Add dynamic expansion capability to QwenEmbedRope pos_freqs buffer
- Expand buffer when max_vid_index + max_len exceeds current size
- Prevent RuntimeError when text prompts exceed 1024 tokens with large images
- Add comprehensive test case for long prompt scenarios
- Maintain backward compatibility with existing functionality

Fixes: huggingface#12083
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
- Add warning when prompts exceed 512 tokens (model's training limit)

- Warn users about potential unpredictable behavior with long prompts

- Add comprehensive test with CaptureLogger to verify warning system

- Follow established diffusers warning patterns for consistency
- Move CaptureLogger import to top level following established patterns

- Use logging.WARNING constant instead of hardcoded value

- Simplify device handling to match other QwenImage tests

- Remove redundant variable assignments and comments
- Fix whitespace and string quote consistency

- Add trailing commas where appropriate

- Clean up formatting per diffusers code standards
@robin-ede robin-ede force-pushed the fix/qwen-image-long-prompt-issue-12083 branch from f250165 to 35cb2c8 on August 7, 2025 13:25
@robin-ede
Author

Should be fixed! @sayakpaul

@robin-ede robin-ede requested a review from sayakpaul August 7, 2025 13:40
- Fix test_long_prompt_warning to properly trigger the 512-token warning
- Replace inefficient wall-of-text approach with elegant hardcoded multiplier
- Use precise token counting to ensure required_len > _current_max_len threshold
- Add runtime assertion for test robustness and maintainability
- Fix max_sequence_length validation error in test_long_prompt_no_error
- Replace character counting with actual token counting for accuracy
- Use multiplier that generates ~521 tokens (well within limits)
- Add runtime assertions to verify token count assumptions
- Ensure test validates the original fix without triggering warnings
- Make test intent clearer with proper token-based thresholds
@sayakpaul
Member

@bot /style

Contributor

github-actions bot commented Aug 7, 2025

Style bot fixed some files and pushed the changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Development

Successfully merging this pull request may close these issues.

Qwen-Image long prompt will cause error
5 participants