
Fix Qwen-Image long prompt dimension mismatch error (issue #12083) #12087


Open · wants to merge 9 commits into main
92 changes: 76 additions & 16 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -160,23 +160,31 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
-            [
-                self.rope_params(pos_index, self.axes_dim[0], self.theta),
-                self.rope_params(pos_index, self.axes_dim[1], self.theta),
-                self.rope_params(pos_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
+        # Initialize with default size 1024, but allow dynamic expansion
+        self._current_max_len = 1024
+        pos_index = torch.arange(self._current_max_len)
+        neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
+        self.register_buffer(
+            "pos_freqs",
+            torch.cat(
+                [
+                    self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
        )
-        self.neg_freqs = torch.cat(
-            [
-                self.rope_params(neg_index, self.axes_dim[0], self.theta),
-                self.rope_params(neg_index, self.axes_dim[1], self.theta),
-                self.rope_params(neg_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
+        self.register_buffer(
+            "neg_freqs",
+            torch.cat(
+                [
+                    self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
)
self.rope_cache = {}

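For context, a self-contained sketch (plain PyTorch with illustrative class names, not part of the PR) of why moving `pos_freqs`/`neg_freqs` into `register_buffer` matters: buffers follow `module.to(device)`, while plain tensor attributes stay on their original device.

import torch
import torch.nn as nn

class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.freqs = torch.ones(4)  # plain attribute: invisible to .to()

class Buffered(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("freqs", torch.ones(4))  # tracked by .to()

if torch.cuda.is_available():
    print(PlainAttr().to("cuda").freqs.device)  # cpu: left behind
    print(Buffered().to("cuda").freqs.device)   # cuda:0: moved with the module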
@@ -193,6 +201,53 @@ def rope_params(self, index, dim, theta=10000):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

+    def _expand_pos_freqs_if_needed(self, required_len):
+        """Expand pos_freqs and neg_freqs if the required length exceeds the current size"""
+        if required_len <= self._current_max_len:
+            return
+
+        # Round the new size up to the next multiple of 512 for efficiency
+        new_max_len = max(required_len, int((required_len + 511) // 512) * 512)
+
+        # Log a warning about potential quality degradation for long prompts
+        if required_len > 512:
+            logger.warning(
+                f"QwenImage model was trained on prompts up to 512 tokens. "
+                f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
+                f"Consider using shorter prompts for better results."
+            )
+
+        # Generate expanded indices
+        pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
+        neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1
+
+        # Generate expanded frequency embeddings
+        new_pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        ).to(device=self.pos_freqs.device, dtype=self.pos_freqs.dtype)
+
+        new_neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        ).to(device=self.neg_freqs.device, dtype=self.neg_freqs.dtype)
+
+        # Update buffers
+        self.register_buffer("pos_freqs", new_pos_freqs)
+        self.register_buffer("neg_freqs", new_neg_freqs)
+        self._current_max_len = new_max_len
+
+        # Clear the RoPE cache since the buffer dimensions changed
+        self.rope_cache = {}
+
def forward(self, video_fhw, txt_seq_lens, device):
"""
        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
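The size calculation in `_expand_pos_freqs_if_needed` rounds `required_len` up to the next multiple of 512; the outer `max` is a safety net that never changes the result. A standalone check of the arithmetic, with illustrative values:

def round_up_to_512(required_len: int) -> int:
    # Mirrors the expression used in _expand_pos_freqs_if_needed
    return max(required_len, ((required_len + 511) // 512) * 512)

assert round_up_to_512(1025) == 1536  # first expansion past the 1024 default
assert round_up_to_512(1536) == 1536  # exact multiples stay unchanged
assert round_up_to_512(2000) == 2048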
@@ -232,6 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
max_vid_index = max(height, width)

max_len = max(txt_seq_lens)
+
+        # Expand pos_freqs if needed to accommodate max_vid_index + max_len
+        required_len = max_vid_index + max_len
+        self._expand_pos_freqs_if_needed(required_len)
+
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]

return vid_freqs, txt_freqs
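The bound passed to `_expand_pos_freqs_if_needed` is exactly the highest row the text slice touches: `txt_freqs` reads rows `[max_vid_index, max_vid_index + max_len)` of `pos_freqs`. A minimal sketch (illustrative sizes) of the silent truncation that previously surfaced as the dimension mismatch reported in issue #12083:

import torch

pos_freqs = torch.randn(1024, 64)    # fixed-size table, pre-PR behavior
max_vid_index, max_len = 64, 1000    # large image grid plus a long prompt

txt_freqs = pos_freqs[max_vid_index : max_vid_index + max_len]
print(txt_freqs.shape[0])  # 960, not 1000: the slice is clipped at the
                           # table edge and the mismatch errors out downstream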
78 changes: 77 additions & 1 deletion tests/pipelines/qwenimage/test_qwenimage.py
@@ -24,7 +24,7 @@
QwenImagePipeline,
QwenImageTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -234,3 +234,79 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
expected_diff_max,
"VAE tiling should not affect the inference results",
)
+
+    def test_long_prompt_no_error(self):
+        # Test for issue #12083: long prompts should not cause dimension mismatch errors
+        device = torch_device
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+
+        # Create a long prompt that approaches but stays within limits.
+        # This exercises the fix for the original issue without triggering the warning.
+        phrase = "A beautiful, detailed, high-resolution, photorealistic image showing "
+        long_prompt = phrase * 40  # Generates ~800 tokens, well within limits
+
+        # Verify the token count for test clarity
+        tokenizer = components["tokenizer"]
+        token_count = len(tokenizer.encode(long_prompt))
+        required_len = 32 + token_count  # height/width + tokens
+        # Should be large enough to test the fix but not trigger the expansion warning
+        self.assertGreater(token_count, 500, f"Test prompt should be substantial (got {token_count} tokens)")
+        self.assertLess(required_len, 1024, f"Test should stay within limits (got {required_len})")
+
+        inputs = {
+            "prompt": long_prompt,
+            "generator": torch.Generator(device=device).manual_seed(0),
+            "num_inference_steps": 2,
+            "guidance_scale": 3.0,
+            "true_cfg_scale": 1.0,
+            "height": 32,  # Small size for a fast test
+            "width": 32,  # Small size for a fast test
+            "max_sequence_length": 1024,  # Allow the longest permitted sequence
+            "output_type": "pt",
+        }
+
+        # This should not raise a RuntimeError about a tensor dimension mismatch
+        _ = pipe(**inputs)
+
+    def test_long_prompt_warning(self):
+        """Test that long prompts trigger the warning about the training limitation"""
+        from diffusers.utils import logging
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        # Create a long prompt that will exceed the RoPE expansion threshold.
+        # The warning is triggered when required_len = max(height, width) + text_tokens > _current_max_len.
+        # Since _current_max_len is 1024 and height=width=32, we need > 992 tokens.
+        phrase = "A detailed photorealistic image showing many beautiful elements and complex artistic creative features with intricate designs."
+        long_prompt = phrase * 58  # Generates ~1045 tokens, ensuring required_len > 1024
+
+        # Verify we exceed the threshold (for test robustness)
+        tokenizer = components["tokenizer"]
+        token_count = len(tokenizer.encode(long_prompt))
+        required_len = 32 + token_count  # height/width + tokens
+        self.assertGreater(required_len, 1024, f"Test prompt must exceed threshold (got {required_len})")
+
+        # Capture transformer logging
+        logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
+        logger.setLevel(logging.WARNING)
+
+        with CaptureLogger(logger) as cap_logger:
+            _ = pipe(
+                prompt=long_prompt,
+                generator=torch.Generator(device=torch_device).manual_seed(0),
+                num_inference_steps=2,
+                guidance_scale=3.0,
+                true_cfg_scale=1.0,
+                height=32,  # Small size for a fast test
+                width=32,  # Small size for a fast test
+                max_sequence_length=1024,  # Allow a long sequence
+                output_type="pt",
+            )
+
+        # Verify a warning was logged about the 512-token training limitation
+        self.assertTrue("512 tokens" in cap_logger.out)
+        self.assertTrue("unpredictable behavior" in cap_logger.out)
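For manual verification against the report in issue #12083, a sketch along these lines should complete without the previous RuntimeError once this branch is installed. The checkpoint id and device are assumptions for illustration, not part of this PR:

import torch
from diffusers import QwenImagePipeline

# Assumed checkpoint id; substitute whatever Qwen-Image weights you use
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # assumes a CUDA device

long_prompt = "A beautiful, detailed, photorealistic image showing " * 120  # well over 512 tokens
image = pipe(prompt=long_prompt, max_sequence_length=1024).images[0]
image.save("long_prompt.png")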