Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinBaseModularPipeline",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
Expand All @@ -431,8 +432,13 @@
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"Wan22Blocks",
"Wan22Image2VideoBlocks",
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanBlocks",
"WanImage2VideoAutoBlocks",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
"ZImageAutoBlocks",
"ZImageModularPipeline",
Expand Down Expand Up @@ -1151,6 +1157,7 @@
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
Expand All @@ -1167,8 +1174,13 @@
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
ZImageAutoBlocks,
ZImageModularPipeline,
Expand Down
24 changes: 22 additions & 2 deletions src/diffusers/modular_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,16 @@
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
_import_structure["wan"] = [
"WanBlocks",
"Wan22Blocks",
"WanImage2VideoAutoBlocks",
"Wan22Image2VideoBlocks",
"WanModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"Wan22Image2VideoModularPipeline",
]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
Expand All @@ -58,6 +67,7 @@
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
Expand Down Expand Up @@ -88,6 +98,7 @@
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
Expand All @@ -112,7 +123,16 @@
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
from .wan import (
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
import sys
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/modular_pipelines/flux2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
_import_structure["modular_pipeline"] = [
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
Expand Down Expand Up @@ -101,7 +105,7 @@
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

Expand Down
49 changes: 18 additions & 31 deletions src/diffusers/modular_pipelines/flux2/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.


from typing import Any, Dict, Optional

from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
Expand Down Expand Up @@ -59,46 +57,35 @@ def num_channels_latents(self):
return num_channels_latents


class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
class Flux2KleinModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein.
A ModularPipeline for Flux2-Klein (distilled model).

> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""

default_blocks_name = "Flux2KleinBaseAutoBlocks"

def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
default_blocks_name = "Flux2KleinAutoBlocks"

@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False

@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

@property
def default_sample_size(self):
return 128
return requires_unconditional_embeds

@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor

@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein (base model).

> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""

default_blocks_name = "Flux2KleinBaseAutoBlocks"

@property
def requires_unconditional_embeds(self):
Expand Down
75 changes: 58 additions & 17 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,61 @@


# map regular pipeline to modular pipeline class name


def _create_default_map_fn(pipeline_class_name: str):
"""Create a mapping function that always returns the same pipeline class."""

def _map_fn(config_dict=None):
return pipeline_class_name

return _map_fn


def _flux2_klein_map_fn(config_dict=None):
if config_dict is None:
return "Flux2KleinModularPipeline"

if "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinModularPipeline"
else:
return "Flux2KleinBaseModularPipeline"


def _wan_map_fn(config_dict=None):
if config_dict is None:
return "WanModularPipeline"

if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22ModularPipeline"
else:
return "WanModularPipeline"


def _wan_i2v_map_fn(config_dict=None):
if config_dict is None:
return "WanImage2VideoModularPipeline"

if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22Image2VideoModularPipeline"
else:
return "WanImage2VideoModularPipeline"


MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
("z-image", "ZImageModularPipeline"),
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
("wan", _wan_map_fn),
("wan-i2v", _wan_i2v_map_fn),
("flux", _create_default_map_fn("FluxModularPipeline")),
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
("flux2-klein", _flux2_klein_map_fn),
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
("z-image", _create_default_map_fn("ZImageModularPipeline")),
]
)

Expand Down Expand Up @@ -366,7 +408,8 @@ def init_pipeline(
"""
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
"""
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn()
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)

Expand Down Expand Up @@ -1545,7 +1588,7 @@ def __init__(
if modular_config_dict is not None:
blocks_class_name = modular_config_dict.get("_blocks_class_name")
else:
blocks_class_name = self.get_default_blocks_name(config_dict)
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
Expand Down Expand Up @@ -1617,9 +1660,6 @@ def default_call_parameters(self) -> Dict[str, Any]:
params[input_param.name] = input_param.default
return params

def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
return self.default_blocks_name

@classmethod
def _load_pipeline_config(
cls,
Expand Down Expand Up @@ -1715,7 +1755,8 @@ def from_pretrained(
logger.debug(" try to determine the modular pipeline class from model_index.json")
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
model_name = _get_model(standard_pipeline_class.__name__)
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn(config_dict)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
else:
Expand Down
36 changes: 18 additions & 18 deletions src/diffusers/modular_pipelines/wan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanAutoImageEncoderStep",
"WanAutoVaeImageEncoderStep",
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
_import_structure["modular_pipeline"] = [
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
Expand All @@ -39,16 +39,16 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .decoders import WanImageVaeDecoderStep
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
Wan22AutoBlocks,
WanAutoBlocks,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
from .modular_blocks_wan import WanBlocks
from .modular_blocks_wan22 import Wan22Blocks
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
from .modular_pipeline import (
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .modular_pipeline import WanModularPipeline
else:
import sys

Expand Down
Loading
Loading