-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Cosmos Transfer2.5 inference pipeline: general/{seg, depth, blur, edge} #13066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Cosmos Transfer2.5 inference pipeline: general/{seg, depth, blur, edge} #13066
Conversation
yiyixuxu
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! The overall structure looks good. I left some minor comments.
One question before I can review further: Are the base transformer weights the same across the different control variants?
This helps us understand whether splitting the controlnet from the transformer makes sense (i.e., can users mix and match?), and also helps me understand whether the controlnet is required for this pipeline etc
| --save_pipeline | ||
|
|
||
| # seg | ||
| transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohh does each variant come with its own base transformer?
in diffusers we typically split controlnet from the base model is so that user can mix and match, it this something possible with cosmos?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each variant should have the same weights as the base transformer, I will double check this, but I split out the controlnet and save the pipeline (saves base transformer + controlnet), such that the pipeline can be loaded directly from a model_id/revision.
I will look into only loading the controlnet from the converted script.
| raise AttributeError("Could not access latents of provided encoder_output") | ||
|
|
||
|
|
||
| def transfer2_5_forward( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we inline this inside the __call__ method? We typically only create separate methods for operations users might need to run standalone to pre-compute things (like encode_prompt, encode_video, etc.). It's also easier to read when you don't have to jump around the file.
| transformer: CosmosTransformer3DModel, | ||
| vae: AutoencoderKLWan, | ||
| scheduler: UniPCMultistepScheduler, | ||
| controlnet: CosmosControlNetModel, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is controlnet optional here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I will change the typehint
What does this PR do?
This PR introduces Cosmos Transfer2.5 inference pipeline, which extends the existing code in transformer_cosmos.py and introduces a new controlnet class for cosmos. The conversion script is updated to convert the checkpoints too.
I've intentionally split the controlnet from the base predict model to match the rest of the diffusers codebase. To do this, I have had to duplicate some layers/weights from the base model (relating to the patch & timestep embeddings), but I believe SD3 does this.
Similar to predict2.5, I have added documentation and unit tests.
Additional PRs will be submitted for the following features (in order of priority):
In addition, unfortunately, the guardrails safety model is too aggressive: it currently flags "not safe" for the examples we have on cosmos-transfer2.5 (e.g. edge example for 93 frames is flagged). This guardrail model needs to be updated, but this work is ~orthogonal of this PR.
Who can review?
Core library: