Note: I found a site with more detailed code walkthroughs, annotated deep learning implementations.
- However, it seems that as a researcher, it is better to read the paper and the code yourself before looking at that site's annotations.
In this post I go over the implementation of the DDIM sampling method (for implicit models) and some variants of the implicit diffusion model paradigm. I advise first reading the DDPM post about the general training setup of diffusion models; DDIM changes only the inference / sampling process to make it more efficient.
The general idea of DDIM is that reverse sampling does not have to proceed step-by-step in sync with the forward Markovian diffusion process: the KL objective holds as long as the marginal distributions agree, and to make this possible an additional variable $\sigma$ is introduced into the variational distribution. The inference process then takes a more complex, non-Markovian form (equation 7 in DDIM):

$$q_\sigma(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \sqrt{\alpha_{t-1}}\, x_0 + \sqrt{1 - \alpha_{t-1} - \sigma_t^2} \cdot \frac{x_t - \sqrt{\alpha_t}\, x_0}{\sqrt{1 - \alpha_t}},\ \sigma_t^2 I\right) \tag{1}$$

where, following the DDIM paper's notation, $\alpha_t$ denotes what DDPM calls the cumulative product $\bar\alpha_t$.
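To see why the marginals agree for any choice of $\sigma$, here is a small numerical sanity check (my own sketch, not from the paper or the reference code): sampling $x_{t-1}$ through equation 1 should leave the marginal $q(x_{t-1} \mid x_0)$ unchanged no matter which $\sigma_t$ is used.

```python
import torch

torch.manual_seed(0)
abar_t, abar_prev = 0.5, 0.7   # alphas_cumprod at t and t-1 (made-up values)
x0, n = 2.0, 1_000_000

# sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x0, (1 - abar_t) I)
xt = abar_t ** 0.5 * x0 + (1 - abar_t) ** 0.5 * torch.randn(n)

for sigma in (0.0, 0.2, 0.4):
    # mean of q_sigma(x_{t-1} | x_t, x_0) from equation 1
    mean = (abar_prev ** 0.5 * x0
            + (1 - abar_prev - sigma ** 2) ** 0.5
            * (xt - abar_t ** 0.5 * x0) / (1 - abar_t) ** 0.5)
    x_prev = mean + sigma * torch.randn(n)
    # the marginal should stay N(sqrt(abar_prev) * x0, 1 - abar_prev) for every sigma
    print(f'sigma={sigma}: mean={x_prev.mean():.3f} (expect {abar_prev ** 0.5 * x0:.3f}), '
          f'var={x_prev.var():.3f} (expect {1 - abar_prev:.3f})')
```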
Then one step of the reverse (generative) process has the following decomposition (equation 12 in DDIM):

$$x_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right)}_{\text{predicted } x_0} + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\epsilon_\theta^{(t)}(x_t)}_{\text{direction pointing to } x_t} + \underbrace{\sigma_t\,\epsilon_t}_{\text{random noise}} \tag{2}$$
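Written as code, one reverse step of equation 2 looks like the following (a minimal sketch of mine, not the reference implementation; `abar` / `abar_prev` stand for alphas_cumprod at the current and previous timestep, and `eps` is the network's noise prediction $\epsilon_\theta^{(t)}(x_t)$):

```python
import torch

def ddim_step(x_t, eps, abar, abar_prev, sigma):
    x0_pred = (x_t - (1 - abar).sqrt() * eps) / abar.sqrt()   # "predicted x_0"
    dir_xt  = (1 - abar_prev - sigma ** 2).sqrt() * eps       # "direction pointing to x_t"
    noise   = sigma * torch.randn_like(x_t)                   # random noise; zero when sigma == 0
    return abar_prev.sqrt() * x0_pred + dir_xt + noise
```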
It was mentioned in the paper that:
- If $\sigma_t = 0$ for all $t$, the forward process becomes deterministic given $x_{t-1}$ and $x_0$, and we get the Denoising Diffusion Implicit Model (DDIM); both choices are sketched in code right after this list.
- If $\sigma_t = \sqrt{(1-\alpha_{t-1})/(1-\alpha_t)}\sqrt{1 - \alpha_t/\alpha_{t-1}}$ for all $t$, the forward process becomes Markovian and we recover DDPM.
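In code, the two special cases look like this (my own sketch, assuming a linear beta schedule; `alphas_cumprod` plays the role of $\alpha_t$ in the DDIM paper's notation):

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)           # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1 - betas, dim = 0)

def sigma_ddim(t):
    # sigma_t = 0 for all t: the reverse process becomes deterministic (DDIM)
    return torch.zeros(())

def sigma_ddpm(t):
    # sigma_t = sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1}):
    # recovers the DDPM posterior standard deviation (Markovian case)
    abar, abar_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
    return ((1 - abar_prev) / (1 - abar)).sqrt() * (1 - abar / abar_prev).sqrt()

print(sigma_ddim(500), sigma_ddpm(500))
```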
Now the sampling process is proposed to run over only an ascending subsequence of the original time steps $[1, \dots, T]$. Suppose the subsequence has time stamps $\{\tau_1, \dots, \tau_S\}$; then $\sigma_{\tau_i}$ is designed to be (equation 16 in DDIM):

$$\sigma_{\tau_i}(\eta) = \eta\,\sqrt{\frac{1-\alpha_{\tau_{i-1}}}{1-\alpha_{\tau_i}}}\sqrt{1-\frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}} \tag{3}$$
with $\eta$ a hyperparameter: $\eta = 0$ gives deterministic DDIM sampling and $\eta = 1$ recovers DDPM. This is the crucial part of the DDIM implementation; this equation corresponds to L717 in the DDIM implementation under denoising_diffusion_pytorch.py.
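As a quick check (mine, not from the repo), the sigma expression used at that line is algebraically the same as equation 3 written out directly; note that in the code, `alpha_next` is the alphas_cumprod at the next step of the *reverse* loop, i.e. $\alpha_{\tau_{i-1}}$:

```python
import torch

eta = 1.0
# abar at the current step tau_i and the previous step tau_{i-1} (made-up values)
alpha, alpha_next = torch.tensor(0.5), torch.tensor(0.7)

sigma_code  = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
sigma_paper = eta * ((1 - alpha_next) / (1 - alpha)).sqrt() * (1 - alpha / alpha_next).sqrt()
assert torch.allclose(sigma_code, sigma_paper)
```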
```python
# DDIM sampling, L691-731
# sampling_timesteps < total_timesteps in the case of DDIM
@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
    batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))   # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None

        # obtain predicted x_0 and predicted noise in equation 12 (equation 2 in our case)
        pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

        if time_next < 0:
            img = x_start
            imgs.append(img)
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        # compute sigma from equation 16 (equation 3 in our case)
        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        # coefficient in front of the predicted noise (the "direction pointing to x_t" term)
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(img)

        # direct implementation of equation 12 (equation 2 in our case)
        img = x_start * alpha_next.sqrt() + \
              c * pred_noise + \
              sigma * noise

        imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
    ret = self.unnormalize(ret)
    return ret
```
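For context, this code path is typically triggered by setting `sampling_timesteps` lower than `timesteps` when constructing `GaussianDiffusion` (adapted from the repository's README; exact arguments may differ across versions):

```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,          # total_timesteps T
    sampling_timesteps = 250   # < timesteps, so sample() dispatches to ddim_sample
)

sampled_images = diffusion.sample(batch_size = 4)
print(sampled_images.shape)    # (4, 3, 128, 128)
```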
Some crucial implementation mappings:
Everything regarding `times` before the for loop deals with obtaining the subsequence of time steps $\{\tau_1, \dots, \tau_S\}$ (a concrete toy example is given at the end of this post). For equation 12 in DDIM (equation 2 in our case), the "predicted $x_0$" and $\epsilon_\theta^{(t)}(x_t)$ terms are obtained through:
```python
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
```
Then

```python
img = x_start * alpha_next.sqrt() + \
      c * pred_noise + \
      sigma * noise
```

directly maps to equation 2 itself.
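To make the first mapping concrete, here is what `times` and `time_pairs` look like for toy values (my own illustration):

```python
import torch

total_timesteps, sampling_timesteps = 10, 5   # toy values
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
print(times)        # [9, 7, 5, 3, 1, -1]
print(time_pairs)   # [(9, 7), (7, 5), (5, 3), (3, 1), (1, -1)]
```

The final pair `(1, -1)` hits the `time_next < 0` branch, where the loop simply returns the current `x_start` prediction.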