Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import os | |
import PIL | |
from typing import List, Optional, Union | |
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput | |
from PIL import Image | |
from diffusers.utils import logging | |
# Sub-folder (below the scheduler's ``folder_name``) that receives the per-image
# vector-breakdown dumps written by ``step_use_latents``.
VECTOR_DATA_FOLDER = "vector_data"
# Name of the scheduler attribute that holds the in-memory breakdown dict
# (checked with ``hasattr``), also used as the stem of the saved ``.pt`` file.
VECTOR_DATA_DICT = "vector_data"
logger = logging.get_logger(__name__)
def get_ddpm_inversion_scheduler(
    scheduler,
    step_function,
    config,
    timesteps,
    save_timesteps,
    latents,
    x_ts,
    x_ts_c_hat,
    save_intermediate_results,
    pipe,
    x_0,
    v1s_images,
    v2s_images,
    deltas_images,
    v1_x0s,
    v2_x0s,
    deltas_x0s,
    folder_name,
    image_name,
    time_measure_n,
):
    """Patch ``scheduler`` in place so one ``step`` call runs inversion and editing together.

    The installed ``step`` splits the batch. Row 0 is handled by
    ``step_save_latents``, which records the noise needed to land on the stored
    trajectory ``x_ts``. Rows ``1:`` are handled by ``step_use_latents``, which
    reuses that recorded noise. Everything those two functions read is attached
    to ``scheduler`` as attributes.

    Args:
        scheduler: scheduler instance to mutate and return.
        step_function: deterministic update, called as
            ``step_function(model_output=..., ..., return_dict=False, scheduler=scheduler)``.
        config: run configuration; ``noise_shift_delta``, ``noise_timesteps`` and
            ``clean_step_timestep`` are read here, and other fields by the step functions.
        timesteps: timesteps used to index the editing half.
        save_timesteps: timesteps used to index the inversion half; falls back to
            ``timesteps`` when falsy.
        latents, x_ts, x_ts_c_hat: trajectory lists shared with the step functions.
        save_intermediate_results: whether ``step_use_latents`` decodes and stores
            intermediate images into the ``v1s_images`` ... ``deltas_x0s`` lists.
        pipe: pipeline whose ``scheduler`` builds ``x_0s`` and whose VAE decodes images.
        x_0: clean latent; noise-free copies of it (``no_add_noise=True``) become
            ``scheduler.x_0s``.
        folder_name, image_name: output location for vector-breakdown dumps.
        time_measure_n: when truthy, ``step_use_latents`` assumes no leading
            reconstruction row.

    Returns:
        The same ``scheduler`` object, mutated.
    """

    def step(
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        # The inversion row must run first: it appends to ``scheduler.latents`` and
        # ``scheduler.x_ts_c_hat``, which the editing rows then read back.
        # Both halves are asked for plain tuples so that ``[0]`` always selects the
        # sample, regardless of the caller's ``return_dict``.
        inversion_prev = step_save_latents(
            scheduler,
            model_output[:1, :, :, :],
            timestep,
            sample[:1, :, :, :],
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            False,
        )[0]
        editing_prev = step_use_latents(
            scheduler,
            model_output[1:, :, :, :],
            timestep,
            sample[1:, :, :, :],
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            False,
        )[0]
        prev_sample = torch.cat((inversion_prev, editing_prev), dim=0)
        if not return_dict:
            return (prev_sample,)
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=None)

    scheduler.step_function = step_function
    scheduler.is_save = True
    scheduler._timesteps = timesteps
    scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
    scheduler._config = config
    scheduler.latents = latents
    scheduler.x_ts = x_ts
    scheduler.x_ts_c_hat = x_ts_c_hat
    scheduler.step = step
    scheduler.save_intermediate_results = save_intermediate_results
    scheduler.pipe = pipe
    scheduler.v1s_images = v1s_images
    scheduler.v2s_images = v2s_images
    scheduler.deltas_images = deltas_images
    scheduler.v1_x0s = v1_x0s
    scheduler.v2_x0s = v2_x0s
    scheduler.deltas_x0s = deltas_x0s
    scheduler.clean_step_run = False
    # Noise-free copies of x_0 aligned with the trajectory (generator unused).
    scheduler.x_0s = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        config.clean_step_timestep,
        None,
        pipe.scheduler,
        timesteps,
        x_0,
        no_add_noise=True,
    )
    scheduler.folder_name = folder_name
    scheduler.image_name = image_name
    scheduler.p_to_p = False
    scheduler.p_to_p_replace = False
    scheduler.time_measure_n = time_measure_n
    return scheduler
