DiT inference model #612

Open
wants to merge 1 commit into develop
Empty file added ppdiffusers/deploy/__init__.py
298 changes: 298 additions & 0 deletions ppdiffusers/deploy/custom_ddim.py
@@ -0,0 +1,298 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import paddle

from ppdiffusers.configuration_utils import ConfigMixin, register_to_config
from ppdiffusers.utils import BaseOutput, logging
from ppdiffusers.schedulers import SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.

Args:
prev_sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: paddle.Tensor
pred_original_sample: Optional[paddle.Tensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> paddle.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].

Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.


Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.

Returns:
        betas (`paddle.Tensor`): the betas used by the scheduler to step the model outputs
"""

def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return paddle.to_tensor(betas)


class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
    Denoising Diffusion Implicit Models (DDIM) is a scheduler that extends the denoising procedure introduced in
    Denoising Diffusion Probabilistic Models (DDPMs) with non-Markovian guidance.

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/abs/2010.02502

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
            https://imagen.research.google/video/paper.pdf)
"""

_deprecated_kwargs = ["predict_epsilon"]
order = 1

@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
if trained_betas is not None:
self.betas = paddle.to_tensor(trained_betas, dtype="float32")
elif beta_schedule == "linear":
self.betas = paddle.linspace(beta_start, beta_end, num_train_timesteps, dtype="float32")
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = paddle.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype="float32") ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = paddle.cumprod(self.alphas, 0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = paddle.to_tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0

# setable values
self.num_inference_steps = None
self.timesteps = paddle.to_tensor(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

def scale_model_input(self, sample: paddle.Tensor, timestep: Optional[int] = None) -> paddle.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.

Args:
sample (`paddle.Tensor`): input sample
timestep (`int`, optional): current timestep

Returns:
`paddle.Tensor`: scaled input sample
"""
return sample

def _get_variance(self, timestep, prev_timestep):
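        # NOTE: kept for parity with the stock DDIMScheduler; unused in this deploy path,
        # since step() rejects any eta != 0 and therefore never adds variance noise.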
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

return variance

def set_timesteps(self, num_inference_steps: int):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_steps is a power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.timesteps = paddle.to_tensor(timesteps)
self.timesteps += self.config.steps_offset
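        # e.g. num_train_timesteps=1000, num_inference_steps=50, steps_offset=0
        # gives timesteps = [980, 960, ..., 20, 0]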
self.cal_mul_params()

def cal_mul_params(self):
self.sample_mul_params = []
self.model_output_mul_params = []
for timestep in self.timesteps:
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
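            # Deterministic DDIM update (eta = 0, epsilon prediction) folded into two
            # per-step scalars, so step() reduces to one fused multiply-subtract:
            #   x_{t-1} = sqrt(a_prev / a_t) * x_t
            #             - (sqrt(a_prev * (1 - a_t) / a_t) - sqrt(1 - a_prev)) * eps
            # where a_t = alphas_cumprod[t] and a_prev = alphas_cumprod[prev_timestep].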
sample_params = alpha_prod_t_prev ** (0.5) / alpha_prod_t ** (0.5)
model_output_params = (beta_prod_t ** (0.5) / alpha_prod_t ** (0.5) * alpha_prod_t_prev ** (0.5) - (1 - alpha_prod_t_prev) ** (0.5))
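            # NOTE: cast to float16, presumably to match an exported FP16 inference graph.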
self.sample_mul_params.append(sample_params.astype("float16"))
self.model_output_mul_params.append(model_output_params.astype("float16"))


def step(
self,
model_output: paddle.Tensor,
index: int,
timestep: int,
sample: paddle.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
variance_noise: Optional[paddle.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
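        """
        Predict the previous sample x_{t-1} with the deterministic (eta = 0) DDIM update,
        using the per-step multipliers precomputed in `cal_mul_params`.

        Args:
            model_output (`paddle.Tensor`): predicted noise (epsilon) from the diffusion model.
            index (`int`): position of `timestep` within `self.timesteps`, used to look up the
                precomputed multipliers.
            timestep (`int`): current discrete timestep in the diffusion chain (accepted for API
                compatibility; the fused update only needs `index`).
            sample (`paddle.Tensor`): current instance of the sample being denoised.
            return_dict (`bool`): whether to return a [`DDIMSchedulerOutput`] or a plain tuple.

        The remaining arguments are accepted for interface compatibility and ignored.

        Returns:
            [`DDIMSchedulerOutput`] or `tuple` with `prev_sample` as the first element.
        """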

if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

        if self.config.prediction_type != "epsilon":
            raise ValueError(
                f"This deploy scheduler only supports prediction_type 'epsilon', got '{self.config.prediction_type}'."
            )

        if eta != 0:
            raise ValueError("This fused DDIM step is deterministic; only eta=0 is supported.")

prev_sample = self.sample_mul_params[index] * sample - self.model_output_mul_params[index] * model_output
if not return_dict:
return (prev_sample,)

return DDIMSchedulerOutput(prev_sample=prev_sample)

def add_noise(
self,
original_samples: paddle.Tensor,
noise: paddle.Tensor,
timesteps: paddle.Tensor,
) -> paddle.Tensor:
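        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise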
# Make sure alphas_cumprod and timestep have same dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.cast(original_samples.dtype)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(self, sample: paddle.Tensor, noise: paddle.Tensor, timesteps: paddle.Tensor) -> paddle.Tensor:
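        # v-prediction target (Salimans & Ho, 2022): v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample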
# Make sure alphas_cumprod and timestep have same dtype as sample
self.alphas_cumprod = self.alphas_cumprod.cast(sample.dtype)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def __len__(self):
return self.config.num_train_timesteps
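
For reference, a minimal sketch of how this scheduler could drive a DiT denoising loop at inference time. The `transformer` call and the latent shape below are illustrative placeholders, not part of this PR:

import paddle
from ppdiffusers.deploy.custom_ddim import DDIMScheduler

scheduler = DDIMScheduler()  # defaults: 1000 training steps, linear beta schedule
scheduler.set_timesteps(50)  # also precomputes the fused float16 multipliers

latents = paddle.randn([1, 4, 32, 32], dtype="float16")  # illustrative latent shape
for index, timestep in enumerate(scheduler.timesteps):
    noise_pred = transformer(latents, timestep)  # placeholder for the DiT forward pass
    latents = scheduler.step(noise_pred, index, timestep, latents).prev_sample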