Enable compilation in Qwen Image #12061
Changes from all commits: ee89f79, 5620f87, 365b8c1, dfc6018, 431cd77, 6a5fcec, d2241b9, ebd8289, 4ffc993, bc4016d, 20e30cb, 0305b5a, 6eaca43, e32a353, c630005, fdfc758, 385f819
@@ -13,6 +13,7 @@
 # limitations under the License.

+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self.axes_dim = axes_dim
         pos_index = torch.arange(1024)
         neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
+        pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        self.neg_freqs = torch.cat(
+        neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
+        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
+        self.register_buffer("neg_freqs", neg_freqs, persistent=False)

         # whether to use scale rope
         self.scale_rope = scale_rope
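A note on the `register_buffer(..., persistent=False)` change above: non-persistent buffers travel with the module on `.to()` / `.cuda()` calls but stay out of the state dict, so existing checkpoints are unaffected and `forward` no longer has to move the frequency tensors by hand. A minimal sketch of that behavior, plain PyTorch only, not the diffusers implementation:

```python
import torch
from torch import nn


class RopeLike(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: moved by module.to(...) but excluded from state_dict().
        self.register_buffer("pos_freqs", torch.randn(1024, 64), persistent=False)


m = RopeLike()
print("pos_freqs" in m.state_dict())  # False -> existing checkpoints are unaffected
m = m.to("cuda" if torch.cuda.is_available() else "cpu")
print(m.pos_freqs.device)  # follows the module; forward() no longer moves tensors by hand
```

The trade-off is that `.to(dtype)` casts floating-point buffers as well, which is the precision concern raised in the review discussion at the bottom of this page.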
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
             txt_length: [bs] a list of 1 integers representing the length of the text
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
-

Comment on lines -201 to -204: Recompilation trigger one.

         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
         frame, height, width = video_fhw
         rope_key = f"{frame}_{height}_{width}"

-        if rope_key not in self.rope_cache:

Comment: Recompilation trigger two.

Comment: Is this again because there is something special happening on the first run?

-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-            if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
+        if not torch.compiler.is_compiling():
+            if rope_key not in self.rope_cache:
+                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
+            vid_freqs = self.rope_cache[rope_key]
+        else:
+            vid_freqs = self._compute_video_freqs(frame, height, width)

         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2)
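The new branch keeps the eager-mode Python cache but bypasses it while tracing, since growing `rope_cache` (and the `.clone()` stored into it) from inside the graph is what kept invalidating guards. A minimal sketch of the same pattern, assuming PyTorch 2.3+ where `torch.compiler.is_compiling()` is available (older releases expose `torch._dynamo.is_compiling()` instead):

```python
import torch

_cache = {}


def cached_freqs(key, compute):
    """Serve from a Python-side cache eagerly, but bypass it while torch.compile traces."""
    if not torch.compiler.is_compiling():
        if key not in _cache:
            _cache[key] = compute()
        return _cache[key]
    # Under compilation, mutating module or global state adds guards and invites recompiles,
    # so recompute instead; the compiled graph amortizes the extra work.
    return compute()


freqs = cached_freqs("1_32_32", lambda: torch.randn(1024, 16))
```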
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):

         return vid_freqs, txt_freqs

+    @functools.lru_cache(maxsize=None)

Comment: Let's remove the

Comment: WDYT about maybe putting maxsize=128 or something here so that long-running services that use diffusers don't accidentally die with OOM (probably very unlikely though)? @sayakpaul

+    def _compute_video_freqs(self, frame, height, width):

Comment: TODO: we need to remove frame (can be done in a future PR)

+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+

 class QwenDoubleStreamAttnProcessor2_0:
     """
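On the `maxsize` question above: `functools.lru_cache(maxsize=None)` grows without bound, one entry per distinct `(frame, height, width)`, while a finite `maxsize` evicts least-recently-used entries. A small sketch of the suggested bound; the value 128 is the reviewer's example, not a diffusers default:

```python
import functools


class FreqCacheDemo:
    calls = 0

    @functools.lru_cache(maxsize=128)  # bounded: least-recently-used entries get evicted
    def _compute_video_freqs(self, frame, height, width):
        FreqCacheDemo.calls += 1
        return frame * height * width  # stand-in for the real tensor computation


demo = FreqCacheDemo()
demo._compute_video_freqs(1, 32, 32)
demo._compute_video_freqs(1, 32, 32)
print(FreqCacheDemo.calls)  # 1 -- the second call was served from the cache
```

Worth noting that the cache key includes `self`, so cached entries also keep the module alive, and with `maxsize=None` the cache never shrinks, which is exactly the long-running-service concern.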
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
     _supports_gradient_checkpointing = True
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["QwenImageTransformerBlock"]

     @register_to_config
     def __init__(
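For context, `_repeated_blocks` is the hint diffusers' regional-compilation support reads so that only the repeated transformer blocks get compiled instead of the full model. A hedged usage sketch; the `compile_repeated_blocks` method name, the forwarded `fullgraph=True` kwarg, and the `Qwen/Qwen-Image` repo id are assumptions to verify against your diffusers version:

```python
import torch
from diffusers import QwenImageTransformer2DModel

# Assumed checkpoint location; adjust to the checkpoint you actually use.
model = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Compiles only the modules listed in _repeated_blocks ("QwenImageTransformerBlock" here),
# so one compiled graph is reused across all blocks of that type.
model.compile_repeated_blocks(fullgraph=True)
```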
@@ -1711,6 +1711,11 @@ def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

+        if self.model_class.__name__ == "QwenImageTransformer2DModel":
+            pytest.skip(
+                "QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
+            )
+

Comment on lines +1714 to +1718: Will investigate in a follow-up.

         def _has_generator_arg(model):
             sig = inspect.signature(model.forward)
             params = sig.parameters
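For readers unfamiliar with the skipped path: group offloading keeps groups of layers on the CPU (or on disk) and onloads them to the accelerator on demand. A hedged sketch of how it is typically enabled on a diffusers model; the method and parameter names, in particular `offload_to_disk_path`, are assumptions to check against your diffusers version:

```python
import torch
from diffusers import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Block-level group offloading; offload_to_disk_path is the "with disk" variant this test covers.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk_path="/tmp/qwen_image_offload",
)
```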
@@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = QwenImageTransformer2DModel
    main_input_name = "hidden_states"
    # We override the items here because the transformer under consideration is small.
    model_split_percents = [0.7, 0.6, 0.6]

    # Skip setting testing with default: AttnProcessor
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        return self.prepare_dummy_input()

    @property
    def input_shape(self):
        return (16, 16)

    @property
    def output_shape(self):
        return (16, 16)

    def prepare_dummy_input(self, height=4, width=4):
        batch_size = 1
        num_latent_channels = embedding_dim = 16
        sequence_length = 7
        vae_scale_factor = 4

        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
        orig_width = width * 2 * vae_scale_factor
        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
            "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
        }

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": 2,
            "in_channels": 16,
            "out_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 16,
            "num_attention_heads": 3,
            "joint_attention_dim": 16,
            "guidance_embeds": False,
            "axes_dims_rope": (8, 4, 4),
        }

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"QwenImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = QwenImageTransformer2DModel

    def prepare_init_args_and_inputs_for_common(self):
        return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
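To make the intent of the compile test concrete, here is a standalone smoke test in the same spirit (not the actual `TorchCompileTesterMixin` internals): with the tiny config above, repeated calls at a fixed shape should reuse one compiled graph, and `torch._dynamo.config.patch(error_on_recompile=True)` turns any leftover RoPE-induced recompilation into a hard failure.

```python
import torch
from diffusers import QwenImageTransformer2DModel

device = "cuda" if torch.cuda.is_available() else "cpu"
init_dict = {
    "patch_size": 2,
    "in_channels": 16,
    "out_channels": 4,
    "num_layers": 2,
    "attention_head_dim": 16,
    "num_attention_heads": 3,
    "joint_attention_dim": 16,
    "guidance_embeds": False,
    "axes_dims_rope": (8, 4, 4),
}
model = QwenImageTransformer2DModel(**init_dict).to(device).eval()
model = torch.compile(model)

batch_size, height, width, seq_len = 1, 4, 4, 7
inputs = {
    "hidden_states": torch.randn(batch_size, height * width, 16, device=device),
    "encoder_hidden_states": torch.randn(batch_size, seq_len, 16, device=device),
    "encoder_hidden_states_mask": torch.ones(batch_size, seq_len, dtype=torch.long, device=device),
    "timestep": torch.tensor([1.0], device=device),
    "img_shapes": [(1, height, width)] * batch_size,
    "txt_seq_lens": [seq_len],
}

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    for _ in range(2):  # the second call must not trigger a recompile
        _ = model(**inputs)
```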
Comment: This is most likely not equivalent. When registered as a buffer, if the model is loaded in bf16, the precision of these will be bf16 instead of fp32. Doing RoPE in bf16 may harm image quality, so we need to be careful here. Not sure what's best to do here -- maybe for now we can put the rope layer in `_keep_modules_in_fp32`? This recompilation-related problem seems to have become too common with RoPE. Maybe we need to rethink the design a bit.
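A minimal sketch of the precision concern, plain PyTorch only (the fp32-pinning mechanism itself is diffusers-specific and not shown): `nn.Module.to(dtype)` casts floating-point buffers along with parameters, whereas the previous plain tensor attributes stayed in fp32.

```python
import torch
from torch import nn

m = nn.Module()
m.register_buffer("pos_freqs", torch.randn(4, 8), persistent=False)
print(m.to(torch.bfloat16).pos_freqs.dtype)  # torch.bfloat16 -- the RoPE table follows the model dtype
print(m.pos_freqs.float().dtype)             # torch.float32 -- upcasting at use time is one mitigation
```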
Comment: Just for the record, sharing the recompilation error we get without the buffer implementation. But I agree with your first point.
Comment: The recompilation message here says that on the first compilation `self._modules['pos_embed'].neg_freqs` was a CPU tensor, and on the second it became a CUDA tensor. Does that match your expectation? If yes, is it possible to change that somehow? If there is something special happening on the first invocation, you can put compile on the second invocation onwards.
Comment: That is, IMO, an involved user experience we should probably avoid.