Quick Overview of Main Result

The work PDE-Refiner is about ensuring stable long-time rollouts of dynamical systems. The authors identify empirically that the neglect of non-dominant spatial frequency information (the low-amplitude, high-frequency modes of the PDE solution) is the primary reason existing models fail at long temporal rollouts. To address this issue, they force the model to learn these modes by actively injecting noise at each step and predicting that noise, just like the mechanism in diffusion models (they also claim the additional benefits of data augmentation and uncertainty quantification).

In addition to the standard MSE loss in operator learning:

$$ L_1(u,t) = \lVert u(t) - NO\big(0,\; u(t-\Delta t),\; 0\big) \rVert^2 $$

where $NO$ is the neural operator, whose arguments match the $k = 0$ branch of the pseudocode below (a zero tensor, the previous solution, and the refinement step). They add a refinement process in which, for $k = 1, \cdots, K$, they repeatedly add noise to the prediction and train the model to predict that noise:

$$ L_2(u,t) = \mathbb{E}_{k\sim U(1,K)}\, \mathbb{E}_{\epsilon_k \sim \mathcal{N}(0,1)} \left[ \lVert \epsilon_k - NO\big(u(t) + \sigma_k \epsilon_k,\; u(t-\Delta t),\; k\big) \rVert^2 \right] $$

where the noise standard deviations $\sigma_k$ are scheduled to decrease exponentially in $k$, with the minimum noise level set by the amplitude of the least-dominant frequencies of the underlying dynamical process; for the KS equation, they use $\sigma_{min} = 2\cdot 10^{-7}$.
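
As a quick illustration (a minimal sketch, assuming the exponential schedule $\sigma_k = \sigma_{min}^{k/K}$ used in the pseudocode below), the noise levels for $K = 3$ refinement steps look like this:

K = 3                # number of refinement steps
sigma_min = 2e-7     # minimum noise std, the KS value from the paper

# sigma_0 = 1 would correspond to k = 0, but k = 0 is the pure MSE branch
# in training; refinement steps k = 1..K use exponentially shrinking noise
sigmas = [sigma_min ** (k / K) for k in range(K + 1)]
print(sigmas)  # [1.0, ~5.8e-3, ~3.4e-5, 2e-7]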

Implementation

Appendix C of the paper provides an elucidating pseudocode (I added some comments):

Pseudocode One

class PDERefiner:
    def __init__(self, num_steps, min_noise_std):
        self.num_steps = num_steps
        self.min_noise_std = min_noise_std
        self.neural_operator = MyNetwork(...)

    # one step training: sample k ~ Uniform(0,K)
    # k = 0 => standard operator learning MSE loss 
    # k > 0 => predict noise from noised u_t 
    def train_step(self, u_t, u_prev):
        k = randint(0, self.num_steps + 1)
        if k == 0:
            pred = self.neural_operator(zeros_like(u_t), u_prev, k)
            target = u_t
        else:
            noise_std = self.min_noise_std ** (k / self.num_steps)
            noise = randn_like(u_t)
            u_t_noised = u_t + noise * noise_std
            pred = self.neural_operator(u_t_noised, u_prev, k)
            target = noise
        loss = mse(pred, target)
        return loss

    # prediction: first predict u_t
    # then predict noise K times and denoise
    def predict_next_solution(self, u_prev):
        u_hat_t = self.neural_operator(zeros_like(u_prev), u_prev, 0)
        for k in range(1, self.num_steps + 1):
            noise_std = self.min_noise_std ** (k / self.num_steps)
            noise = randn_like(u_hat_t)  # bug fix: u_t is not in scope here
            u_hat_t_noised = u_hat_t + noise * noise_std
            pred = self.neural_operator(u_hat_t_noised, u_prev, k)
            u_hat_t = u_hat_t_noised - pred * noise_std
        return u_hat_t

    # rollout: iteratively call the predict_next_solution
    def rollout(self, u_initial, timesteps):
        trajectory = [u_initial]
        for t in range(timesteps):
            u_hat_t = self.predict_next_solution(trajectory[-1])
            trajectory.append(u_hat_t)  # bug fix: append inside the loop
        return trajectory

Straightforwardly, train_step corresponds to the two losses defined above. Notice that the two losses are not summed but computed in an either/or fashion, depending on the sampled $k$. predict_next_solution first predicts the natural target, then performs a forward noising + backward denoising process $K$ times.
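
To make this concrete, below is a minimal runnable sketch of the training step in PyTorch. The tiny convolutional network standing in for MyNetwork and the toy data shapes are my own assumptions, not from the paper:

import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyOperator(nn.Module):
    # hypothetical stand-in for MyNetwork: a small 1D CNN conditioned on k
    def __init__(self, channels=1, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(2 * channels + 1, hidden, 3, padding=1), nn.GELU(),
            nn.Conv1d(hidden, channels, 3, padding=1),
        )

    def forward(self, u_in, u_prev, k):
        # broadcast the refinement step k as an extra input channel
        k_ch = torch.full_like(u_in[:, :1], float(k))
        return self.net(torch.cat([u_in, u_prev, k_ch], dim=1))

num_steps, min_noise_std = 3, 2e-7
op = TinyOperator()
u_prev = torch.randn(8, 1, 64)  # u(t - dt): batch of 8, 64 grid points
u_t = torch.randn(8, 1, 64)     # u(t)

k = random.randint(0, num_steps)  # inclusive on both ends
if k == 0:
    pred, target = op(torch.zeros_like(u_t), u_prev, k), u_t
else:
    noise_std = min_noise_std ** (k / num_steps)
    noise = torch.randn_like(u_t)
    pred, target = op(u_t + noise * noise_std, u_prev, k), noise
loss = F.mse_loss(pred, target)
loss.backward()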

Pseudocode Two

The connection with diffusion models can be readily seen from the implementation; this is also spelled out in Section 3.3 and Appendix C of the paper. That part of the code looks like the following:

from diffusers.schedulers import DDPMScheduler

class PDERefinerDiffusion:
    def __init__(self, num_steps, min_noise_std):
        betas = [min_noise_std ** (k / num_steps)
                 for k in reversed(range(num_steps + 1))]
        self.scheduler = DDPMScheduler(num_train_timesteps=num_steps + 1,
                                       trained_betas=betas,
                                       prediction_type='v_prediction',
                                       clip_sample=False)
        self.num_steps = num_steps
        self.neural_operator = MyNetwork(...)

    def train_step(self, u_t, u_prev):
        k = randint(0, self.num_steps + 1)
        # The scheduler uses t=K for the first-step prediction and t=0 for the
        # minimum noise level. To stay consistent with the paper's presentation,
        # we keep k and the scheduler time separate; alternatively, one could
        # use the scheduler time step as k directly, acting as the conditional
        # input to the neural operator.
        scheduler_t = self.num_steps - k
        noise_factor = self.scheduler.alphas_cumprod[scheduler_t]
        signal_factor = 1 - noise_factor
        noise = randn_like(u_t)
        u_t_noised = self.scheduler.add_noise(u_t, noise, scheduler_t)
        pred = self.neural_operator(u_t_noised, u_prev, k)
        target = (noise_factor ** 0.5) * noise - (signal_factor ** 0.5) * u_t
        loss = mse(pred, target)
        return loss
        
    # almost identical to Pseudocode One
    def predict_next_solution(self, u_prev):
        # start from pure noise: the last beta equals 1, so the highest
        # scheduler time carries no signal and the first prediction comes
        # entirely from the conditioning on u_prev
        u_hat_t_noised = randn_like(u_prev)
        for scheduler_t in self.scheduler.timesteps:
            k = self.num_steps - scheduler_t
            pred = self.neural_operator(u_hat_t_noised, u_prev, k)
            out = self.scheduler.step(pred, scheduler_t, u_hat_t_noised)
            u_hat_t_noised = out.prev_sample
        return u_hat_t_noised

    # same rollout loop as in Pseudocode One
    def rollout(self, u_initial, timesteps):
        trajectory = [u_initial]
        for t in range(timesteps):
            u_hat_t = self.predict_next_solution(trajectory[-1])
            trajectory.append(u_hat_t)
        return trajectory

In this case, they use the full noising-denoising machinery of DDPM via the diffusers package from Hugging Face. The DDPM scheduler allows a choice of prediction type; in their official implementation, instead of predicting the noise, they use v_prediction, which has been reported to work better in video diffusion models.
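
Concretely, writing $\bar{\alpha}_k$ for the scheduler's cumulative product of $(1-\beta)$ terms (the value the code reads from alphas_cumprod), the v-prediction target is

$$ v_k = \sqrt{\bar{\alpha}_k}\, \epsilon - \sqrt{1 - \bar{\alpha}_k}\, u(t), $$

which is exactly the line target = (noise_factor ** 0.5) * noise - (signal_factor ** 0.5) * u_t above. Note that, despite the variable names, alphas_cumprod is conventionally the coefficient of the clean signal in DDPM's forward process.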

The official implementation for this work is at PDE Arena, with the main module at pdearena.models.pderefiner; the scheduler construction (L94-99 of that file) is shown below:

self.scheduler = DDPMScheduler(
    num_train_timesteps=num_refinement_steps + 1,
    trained_betas=betas,
    prediction_type="v_prediction",
    clip_sample=False,
)
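
To see what this beta schedule does inside the scheduler, one can inspect the cumulative signal coefficients directly (a quick check, assuming diffusers is installed; values shown for $K = 3$ and the KS $\sigma_{min}$):

from diffusers.schedulers import DDPMScheduler

num_refinement_steps, min_noise_std = 3, 2e-7
betas = [min_noise_std ** (k / num_refinement_steps)
         for k in reversed(range(num_refinement_steps + 1))]
scheduler = DDPMScheduler(num_train_timesteps=num_refinement_steps + 1,
                          trained_betas=betas,
                          prediction_type="v_prediction",
                          clip_sample=False)
# betas run from sigma_min up to 1, so alphas_cumprod decays from ~1 to
# exactly 0; the highest scheduler time (k = 0) therefore starts from pure noise
print(scheduler.alphas_cumprod)  # ~[1.0, 1.0, 0.994, 0.0]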

Official Implementation

In the official implementation, let us again take a look at the three methods above, especially train_step:

def train_step(self, batch):
    x, y, cond = batch
    if self.hparams.predict_difference:
        # Predict difference to next step instead of next step directly.
        y = (y - x[:, -1:]) / self.hparams.difference_weight
    k = torch.randint(0, self.scheduler.config.num_train_timesteps, (x.shape[0],), device=x.device)
    noise_factor = self.scheduler.alphas_cumprod.to(x.device)[k]
    noise_factor = noise_factor.view(-1, *[1 for _ in range(x.ndim - 1)])
    signal_factor = 1 - noise_factor
    noise = torch.randn_like(y)
    y_noised = self.scheduler.add_noise(y, noise, k)
    x_in = torch.cat([x, y_noised], axis=1)
    pred = self.model(x_in, time=k * self.time_multiplier, z=cond)
    target = (noise_factor**0.5) * noise - (signal_factor**0.5) * y
    loss = self.train_criterion(pred, target)
    return loss, pred, target

This is basically Pseudocode Two with some additional code to handle the shape of the batched input tensor, plus an option to perform residual prediction (the actual practice in their Kolmogorov 2D flow experiment, Section 4.3).
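
For the residual-prediction option, the corresponding inverse transform at inference time would look roughly like this (a sketch inferred from the scaling in train_step above, not a verbatim quote of the repository):

# train time scales the target: y = (y - x[:, -1:]) / difference_weight
# inference must undo it: rescale, then add back the last input state
u_next = x[:, -1:] + self.hparams.difference_weight * pred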

(Conditional) Neural Operator

The neural operator in their case takes a conditioning value (here, the step $k$ of the diffusion/noising process):

self.model(x_in, time=k * self.time_multiplier, z=cond)

As explained in Appendix D.1, Model Architecture, the conditioning parameters $\Delta t$ and $\Delta x$ are embedded into the feature space via sinusoidal embeddings, just like in the DiT paper (see this post). A minimal sketch of such an embedding (my own illustration, not code from the repository):
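
import math
import torch

def sinusoidal_embedding(x, dim=64, max_period=10000.0):
    # embed a batch of scalars (e.g., dt, dx, or the refinement step k)
    # into a dim-dimensional feature vector, transformer-style
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)
    args = x[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0.1, 0.2]))  # shape (2, 64)

Instead of using a transformer model, though, their backbone is a UNet. This may suggest further improvements to their modeling process: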

  • Does a latent diffusion process help with this type of PDE-Refiner training?
  • Does a transformer backbone improve over the UNet model in this case?

I am currently examining both questions in my climate project, which requires long rollouts.