diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 52ec30c536bd..9554a914a966 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -415,6 +415,7 @@ "Flux2AutoBlocks", "Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks", + "Flux2KleinBaseModularPipeline", "Flux2KleinModularPipeline", "Flux2ModularPipeline", "FluxAutoBlocks", @@ -431,8 +432,13 @@ "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", - "Wan22AutoBlocks", - "WanAutoBlocks", + "Wan22Blocks", + "Wan22Image2VideoBlocks", + "Wan22Image2VideoModularPipeline", + "Wan22ModularPipeline", + "WanBlocks", + "WanImage2VideoAutoBlocks", + "WanImage2VideoModularPipeline", "WanModularPipeline", "ZImageAutoBlocks", "ZImageModularPipeline", @@ -1151,6 +1157,7 @@ Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, + Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline, FluxAutoBlocks, @@ -1167,8 +1174,13 @@ QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, - Wan22AutoBlocks, - WanAutoBlocks, + Wan22Blocks, + Wan22Image2VideoBlocks, + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanBlocks, + WanImage2VideoAutoBlocks, + WanImage2VideoModularPipeline, WanModularPipeline, ZImageAutoBlocks, ZImageModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 823a3d263ea9..94b87c61c234 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -45,7 +45,16 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"] + _import_structure["wan"] = [ + "WanBlocks", + "Wan22Blocks", + "WanImage2VideoAutoBlocks", + "Wan22Image2VideoBlocks", + "WanModularPipeline", + "Wan22ModularPipeline", + "WanImage2VideoModularPipeline", + "Wan22Image2VideoModularPipeline", + ] _import_structure["flux"] = [ "FluxAutoBlocks", "FluxModularPipeline", @@ -58,6 +67,7 @@ "Flux2KleinBaseAutoBlocks", "Flux2ModularPipeline", "Flux2KleinModularPipeline", + "Flux2KleinBaseModularPipeline", ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", @@ -88,6 +98,7 @@ Flux2AutoBlocks, Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, + Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline, ) @@ -112,7 +123,16 @@ QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline - from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline + from .wan import ( + Wan22Blocks, + Wan22Image2VideoBlocks, + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanBlocks, + WanImage2VideoAutoBlocks, + WanImage2VideoModularPipeline, + WanModularPipeline, + ) from .z_image import ZImageAutoBlocks, ZImageModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py index 220ec0c4ab65..74907a9af806 100644 --- a/src/diffusers/modular_pipelines/flux2/__init__.py +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -55,7 +55,11 @@ "Flux2VaeEncoderSequentialStep", ] _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"] - _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"] + 
_import_structure["modular_pipeline"] = [ + "Flux2ModularPipeline", + "Flux2KleinModularPipeline", + "Flux2KleinBaseModularPipeline", + ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -101,7 +105,7 @@ Flux2KleinAutoBlocks, Flux2KleinBaseAutoBlocks, ) - from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py index 29fbeba07c24..31ba5aec7cfb 100644 --- a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -13,8 +13,6 @@ # limitations under the License. -from typing import Any, Dict, Optional - from ...loaders import Flux2LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline @@ -59,46 +57,35 @@ def num_channels_latents(self): return num_channels_latents -class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): +class Flux2KleinModularPipeline(Flux2ModularPipeline): """ - A ModularPipeline for Flux2-Klein. + A ModularPipeline for Flux2-Klein (distilled model). > [!WARNING] > This is an experimental feature and is likely to change in the future. """ - default_blocks_name = "Flux2KleinBaseAutoBlocks" - - def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: - return "Flux2KleinAutoBlocks" - else: - return "Flux2KleinBaseAutoBlocks" + default_blocks_name = "Flux2KleinAutoBlocks" @property - def default_height(self): - return self.default_sample_size * self.vae_scale_factor + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False - @property - def default_width(self): - return self.default_sample_size * self.vae_scale_factor + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 - @property - def default_sample_size(self): - return 128 + return requires_unconditional_embeds - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if getattr(self, "vae", None) is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - @property - def num_channels_latents(self): - num_channels_latents = 32 - if getattr(self, "transformer", None): - num_channels_latents = self.transformer.config.in_channels // 4 - return num_channels_latents +class Flux2KleinBaseModularPipeline(Flux2ModularPipeline): + """ + A ModularPipeline for Flux2-Klein (base model). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" @property def requires_unconditional_embeds(self): diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 98ede73c21fe..952efa70ed04 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -52,19 +52,61 @@ # map regular pipeline to modular pipeline class name + + +def _create_default_map_fn(pipeline_class_name: str): + """Create a mapping function that always returns the same pipeline class.""" + + def _map_fn(config_dict=None): + return pipeline_class_name + + return _map_fn + + +def _flux2_klein_map_fn(config_dict=None): + if config_dict is None: + return "Flux2KleinModularPipeline" + + if "is_distilled" in config_dict and config_dict["is_distilled"]: + return "Flux2KleinModularPipeline" + else: + return "Flux2KleinBaseModularPipeline" + + +def _wan_map_fn(config_dict=None): + if config_dict is None: + return "WanModularPipeline" + + if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: + return "Wan22ModularPipeline" + else: + return "WanModularPipeline" + + +def _wan_i2v_map_fn(config_dict=None): + if config_dict is None: + return "WanImage2VideoModularPipeline" + + if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: + return "Wan22Image2VideoModularPipeline" + else: + return "WanImage2VideoModularPipeline" + + MODULAR_PIPELINE_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), - ("wan", "WanModularPipeline"), - ("flux", "FluxModularPipeline"), - ("flux-kontext", "FluxKontextModularPipeline"), - ("flux2", "Flux2ModularPipeline"), - ("flux2-klein", "Flux2KleinModularPipeline"), - ("qwenimage", "QwenImageModularPipeline"), - ("qwenimage-edit", "QwenImageEditModularPipeline"), - ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), - ("qwenimage-layered", "QwenImageLayeredModularPipeline"), - ("z-image", "ZImageModularPipeline"), + ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), + ("wan", _wan_map_fn), + ("wan-i2v", _wan_i2v_map_fn), + ("flux", _create_default_map_fn("FluxModularPipeline")), + ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), + ("flux2", _create_default_map_fn("Flux2ModularPipeline")), + ("flux2-klein", _flux2_klein_map_fn), + ("qwenimage", _create_default_map_fn("QwenImageModularPipeline")), + ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), + ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), + ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), + ("z-image", _create_default_map_fn("ZImageModularPipeline")), ] ) @@ -366,7 +408,8 @@ def init_pipeline( """ create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub. 
""" - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__) + map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline")) + pipeline_class_name = map_fn() diffusers_module = importlib.import_module("diffusers") pipeline_class = getattr(diffusers_module, pipeline_class_name) @@ -1545,7 +1588,7 @@ def __init__( if modular_config_dict is not None: blocks_class_name = modular_config_dict.get("_blocks_class_name") else: - blocks_class_name = self.get_default_blocks_name(config_dict) + blocks_class_name = self.default_blocks_name if blocks_class_name is not None: diffusers_module = importlib.import_module("diffusers") blocks_class = getattr(diffusers_module, blocks_class_name) @@ -1617,9 +1660,6 @@ def default_call_parameters(self) -> Dict[str, Any]: params[input_param.name] = input_param.default return params - def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: - return self.default_blocks_name - @classmethod def _load_pipeline_config( cls, @@ -1715,7 +1755,8 @@ def from_pretrained( logger.debug(" try to determine the modular pipeline class from model_index.json") standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) model_name = _get_model(standard_pipeline_class.__name__) - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) + map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline")) + pipeline_class_name = map_fn(config_dict) diffusers_module = importlib.import_module("diffusers") pipeline_class = getattr(diffusers_module, pipeline_class_name) else: diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py index 73f67c9afed2..284b6c9fa436 100644 --- a/src/diffusers/modular_pipelines/wan/__init__.py +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -21,16 +21,16 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["decoders"] = ["WanImageVaeDecoderStep"] - _import_structure["encoders"] = ["WanTextEncoderStep"] - _import_structure["modular_blocks"] = [ - "ALL_BLOCKS", - "Wan22AutoBlocks", - "WanAutoBlocks", - "WanAutoImageEncoderStep", - "WanAutoVaeImageEncoderStep", + _import_structure["modular_blocks_wan"] = ["WanBlocks"] + _import_structure["modular_blocks_wan22"] = ["Wan22Blocks"] + _import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"] + _import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "Wan22Image2VideoModularPipeline", + "Wan22ModularPipeline", + "WanImage2VideoModularPipeline", + "WanModularPipeline", ] - _import_structure["modular_pipeline"] = ["WanModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -39,16 +39,16 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .decoders import WanImageVaeDecoderStep - from .encoders import WanTextEncoderStep - from .modular_blocks import ( - ALL_BLOCKS, - Wan22AutoBlocks, - WanAutoBlocks, - WanAutoImageEncoderStep, - WanAutoVaeImageEncoderStep, + from .modular_blocks_wan import WanBlocks + from .modular_blocks_wan22 import Wan22Blocks + from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks + from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks + from .modular_pipeline import ( + Wan22Image2VideoModularPipeline, + 
Wan22ModularPipeline, + WanImage2VideoModularPipeline, + WanModularPipeline, ) - from .modular_pipeline import WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index e2f8d3e7d88b..719ba4c21148 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks): def __init__( self, - image_latent_inputs: List[str] = ["first_frame_latents"], + image_latent_inputs: List[str] = ["image_condition_latents"], additional_batch_inputs: List[str] = [], ): """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" @@ -294,20 +294,16 @@ def __init__( Args: image_latent_inputs (List[str], optional): Names of image latent tensors to process. In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be - a single string or list of strings. Defaults to ["first_frame_latents"]. + a single string or list of strings. Defaults to ["image_condition_latents"]. additional_batch_inputs (List[str], optional): Names of additional conditional input tensors to expand batch size. These tensors will only have their batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. Defaults to []. Examples: - # Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep() - - # Configure to process multiple image latent inputs - WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"]) - - # Configure to process image latents and additional batch inputs WanAdditionalInputsStep( - image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"] + # Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to + process image latents and additional batch inputs WanAdditionalInputsStep( + image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"] ) """ if not isinstance(image_latent_inputs, list): @@ -557,81 +553,3 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state - - -class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): - model_name = "wan" - - @property - def description(self) -> str: - return "step that prepares the masked first frame latents and add it to the latent condition" - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), - InputParam("num_frames", type_hint=int), - ] - - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape - - mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) - mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 - - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave( - first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal - ) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view( - batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width - ) 
- mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) - block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) - - self.set_block_state(state, block_state) - return components, state - - -class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): - model_name = "wan" - - @property - def description(self) -> str: - return "step that prepares the masked latents with first and last frames and add it to the latent condition" - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]), - InputParam("num_frames", type_hint=int), - ] - - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - - batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape - - mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) - mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 - - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave( - first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal - ) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view( - batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width - ) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) - block_state.first_last_frame_latents = torch.concat( - [mask_lat_size, block_state.first_last_frame_latents], dim=1 - ) - - self.set_block_state(state, block_state) - return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index 7cec318c1706..181f3aae1d58 100644 --- a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanImageVaeDecoderStep(ModularPipelineBlocks): +class WanVaeDecoderStep(ModularPipelineBlocks): model_name = "wan" @property diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 2da36f52da87..7f44b0230d78 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -89,52 +89,10 @@ def inputs(self) -> List[InputParam]: description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( - "first_frame_latents", + "image_condition_latents", required=True, type_hint=torch.Tensor, - description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.", - ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs. 
Can be generated in input step.", - ), - ] - - @torch.no_grad() - def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to( - block_state.dtype - ) - return components, block_state - - -class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks): - model_name = "wan" - - @property - def description(self) -> str: - return ( - "step within the denoising loop that prepares the latent input for the denoiser. " - "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " - "object (e.g. `WanDenoiseLoopWrapper`)" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", - ), - InputParam( - "first_last_frame_latents", - required=True, - type_hint=torch.Tensor, - description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.", + description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.", ), InputParam( "dtype", @@ -147,7 +105,7 @@ def inputs(self) -> List[InputParam]: @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): block_state.latent_model_input = torch.cat( - [block_state.latents, block_state.first_last_frame_latents], dim=1 + [block_state.latents, block_state.image_condition_latents], dim=1 ).to(block_state.dtype) return components, block_state @@ -584,29 +542,3 @@ def description(self) -> str: " - `WanLoopAfterDenoiser`\n" "This block supports image-to-video tasks for Wan2.2." ) - - -class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper): - block_classes = [ - WanFLF2VLoopBeforeDenoiser, - WanLoopDenoiser( - guider_input_fields={ - "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), - "encoder_hidden_states_image": "image_embeds", - } - ), - WanLoopAfterDenoiser, - ] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. \n" - "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" - "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" - " - `WanFLF2VLoopBeforeDenoiser`\n" - " - `WanLoopDenoiser`\n" - " - `WanLoopAfterDenoiser`\n" - "This block supports FLF2V tasks for wan2.1." 
- ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index 4fd69c6ca6ab..22b62a601d34 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -468,7 +468,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanVaeImageEncoderStep(ModularPipelineBlocks): +class WanVaeEncoderStep(ModularPipelineBlocks): model_name = "wan" @property @@ -493,7 +493,7 @@ def inputs(self) -> List[InputParam]: InputParam("resized_image", type_hint=PIL.Image.Image, required=True), InputParam("height"), InputParam("width"), - InputParam("num_frames"), + InputParam("num_frames", type_hint=int, default=81), InputParam("generator"), ] @@ -564,7 +564,51 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): +class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked first frame latents and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks): model_name = "wan" @property @@ -590,7 +634,7 @@ def inputs(self) -> List[InputParam]: InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), InputParam("height"), InputParam("width"), - InputParam("num_frames"), + InputParam("num_frames", type_hint=int, default=81), InputParam("generator"), ] @@ -667,3 +711,49 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state + + +class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked latents with first and last frames and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_last_frame_latents", 
type_hint=Optional[torch.Tensor]), + InputParam("num_frames", type_hint=int, required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) + block_state.image_condition_latents = torch.concat( + [mask_lat_size, block_state.first_last_frame_latents], dim=1 + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py deleted file mode 100644 index 905111bcf42d..000000000000 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ /dev/null @@ -1,474 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict -from .before_denoise import ( - WanAdditionalInputsStep, - WanPrepareFirstFrameLatentsStep, - WanPrepareFirstLastFrameLatentsStep, - WanPrepareLatentsStep, - WanSetTimestepsStep, - WanTextInputStep, -) -from .decoders import WanImageVaeDecoderStep -from .denoise import ( - Wan22DenoiseStep, - Wan22Image2VideoDenoiseStep, - WanDenoiseStep, - WanFLF2VDenoiseStep, - WanImage2VideoDenoiseStep, -) -from .encoders import ( - WanFirstLastFrameImageEncoderStep, - WanFirstLastFrameVaeImageEncoderStep, - WanImageCropResizeStep, - WanImageEncoderStep, - WanImageResizeStep, - WanTextEncoderStep, - WanVaeImageEncoderStep, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# wan2.1 -# wan2.1: text2vid -class WanCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanDenoiseStep, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] - - @property - def description(self): - return ( - "denoise block that takes encoded conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanDenoiseStep` is used to denoise the latents\n" - ) - - -# wan2.1: image2video -## image encoder -class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanImageEncoderStep] - block_names = ["image_resize", "image_encoder"] - - @property - def description(self): - return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" - - -## vae encoder -class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] - block_names = ["image_resize", "vae_encoder"] - - @property - def description(self): - return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" - - -## denoise -class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]), - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanPrepareFirstFrameLatentsStep, - WanImage2VideoDenoiseStep, - ] - block_names = [ - "input", - "additional_inputs", - "set_timesteps", - "prepare_latents", - "prepare_first_frame_latents", - "denoise", - ] - - @property - def description(self): - return ( - "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n" - + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" - ) - - -# wan2.1: FLF2v - - -## image encoder -class 
WanFLF2VImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] - block_names = ["image_resize", "last_image_resize", "image_encoder"] - - @property - def description(self): - return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" - - -## vae encoder -class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): - model_name = "wan" - block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] - block_names = ["image_resize", "last_image_resize", "vae_encoder"] - - @property - def description(self): - return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" - - -## denoise -class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]), - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanPrepareFirstLastFrameLatentsStep, - WanFLF2VDenoiseStep, - ] - block_names = [ - "input", - "additional_inputs", - "set_timesteps", - "prepare_latents", - "prepare_first_last_frame_latents", - "denoise", - ] - - @property - def description(self): - return ( - "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n" - + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" - ) - - -# wan2.1: auto blocks -## image encoder -class WanAutoImageEncoderStep(AutoPipelineBlocks): - block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] - block_names = ["flf2v_image_encoder", "image2video_image_encoder"] - block_trigger_inputs = ["last_image", "image"] - - @property - def description(self): - return ( - "Image Encoder step that encode the image to generate the image embeddings" - + "This is an auto pipeline block that works for image2video tasks." - + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." - + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `last_image` or `image` is not provided, step will be skipped." - ) - - -## vae encoder -class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): - block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] - block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] - block_trigger_inputs = ["last_image", "image"] - - @property - def description(self): - return ( - "Vae Image Encoder step that encode the image to generate the image latents" - + "This is an auto pipeline block that works for image2video tasks." - + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided." - + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `last_image` or `image` is not provided, step will be skipped." 
- ) - - -## denoise -class WanAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ - WanFLF2VCoreDenoiseStep, - WanImage2VideoCoreDenoiseStep, - WanCoreDenoiseStep, - ] - block_names = ["flf2v", "image2video", "text2video"] - block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None] - - @property - def description(self) -> str: - return ( - "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2video and image2video tasks." - " - `WanCoreDenoiseStep` (text2video) for text2vid tasks." - " - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks." - + " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n" - + " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n" - ) - - -# auto pipeline blocks -class WanAutoBlocks(SequentialPipelineBlocks): - block_classes = [ - WanTextEncoderStep, - WanAutoImageEncoderStep, - WanAutoVaeImageEncoderStep, - WanAutoDenoiseStep, - WanImageVaeDecoderStep, - ] - block_names = [ - "text_encoder", - "image_encoder", - "vae_encoder", - "denoise", - "decode", - ] - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-video using Wan.\n" - + "- for text-to-video generation, all you need to provide is `prompt`" - ) - - -# wan22 -# wan2.2: text2vid - - -## denoise -class Wan22CoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanSetTimestepsStep, - WanPrepareLatentsStep, - Wan22DenoiseStep, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] - - @property - def description(self): - return ( - "denoise block that takes encoded conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n" - ) - - -# wan2.2: image2video -## denoise -class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): - block_classes = [ - WanTextInputStep, - WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]), - WanSetTimestepsStep, - WanPrepareLatentsStep, - WanPrepareFirstFrameLatentsStep, - Wan22Image2VideoDenoiseStep, - ] - block_names = [ - "input", - "additional_inputs", - "set_timesteps", - "prepare_latents", - "prepare_first_frame_latents", - "denoise", - ] - - @property - def description(self): - return ( - "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n" - + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" - ) - - -class Wan22AutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ - Wan22Image2VideoCoreDenoiseStep, - Wan22CoreDenoiseStep, - ] - block_names = ["image2video", "text2video"] - block_trigger_inputs = ["first_frame_latents", None] - - @property - def description(self) -> str: - 
return ( - "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2video and image2video tasks." - " - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks." - " - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks." - + " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n" - + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n" - ) - - -class Wan22AutoBlocks(SequentialPipelineBlocks): - block_classes = [ - WanTextEncoderStep, - WanAutoVaeImageEncoderStep, - Wan22AutoDenoiseStep, - WanImageVaeDecoderStep, - ] - block_names = [ - "text_encoder", - "vae_encoder", - "denoise", - "decode", - ] - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-video using Wan2.2.\n" - + "- for text-to-video generation, all you need to provide is `prompt`" - ) - - -# presets for wan2.1 and wan2.2 -# YiYi Notes: should we move these to doc? -# wan2.1 -TEXT2VIDEO_BLOCKS = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("input", WanTextInputStep), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("denoise", WanDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -IMAGE2VIDEO_BLOCKS = InsertableDict( - [ - ("image_resize", WanImageResizeStep), - ("image_encoder", WanImage2VideoImageEncoderStep), - ("vae_encoder", WanImage2VideoVaeImageEncoderStep), - ("input", WanTextInputStep), - ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep), - ("denoise", WanImage2VideoDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - - -FLF2V_BLOCKS = InsertableDict( - [ - ("image_resize", WanImageResizeStep), - ("last_image_resize", WanImageCropResizeStep), - ("image_encoder", WanFLF2VImageEncoderStep), - ("vae_encoder", WanFLF2VVaeImageEncoderStep), - ("input", WanTextInputStep), - ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep), - ("denoise", WanFLF2VDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("image_encoder", WanAutoImageEncoderStep), - ("vae_encoder", WanAutoVaeImageEncoderStep), - ("denoise", WanAutoDenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -# wan2.2 presets - -TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("input", WanTextInputStep), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("denoise", Wan22DenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict( - [ - ("image_resize", WanImageResizeStep), - ("vae_encoder", WanImage2VideoVaeImageEncoderStep), - ("input", WanTextInputStep), - ("set_timesteps", WanSetTimestepsStep), - ("prepare_latents", WanPrepareLatentsStep), - ("denoise", Wan22DenoiseStep), - ("decode", WanImageVaeDecoderStep), - ] -) - -AUTO_BLOCKS_WAN22 = InsertableDict( - [ - ("text_encoder", WanTextEncoderStep), - ("vae_encoder", WanAutoVaeImageEncoderStep), - ("denoise", Wan22AutoDenoiseStep), - ("decode", 
WanImageVaeDecoderStep), - ] -) - -# presets all blocks (wan and wan22) - - -ALL_BLOCKS = { - "wan2.1": { - "text2video": TEXT2VIDEO_BLOCKS, - "image2video": IMAGE2VIDEO_BLOCKS, - "flf2v": FLF2V_BLOCKS, - "auto": AUTO_BLOCKS, - }, - "wan2.2": { - "text2video": TEXT2VIDEO_BLOCKS_WAN22, - "image2video": IMAGE2VIDEO_BLOCKS_WAN22, - "auto": AUTO_BLOCKS_WAN22, - }, -} diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan.py new file mode 100644 index 000000000000..d01a86ca09b5 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan.py @@ -0,0 +1,83 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanDenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise +class WanCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanDenoiseStep` is used to denoise the latents\n" + ) + + +# ==================== +# 2. BLOCKS (Wan2.1 text2video) +# ==================== + + +class WanBlocks(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + WanCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return ( + "Modular pipeline blocks for Wan2.1.\n" + + "- `WanTextEncoderStep` is used to encode the text\n" + + "- `WanCoreDenoiseStep` is used to denoise the latents\n" + + "- `WanVaeDecoderStep` is used to decode the latents to images" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py new file mode 100644 index 000000000000..21164422f3d9 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py @@ -0,0 +1,88 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22DenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise + + +class Wan22CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +# ==================== +# 2. BLOCKS (Wan2.2 text2video) +# ==================== + + +class Wan22Blocks(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + Wan22CoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Modular pipeline for text-to-video using Wan2.2.\n" + + " - `WanTextEncoderStep` encodes the text\n" + + " - `Wan22CoreDenoiseStep` denoises the latents\n" + + " - `WanVaeDecoderStep` decodes the latents to video frames\n" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py new file mode 100644 index 000000000000..3db1c8fa837b --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py @@ -0,0 +1,117 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22Image2VideoDenoiseStep, +) +from .encoders import ( + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. VAE ENCODER +# ==================== + + +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# ==================== +# 2. DENOISE +# ==================== + + +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22Image2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +# ==================== +# 3. BLOCKS (Wan2.2 Image2Video) +# ==================== + + +class Wan22Image2VideoBlocks(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanImage2VideoVaeEncoderStep, + Wan22Image2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Modular pipeline for image-to-video using Wan2.2.\n" + + " - `WanTextEncoderStep` encodes the text\n" + + " - `WanImage2VideoVaeEncoderStep` encodes the image\n" + + " - `Wan22Image2VideoCoreDenoiseStep` denoises the latents\n" + + " - `WanVaeDecoderStep` decodes the latents to video frames\n" + ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py b/src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py new file mode 100644 index 000000000000..d07ab8ecf473 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py @@ -0,0 +1,203 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanImage2VideoDenoiseStep, +) +from .encoders import ( + WanFirstLastFrameImageEncoderStep, + WanFirstLastFrameVaeEncoderStep, + WanImageCropResizeStep, + WanImageEncoderStep, + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanPrepareFirstLastFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# ==================== +# 1. IMAGE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageEncoderStep] + block_names = ["image_resize", "image_encoder"] + + @property + def description(self): + return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" + + +# wan2.1 FLF2V (first and last frame) +class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "image_encoder"] + + @property + def description(self): + return "FLF2V Image Encoder step that resize and encode the first and last frame images to generate the image embeddings" + + +# wan2.1 Auto Image Encoder +class WanAutoImageEncoderStep(AutoPipelineBlocks): + block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] + block_names = ["flf2v_image_encoder", "image2video_image_encoder"] + block_trigger_inputs = ["last_image", "image"] + model_name = "wan-i2v" + + @property + def description(self): + return ( + "Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 2. 
VAE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# wan2.1 FLF2V (first and last frame) +class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanImageResizeStep, + WanImageCropResizeStep, + WanFirstLastFrameVaeEncoderStep, + WanPrepareFirstLastFrameLatentsStep, + ] + block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"] + + @property + def description(self): + return "FLF2V Vae Image Encoder step that resize and encode the first and last frame images to generate the latent conditions" + + +# wan2.1 Auto Vae Encoder +class WanAutoVaeEncoderStep(AutoPipelineBlocks): + model_name = "wan-i2v" + block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep] + block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] + block_trigger_inputs = ["last_image", "image"] + + @property + def description(self): + return ( + "Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise) +# ==================== + + +# wan2.1 I2V core denoise (support both I2V and FLF2V) +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanImage2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" + ) + + +# ==================== +# 4. 
+# ====================
+
+
+# wan2.1 Image2Video Auto Blocks
+class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
+    model_name = "wan-i2v"
+    block_classes = [
+        WanTextEncoderStep,
+        WanAutoImageEncoderStep,
+        WanAutoVaeEncoderStep,
+        WanImage2VideoCoreDenoiseStep,
+        WanVaeDecoderStep,
+    ]
+    block_names = [
+        "text_encoder",
+        "image_encoder",
+        "vae_encoder",
+        "denoise",
+        "decode",
+    ]
+
+    @property
+    def description(self):
+        return (
+            "Auto Modular pipeline for image-to-video using Wan.\n"
+            + "- for the I2V workflow, all you need to provide is `image`\n"
+            + "- for the FLF2V workflow, all you need to provide is `last_image` and `image`"
+        )
diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
index 930b25e4b905..0e52026a51bf 100644
--- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Optional
-
 from ...loaders import WanLoraLoaderMixin
 from ...pipelines.pipeline_utils import StableDiffusionMixin
 from ...utils import logging
