Note: I later found a more detailed code walkthrough resource, annotated deep learning implementations.

  • That said, as a researcher it is probably better to read the paper and the code yourself before turning to the illustrations on that site.

In this post I go over the implementation of the DDIM (Denoising Diffusion Implicit Models) sampling method, along with some variants of the implicit diffusion model paradigm. I suggest first reading the DDPM post for the general training setup of diffusion models; DDIM keeps that training procedure and only changes the inference / sampling process to make it more efficient.

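The key idea of DDIM is that the reverse (sampling) process does not have to mirror the forward Markovian diffusion process step by step. The DDPM training objective depends only on the marginals $q(x_t \mid x_0)$, not on the joint distribution, so a model trained with that objective remains valid for a whole family of inference processes that share those marginals, including non-Markovian ones. DDIM parameterizes this family with an additional variance $\sigma_t$. Following DDIM's notation, $\alpha_t$ here denotes the cumulative product written $\bar{\alpha}_t$ in DDPM, which is why the code below reads it from `alphas_cumprod`. The inference process then takes a more general form (equation 7 in DDIM):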

$$\begin{align} q(x_{t-1} \mid x_{t}, x_{0}) = \mathcal{N}\left( \sqrt{\alpha_{t-1}}x_0 + \sqrt{1-\alpha_{t-1}-\sigma_{t}^2}\cdot \frac{x_{t} - \sqrt{\alpha_{t}}x_0}{\sqrt{1-\alpha_{t}}} , \sigma_{t}^2 I\right) \end{align}$$
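A quick check of why equation 7 leaves the DDPM marginals unchanged: assume $q(x_t \mid x_0) = \mathcal{N}(\sqrt{\alpha_t}x_0, (1-\alpha_t)I)$ and let $\epsilon, \epsilon' \sim \mathcal{N}(0, I)$ be independent. Sampling $x_t$ and then $x_{t-1}$ from the distribution above gives

$$\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_0 + \sqrt{1-\alpha_t}\,\epsilon \\
x_{t-1} &= \sqrt{\alpha_{t-1}}\,x_0 + \sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\frac{x_t - \sqrt{\alpha_t}\,x_0}{\sqrt{1-\alpha_t}} + \sigma_t\,\epsilon' \\
&= \sqrt{\alpha_{t-1}}\,x_0 + \sqrt{1-\alpha_{t-1}-\sigma_t^2}\,\epsilon + \sigma_t\,\epsilon' \\
\operatorname{Cov}(x_{t-1} \mid x_0) &= (1-\alpha_{t-1}-\sigma_t^2)\,I + \sigma_t^2\,I = (1-\alpha_{t-1})\,I
\end{aligned}$$

so $x_{t-1} \mid x_0 \sim \mathcal{N}(\sqrt{\alpha_{t-1}}x_0, (1-\alpha_{t-1})I)$ for any $\sigma_t^2 \le 1-\alpha_{t-1}$. Applying this from $t = T$ downward (starting from $q(x_T \mid x_0) = \mathcal{N}(\sqrt{\alpha_T}x_0, (1-\alpha_T)I)$) shows every marginal matches DDPM regardless of the choice of $\sigma$.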

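At sampling time $x_0$ is unknown, so it is replaced by the model's prediction computed from $\epsilon_{\theta}^{(t)}(x_t)$, and a sampling step from equation 7 decomposes as follows (equation 12 in DDIM):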

$$\begin{align} x_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\left( \frac{x_{t} - \sqrt{1-\alpha_{t}}\epsilon_{\theta}^{(t)}(x_{t})}{\sqrt{\alpha_{t}}} \right)}_{\text{predicted }x_0} + \underbrace{\sqrt{1 - \alpha_{t-1} - \sigma_{t}^2} \cdot \epsilon_{\theta}^{(t)} (x_t)}_{\text{direction pointing to }x_t} + \underbrace{\sigma_{t} \epsilon_{t}}_{\text{random noise}} \end{align}$$
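A minimal scalar sketch of equation 12 in plain Python, useful for checking the algebra without a trained network. `ddim_step` and its arguments are hypothetical names for this illustration: `eps_pred` stands in for $\epsilon_\theta^{(t)}(x_t)$, `noise` for $\epsilon_t$, and the alphas are cumulative products.

```python
import math


def ddim_step(x_t, eps_pred, alpha_t, alpha_prev, sigma, noise):
    """One scalar update of equation 12 (alpha_t, alpha_prev are cumulative products)."""
    # predicted x_0: invert x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps
    x0_pred = (x_t - math.sqrt(1.0 - alpha_t) * eps_pred) / math.sqrt(alpha_t)
    # direction pointing to x_t
    direction = math.sqrt(1.0 - alpha_prev - sigma ** 2) * eps_pred
    # add the random-noise term sigma * noise
    return math.sqrt(alpha_prev) * x0_pred + direction + sigma * noise


# Toy check: with an exact noise prediction and sigma = 0, the step lands on
# sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * eps, i.e. the same noise at a "cleaner" level.
x0, eps = 0.7, -0.3
alpha_t, alpha_prev = 0.5, 0.8  # cumulative alphas shrink with t, so alpha_prev > alpha_t
x_t = math.sqrt(alpha_t) * x0 + math.sqrt(1 - alpha_t) * eps
x_prev = ddim_step(x_t, eps, alpha_t, alpha_prev, sigma=0.0, noise=0.0)
```

With $\sigma_t = 0$ the update is a deterministic map from $x_t$ to $x_{t-1}$, which is what makes DDIM sampling reproducible from a fixed $x_T$.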

The paper points out two special cases:

  • If $\sigma_t = 0$ for all $t$, the random-noise term vanishes and sampling becomes deterministic given $x_T$; the resulting model is the denoising diffusion implicit model (DDIM).
  • If $\sigma_t = \sqrt{(1-\alpha_{t-1})/(1-\alpha_t)} \sqrt{1-\alpha_t/\alpha_{t-1}}$ for all $t$, the forward process becomes Markovian and the generative process becomes DDPM.
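The second case can be checked numerically. Expanding the predicted $x_0$ in equation 12 gives one coefficient for $x_t$ and one for $\epsilon_\theta^{(t)}(x_t)$; with the DDPM choice of $\sigma_t$ (which equals the square root of DDPM's posterior variance $\tilde\beta_t$), these should coincide with DDPM's ancestral update $x_{t-1} = \frac{1}{\sqrt{a_t}}\big(x_t - \frac{1-a_t}{\sqrt{1-\alpha_t}}\epsilon_\theta\big) + \sigma_t z$, where $a_t = \alpha_t/\alpha_{t-1}$ is the per-step factor (DDPM's own $\alpha_t$). This is a plain-Python sketch; the function names are hypothetical.

```python
import math


def ddim_coeffs(alpha_t, alpha_prev, sigma):
    """Coefficients of x_t and eps_theta in equation 12 once predicted x_0 is expanded."""
    coef_x = math.sqrt(alpha_prev / alpha_t)
    coef_eps = (-math.sqrt(alpha_prev * (1.0 - alpha_t) / alpha_t)
                + math.sqrt(1.0 - alpha_prev - sigma ** 2))
    return coef_x, coef_eps


def ddpm_sigma(alpha_t, alpha_prev):
    """sigma_t from the second bullet; alphas are cumulative products."""
    return math.sqrt((1.0 - alpha_prev) / (1.0 - alpha_t)) * math.sqrt(1.0 - alpha_t / alpha_prev)


def ddpm_coeffs(alpha_t, alpha_prev):
    """DDPM ancestral update: x_{t-1} = (x_t - (1 - a) / sqrt(1 - alpha_t) * eps) / sqrt(a) + noise."""
    a = alpha_t / alpha_prev  # per-step factor, DDPM's own alpha_t
    return 1.0 / math.sqrt(a), -(1.0 - a) / (math.sqrt(a) * math.sqrt(1.0 - alpha_t))


PAIRS = [(0.5, 0.8), (0.05, 0.06), (0.9, 0.99)]  # (alpha_t, alpha_{t-1}) with alpha_t < alpha_{t-1}
for alpha_t, alpha_prev in PAIRS:
    print(ddim_coeffs(alpha_t, alpha_prev, ddpm_sigma(alpha_t, alpha_prev)),
          ddpm_coeffs(alpha_t, alpha_prev))
```

Each printed pair of tuples should agree up to floating-point error, which is the numerical counterpart of the algebra behind the DDPM special case.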

To speed up sampling, DDIM runs the generative process only on an increasing subsequence $\tau = (\tau_1, \ldots, \tau_S)$ of $\{1, \ldots, T\}$ instead of on all $T$ steps. On this subsequence, $\sigma$ is chosen as (equation 16 in DDIM):

$$\begin{align} \sigma_{\tau_i}(\eta) = \eta \sqrt{\frac{1-\alpha_{\tau_{i-1}}}{1-\alpha_{\tau_i}}} \sqrt{1- \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}} \end{align}$$

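where $\eta \ge 0$ is a hyperparameter: $\eta = 0$ gives the deterministic DDIM sampler, and $\eta = 1$ recovers the DDPM choice of $\sigma$ (evaluated on the subsequence). This is the crucial part of the DDIM implementation; the equation corresponds to L717 of the DDIM sampling code in denoising_diffusion_pytorch.py.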
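A stdlib-only sketch of the two ingredients above: building the timestep pairs the way `ddim_sample` does with `torch.linspace`, then evaluating equation 16 on each pair. The linear β schedule (1e-4 to 0.02, as in DDPM) is an assumption for illustration, not necessarily the library's default, and `time_pairs`, `ddim_sigma`, and `linear_alphas_cumprod` are hypothetical names. Python's `int()` truncates toward zero like `.int()`, though float32 rounding in `torch.linspace` could differ in edge cases.

```python
import math


def linear_alphas_cumprod(total_timesteps, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta) for an assumed linear beta schedule."""
    alphas_cumprod, prod = [], 1.0
    for i in range(total_timesteps):
        beta = beta_start + (beta_end - beta_start) * i / (total_timesteps - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def time_pairs(total_timesteps, sampling_timesteps):
    """Pure-Python analogue of the linspace / reversed / zip logic in ddim_sample."""
    times = [int(-1 + i * total_timesteps / sampling_timesteps)
             for i in range(sampling_timesteps + 1)]
    times = list(reversed(times))
    return list(zip(times[:-1], times[1:]))


def ddim_sigma(alpha, alpha_next, eta):
    """Equation 16 with alpha = alpha_{tau_i} and alpha_next = alpha_{tau_{i-1}}."""
    return eta * math.sqrt((1.0 - alpha_next) / (1.0 - alpha) * (1.0 - alpha / alpha_next))


T, S = 1000, 50
acp = linear_alphas_cumprod(T)
pairs = time_pairs(T, S)
print(pairs[:3], pairs[-1])  # → [(999, 979), (979, 959), (959, 939)] (19, -1)
```

With 50 sampling steps out of 1000, each update jumps 20 timesteps, and the final pair `(19, -1)` is the one where the loop returns the predicted $x_0$ instead of applying equation 16.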

```python
# DDIM sampling, L691-731
# sampling_timesteps < total_timesteps in the case of DDIM

@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
    batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None
        # obtain predicted x_0 and predicted noise in equation 12 (equation 2 in our case)
        pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

        if time_next < 0:
            img = x_start
            imgs.append(img)
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        # compute sigma from equation 16 (equation 3 in our case)
        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        # compute term before predicted noise
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(img)

        # direct implementation of equation 12 (equation 2 in our case)
        img = x_start * alpha_next.sqrt() + \
                c * pred_noise + \
                sigma * noise

        imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

    ret = self.unnormalize(ret)
    return ret
```
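To see the whole loop run end to end without a trained network, here is a stdlib-only scalar re-creation of `ddim_sample`. It assumes a hypothetical "dataset" consisting of a single point `TARGET`, so the exact noise predictor is known in closed form (`oracle_eps`, also hypothetical). Clipping, self-conditioning, and normalization are left out, and the linear β schedule is an assumption rather than the library's default.

```python
import math
import random


def toy_ddim_sample(alphas_cumprod, sampling_timesteps, eps_model, eta=0.0, seed=0):
    """Scalar version of the ddim_sample loop (no clipping, no self-conditioning)."""
    rng = random.Random(seed)
    total = len(alphas_cumprod)
    times = [int(-1 + i * total / sampling_timesteps) for i in range(sampling_timesteps + 1)]
    times = list(reversed(times))
    img = rng.gauss(0.0, 1.0)  # x_T ~ N(0, 1)
    for time, time_next in zip(times[:-1], times[1:]):
        alpha = alphas_cumprod[time]
        pred_noise = eps_model(img, time)
        x_start = (img - math.sqrt(1.0 - alpha) * pred_noise) / math.sqrt(alpha)
        if time_next < 0:
            return x_start  # last step: return the predicted x_0
        alpha_next = alphas_cumprod[time_next]
        sigma = eta * math.sqrt((1.0 - alpha / alpha_next) * (1.0 - alpha_next) / (1.0 - alpha))
        c = math.sqrt(1.0 - alpha_next - sigma ** 2)
        img = x_start * math.sqrt(alpha_next) + c * pred_noise + sigma * rng.gauss(0.0, 1.0)
    return img


T = 1000
acp, prod = [], 1.0
for i in range(T):  # assumed linear beta schedule from 1e-4 to 0.02
    prod *= 1.0 - (1e-4 + (0.02 - 1e-4) * i / (T - 1))
    acp.append(prod)

TARGET = 0.42  # hypothetical single data point


def oracle_eps(x_t, t):
    # exact noise for a point-mass data distribution at TARGET
    return (x_t - math.sqrt(acp[t]) * TARGET) / math.sqrt(1.0 - acp[t])


print(toy_ddim_sample(acp, 50, oracle_eps, eta=0.0))
print(toy_ddim_sample(acp, 50, oracle_eps, eta=1.0))
```

Because the oracle predicts $x_0$ exactly, both η = 0 and η = 1 end at `TARGET` (up to rounding). With a real network the difference shows up: η = 0 makes the output a deterministic function of the initial noise, while η > 0 injects fresh noise at every step.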

Some crucial implementation mappings:

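  1. Everything involving `times` before the for loop builds the subsequence of timesteps and pairs each step $\tau_i$ with the next one $\tau_{i-1}$ (`time_pairs`); the final pair uses `-1` to signal that the predicted $x_0$ should be returned directly.

  2. For equation 12 in DDIM (equation 2 in our case), the “predicted $x_0$” and the $\epsilon_\theta^{(t)}(x_t)$ term are obtained through:

```python
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
```

Then

```python
img = x_start * alpha_next.sqrt() + \
            c * pred_noise + \
            sigma * noise
```

maps directly onto equation 2 itself: `x_start * alpha_next.sqrt()` is the predicted-$x_0$ term, `c * pred_noise` is the direction pointing to $x_t$, and `sigma * noise` is the random noise.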
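My reading of what `clip_x_start = True, rederive_pred_noise = True` are for, sketched on scalars: the predicted $x_0$ is clipped to the data range, and the noise is then recomputed so that the pair still reconstructs the current $x_t$. This is an assumption about the library's behavior for the noise-prediction objective, not a copy of its code; `x0_from_noise`, `noise_from_x0`, and `toy_model_predictions` are hypothetical names, and `model_noise` stands in for the network output.

```python
import math


def x0_from_noise(x_t, alpha, noise):
    # invert x_t = sqrt(alpha) * x_0 + sqrt(1 - alpha) * noise for x_0
    return (x_t - math.sqrt(1.0 - alpha) * noise) / math.sqrt(alpha)


def noise_from_x0(x_t, alpha, x_start):
    # invert the same relation for the noise, given x_0
    return (x_t - math.sqrt(alpha) * x_start) / math.sqrt(1.0 - alpha)


def toy_model_predictions(x_t, alpha, model_noise, clip_x_start=True, rederive_pred_noise=True):
    """Hypothetical scalar stand-in for the (pred_noise, x_start) pair used in ddim_sample."""
    pred_noise = model_noise
    x_start = x0_from_noise(x_t, alpha, pred_noise)
    if clip_x_start:
        x_start = min(max(x_start, -1.0), 1.0)  # data assumed normalized to [-1, 1]
        if rederive_pred_noise:
            # recompute the noise so (x_start, pred_noise) is consistent with x_t again
            pred_noise = noise_from_x0(x_t, alpha, x_start)
    return pred_noise, x_start


pred_noise, x_start = toy_model_predictions(0.3, 0.25, -2.0)
```

Here the raw prediction implies $x_0 \approx 4.06$, which is clipped to 1.0; without re-deriving the noise, the "predicted $x_0$" and "direction" terms of equation 2 would describe two different $x_t$.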