Note: I found a more detailed code walkthrough resource, annotated deep learning implementations.

  • However, as a researcher it seems better to read the paper and the code yourself before turning to that site's annotations.

In this post I go over the implementation of the DDIM sampling method (denoising diffusion implicit models) and some variants of the implicit diffusion paradigm. It is advised to first read the DDPM post about the general training setup of diffusion models; DDIM changes only the inference / sampling process to make it more efficient.

The general idea behind DDIM is that reverse sampling does not have to stay in lockstep with the forward Markovian diffusion process: the variational (KL) objective only depends on the marginals $q(x_t \mid x_0)$, so any inference distribution that matches those marginals will do. To parameterize this family, an additional variable $\sigma_t$ is introduced, and the inference process takes a more general form (equation 7 in DDIM):

\begin{align} q(x_{t-1} \mid x_{t}, x_{0}) = \mathcal{N}\left( \sqrt{\alpha_{t-1}}\,x_0 + \sqrt{1-\alpha_{t-1}-\sigma_{t}^2}\cdot \frac{x_{t} - \sqrt{\alpha_{t}}\,x_0}{\sqrt{1-\alpha_{t}}} ,\ \sigma_{t}^2 I\right) \end{align}
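
This choice is exactly what keeps the marginals consistent: whatever $\sigma_t$ is, composing equation 1 with the forward marginal $q(x_t \mid x_0) = \mathcal{N}(\sqrt{\alpha_t}\,x_0, (1-\alpha_t)I)$ recovers $q(x_{t-1} \mid x_0) = \mathcal{N}(\sqrt{\alpha_{t-1}}\,x_0, (1-\alpha_{t-1})I)$. A minimal Monte Carlo sketch of this (toy scalar values, all hypothetical; note that $\alpha$ here denotes the cumulative product, DDPM's $\bar{\alpha}$):

import numpy as np

# toy scalar check, all values hypothetical
rng = np.random.default_rng(0)
x0 = 0.7
alpha_t, alpha_prev = 0.5, 0.8
sigma_t = 0.3    # any value with sigma_t**2 <= 1 - alpha_prev works

# sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_t) * x0, 1 - alpha_t)
n = 1_000_000
x_t = np.sqrt(alpha_t) * x0 + np.sqrt(1 - alpha_t) * rng.standard_normal(n)

# sample x_{t-1} ~ q(x_{t-1} | x_t, x_0) from equation 1
mean = np.sqrt(alpha_prev) * x0 \
     + np.sqrt(1 - alpha_prev - sigma_t**2) * (x_t - np.sqrt(alpha_t) * x0) / np.sqrt(1 - alpha_t)
x_prev = mean + sigma_t * rng.standard_normal(n)

# the marginal should match q(x_{t-1} | x_0) = N(sqrt(alpha_prev) * x0, 1 - alpha_prev)
print(x_prev.mean(), np.sqrt(alpha_prev) * x0)   # both ~ 0.626
print(x_prev.var(), 1 - alpha_prev)              # both ~ 0.2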

Then the reverse (generative) process has the following decomposition (equation 12 in DDIM):

\begin{align} x_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\left( \frac{x_{t} - \sqrt{1-\alpha_{t}}\,\epsilon_{\theta}^{(t)}(x_{t})}{\sqrt{\alpha_{t}}} \right)}_{\text{predicted }x_0} + \underbrace{\sqrt{1 - \alpha_{t-1} - \sigma_{t}^2} \cdot \epsilon_{\theta}^{(t)} (x_t)}_{\text{direction pointing to }x_t} + \underbrace{\sigma_{t} \epsilon_{t}}_{\text{random noise}} \end{align}
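
The "predicted $x_0$" term comes from inverting the forward marginal: since $x_t = \sqrt{\alpha_t}\,x_0 + \sqrt{1-\alpha_t}\,\epsilon$, plugging in the noise estimate $\epsilon_\theta^{(t)}(x_t)$ for $\epsilon$ and solving for $x_0$ gives

\begin{align} \hat{x}_0^{(t)} = \frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}} \end{align}

which is exactly the first bracketed term above.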

It was mentioned in the paper that:

  • If $\sigma_t = 0$ for all $t$, the sampling process becomes deterministic (this is the denoising diffusion implicit model, DDIM).
  • If $\sigma_t = \sqrt{(1-\alpha_{t-1})/(1-\alpha_t)} \sqrt{1-\alpha_t/\alpha_{t-1}}$ for all $t$, the forward process becomes Markovian and we recover DDPM.

The sampling process is then proposed to run over only an ascending subsequence of the original timesteps $\{0, \cdots, T\}$. Denoting the subsequence's timestamps by $\tau_i$, $\sigma$ is designed to be (equation 16 in DDIM):

\begin{align} \sigma_{\tau_i}(\eta) = \eta \sqrt{\frac{1-\alpha_{\tau_{i-1}}}{1-\alpha_{\tau_i}}} \sqrt{1- \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}} \end{align}

with $\eta$ a hyperparameter. This is the crucial part for the DDIM implementation. The equation corresponds to L717 in the DDIM implementation under denoising_diffusion_pytorch.py.
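
A minimal sketch of equation 3 in code, assuming `alphas_cumprod` is the usual cumulative product of $1-\beta_t$ (the linear $\beta$ schedule below is a hypothetical stand-in): $\eta = 0$ zeroes the noise term entirely (deterministic DDIM), while $\eta = 1$ recovers the DDPM-like variance.

import torch

# hypothetical linear beta schedule; alphas_cumprod plays the role of the paper's alpha_t
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1 - betas, dim = 0)

def ddim_sigma(alphas_cumprod, tau, tau_prev, eta):
    # equation 3: sigma_{tau_i}(eta)
    alpha, alpha_prev = alphas_cumprod[tau], alphas_cumprod[tau_prev]
    return eta * ((1 - alpha_prev) / (1 - alpha)).sqrt() * (1 - alpha / alpha_prev).sqrt()

print(ddim_sigma(alphas_cumprod, 500, 450, eta = 0.0))  # tensor(0.) -> deterministic DDIM step
print(ddim_sigma(alphas_cumprod, 500, 450, eta = 1.0))  # > 0 -> stochastic, DDPM-like step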

# DDIM sampling, L691-731
# sampling_timesteps < total_timesteps in the case of DDIM 

@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
    batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None
        # obtain predicted x_0 and predicted noise in equation 12 (equation 2 in our case)
        pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

        if time_next < 0:
            # last pair is (tau_1, -1): return the predicted x_0 directly as the final sample
            img = x_start
            imgs.append(img)
            continue

        # note: the paper's alpha_t is the cumulative product (DDPM's alpha-bar),
        # which the repo stores as alphas_cumprod
        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        # compute sigma from equation 16 (equation 3 in our case)
        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        # compute term before predicted noise 
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(img)

        # direct implementation of equation 12 (equation 2 in our case)
        img = x_start * alpha_next.sqrt() + \
                c * pred_noise + \
                sigma * noise

        imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

    ret = self.unnormalize(ret)
    return ret
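
To make the timestep bookkeeping concrete, here is what the lines before the loop produce for small hypothetical values of total_timesteps and sampling_timesteps:

import torch

total_timesteps, sampling_timesteps = 1000, 5   # hypothetical values for illustration
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
print(times)        # [999, 799, 599, 399, 199, -1]
print(time_pairs)   # [(999, 799), (799, 599), (599, 399), (399, 199), (199, -1)]

The sentinel -1 in the final pair triggers the time_next < 0 branch, which simply returns the predicted $x_0$ as the final sample.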

Some crucial implementation mappings:

  1. Everything regarding times before the for loop deals with obtaining a subsequence of the time steps (see the concrete example above).

  2. For equation 12 in DDIM (equation 2 in our case), the “predicted $x_0$” and $\epsilon_\theta^{(t)}(x_t)$ terms are obtained through:

pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

Then

img = x_start * alpha_next.sqrt() + \
        c * pred_noise + \
        sigma * noise

directly maps to equation 2 itself.
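
For completeness, a usage sketch based on the repo's README (parameter values here are hypothetical): GaussianDiffusion dispatches sample() to ddim_sample whenever sampling_timesteps is set below timesteps, with ddim_sampling_eta playing the role of $\eta$ above.

from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # total_timesteps T
    sampling_timesteps = 250,   # < timesteps, so sample() uses ddim_sample
    ddim_sampling_eta = 0.      # eta = 0: fully deterministic DDIM
)
sampled_images = diffusion.sample(batch_size = 4)  # (4, 3, 128, 128)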