@@ -30,19 +28,12 @@ class WanModularPipeline(
     WanLoraLoaderMixin,
 ):
     """
-    A ModularPipeline for Wan.
+    A ModularPipeline for Wan2.1 text2video.
 
     > [!WARNING] > This is an experimental feature and is likely to change in the future.
     """
 
-    default_blocks_name = "WanAutoBlocks"
-
-    # override the default_blocks_name in base class, which is just return self.default_blocks_name
-    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
-        if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
-            return "Wan22AutoBlocks"
-        else:
-            return "WanAutoBlocks"
+    default_blocks_name = "WanBlocks"
 
     @property
     def default_height(self):
@@ -118,3 +109,33 @@ def num_train_timesteps(self):
         if hasattr(self, "scheduler") and self.scheduler is not None:
             num_train_timesteps = self.scheduler.config.num_train_timesteps
         return num_train_timesteps
+
+
+class WanImage2VideoModularPipeline(WanModularPipeline):
+    """
+    A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).
+
+    > [!WARNING] > This is an experimental feature and is likely to change in the future.
+    """
+
+    default_blocks_name = "WanImage2VideoAutoBlocks"
+
+
+class Wan22ModularPipeline(WanModularPipeline):
+    """
+    A ModularPipeline for Wan2.2 text2video.
+
+    > [!WARNING] > This is an experimental feature and is likely to change in the future.
+    """
+
+    default_blocks_name = "Wan22Blocks"
+
+
+class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
+    """
+    A ModularPipeline for Wan2.2 image2video.
+
+    > [!WARNING] > This is an experimental feature and is likely to change in the future.
+    """
+
+    default_blocks_name = "Wan22Image2VideoBlocks"
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 5ee44190e23b..963ce19c324d 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -246,7 +246,7 @@
 AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
     [
-        ("wan", WanImageToVideoPipeline),
+        ("wan-i2v", WanImageToVideoPipeline),
     ]
 )
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index a23f852616c0..9251c4c33b4d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class Flux2KleinModularPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
@@ -287,7 +302,82 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class Wan22AutoBlocks(metaclass=DummyObject):
+class Wan22Blocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class Wan22Image2VideoBlocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class Wan22ModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class WanBlocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class WanImage2VideoAutoBlocks(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch", "transformers"])
 
@@ -302,7 +392,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
-class WanAutoBlocks(metaclass=DummyObject):
+class WanImage2VideoModularPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
     def __init__(self, *args, **kwargs):