def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Inversion half of the coupled step: record the noise that reaches the stored trajectory.

    ``self`` is a scheduler patched by ``get_ddpm_inversion_scheduler``. The
    deterministic prediction ``u_hat_t = self.step_function(...)`` is appended to
    ``self.x_ts_c_hat``. The residual ``z_t = self.x_ts[next] - u_hat_t`` is
    appended, unclipped, to ``self.latents``. The returned sample is
    ``u_hat_t + normalize(z_t)``, which equals ``self.x_ts[next]`` unless
    ``z_t`` exceeds the step's ``max_norm_zs`` budget.

    Returns:
        ``(prev_sample,)`` when ``return_dict`` is false, otherwise a
        ``DDIMSchedulerOutput`` whose ``prev_sample`` is the same tensor.
    """
    if self.clean_step_run:
        # The extra clean step always uses the last stored entries.
        timestep_index = next_timestep_index = -1
    else:
        timestep_index = self._save_timesteps.index(timestep)
        next_timestep_index = timestep_index + 1
    u_hat_t = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )
    x_t_minus_1 = self.x_ts[next_timestep_index]
    self.x_ts_c_hat.append(u_hat_t)
    z_t = x_t_minus_1 - u_hat_t
    # The raw residual is stored; step_use_latents re-derives its clip coefficient.
    self.latents.append(z_t)
    z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
    x_t_minus_1_predicted = u_hat_t + z_t
    if not return_dict:
        return (x_t_minus_1_predicted,)
    # Return the same (possibly clipped) sample as the tuple path; previously this
    # returned the unclipped x_t_minus_1, so the two paths could disagree.
    return DDIMSchedulerOutput(prev_sample=x_t_minus_1_predicted, pred_original_sample=None)
def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Editing half of the coupled step: reuse the recorded inversion noise for the edit rows.

    ``self`` is a scheduler patched by ``get_ddpm_inversion_scheduler``. The
    deterministic update ``x_t_hat_c_hat = self.step_function(...)`` is combined
    with the stored trajectory as::

        x_t_minus_1 = c * x_t_minus_1_exact + w1 * v1 + w2 * v2

    The terms are:

    * ``c`` is the clip coefficient that ``normalize`` reports for the stored
      noise ``self.latents[next]``.
    * ``w1`` and ``w2`` are ``self._config.ws1`` and ``self._config.ws2`` at this
      step.
    * ``v1`` and ``v2`` depend on ``self._config.breakdown``:

      - ``"x_t_hat_c"`` / ``"x_t_hat_c_with_zeros"``:
        ``v1 = x_t_hat_c_hat - x_t_hat_c`` and ``v2 = x_t_hat_c - c * x_t_c``.
        ``x_t_hat_c`` copies the rows at ``x_t_hat_c_indices`` into the
        ``edit_images_indices`` slots and is zero elsewhere.
      - ``"x_t_c_hat"``: raises ``NotImplementedError``.
      - anything else: ``v1 = x_t_hat_c_hat - c * x_t_c`` and ``v2 = 0``.

    Side effects:

    * When ``self.save_intermediate_results`` is set (and not ``p_to_p``),
      decoded images are appended to the ``v1s_images`` ... ``deltas_x0s`` lists.
    * For ``"x_t_hat_c_with_zeros"`` (not ``p_to_p``), per-timestep tensors are
      recorded in ``self.vector_data``. They are saved to
      ``<folder_name>/vector_data/<image_name>/vector_data.pt`` when ``timestep``
      equals ``self._timesteps[-1]``.

    Unless ``self.time_measure_n`` is truthy, row 0 is overwritten with the exact
    trajectory value.

    Returns:
        ``(x_t_minus_1,)`` when ``return_dict`` is false, otherwise a
        ``DDIMSchedulerOutput`` with ``prev_sample=x_t_minus_1``.
    """
    breakdown = self._config.breakdown
    splits_c = breakdown in ("x_t_hat_c", "x_t_hat_c_with_zeros")
    records_vectors = breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p

    if self.clean_step_run:
        timestep_index = next_timestep_index = -1
    else:
        timestep_index = self._timesteps.index(timestep)
        next_timestep_index = timestep_index + 1

    z_t = self.latents[next_timestep_index]  # + 1 because latents[0] is X_T
    # Only the clip coefficient of the stored noise is needed here.
    _, normalize_coefficient = normalize(
        z_t[0] if breakdown == "x_t_hat_c_with_zeros" else z_t,
        timestep_index,
        self._config.max_norm_zs,
    )
    if normalize_coefficient == 0:
        # A zero norm budget also zeroes eta for the deterministic update.
        eta = 0
    x_t_hat_c_hat = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )
    w1 = self._config.ws1[timestep_index]
    w2 = self._config.ws2[timestep_index]
    x_t_minus_1_exact = self.x_ts[next_timestep_index].expand_as(x_t_hat_c_hat)
    x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
    if breakdown == "x_t_c_hat":
        raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
    x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)

    if splits_c:
        # Row layout: [reconstruction row (unless time_measure_n)]
        # + edit_prompts_num rows at x_t_hat_c_indices
        # + edit_prompts_num rows at edit_images_indices
        # (+ edit_prompts_num empty-prompt rows for "x_t_hat_c_with_zeros").
        zero_index_reconstruction = 1 if not self.time_measure_n else 0
        rows = model_output.size(0) - zero_index_reconstruction
        edit_prompts_num = rows // 3 if records_vectors else rows // 2
        x_t_hat_c_indices = (
            zero_index_reconstruction,
            edit_prompts_num + zero_index_reconstruction,
        )
        edit_images_indices = (
            edit_prompts_num + zero_index_reconstruction,
            (
                model_output.size(0)
                if breakdown == "x_t_hat_c"
                else zero_index_reconstruction + 2 * edit_prompts_num
            ),
        )
        x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
        x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
            x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
        ]
        v1 = x_t_hat_c_hat - x_t_hat_c
        v2 = x_t_hat_c - normalize_coefficient * x_t_c
        if records_vectors:
            path = os.path.join(
                self.folder_name,
                VECTOR_DATA_FOLDER,
                self.image_name,
            )
            if not hasattr(self, VECTOR_DATA_DICT):
                os.makedirs(path, exist_ok=True)
                self.vector_data = dict()
            # NOTE(review): entries appended by step_save_latents come from a
            # model_output[:1] slice in the patched step, so index 1 presumably
            # relies on differently-shaped entries seeded by the caller — confirm.
            x_t_0 = x_t_c_hat[1]
            # Offset by zero_index_reconstruction like every other index group
            # above (previously hard-coded to 1, wrong when time_measure_n is set).
            empty_prompt_indices = (
                zero_index_reconstruction + 2 * edit_prompts_num,
                zero_index_reconstruction + 3 * edit_prompts_num,
            )
            x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
            step_record = dict()
            self.vector_data[timestep.item()] = step_record
            step_record["x_t_hat_c"] = x_t_hat_c[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            step_record["x_t_hat_0"] = x_t_hat_0
            step_record["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
            step_record["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
            step_record["x_t_hat_c_hat"] = x_t_hat_c_hat[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            step_record["x_t_minus_1_noisy"] = x_t_minus_1_exact[0].expand_as(x_t_hat_0)
            step_record["x_t_minus_1_clean"] = self.x_0s[next_timestep_index].expand_as(
                x_t_hat_0
            )
    else:  # no breakdown
        v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
        v2 = 0

    if self.save_intermediate_results and not self.p_to_p:
        delta = v1 + v2
        x_0_next = self.x_0s[next_timestep_index]
        v1_plus_x0 = x_0_next + v1
        v2_plus_x0 = x_0_next + v2
        delta_plus_x0 = x_0_next + delta
        self.v1s_images.append(decode_latents(v1, self.pipe))
        self.v2s_images.append(
            decode_latents(v2, self.pipe)
            if breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.deltas_images.append(decode_latents(delta, self.pipe))
        self.v1_x0s.append(decode_latents(v1_plus_x0, self.pipe))
        self.v2_x0s.append(
            decode_latents(v2_plus_x0, self.pipe)
            if breakdown != "no_breakdown"
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.deltas_x0s.append(decode_latents(delta_plus_x0, self.pipe))

    x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
    if splits_c:
        # Propagate the edited rows into the x_t_hat_c slots for the next step.
        x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
            edit_images_indices[0] : edit_images_indices[1]
        ]
        if records_vectors:
            x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = x_t_minus_1[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            if timestep == self._timesteps[-1]:
                torch.save(
                    self.vector_data,
                    os.path.join(path, f"{VECTOR_DATA_DICT}.pt"),
                )
    # p_to_p_force_perfect_reconstruction
    if not self.time_measure_n:
        x_t_minus_1[0] = x_t_minus_1_exact[0]
    if not return_dict:
        return (x_t_minus_1,)
    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1,
        pred_original_sample=None,
    )
def create_xts(
    noise_shift_delta,
    noise_timesteps,
    clean_step_timestep,
    generator,
    scheduler,
    timesteps,
    x_0,
    no_add_noise=False,
):
    """Build the reference latents ``x_t`` for every step of the trajectory.

    If ``noise_timesteps`` is None, it is derived from ``timesteps``: each one is
    shifted down by ``int(noise_shift_delta * (timesteps[0] - timesteps[1]))``.
    Starting at the first noise level that is ``<= 0``, all later slots use the
    clean latent ``x_0`` instead of a noised copy.

    Args:
        noise_shift_delta: fraction of the first timestep spacing to shift by.
        noise_timesteps: explicit noise levels, or None to derive them.
        clean_step_timestep: when ``> 0``, one extra ``x_0`` is appended.
        generator: torch generator used for the Gaussian noise (drawn on CPU).
        scheduler: object providing ``add_noise(samples, noise, timesteps)``.
        timesteps: the sampling timesteps; their count sets the list length.
        x_0: clean latent, expanded along dim 0 for the noised entries.
        no_add_noise: use zeros instead of Gaussian noise.

    Returns:
        List of latents: one ``unsqueeze(0)``-ed noised tensor per kept noise
        level, then ``x_0`` for each remaining timestep, one terminal ``x_0``,
        and one more ``x_0`` if ``clean_step_timestep > 0``.
    """
    if noise_timesteps is None:
        spacing = timesteps[0] - timesteps[1]
        shift = int(noise_shift_delta * spacing)
        noise_timesteps = [t - shift for t in timesteps]

    # Index of the first non-positive level; everything from there on is clean.
    cutoff = next(
        (idx for idx, level in enumerate(noise_timesteps) if level <= 0),
        len(noise_timesteps),
    )
    kept_levels = noise_timesteps[:cutoff]

    batch = x_0.expand(len(kept_levels), -1, -1, -1)
    if no_add_noise:
        noise = torch.zeros_like(batch)
    else:
        noise = torch.randn(batch.size(), generator=generator, device="cpu").to(
            x_0.device
        )
    noised = scheduler.add_noise(batch, noise, torch.IntTensor(kept_levels))

    result = [row.unsqueeze(dim=0) for row in noised]
    result.extend([x_0] * (len(timesteps) - cutoff))
    result.append(x_0)
    if clean_step_timestep > 0:
        result.append(x_0)
    return result
def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clip ``z_t`` to the L2-norm budget of step ``i``.

    Args:
        z_t: tensor to (possibly) rescale.
        i: index into ``max_norm_zs``.
        max_norm_zs: per-step norm limits; a negative entry disables clipping.

    Returns:
        ``(z_t, 1)`` when clipping is disabled or ``torch.norm(z_t)`` is strictly
        below the limit. Otherwise ``(z_t * scale, scale)`` with
        ``scale = limit / torch.norm(z_t)``.
    """
    limit = max_norm_zs[i]
    scale = 1
    if not limit < 0:
        magnitude = torch.norm(z_t)
        if not magnitude < limit:
            scale = limit / magnitude
            z_t = z_t * scale
    return z_t, scale
def decode_latents(latent, pipe):
    """Decode VAE latents into PIL images.

    Divides ``latent`` by ``pipe.vae.config.scaling_factor`` and runs
    ``pipe.vae.decode(..., return_dict=False)``. The first element of the decode
    output is passed to ``pipe.image_processor.postprocess(..., output_type="pil")``.
    """
    vae = pipe.vae
    unscaled = latent / vae.config.scaling_factor
    decoded = vae.decode(unscaled, return_dict=False)[0]
    return pipe.image_processor.postprocess(decoded, output_type="pil")
def deterministic_ddim_step(
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Noise-free DDIM update (formula 12 of https://arxiv.org/pdf/2010.02502.pdf).

    Follows the DDIM scheduler step but never adds random noise. ``eta`` only
    shrinks the "direction pointing to x_t" term. ``generator``,
    ``variance_noise`` and ``return_dict`` are accepted for signature
    compatibility and ignored.

    Returns:
        The predicted previous sample tensor.

    Raises:
        ValueError: if ``set_timesteps`` was not run on ``scheduler``, or its
            ``prediction_type`` is unsupported.
    """
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    cfg = scheduler.config
    stride = cfg.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - stride

    # Cumulative alpha products at the current and previous timesteps.
    alpha_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_prev = scheduler.final_alpha_cumprod
    beta_t = 1 - alpha_t
    sqrt_alpha_t = alpha_t ** 0.5
    sqrt_beta_t = beta_t ** 0.5

    # Recover (x_0, epsilon) from whatever the model predicts.
    kind = cfg.prediction_type
    if kind == "epsilon":
        x0_hat = (sample - sqrt_beta_t * model_output) / sqrt_alpha_t
        eps_hat = model_output
    elif kind == "sample":
        x0_hat = model_output
        eps_hat = (sample - sqrt_alpha_t * x0_hat) / sqrt_beta_t
    elif kind == "v_prediction":
        x0_hat = sqrt_alpha_t * sample - sqrt_beta_t * model_output
        eps_hat = sqrt_alpha_t * model_output + sqrt_beta_t * sample
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    if cfg.thresholding:
        x0_hat = scheduler._threshold_sample(x0_hat)
    elif cfg.clip_sample:
        bound = cfg.clip_sample_range
        x0_hat = x0_hat.clamp(-bound, bound)

    # sigma_t(eta) from formula (16); only used to shrink the direction term.
    variance = scheduler._get_variance(timestep, prev_timestep)
    sigma_t = eta * variance ** 0.5

    if use_clipped_model_output:
        # Re-derive epsilon from the (possibly clipped) x_0, as in GLIDE.
        eps_hat = (sample - sqrt_alpha_t * x0_hat) / sqrt_beta_t

    direction = (1 - alpha_prev - sigma_t ** 2) ** 0.5 * eps_hat
    return alpha_prev ** 0.5 * x0_hat + direction
def deterministic_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Euler-ancestral update with the ancestral noise term removed.

    The step moves from ``sigma`` to ``sigma_down`` using the ODE derivative.
    ``sigma_down`` and ``sigma_up`` are computed as in the Euler-ancestral
    scheduler, but no ``sigma_up`` noise is added. Uses and advances
    ``scheduler._step_index`` and reads ``scheduler.sigmas``. ``eta``,
    ``use_clipped_model_output``, ``generator``, ``variance_noise`` and
    ``return_dict`` are ignored.

    Args:
        model_output: direct output of the diffusion model.
        timestep: one of ``scheduler.timesteps`` (not an integer index).
        sample: current sample; upcast to float32 for the update.
        scheduler: Euler-style scheduler providing sigmas and step-index state.

    Returns:
        The previous sample, cast back to ``model_output.dtype``.

    Raises:
        ValueError: for integer timesteps or an unknown ``prediction_type``.
        NotImplementedError: for ``prediction_type == "sample"``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )
    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)
    idx = scheduler.step_index
    sigma = scheduler.sigmas[idx]
    # Work in float32 to avoid precision loss in the update.
    sample = sample.to(torch.float32)

    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        x0_hat = sample - sigma * model_output
    elif kind == "v_prediction":
        # model_output * c_out + input * c_skip
        x0_hat = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    elif kind == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_next = scheduler.sigmas[idx + 1]
    sigma_up = (sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2) ** 0.5
    sigma_down = (sigma_next**2 - sigma_up**2) ** 0.5

    slope = (sample - x0_hat) / sigma
    prev_sample = (sample + slope * (sigma_down - sigma)).to(model_output.dtype)
    scheduler._step_index += 1
    return prev_sample
def deterministic_non_ancestral_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Euler-discrete update without the churn noise injection.

    ``sigma`` is inflated to ``sigma_hat = sigma * (gamma + 1)`` exactly as in the
    Euler-discrete scheduler, where ``gamma`` depends on ``s_churn``,
    ``s_tmin`` and ``s_tmax``. However, no noise is added to ``sample``, so the
    step is deterministic. Uses and advances ``scheduler._step_index``.
    ``s_noise``, ``generator``, ``variance_noise``, ``eta``,
    ``use_clipped_model_output`` and ``return_dict`` are ignored.

    Args:
        model_output: direct output of the diffusion model.
        timestep: one of ``scheduler.timesteps`` (not an integer index).
        sample: current sample; upcast to float32 for the update.
        s_churn, s_tmin, s_tmax: control ``gamma`` as described above.
        scheduler: Euler-style scheduler providing sigmas and step-index state.

    Returns:
        The previous sample, cast back to ``model_output.dtype``.

    Raises:
        ValueError: for integer timesteps or an unknown ``prediction_type``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )
    if not scheduler.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )
    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    # Work in float32 to avoid precision loss in the update.
    sample = sample.to(torch.float32)
    idx = scheduler.step_index
    sigma = scheduler.sigmas[idx]
    churn_window = s_tmin <= sigma <= s_tmax
    if churn_window:
        gamma = min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
    else:
        gamma = 0.0
    sigma_hat = sigma * (gamma + 1)

    # "original_sample" is kept for backwards compatibility with older configs.
    kind = scheduler.config.prediction_type
    if kind in ("original_sample", "sample"):
        x0_hat = model_output
    elif kind == "epsilon":
        x0_hat = sample - sigma_hat * model_output
    elif kind == "v_prediction":
        # model_output * c_out + input * c_skip
        x0_hat = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    slope = (sample - x0_hat) / sigma_hat
    step_size = scheduler.sigmas[idx + 1] - sigma_hat
    prev_sample = (sample + slope * step_size).to(model_output.dtype)
    scheduler._step_index += 1
    return prev_sample
def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """DDPM posterior-mean update with no sampled noise.

    Computes ``mu_t`` from formula (7) of https://arxiv.org/pdf/2006.11239.pdf
    and returns it directly; the variance term is never sampled. If the model
    also predicts a variance (twice the sample channels with a ``"learned"`` /
    ``"learned_range"`` ``variance_type``), the extra channels are discarded.
    ``eta``, ``use_clipped_model_output``, ``generator``, ``variance_noise`` and
    ``return_dict`` are ignored.

    Args:
        model_output: direct output of the diffusion model.
        timestep: current discrete timestep.
        sample: current sample ``x_t``.
        scheduler: DDPM-style scheduler (alphas, ``previous_timestep``, config).

    Returns:
        The predicted previous sample tensor.

    Raises:
        ValueError: for an unsupported ``prediction_type``.
    """
    t = timestep
    prev_t = scheduler.previous_timestep(t)

    channels = sample.shape[1]
    if model_output.shape[1] == channels * 2 and scheduler.variance_type in (
        "learned",
        "learned_range",
    ):
        # Drop the predicted-variance half; only the mean is used.
        model_output, _ = torch.split(model_output, channels, dim=1)

    alpha_t = scheduler.alphas_cumprod[t]
    if prev_t >= 0:
        alpha_prev = scheduler.alphas_cumprod[prev_t]
    else:
        alpha_prev = scheduler.one
    beta_t = 1 - alpha_t
    beta_prev = 1 - alpha_prev
    step_alpha = alpha_t / alpha_prev
    step_beta = 1 - step_alpha

    # "Predicted x_0" of formula (15).
    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        x0_hat = (sample - beta_t ** 0.5 * model_output) / alpha_t ** 0.5
    elif kind == "sample":
        x0_hat = model_output
    elif kind == "v_prediction":
        x0_hat = alpha_t ** 0.5 * sample - beta_t ** 0.5 * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    if scheduler.config.thresholding:
        x0_hat = scheduler._threshold_sample(x0_hat)
    elif scheduler.config.clip_sample:
        bound = scheduler.config.clip_sample_range
        x0_hat = x0_hat.clamp(-bound, bound)

    # Posterior-mean coefficients of formula (7).
    x0_weight = alpha_prev ** 0.5 * step_beta / beta_t
    xt_weight = step_alpha ** 0.5 * beta_prev / beta_t
    return x0_weight * x0_hat + xt_weight * sample