Open
35 commits
dd241dc
initial conversion script
miguelmartin75 Jan 9, 2026
7e475bd
cosmos control net block
miguelmartin75 Jan 10, 2026
1b934ff
CosmosAttention
miguelmartin75 Jan 15, 2026
b40da24
base model conversion
miguelmartin75 Jan 15, 2026
cfedde1
wip
miguelmartin75 Jan 16, 2026
8222e9f
pipeline updates
miguelmartin75 Jan 17, 2026
9fefe1f
convert controlnet
miguelmartin75 Jan 20, 2026
2b67a31
pipeline: working without controls
miguelmartin75 Jan 20, 2026
5f2bab8
wip
miguelmartin75 Jan 20, 2026
97f10d8
debugging
miguelmartin75 Jan 21, 2026
cc6cf13
Almost working
miguelmartin75 Jan 22, 2026
4ba9945
temp
miguelmartin75 Jan 22, 2026
35e0653
control working
miguelmartin75 Jan 22, 2026
9da2e88
cleanup + detail on neg_encoder_hidden_states
miguelmartin75 Jan 23, 2026
b3852ac
convert edge
miguelmartin75 Jan 23, 2026
a16e81a
pos emb for control latents
miguelmartin75 Jan 24, 2026
cd65899
convert all chkpts
miguelmartin75 Jan 26, 2026
dfe99b8
resolve TODOs
miguelmartin75 Jan 27, 2026
aadf51a
remove prints
miguelmartin75 Jan 27, 2026
26b7ee5
Docs
miguelmartin75 Jan 27, 2026
d7f122d
add siglip image reference encoder
miguelmartin75 Jan 27, 2026
50f7e53
Add unit tests
miguelmartin75 Jan 30, 2026
c5c2456
controlnet: add duplicate layers
miguelmartin75 Jan 31, 2026
9a55923
Additional tests
miguelmartin75 Jan 31, 2026
2e2fea1
skip less
miguelmartin75 Jan 31, 2026
bf1f99d
skip less
miguelmartin75 Jan 31, 2026
910103f
remove image_ref
miguelmartin75 Feb 2, 2026
751fba4
minor
miguelmartin75 Feb 2, 2026
251b5c1
docs
miguelmartin75 Feb 2, 2026
44db782
remove skipped test in transfer
miguelmartin75 Feb 2, 2026
c1cfa9d
Don't crash process
miguelmartin75 Feb 2, 2026
9b8338c
formatting
miguelmartin75 Feb 2, 2026
b9dd0cb
revert some changes
miguelmartin75 Feb 2, 2026
d09cf24
remove skipped test
miguelmartin75 Feb 2, 2026
2cd7f23
make style
miguelmartin75 Feb 2, 2026
298 changes: 283 additions & 15 deletions scripts/convert_cosmos_to_diffusers.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -219,6 +219,7 @@
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
"CosmosControlNetModel",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
@@ -477,6 +478,7 @@
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2_5_TransferPipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -981,6 +983,7 @@
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosControlNetModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
@@ -1209,6 +1212,7 @@
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2_5_TransferPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
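The net effect of the `src/diffusers/__init__.py` additions is that both new classes resolve from the top-level package. A minimal import check, assuming this PR branch is installed:

# Quick sanity check that the new public exports resolve (run against this branch).
from diffusers import Cosmos2_5_TransferPipeline, CosmosControlNetModel

print(Cosmos2_5_TransferPipeline.__name__, CosmosControlNetModel.__name__)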
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -54,6 +54,7 @@
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
"HunyuanDiT2DControlNetModel",
@@ -175,6 +176,7 @@
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosControlNetModel,
FluxControlNetModel,
FluxMultiControlNetModel,
HunyuanDiT2DControlNetModel,
1 change: 1 addition & 0 deletions src/diffusers/models/controlnets/__init__.py
@@ -3,6 +3,7 @@

if is_torch_available():
from .controlnet import ControlNetModel, ControlNetOutput
from .controlnet_cosmos import CosmosControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
from .controlnet_hunyuan import (
HunyuanControlNetOutput,
312 changes: 312 additions & 0 deletions src/diffusers/models/controlnets/controlnet_cosmos.py
@@ -0,0 +1,312 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
CosmosEmbedding,
CosmosLearnablePositionalEmbed,
CosmosPatchEmbed,
CosmosRotaryPosEmbed,
CosmosTransformerBlock,
)


if is_torchvision_available():
from torchvision import transforms

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@dataclass
class CosmosControlNetOutput(BaseOutput):
"""
Output of [`CosmosControlNetModel`].

Args:
control_block_samples (`list[torch.Tensor]`):
List of control block activations to be injected into transformer blocks.
"""

control_block_samples: List[torch.Tensor]


class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
ControlNet for Cosmos Transfer2.5.

    This model duplicates the transformer's shared embedding modules (patch_embed, time_embed, learnable_pos_embed,
    img_context_proj) so that it can be CPU-offloaded independently of the transformer. The forward() method therefore
    computes all embeddings internally from the raw inputs rather than reusing the transformer's intermediate
    activations.
"""

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]

@register_to_config
def __init__(
self,
n_controlnet_blocks: int = 4,
in_channels: int = 130,
latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask
model_channels: int = 2048,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
patch_size: Tuple[int, int, int] = (1, 2, 2),
max_size: Tuple[int, int, int] = (128, 240, 240),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
extra_pos_embed_type: Optional[str] = None,
img_context_dim_in: Optional[int] = None,
img_context_dim_out: int = 2048,
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
):
super().__init__()

self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)

self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
self.time_embed = CosmosEmbedding(model_channels, model_channels)

self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=model_channels,
max_size=max_size,
patch_size=patch_size,
)

self.img_context_proj = None
if img_context_dim_in is not None and img_context_dim_in > 0:
self.img_context_proj = nn.Sequential(
nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
nn.GELU(),
)

# Cross-attention projection for text embeddings (same as transformer)
self.crossattn_proj = None
if use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)

# RoPE for both control and base latents
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)

self.control_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=img_context_dim_in is not None and img_context_dim_in > 0,
before_proj=(block_idx == 0),
after_proj=True,
)
for block_idx in range(n_controlnet_blocks)
]
)

self.gradient_checkpointing = False

def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
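        """Expand `conditioning_scale` into one scale per control block, repeating and trimming when fewer scales than blocks are provided."""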
if isinstance(conditioning_scale, list):
scales = conditioning_scale
else:
scales = [conditioning_scale] * len(self.control_blocks)

if len(scales) < len(self.control_blocks):
logger.warning(
"Received %d control scales, but control network defines %d blocks. "
"Scales will be trimmed or repeated to match.",
len(scales),
len(self.control_blocks),
)
scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
return scales

def forward(
self,
controls_latents: torch.Tensor,
latents: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
condition_mask: torch.Tensor,
conditioning_scale: Union[float, List[float]] = 1.0,
padding_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
return_dict: bool = True,
) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
"""
Forward pass for the ControlNet.

Args:
controls_latents: Control signal latents [B, C, T, H, W]
latents: Base latents from the noising process [B, C, T, H, W]
timestep: Diffusion timestep tensor
encoder_hidden_states: Tuple of (text_context, img_context) or text_context
condition_mask: Conditioning mask [B, 1, T, H, W]
conditioning_scale: Scale factor(s) for control outputs
            padding_mask: Padding mask [B, 1, H, W]; required by the current implementation (it is resized to each stream's spatial size)
attention_mask: Optional attention mask or None
fps: Frames per second for RoPE or None
return_dict: Whether to return a CosmosControlNetOutput or a tuple

Returns:
CosmosControlNetOutput or tuple of control tensors
"""
B, C, T, H, W = controls_latents.shape

# 1. Prepare control latents
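        # Zero-pad the control channels up to in_channels - 2, then append an all-zeros channel and the resized
        # padding mask so the final channel count matches the config.in_channels expected by patch_embed.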
control_hidden_states = controls_latents
vace_in_channels = self.config.in_channels - 1
if control_hidden_states.shape[1] < vace_in_channels - 1:
pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
control_hidden_states = torch.cat(
[
control_hidden_states,
torch.zeros(
(B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
),
],
dim=1,
)

control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)

padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
control_hidden_states = torch.cat(
[control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)

# 2. Prepare base latents (same processing as transformer.forward)
base_hidden_states = latents
if condition_mask is not None:
base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)

base_padding_mask = transforms.functional.resize(
padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
base_hidden_states = torch.cat(
[base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)

# 3. Generate positional embeddings (shared for both)
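        # Both streams share the same (T, H, W) patch grid, so the RoPE (and optional learnable embedding)
        # computed from the control latents applies to the base latents as well.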
image_rotary_emb = self.rope(control_hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None

# 4. Patchify control latents
control_hidden_states = self.patch_embed(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(1, 3)

# 5. Patchify base latents
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = T // p_t
post_patch_height = H // p_h
post_patch_width = W // p_w

base_hidden_states = self.patch_embed_base(base_hidden_states)
base_hidden_states = base_hidden_states.flatten(1, 3)

# 6. Time embeddings
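        # A 1-D timestep yields one embedding per sample; a 5-D [B, 1, T, 1, 1] timestep yields per-frame
        # embeddings that are broadcast over the spatial patch grid below.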
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
elif timestep.ndim == 5:
batch_size, _, num_frames, _, _ = latents.shape
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep_flat = timestep.flatten()
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
)
else:
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")

# 7. Process encoder hidden states
if isinstance(encoder_hidden_states, tuple):
text_context, img_context = encoder_hidden_states
else:
text_context = encoder_hidden_states
img_context = None

# Apply cross-attention projection to text context
if self.crossattn_proj is not None:
text_context = self.crossattn_proj(text_context)

# Apply cross-attention projection to image context (if provided)
if img_context is not None and self.img_context_proj is not None:
img_context = self.img_context_proj(img_context)

# Combine text and image context into a single tuple
if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
processed_encoder_hidden_states = (text_context, img_context)
else:
processed_encoder_hidden_states = text_context

# 8. Prepare attention mask
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]

# 9. Run control blocks
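        # Each block returns the updated control stream plus a projected residual (control_proj); the residuals
        # are scaled per block and collected for injection into the transformer blocks (see CosmosControlNetOutput).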
scales = self._expand_conditioning_scale(conditioning_scale)
result = []
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_hidden_states, control_proj = self._gradient_checkpointing_func(
block,
control_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
None, # controlnet_residual
base_hidden_states,
block_idx,
)
else:
control_hidden_states, control_proj = block(
hidden_states=control_hidden_states,
encoder_hidden_states=processed_encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
controlnet_residual=None,
latents=base_hidden_states,
block_idx=block_idx,
)
result.append(control_proj * scale)

if not return_dict:
return (result,)

return CosmosControlNetOutput(control_block_samples=result)
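A rough smoke-test sketch for the new model (not part of the PR): it instantiates CosmosControlNetModel with a deliberately tiny, made-up config and calls forward() once with random tensors shaped as documented in the forward docstring. The hyperparameters and shapes are illustrative assumptions, not values from the released checkpoints, and it assumes this branch of diffusers is installed.

# Illustrative smoke test only: tiny made-up config, random inputs shaped as documented in forward().
import torch

from diffusers import CosmosControlNetModel

controlnet = CosmosControlNetModel(
    n_controlnet_blocks=1,
    in_channels=6,       # 4 padded control channels + zeros channel + padding mask
    latent_channels=4,   # 2 latent channels + condition mask + padding mask
    model_channels=48,   # kept equal to num_attention_heads * attention_head_dim
    num_attention_heads=2,
    attention_head_dim=24,
    text_embed_dim=32,
    adaln_lora_dim=4,
)

with torch.no_grad():
    out = controlnet(
        controls_latents=torch.randn(1, 2, 1, 16, 16),
        latents=torch.randn(1, 2, 1, 16, 16),
        timestep=torch.rand(1),
        encoder_hidden_states=torch.randn(1, 8, 32),
        condition_mask=torch.zeros(1, 1, 1, 16, 16),
        padding_mask=torch.zeros(1, 1, 16, 16),
        fps=16,
    )

print([tuple(t.shape) for t in out.control_block_samples])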