Note: I found a more detailed code walk-through resource, annotated deep learning implementations.
- However, as a researcher it is probably better to read the paper and the code yourself before turning to the illustrations on that site.
In this post I go over the implementation of the DDIM sampling method (for implicit models) and some variants of the implicit diffusion paradigm. It is advised to first read the DDPM post about the general training setup of diffusion models; DDIM changes only the inference / sampling process to make it more efficient.
The general idea of DDIM is that reverse sampling does not have to proceed step-by-step in sync with the forward Markovian diffusion process: the KL objective holds as long as the marginal distributions agree. To make this possible, an additional variable $\sigma_t$ is introduced for the variational distribution, and the inference process takes a more general form (equation 7 in DDIM; note that DDIM's $\alpha_t$ denotes the cumulative product written $\bar\alpha_t$ in DDPM):
$$\begin{align} q(x_{t-1} \mid x_{t}, x_{0}) = \mathcal{N}\left( \sqrt{\alpha_{t-1}}x_0 + \sqrt{1-\alpha_{t-1}-\sigma_{t}^2}\cdot \frac{x_{t} - \sqrt{\alpha_{t}}x_0}{\sqrt{1-\alpha_{t}}} , \sigma_{t}^2 I\right) \end{align}$$
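As a quick sanity check that equation 1 preserves the forward marginals, here is a standalone Monte Carlo sketch (with a hypothetical scalar $x_0$ and an arbitrary admissible $\sigma_t$, values chosen only for illustration): sampling $x_t$ from the forward marginal and then $x_{t-1}$ via equation 1 should again give the marginal $\mathcal{N}(\sqrt{\alpha_{t-1}}\,x_0,\ (1-\alpha_{t-1})I)$.

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical values; DDIM's alpha_t is the cumulative product (DDPM's alpha-bar)
alpha_t, alpha_prev = 0.5, 0.7
sigma_t = 0.2          # any choice with sigma_t**2 <= 1 - alpha_prev
x0 = 1.5               # a fixed scalar standing in for x_0
n = 1_000_000

# forward marginal: x_t ~ N(sqrt(alpha_t) * x0, 1 - alpha_t)
x_t = np.sqrt(alpha_t) * x0 + np.sqrt(1 - alpha_t) * rng.standard_normal(n)

# equation 1: sample x_{t-1} given x_t and x_0
mean = np.sqrt(alpha_prev) * x0 \
     + np.sqrt(1 - alpha_prev - sigma_t**2) * (x_t - np.sqrt(alpha_t) * x0) / np.sqrt(1 - alpha_t)
x_prev = mean + sigma_t * rng.standard_normal(n)

# the marginal of x_{t-1} should again be N(sqrt(alpha_prev) * x0, 1 - alpha_prev)
print(x_prev.mean(), x_prev.var())
```

Empirically the mean and variance come out close to $\sqrt{0.7}\cdot 1.5$ and $0.3$, independent of the $\sigma_t$ chosen, which is exactly the "marginals agree" property DDIM relies on.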
Then the reverse (generative) process has the following decomposition (equation 12 in DDIM):
$$\begin{align} x_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\left( \frac{x_{t} - \sqrt{1-\alpha_{t}}\epsilon_{\theta}^{(t)}(x_{t})}{\sqrt{\alpha_{t}}} \right)}_{\text{predicted }x_0} + \underbrace{\sqrt{1 - \alpha_{t-1} - \sigma_{t}^2} \cdot \epsilon_{\theta}^{(t)} (x_t)}_{\text{direction pointing to }x_t} + \underbrace{\sigma_{t} \epsilon_{t}}_{\text{random noise}} \end{align}$$
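The "predicted $x_0$" term is just the forward process inverted with the model's noise estimate. A minimal PyTorch sketch (hypothetical values, and the exact forward noise standing in for $\epsilon_\theta^{(t)}$) shows the inversion recovers $x_0$:

```python
import torch

torch.manual_seed(0)

alpha_t = torch.tensor(0.5)   # hypothetical cumulative alpha at step t
x0 = torch.randn(4)           # hypothetical clean sample
eps = torch.randn(4)          # the noise actually used in the forward process

# forward process: x_t = sqrt(alpha_t) * x0 + sqrt(1 - alpha_t) * eps
x_t = alpha_t.sqrt() * x0 + (1 - alpha_t).sqrt() * eps

# "predicted x_0" term of equation 2, with the exact noise in place of eps_theta
x0_hat = (x_t - (1 - alpha_t).sqrt() * eps) / alpha_t.sqrt()
print(torch.allclose(x0_hat, x0, atol = 1e-5))
```

With a trained network, $\epsilon_\theta^{(t)}(x_t)$ replaces `eps` and the recovery is only approximate.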
It was mentioned in the paper that:
- If $\sigma_t = 0$ for all $t$, the sampling process becomes deterministic (a denoising diffusion implicit model).
- If $\sigma_t = \sqrt{(1-\alpha_{t-1})/(1-\alpha_t)} \sqrt{1-\alpha_t/\alpha_{t-1}}$ for all $t$, the forward process is Markovian and we recover DDPM.
The sampling process is now proposed to run over only an increasing subsequence of the original $\{1, \cdots, T\}$. Suppose the subsequence has time stamps $\tau_i$; then $\sigma$ is designed to be (equation 16 in DDIM):
$$\begin{align} \sigma_{\tau_i}(\eta) = \eta \sqrt{\frac{1-\alpha_{\tau_{i-1}}}{1-\alpha_{\tau_i}}} \sqrt{1- \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}} \end{align}$$
with $\eta$ a hyperparameter. This is the crucial part of the DDIM implementation. This equation corresponds to L717 in the DDIM implementation under denoising_diffusion_pytorch.py.
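A small sketch of equation 3 (hypothetical alpha values; `ddim_sigma` is my own helper name, not from the library) illustrating the two special cases, $\eta = 0$ (deterministic DDIM) and $\eta = 1$ (DDPM-style noise):

```python
import math

def ddim_sigma(eta, alpha_prev, alpha_cur):
    # equation 3: sigma_{tau_i}(eta) for consecutive kept steps tau_{i-1}, tau_i
    return eta * math.sqrt((1 - alpha_prev) / (1 - alpha_cur)) * math.sqrt(1 - alpha_cur / alpha_prev)

# hypothetical cumulative alphas at tau_{i-1} and tau_i (they decrease in t)
alpha_prev, alpha_cur = 0.7, 0.5

print(ddim_sigma(0.0, alpha_prev, alpha_cur))  # eta = 0: deterministic DDIM step
print(ddim_sigma(1.0, alpha_prev, alpha_cur))  # eta = 1: DDPM-style posterior std
```

So $\eta$ interpolates continuously between a deterministic implicit model and DDPM-like stochastic sampling.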
```python
# DDIM sampling, L691-731
# sampling_timesteps < total_timesteps in the case of DDIM
@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
    batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None
        # obtain predicted x_0 and predicted noise in equation 12 (equation 2 in our case)
        pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

        if time_next < 0:
            img = x_start
            imgs.append(img)
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        # compute sigma from equation 16 (equation 3 in our case)
        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        # compute term before predicted noise
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(img)

        # direct implementation of equation 12 (equation 2 in our case)
        img = x_start * alpha_next.sqrt() + \
              c * pred_noise + \
              sigma * noise

        imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
    ret = self.unnormalize(ret)
    return ret
```
Some crucial implementation mappings:

Everything regarding `times` before the for loop deals with obtaining a subsequence of the time steps. For equation 12 in DDIM (equation 2 in our case), the "predicted $x_0$" and $\epsilon_\theta^{(t)}(x_t)$ terms are obtained through:
```python
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
```
Then
```python
img = x_start * alpha_next.sqrt() + \
      c * pred_noise + \
      sigma * noise
```
directly maps to equation 2 itself.
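As a sanity check, the `times` / `time_pairs` construction from before the loop can be run in isolation with toy values:

```python
import torch

total_timesteps, sampling_timesteps = 10, 5   # toy values for illustration

times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

print(times)       # [9, 7, 5, 3, 1, -1]
print(time_pairs)  # [(9, 7), (7, 5), (5, 3), (3, 1), (1, -1)]
```

The trailing $-1$ acts as a sentinel: when `time_next < 0` the loop returns the predicted $x_0$ directly instead of taking another noisy step.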