From 31ed0097063aeb1f43b57384d9c11477515390f0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 2 Feb 2026 15:48:06 +0100 Subject: [PATCH 1/2] update --- .../test_models_transformer_qwenimage.py | 257 ++++++++++++------ 1 file changed, 169 insertions(+), 88 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index e6b19377b14f..40c70d6b050d 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -13,49 +13,86 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import QwenImageTransformer2DModel from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = QwenImageTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.7, 0.6, 0.6] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True - +class QwenImageTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - return self.prepare_dummy_input() + def model_class(self): + return QwenImageTransformer2DModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int]: return (16, 16) @property - def output_shape(self): + def input_shape(self) -> tuple[int, int]: return (16, 16) - def prepare_dummy_input(self, height=4, width=4): + @property + def model_split_percents(self) -> list: + # We override the items here because the transformer under consideration is small. + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + # Skip setting testing with default: AttnProcessor + return True + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel + "joint_attention_dim": 16, + "guidance_embeds": False, + "axes_dims_rope": (8, 4, 4), + } + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: batch_size = 1 num_latent_channels = embedding_dim = 16 - sequence_length = 7 + sequence_length = 8 # Must be divisible by 2 for context parallel tests vae_scale_factor = 4 - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) orig_height = height * 2 * vae_scale_factor @@ -70,29 +107,12 @@ def prepare_dummy_input(self, height=4, width=4): "img_shapes": img_shapes, } - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "patch_size": 2, - "in_channels": 16, - "out_channels": 4, - "num_layers": 2, - "attention_head_dim": 16, - "num_attention_heads": 3, - "joint_attention_dim": 16, - "guidance_embeds": False, - "axes_dims_rope": (8, 4, 4), - } - - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"QwenImageTransformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) +class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin): def test_infers_text_seq_len_from_mask(self): """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) @@ -104,24 +124,24 @@ def test_infers_text_seq_len_from_mask(self): ) # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) - self.assertIsInstance(rope_text_seq_len, int) + assert isinstance(rope_text_seq_len, int) # Verify per_sample_len is computed correctly (max valid position + 1 = 2) - self.assertIsInstance(per_sample_len, torch.Tensor) - self.assertEqual(int(per_sample_len.max().item()), 2) + assert isinstance(per_sample_len, torch.Tensor) + assert int(per_sample_len.max().item()) == 2 # Verify mask is normalized to bool dtype - self.assertTrue(normalized_mask.dtype == torch.bool) - self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values + assert normalized_mask.dtype == torch.bool + assert normalized_mask.sum().item() == 2 # Only 2 True values # Verify rope_text_seq_len is at least the sequence length - self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1] # Test 2: Verify model runs successfully with inferred values inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] # Test 3: Different mask pattern (padding at beginning) encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() @@ -133,21 +153,22 @@ def test_infers_text_seq_len_from_mask(self): ) # Max valid position is 6 (last token), so per_sample_len should be 7 - self.assertEqual(int(per_sample_len2.max().item()), 7) - self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + assert int(per_sample_len2.max().item()) == 7 + assert normalized_mask2.sum().item() == 4 # 4 True values # Test 4: No mask provided (None case) rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], None ) - self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(rope_text_seq_len_none, int) - self.assertIsNone(per_sample_len_none) - self.assertIsNone(normalized_mask_none) + assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1] + assert isinstance(rope_text_seq_len_none, int) + assert per_sample_len_none is None + assert normalized_mask_none is None def test_non_contiguous_attention_mask(self): """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. @@ -160,21 +181,22 @@ def test_non_contiguous_attention_mask(self): inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask ) - self.assertEqual(int(per_sample_len.max().item()), 5) - self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(inferred_rope_len, int) - self.assertTrue(normalized_mask.dtype == torch.bool) + assert int(per_sample_len.max().item()) == 5 + assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1] + assert isinstance(inferred_rope_len, int) + assert normalized_mask.dtype == torch.bool inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] def test_txt_seq_lens_deprecation(self): """Test that passing txt_seq_lens raises a deprecation warning.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) # Prepare inputs with txt_seq_lens (deprecated parameter) @@ -186,18 +208,24 @@ def test_txt_seq_lens_deprecation(self): inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens # Test that deprecation warning is raised - with self.assertWarns(FutureWarning) as warning_context: + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") with torch.no_grad(): output = model(**inputs_with_deprecated) - # Verify the warning message mentions the deprecation - warning_message = str(warning_context.warning) - self.assertIn("txt_seq_lens", warning_message) - self.assertIn("deprecated", warning_message) - self.assertIn("encoder_hidden_states_mask", warning_message) + # Verify a FutureWarning was raised + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning to be raised" + + # Verify the warning message mentions the deprecation + warning_message = str(future_warnings[0].message) + assert "txt_seq_lens" in warning_message + assert "deprecated" in warning_message # Verify the model still works correctly despite the deprecation - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] def test_layered_model_with_mask(self): """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" @@ -208,7 +236,7 @@ def test_layered_model_with_mask(self): "out_channels": 4, "num_layers": 2, "attention_head_dim": 16, - "num_attention_heads": 3, + "num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel "joint_attention_dim": 16, "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) "use_layer3d_rope": True, # Enable layered RoPE @@ -220,11 +248,11 @@ def test_layered_model_with_mask(self): # Verify the model uses QwenEmbedLayer3DRope from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope - self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + assert isinstance(model.pos_embed, QwenEmbedLayer3DRope) # Test single generation with layered structure batch_size = 1 - text_seq_len = 7 + text_seq_len = 8 img_h, img_w = 4, 4 layers = 4 @@ -262,24 +290,69 @@ def test_layered_model_with_mask(self): additional_t_cond=addition_t_cond, ) - self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + assert output.sample.shape[1] == hidden_states.shape[1] + + +class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for QwenImage Transformer.""" + + +class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin): + """Training tests for QwenImage Transformer.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"QwenImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = QwenImageTransformer2DModel +class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for QwenImage Transformer.""" - def prepare_init_args_and_inputs_for_common(self): - return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common() - def prepare_dummy_input(self, height, width): - return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) +class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): + """Context Parallel inference tests for QwenImage Transformer.""" + + +class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for QwenImage Transformer.""" + + +class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for compilation tests.""" + batch_size = 1 + num_latent_channels = embedding_dim = 16 + sequence_length = 8 # Must be divisible by 2 for context parallel tests + vae_scale_factor = 4 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + orig_height = height * 2 * vae_scale_factor + orig_width = width * 2 * vae_scale_factor + img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size - def test_torch_compile_recompilation_and_graph_break(self): - super().test_torch_compile_recompilation_and_graph_break() + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } def test_torch_compile_with_and_without_mask(self): """Test that torch.compile works with both None mask and padding mask.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) model.eval() model.compile(mode="default", fullgraph=True) @@ -300,13 +373,13 @@ def test_torch_compile_with_and_without_mask(self): ): output_no_mask_2 = model(**inputs_no_mask) - self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1] # Test 2: Run with all-ones mask (should behave like None) inputs_all_ones = inputs.copy() # Keep the all-ones mask - self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + assert inputs_all_ones["encoder_hidden_states_mask"].all().item() # First run to allow compilation with torch.no_grad(): @@ -320,8 +393,8 @@ def test_torch_compile_with_and_without_mask(self): ): output_all_ones_2 = model(**inputs_all_ones) - self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1] # Test 3: Run with actual padding mask (has zeros) inputs_with_padding = inputs.copy() @@ -342,8 +415,16 @@ def test_torch_compile_with_and_without_mask(self): ): output_with_padding_2 = model(**inputs_with_padding) - self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1] # Verify that outputs are different (mask should affect results) - self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) + assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3) + + +class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for QwenImage Transformer.""" + + +class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for QwenImage Transformer.""" From ffdfe289830c14692f9956a92959421c9d69e279 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 3 Feb 2026 06:05:12 +0100 Subject: [PATCH 2/2] update --- .../test_models_transformer_qwenimage.py | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 40c70d6b050d..094c1dd11055 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -25,6 +25,7 @@ BaseModelTesterConfig, BitsAndBytesTesterMixin, ContextParallelTesterMixin, + LoraHotSwappingForModelTesterMixin, LoraTesterMixin, MemoryTesterMixin, ModelTesterMixin, @@ -146,15 +147,15 @@ def test_infers_text_seq_len_from_mask(self): # Test 3: Different mask pattern (padding at beginning) encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding - encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + encoder_hidden_states_mask2[:, 3:] = 1 # Last 5 tokens are valid (seq_len=8) rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask2 ) - # Max valid position is 6 (last token), so per_sample_len should be 7 - assert int(per_sample_len2.max().item()) == 7 - assert normalized_mask2.sum().item() == 4 # 4 True values + # Max valid position is 7 (last token), so per_sample_len should be 8 + assert int(per_sample_len2.max().item()) == 8 + assert normalized_mask2.sum().item() == 5 # 5 True values # Test 4: No mask provided (None case) rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( @@ -166,14 +167,14 @@ def test_infers_text_seq_len_from_mask(self): assert normalized_mask_none is None def test_non_contiguous_attention_mask(self): - """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0, 0])""" init_dict = self.get_init_dict() inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - # Pattern: [True, False, True, False, True, False, False] + # Pattern: [True, False, True, False, True, False, False, False] (seq_len=8) encoder_hidden_states_mask[:, 1] = 0 encoder_hidden_states_mask[:, 3] = 0 encoder_hidden_states_mask[:, 5:] = 0 @@ -317,6 +318,41 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM """LoRA adapter tests for QwenImage Transformer.""" +class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for QwenImage Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for LoRA hotswap tests.""" + batch_size = 1 + num_latent_channels = embedding_dim = 16 + sequence_length = 8 + vae_scale_factor = 4 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + orig_height = height * 2 * vae_scale_factor + orig_width = width * 2 * vae_scale_factor + img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin): @property def different_shapes_for_compilation(self):