Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,9 @@ def multistep_dpm_solver_second_order_update(

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
) -> int:
"""
Find the index for a given timestep in the schedule.
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,9 @@ def ind_fn(t, b, c, d):

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
) -> int:
"""
Find the index for a given timestep in the schedule.
Expand Down
68 changes: 53 additions & 15 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,21 +245,42 @@ def __init__(
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
deprecate(
"algorithm_types dpmsolver and sde-dpmsolver",
"1.0.0",
deprecation_message,
)

if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
Expand Down Expand Up @@ -287,7 +308,12 @@ def __init__(
self.init_noise_sigma = 1.0

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type not in [
"dpmsolver",
"dpmsolver++",
"sde-dpmsolver",
"sde-dpmsolver++",
]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
Expand Down Expand Up @@ -724,7 +750,7 @@ def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
sample: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand All @@ -738,7 +764,7 @@ def convert_model_output(
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
sample (`torch.Tensor`, *optional*):
A current instance of a sample created by the diffusion process.

Returns:
Expand Down Expand Up @@ -822,7 +848,7 @@ def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
sample: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
Expand All @@ -832,8 +858,10 @@ def dpm_solver_first_order_update(
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
sample (`torch.Tensor`, *optional*):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.

Returns:
`torch.Tensor`:
Expand All @@ -860,7 +888,10 @@ def dpm_solver_first_order_update(
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)

sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
Expand Down Expand Up @@ -891,7 +922,7 @@ def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
sample: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
Expand All @@ -901,7 +932,7 @@ def multistep_dpm_solver_second_order_update(
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
sample (`torch.Tensor`, *optional*):
A current instance of a sample created by the diffusion process.

Returns:
Expand Down Expand Up @@ -1014,7 +1045,7 @@ def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
sample: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
Expand All @@ -1024,8 +1055,10 @@ def multistep_dpm_solver_third_order_update(
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
sample (`torch.Tensor`, *optional*):
A current instance of a sample created by diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.

Returns:
`torch.Tensor`:
Expand Down Expand Up @@ -1106,7 +1139,9 @@ def multistep_dpm_solver_third_order_update(
return x_t

def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
) -> int:
"""
Find the index for a given timestep in the schedule.
Expand Down Expand Up @@ -1216,7 +1251,10 @@ def step(
sample = sample.to(torch.float32)
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
model_output.shape,
generator=generator,
device=model_output.device,
dtype=torch.float32,
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
Expand Down
Loading
Loading