The discrete versions of diffusion models (DDPM, DDIM) are relatively easy to understand and implement, but the same cannot be said of their continuous counterparts. Here I take a dive into the official implementation of the paper "Score-Based Generative Modeling through Stochastic Differential Equations" and try to map the implementation onto the paper. This should help me (and potentially others) adapt and modify the codebase for my own use. The implementation discussed here is the PyTorch version, and I focus on the essential parts of the paper (continuous DDPM training and sampling) that I used in my project on learning SODE systems with diffusion models.
Training
The training objective for score-based generative models, in the particular case of a continuous SDE, is equation 7 of the original paper:
$$ \theta^\ast = \arg\min_{\theta} \mathbb{E}_t \Big\{ \lambda(t)\, \mathbb{E}_{x(0)}\mathbb{E}_{x(t)\mid x(0)} \Big[ \big\| s_{\theta}(x(t),t) - \nabla_{x(t)}\log p_{0t}(x(t)\mid x(0)) \big\|_2^2 \Big] \Big\} $$
where $x(t)$ is the state of the diffusion process at time $t$ and $\lambda(t)$ is a positive weighting function, typically chosen to be inversely proportional to the expected squared norm of the conditional score:
$$ \lambda(t) \propto 1/\mathbb{E}\left[ \big\| \nabla_{x(t)} \log p_{0t}(x(t)\mid x(0)) \big\|_2^2 \right] $$
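For a Gaussian transition kernel $\mathcal{N}(m, \sigma^2 I)$, which is what the VP SDE below produces, the conditional score is $-(x(t)-m)/\sigma^2 = -z/\sigma$ with $z \sim \mathcal{N}(0, I)$. Its expected squared norm is therefore $d/\sigma^2$ in $d$ dimensions, and the weighting reduces to $\lambda(t) \propto \sigma^2(t)$. The short NumPy sketch below checks this by Monte Carlo; the dimension, mean and $\sigma$ values are arbitrary illustrative choices.

```python
import numpy as np


def expected_sq_score_norm(m, sigma, n_samples=200_000, seed=0):
    """Monte Carlo estimate of E ||grad_x log N(x; m, sigma^2 I)||_2^2."""
    rng = np.random.default_rng(seed)
    d = m.shape[0]
    # Sample x(t) = m + sigma * z from the Gaussian transition kernel
    z = rng.standard_normal((n_samples, d))
    x = m + sigma * z
    # Conditional score of a Gaussian: -(x - m) / sigma^2
    score = -(x - m) / sigma**2
    return np.mean(np.sum(score**2, axis=1))


d = 4
m = np.array([0.3, -1.0, 2.0, 0.5])  # illustrative mean m(t)
for sigma in (0.1, 0.5, 1.0):
    est = expected_sq_score_norm(m, sigma)
    # Closed form is d / sigma^2, hence lambda(t) proportional to sigma(t)^2
    print(f"sigma={sigma}: Monte Carlo={est:.2f}, d/sigma^2={d / sigma**2:.2f}")
```

The estimate scales as $1/\sigma^2$ regardless of the mean, which is why the weighting can be written purely in terms of the kernel's standard deviation.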
In code, this requires two quantities: the output of the score network, which in its basic form takes $x(t)$ and the time $t$ as input, and the ground-truth conditional score of the transition kernel defined by the SDE, which can be evaluated analytically. For the Variance-Preserving (VP) SDE, the continuous version of DDPM, the SDE is (equation 11):
$$\begin{aligned} dx = -\frac{1}{2} \beta(t) x dt + \sqrt{\beta(t)} dw \end{aligned}$$
Because the forward process is linear (it is in fact an Ornstein-Uhlenbeck process with a time-dependent coefficient), the transition kernel is Gaussian, and its mean $m(t)$ and variance $\sigma^2(t)$ are governed by two ODEs (from Section 5.5 of Applied Stochastic Differential Equations):
$$\begin{aligned} &p_{0t}(x(t)\mid x(0)) = \mathcal{N}\big(x(t); m(t), \sigma^2(t) I\big) \\ &\frac{dm}{dt} = -\frac{1}{2}\beta(t) m, \qquad m(0) = x(0) \\ &\frac{d\sigma^2}{dt} = -\beta(t) \sigma^2 + \beta(t), \qquad \sigma^2(0) = 0 \end{aligned}$$
Since $\beta(t)$ is a pre-determined schedule, the two ODEs can be solved analytically, which gives the following code in the `VPSDE` class of `sde_lib`:
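def marginal_prob(self, x, t):
    # Closed-form solution of the moment ODEs for the linear schedule beta(t) = beta_0 + t * (beta_1 - beta_0)
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    # Broadcast the per-example coefficient over the (C, H, W) axes, as std[:, None, None, None] is in the loss
    mean = torch.exp(log_mean_coeff[:, None, None, None]) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std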
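The closed form follows from integrating the two ODEs under the linear schedule implied by `log_mean_coeff`, $\beta(t) = \beta_0 + t(\beta_1 - \beta_0)$. Writing the variance as $\sigma^2(t)$, with $m(0) = x(0)$ and $\sigma^2(0) = 0$:
$$\begin{aligned} \int_0^t \beta(s)\,ds &= \beta_0 t + \tfrac{1}{2} t^2 (\beta_1 - \beta_0), \\ m(t) &= x(0)\exp\Big(-\tfrac{1}{4} t^2 (\beta_1 - \beta_0) - \tfrac{1}{2} t \beta_0\Big), \\ \sigma^2(t) &= 1 - \exp\Big(-\int_0^t \beta(s)\,ds\Big), \end{aligned}$$
so the exponent in $m(t)$ is exactly `log_mean_coeff`, and the variance equals `1 - exp(2 * log_mean_coeff)`, the square of `std`. As a sanity check, the sketch below simulates the VP SDE with Euler-Maruyama in 1-D and compares the empirical mean and standard deviation at one time point with this closed form. It uses NumPy instead of PyTorch, and the values $\beta_0 = 0.1$, $\beta_1 = 20$, $x(0) = 2$ and $t = 0.5$ are illustrative assumptions.

```python
import numpy as np

# Assumed linear schedule beta(t) = beta_0 + t * (beta_1 - beta_0); values are illustrative
BETA_0, BETA_1 = 0.1, 20.0


def beta(t):
    return BETA_0 + t * (BETA_1 - BETA_0)


def simulate_vp_sde(x0, t_end, n_steps=500, n_paths=100_000, seed=0):
    """Euler-Maruyama for dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dw."""
    rng = np.random.default_rng(seed)
    dt = t_end / n_steps
    x = np.full(n_paths, x0, dtype=np.float64)
    for i in range(n_steps):
        t = i * dt
        drift = -0.5 * beta(t) * x
        diffusion = np.sqrt(beta(t))
        x = x + drift * dt + diffusion * np.sqrt(dt) * rng.standard_normal(n_paths)
    return x


def vp_marginal(x0, t):
    """Closed-form mean and std of p_0t(x(t) | x(0)), same expressions as marginal_prob."""
    log_mean_coeff = -0.25 * t**2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    return np.exp(log_mean_coeff) * x0, np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))


x0, t_end = 2.0, 0.5
samples = simulate_vp_sde(x0, t_end)
mean, std = vp_marginal(x0, t_end)
print(f"empirical mean/std: {samples.mean():.3f} / {samples.std():.3f}")
print(f"closed form       : {mean:.3f} / {std:.3f}")
```

The two lines should agree up to Monte Carlo and discretization error, which shrinks as `n_paths` and `n_steps` grow.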
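Knowing the transition kernel, computing $\nabla_{x(t)}\log p_{0t}(x(t)\mid x(0))$ is straightforward because the kernel is Gaussian: writing $x(t) = m(t) + \sigma(t) z$ with $z \sim \mathcal{N}(0, I)$, the conditional score is $-z/\sigma(t)$. This leads to the loss computation in `losses.get_sde_loss_fn` (its inner `loss_fn`):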
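import torch
from models import utils as mutils

# Inside losses.get_sde_loss_fn; sde, train, reduce_mean, continuous, likelihood_weighting and eps are its arguments
reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

def loss_fn(model, batch):
    score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
    # Sample t ~ U(eps, T); eps avoids the singular score at t = 0, where std -> 0
    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
    z = torch.randn_like(batch)
    mean, std = sde.marginal_prob(batch, t)
    # Draw x(t) ~ p_0t(x(t) | x(0)) by reparameterization
    perturbed_data = mean + std[:, None, None, None] * z
    score = score_fn(perturbed_data, t)
    if not likelihood_weighting:
        # lambda(t) = sigma(t)^2: sigma^2 * ||s + z / sigma||^2 == ||s * sigma + z||^2
        losses = torch.square(score * std[:, None, None, None] + z)
        losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
    else:
        # lambda(t) = g(t)^2; the diffusion coefficient does not depend on x, so zeros suffice
        g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
        losses = torch.square(score + z / std[:, None, None, None])
        losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2
    loss = torch.mean(losses)
    return loss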
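To make the weighting explicit, here is a NumPy transcription of the non-likelihood-weighted branch on flat vectors of shape `(B, D)` (so the broadcasting uses `[:, None]`). Two hand-written score functions stand in for the network: an oracle that cheats by using $x(0)$ to return the exact conditional score $-z/\sigma(t)$, and a constant zero score. The oracle attains zero loss, and for any score the `score * std + z` form equals $\sigma(t)^2$ times the plain squared error $\|s_\theta + z/\sigma(t)\|_2^2$, which is the weighting $\lambda(t) = \sigma(t)^2$. The schedule values and all function names are assumptions for illustration.

```python
import numpy as np

BETA_0, BETA_1 = 0.1, 20.0  # assumed linear schedule, as in marginal_prob


def marginal_prob(x, t):
    # NumPy transcription of VPSDE.marginal_prob for batches of shape (B, D)
    log_mean_coeff = -0.25 * t**2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    mean = np.exp(log_mean_coeff)[:, None] * x
    std = np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))
    return mean, std


def dsm_losses(score_fn, batch, rng, eps=1e-5, T=1.0):
    """Per-example losses without likelihood weighting, mirroring loss_fn."""
    t = rng.uniform(size=batch.shape[0]) * (T - eps) + eps
    z = rng.standard_normal(batch.shape)
    mean, std = marginal_prob(batch, t)
    perturbed = mean + std[:, None] * z
    score = score_fn(perturbed, t, batch)
    # Form used in the code: ||score * std + z||^2
    weighted = np.sum((score * std[:, None] + z) ** 2, axis=-1)
    # Unweighted form: ||score - (-z / std)||^2
    unweighted = np.sum((score + z / std[:, None]) ** 2, axis=-1)
    return weighted, unweighted, std


def oracle_score(x_t, t, x0):
    # Exact conditional score -(x(t) - m(t)) / sigma(t)^2; needs x(0), so only usable as a check
    mean, std = marginal_prob(x0, t)
    return -(x_t - mean) / std[:, None] ** 2


def zero_score(x_t, t, x0):
    return np.zeros_like(x_t)


rng = np.random.default_rng(0)
batch = rng.standard_normal((8, 3))
w, u, std = dsm_losses(oracle_score, batch, rng)
print(w.max())  # close to 0, up to floating-point rounding
w0, u0, std0 = dsm_losses(zero_score, batch, rng)
print(np.allclose(w0, std0**2 * u0))  # True: the two forms differ by the factor sigma(t)^2
```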
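Notice that the implementation makes extensive use of higher-order functions: factories such as `get_sde_loss_fn` and `get_score_fn` bind their configuration once and return closures, which is similar in spirit to decorators even though no decorator syntax is used. This functional style translates naturally to JAX, in which the authors also released an official implementation.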
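The pattern is easy to see in isolation. The sketch below is a hypothetical, stripped-down factory in the same spirit as `get_sde_loss_fn`, including the same `reduce_mean` switch between a mean and half a sum; `get_scaled_loss_fn` and its `scale` parameter are invented for illustration and are not part of the repository.

```python
from typing import Callable, List


def get_scaled_loss_fn(scale: float, reduce_mean: bool = True) -> Callable[[List[float]], float]:
    """Hypothetical factory: configuration is bound once, and the returned
    closure only takes the data that changes at every training step."""

    def reduce_op(values: List[float]) -> float:
        # Same choice as in get_sde_loss_fn: mean, or 0.5 * sum
        return sum(values) / len(values) if reduce_mean else 0.5 * sum(values)

    def loss_fn(residuals: List[float]) -> float:
        return scale * reduce_op([r * r for r in residuals])

    return loss_fn


loss_fn = get_scaled_loss_fn(scale=2.0)
print(loss_fn([1.0, 3.0]))  # 2.0 * mean([1.0, 9.0]) = 10.0
```

Because the closure has no hidden mutable state, the same shape of code works with JAX transformations such as `jax.jit`, which only need a pure function of their inputs.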
Sampling
Sampling from a score-based diffusion model requires:
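- A trained score network, which ideally approximates the ground-truth score $\nabla_x \log p_t(x)$.
- Random noise with the same shape as the data, drawn from the prior distribution of the SDE.
- A sampler for either the reverse-time SDE or the probability flow ODE.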
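These three ingredients can be illustrated on a toy problem where the score is known in closed form. The sketch below assumes 1-D Gaussian data $x(0) \sim \mathcal{N}(0, s_0^2)$ and the linear VP schedule, so that every marginal $p_t$ is Gaussian with variance $a(t)^2 s_0^2 + 1 - a(t)^2$, where $a(t) = \exp(\texttt{log\_mean\_coeff})$. The exact score replaces a trained network, and Euler-Maruyama is run on the reverse-time SDE $dx = [f(x,t) - g(t)^2 \nabla_x \log p_t(x)]\,dt + g(t)\,d\bar{w}$ with $f(x,t) = -\frac{1}{2}\beta(t)x$ and $g(t) = \sqrt{\beta(t)}$. The values $\beta_0 = 0.1$, $\beta_1 = 20$, $s_0 = 0.5$ and the step counts are illustrative assumptions.

```python
import numpy as np

BETA_0, BETA_1 = 0.1, 20.0   # assumed linear VP schedule (illustrative values)
S0 = 0.5                     # toy data distribution: x(0) ~ N(0, S0^2)


def beta(t):
    return BETA_0 + t * (BETA_1 - BETA_0)


def true_score(x, t):
    # For Gaussian data, p_t is Gaussian with variance a^2 * S0^2 + (1 - a^2)
    log_mean_coeff = -0.25 * t**2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    a2 = np.exp(2.0 * log_mean_coeff)
    return -x / (a2 * S0**2 + 1.0 - a2)


def reverse_sde_sample(score_fn, n_paths=50_000, n_steps=1000, T=1.0, eps=1e-3, seed=0):
    """Euler-Maruyama on dx = [f - g^2 * score] dt + g dw_bar, from t = T down to t = eps."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_paths)            # prior sample: N(0, 1) for the VP SDE
    ts = np.linspace(T, eps, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        dt = t - t_next                          # positive step size; time runs backwards
        f = -0.5 * beta(t) * x                   # forward drift f(x, t)
        g2 = beta(t)                             # squared diffusion coefficient g(t)^2
        x = x - (f - g2 * score_fn(x, t)) * dt + np.sqrt(g2 * dt) * rng.standard_normal(n_paths)
    return x


samples = reverse_sde_sample(true_score)
print(f"sample mean: {samples.mean():.3f}, sample std: {samples.std():.3f} (target {S0})")
```

With a trained model, `true_score` would simply be replaced by the network's output; everything else in the loop stays the same.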
The probability flow ODE (equation 13) is the following:
$$\begin{aligned} dx = \left[ f(x,t) - \frac{1}{2}g(t)^2 \nabla_x \log p_t(x) \right] dt \end{aligned}$$
where $\nabla_x \log p_t(x)$ is replaced by the trained score network $s_\theta(x, t)$. Because this is a deterministic ODE, it can be solved with black-box ODE solvers such as `scipy.integrate.solve_ivp`. The sampling logic for the probability flow ODE is in `sampling.get_ode_sampler`:
import torch
from scipy import integrate
# Helpers from the repository's models/utils.py; ReverseDiffusionPredictor is defined in sampling.py itself
from models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn


def get_ode_sampler(sde, shape, inverse_scaler,
                    denoise=False, rtol=1e-5, atol=1e-5,
                    method='RK45', eps=1e-3, device='cuda'):
    """Probability flow ODE sampler with the black-box ODE solver.
    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      shape: A sequence of integers. The expected shape of a single sample.
      inverse_scaler: The inverse data normalizer.
      denoise: If `True`, add one-step denoising to final samples.
      rtol: A `float` number. The relative tolerance level of the ODE solver.
      atol: A `float` number. The absolute tolerance level of the ODE solver.
      method: A `str`. The algorithm used for the black-box ODE solver.
        See the documentation of `scipy.integrate.solve_ivp`.
      eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
      device: PyTorch device.
    Returns:
      A sampling function that returns samples and the number of function evaluations during sampling.
    """

    def denoise_update_fn(model, x):
        score_fn = get_score_fn(sde, model, train=False, continuous=True)
        # Reverse diffusion predictor for denoising
        predictor_obj = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
        vec_eps = torch.ones(x.shape[0], device=x.device) * eps
        _, x = predictor_obj.update_fn(x, vec_eps)
        return x

    def drift_fn(model, x, t):
        """Get the drift function of the reverse-time SDE."""
        score_fn = get_score_fn(sde, model, train=False, continuous=True)
        rsde = sde.reverse(score_fn, probability_flow=True)
        return rsde.sde(x, t)[0]

    def ode_sampler(model, z=None):
        """The probability flow ODE sampler with black-box ODE solver.
        Args:
          model: A score model.
          z: If present, generate samples from latent code `z`.
        Returns:
          samples, number of function evaluations.
        """
        with torch.no_grad():
            # Initial sample
            if z is None:
                # If not present, sample the latent code from the prior distribution of the SDE.
                x = sde.prior_sampling(shape).to(device)
            else:
                x = z

            def ode_func(t, x):
                x = from_flattened_numpy(x, shape).to(device).type(torch.float32)
                vec_t = torch.ones(shape[0], device=x.device) * t
                drift = drift_fn(model, x, vec_t)
                return to_flattened_numpy(drift)

            # Black-box ODE solver for the probability flow ODE
            solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x),
                                           rtol=rtol, atol=atol, method=method)
            nfe = solution.nfev
            x = torch.tensor(solution.y[:, -1]).reshape(shape).to(device).type(torch.float32)

            # Denoising is equivalent to running one predictor step without adding noise
            if denoise:
                x = denoise_update_fn(model, x)

            x = inverse_scaler(x)
            return x, nfe

    return ode_sampler
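Here `ode_func` is the crucial part: it wraps `drift_fn`, which obtains the probability flow drift $f(x,t) - \frac{1}{2}g(t)^2 s_\theta(x,t)$ from `sde.reverse(score_fn, probability_flow=True)`, and it converts between the flattened NumPy arrays that `solve_ivp` expects and PyTorch tensors. The solver integrates backward in time, from `sde.T` down to `eps`. The optional `denoise_update_fn` is not part of the ODE solve: it applies one final step of the reverse diffusion predictor (one of the predictors used by the Predictor-Corrector samplers in Appendix G of the paper) and keeps the noise-free mean of that step.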
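The structure of the sampler can be reproduced without a trained model on a toy problem: 1-D Gaussian data $x(0) \sim \mathcal{N}(0, s_0^2)$, whose score under the VP SDE is known exactly. The sketch below mirrors `ode_func` and the `solve_ivp` call (same default tolerances and method, integration from $T$ down to `eps`), but it works on plain NumPy arrays, uses the exact score in place of the network inside `drift_fn`, assumes $\beta_0 = 0.1$, $\beta_1 = 20$, $s_0 = 0.5$, and skips the denoising step.

```python
import numpy as np
from scipy import integrate

BETA_0, BETA_1 = 0.1, 20.0   # assumed linear VP schedule
S0 = 0.5                     # toy data: x(0) ~ N(0, S0^2), so the score is known exactly


def beta(t):
    return BETA_0 + t * (BETA_1 - BETA_0)


def true_score(x, t):
    log_mean_coeff = -0.25 * t**2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    a2 = np.exp(2.0 * log_mean_coeff)
    return -x / (a2 * S0**2 + 1.0 - a2)


def ode_func(t, x):
    # Probability flow drift f(x, t) - 0.5 * g(t)^2 * score(x, t) on a flat array, as solve_ivp expects
    return -0.5 * beta(t) * x - 0.5 * beta(t) * true_score(x, t)


rng = np.random.default_rng(0)
z = rng.standard_normal(10_000)  # latent code drawn from the N(0, I) prior
solution = integrate.solve_ivp(ode_func, (1.0, 1e-3), z, rtol=1e-5, atol=1e-5, method="RK45")
x = solution.y[:, -1]
print(f"nfe={solution.nfev}, sample std={x.std():.3f} (target {S0})")
```

For Gaussian data the probability flow ODE is a deterministic linear map, so each output is the corresponding latent rescaled by a factor very close to $s_0$; with a trained network the map is nonlinear, but the solver call is the same.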
Besides the model, the only input to the sampler is the latent code `z`. If it is omitted, it is generated on the fly by `sde.prior_sampling(shape)`, which for the VP SDE simply draws standard Gaussian noise.
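Plain Gaussian noise is a valid starting point because of the closed-form kernel: under the linear schedule used by `marginal_prob`, almost nothing of $x(0)$ survives at $t = T$ and the standard deviation is essentially 1, so $p_{0T}(x(T)\mid x(0)) \approx \mathcal{N}(0, I)$ for any data point. The sketch below evaluates both quantities; $\beta_0 = 0.1$, $\beta_1 = 20$ and $T = 1$ are illustrative assumptions.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # assumed schedule endpoints (illustrative)
T = 1.0

log_mean_coeff = -0.25 * T**2 * (BETA_1 - BETA_0) - 0.5 * T * BETA_0
mean_coeff = math.exp(log_mean_coeff)  # fraction of x(0) that survives in the mean at t = T
std_T = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))  # kernel standard deviation at t = T
print(f"mean coefficient at T: {mean_coeff:.4f}, std at T: {std_T:.5f}")
```

A schedule with a much smaller $\beta_1$ would leave a visible trace of $x(0)$ at $t = T$, and the mismatch between the true marginal and the $\mathcal{N}(0, I)$ prior would then show up as a sampling error.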