In this post, I study and review spectral methods in PyTorch, with a special focus on visualizing frequency modes, both for the Discrete Fourier Transform (DFT) / Fast Fourier Transform (FFT) on a regular grid and for the Spherical Fourier Transform (SFT), also called the spherical harmonic transform (SHT).
Discrete and Fast Fourier Transform for 2D and 3D signals
Mathematically, both implement the standard Fourier transform on $\mathbb{C}^d$. In torch, there is a family of transform functions covering different scenarios:
- `fft`, `ifft`: the Fourier transform and its inverse on $\mathbb{C}$ (1-D).
- `fftn`, `ifftn`: the Fourier transform and its inverse on $\mathbb{C}^n$.
- `rfft`, `irfft`, `rfftn`, `irfftn`: transforms restricted to real-valued ($\mathbb{R}^n$) signals in the time domain.
- `hfft`, `ihfft`: the Hermitian FFT, for signals that are real-valued in the frequency domain.
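A quick way to see how these variants relate is to compare them on a small real signal. The sketch below (assuming only `torch`) checks that `rfft` keeps the first `n // 2 + 1` bins of `fft`, because the spectrum of a real signal is Hermitian-symmetric, that `irfft` needs the original length `n` to invert it, and that `hfft` is the opposite pairing.

```python
import torch

n = 8
x = torch.randn(n, dtype=torch.float64)   # a real 1-D signal

X_full = torch.fft.fft(x)    # complex spectrum, n bins
X_half = torch.fft.rfft(x)   # complex spectrum, n // 2 + 1 bins

# the spectrum of a real signal is Hermitian-symmetric: X[k] == conj(X[-k])
k = torch.arange(n)
assert torch.allclose(X_full, X_full[(-k) % n].conj())

# so rfft keeps only the non-negative frequencies without losing information
assert X_half.shape == (n // 2 + 1,)
assert torch.allclose(X_half, X_full[: n // 2 + 1])

# irfft needs the original length: 5 bins could come from n = 8 or n = 9
x_back = torch.fft.irfft(X_half, n=n)
assert torch.allclose(x_back, x)

# hfft is the opposite pairing: a one-sided Hermitian input gives a real-valued output
y = torch.fft.hfft(X_half, n=n)
assert y.dtype == torch.float64
assert torch.allclose(y, n * torch.fft.irfft(X_half.conj(), n=n))
```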
For our purpose, let's take a closer look at the input and output of `torch.fft.rfftn`. Here the input (in the context of video, PDEs, and other spatio-temporal signals) usually takes the form $(B, C, H, W)$, where $B$ is the batch size, $C$ is the number of features on the grid, and $H, W$ are the height and width of the grid. In this case, `fft.rfft2` can be used, which performs a 2-dimensional DFT on the input's last two dimensions, i.e., on the two spatial dimensions $H$ and $W$ (equivalently, `rfftn(x, dim=[-2, -1])`). Example:
```python
import torch
import torch.fft as fft

x = torch.randn((32, 3, 64, 64))      # (b, c, h, w)
x_hat = fft.rfftn(x, dim=[-2, -1])    # rfft on h, w: (b, c, h, w // 2 + 1) = (32, 3, 64, 33)
```
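Only the last transformed dimension changes size, to $W/2 + 1$: the spectrum of a real signal is Hermitian-symmetric, so the negative frequencies along that dimension are redundant and are not stored. The output is complex, and the magnitude of each frequency mode is its absolute value. For example:

```python
freq_mode_x0 = torch.abs(x_hat[0])        # all channels of the first sample: (3, 64, 33)
freq_mode_x0_c0 = torch.abs(x_hat[0, 0])  # first channel of the first sample: (64, 33)
```

These contain the frequency-mode magnitudes of all three channels and of the first channel only, respectively. In a 2-D spatial Fourier transform, the frequency index is a spatial frequency (cycles per unit length), while in a 1-D transform over time it is a temporal frequency, often expressed as the angular frequency $\omega = 2\pi f$. A quantity of particular interest is the wave number: for a mode of wavelength $\lambda$, it is $\nu = 1/\lambda$.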
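The integer position of an entry in `x_hat` is not yet a wave number. A minimal sketch, assuming a unit-length periodic domain sampled on the 64 × 64 grid above (the spacings `dy`, `dx` are my assumption), uses `torch.fft.fftfreq` for the full dimension and `torch.fft.rfftfreq` for the halved last dimension to get the wave number of every entry of the `(64, 33)` output.

```python
import torch

H, W = 64, 64
dy, dx = 1.0 / H, 1.0 / W   # assumed grid spacing on a unit-length periodic domain

# frequencies in cycles per unit length, laid out like the rfftn output (H, W // 2 + 1)
ky = torch.fft.fftfreq(H, d=dy)    # 0, 1, ..., 31, -32, ..., -1
kx = torch.fft.rfftfreq(W, d=dx)   # 0, 1, ..., 32
KY, KX = torch.meshgrid(ky, kx, indexing="ij")

# wave number magnitude |k| of every entry of x_hat[..., :, :]
k_mag = torch.sqrt(KX ** 2 + KY ** 2)   # (64, 33)
```

With a different domain length, only `d` changes; multiplying by $2\pi$ gives angular wave numbers instead.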
If we have a spatio-temporal signal, such as the entire rollout of a PDE simulation, then a 3D Fourier transform can be applied to $(T, H, W)$ where $T$ is the number of time-steps, e.g.:
```python
X = torch.randn((32, 3, 10, 64, 64))      # (b, c, t, h, w)
X_hat = fft.rfftn(X, dim=[-3, -2, -1])    # rfft on t, h, w => (32, 3, 10, 64, 33)
```
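To confirm what `rfftn` computes over `(t, h, w)`, here is a small check (only `torch` assumed; the tensor sizes are arbitrary): it applies a real FFT to the last dimension and full complex FFTs to the others, which gives exactly the first `W // 2 + 1` columns of the full `fftn`.

```python
import torch

X = torch.randn(2, 3, 10, 16, 16, dtype=torch.float64)   # (b, c, t, h, w)

X_hat = torch.fft.rfftn(X, dim=[-3, -2, -1])   # (2, 3, 10, 16, 9)
X_full = torch.fft.fftn(X, dim=[-3, -2, -1])   # (2, 3, 10, 16, 16)

# rfftn keeps only the non-negative frequencies of the last transformed dimension
assert torch.allclose(X_hat, X_full[..., : 16 // 2 + 1])

# equivalently: a real FFT along w, then complex FFTs along t and h
X_sep = torch.fft.fftn(torch.fft.rfft(X, dim=-1), dim=[-3, -2])
assert torch.allclose(X_hat, X_sep)
```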
While it is straightforward to plot amplitude against wave number (or frequency) for a 1-D Fourier transform, plotting a 2-D or higher-dimensional spectrum is not; the next section describes how to do it.
The Spectrum Plot for 2D and 3D signals
A popular way to visualize the result of a Fourier transform is the spectrum plot, which is useful for tracking the behavior of Fourier modes during training. This matters for PDE modeling, and especially for monitoring neural operators applied to turbulent dynamics: individual trajectories of such systems are chaotic, and in most cases only their Fourier spectrum (also known as Fourier statistics) is stable. This type of plot appears in works that study the training and inference stability of Fourier Neural Operators, such as PDE-Refiner, iFNO, and Qin et al., which analyzes FNO from a spectral perspective.
Note that PDE-Refiner only visualizes the 1-D Kuramoto-Sivashinsky (KS) equation, which is why it can plot amplitude directly against wave number (a 1-D FFT). For the more general case of n-dimensional spatial data, we turn to Qin et al., who propose a radial energy spectrum plot, a more general way to plot the power spectrum of higher-dimensional PDEs (Qin et al. show this plot for the 3D Navier-Stokes equation).
Appendix A of Qin et al. describes the steps to produce the plot:
- For the spatial signal $X$, compute its 2D FFT $\hat{X}$.
- Compute the wave numbers: shift the zero frequency to the center of $\hat{X}$ with `torch.fft.fftshift`, then compute the distance of each entry from the center.
- Bin the wave numbers: divide the wave numbers into bins, each representing a range of wave numbers.
- Compute the energy in each bin: for each bin, sum the squared magnitudes of the Fourier coefficients whose wave numbers fall in that range. Since every entry of $\hat{X}$ belongs to exactly one bin, this takes a single pass over $\hat{X}$.
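Step 2 relies on where `torch.fft.fftshift` moves each frequency. A small sketch (only `torch` assumed; `int_freqs` is a hypothetical helper): after shifting, the zero frequency of a length-`n` dimension sits at index `n // 2`, so the distance of entry `(p, q)` from `(h // 2, w // 2)` is exactly the wave number magnitude of that entry.

```python
import torch

def int_freqs(n):
    # integer frequency of each FFT bin in torch.fft order, e.g. n = 6 -> [0, 1, 2, -3, -2, -1]
    return torch.round(torch.fft.fftfreq(n) * n).long()

# fftshift moves the zero frequency to index n // 2, for even and odd n alike
for n in (6, 7, 8):
    assert torch.equal(torch.fft.fftshift(int_freqs(n)), torch.arange(n) - n // 2)

# so after shifting, the distance of entry (p, q) from (h // 2, w // 2) is the wave number |k|
h, w = 6, 8
p, q = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
radius = torch.sqrt(((p - h // 2) ** 2 + (q - w // 2) ** 2).float())
KY, KX = torch.meshgrid(torch.fft.fftshift(int_freqs(h)),
                        torch.fft.fftshift(int_freqs(w)), indexing="ij")
assert torch.allclose(radius, torch.sqrt((KX ** 2 + KY ** 2).float()))
```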
Translating this to PyTorch, with data of shape $(b, c, h, w)$, where $b$ is the batch size and $c$ is the number of channels, we select one channel and plot its power spectrum averaged over the batch:
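```python
import torch
import matplotlib.pyplot as plt

def plot_spectrum_channel(X, i, n_bins):
    """Radial energy spectrum of channel i of X (b, c, h, w), averaged over the mini-batch."""
    X = X.detach().cpu()  # plotting happens on the CPU
    h, w = X.shape[-2:]
    # 2-D FFT of channel i, with the zero frequency shifted to the center
    X_hat = torch.fft.fftshift(torch.fft.fft2(X[:, i]), dim=(-2, -1))  # (b, h, w)
    # energy is squared amplitude; average the power spectrum over the batch
    energy = (X_hat.abs() ** 2).mean(dim=0).flatten()  # (h*w,)
    # wave number of each entry = its distance from the center (h // 2, w // 2)
    ky = torch.arange(h) - h // 2
    kx = torch.arange(w) - w // 2
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")
    wavenumbers = torch.sqrt((KX ** 2 + KY ** 2).float()).flatten()  # (h*w,)
    # bin the wave numbers into n_bins equal-width bins
    edges = torch.linspace(0, wavenumbers.max().item(), n_bins + 1)
    indices = torch.bucketize(wavenumbers, edges[1:-1])  # bin index in [0, n_bins)
    # sum the energy of all entries that fall into each bin
    radial_profile = torch.zeros(n_bins).index_add_(0, indices, energy.float())
    centers = 0.5 * (edges[:-1] + edges[1:])
    plt.figure(figsize=(10, 5))
    plt.plot(centers.numpy(), radial_profile.numpy(), marker='o')
    plt.xlabel('Wave number (radial frequency)')
    plt.ylabel('Spectral Energy')
    plt.title('Spectrum Energy Plot')
    plt.grid(True)
    return centers, radial_profile
```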
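One way to sanity-check the binning is a synthetic field with a known spectrum. The sketch below is a compute-only variant (the name `radial_energy` is hypothetical) that bins by rounding each radius to the nearest integer, an alternative to the `n_bins` equal-width bins above and my own choice rather than the scheme of Qin et al. A single cosine with wave vector $(k_y, k_x) = (3, 4)$ should put all of its energy at wave number 5.

```python
import math
import torch

def radial_energy(X, i):
    """Energy per integer wave number for channel i of X (b, c, h, w), batch-averaged."""
    h, w = X.shape[-2:]
    X_hat = torch.fft.fftshift(torch.fft.fft2(X[:, i]), dim=(-2, -1))
    energy = (X_hat.abs() ** 2).mean(dim=0).flatten()
    ky = torch.arange(h) - h // 2
    kx = torch.arange(w) - w // 2
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")
    # round each radius to the nearest integer wave number (one shell per integer)
    shells = torch.round(torch.sqrt((KX ** 2 + KY ** 2).to(energy.dtype))).long().flatten()
    return torch.zeros(int(shells.max()) + 1, dtype=energy.dtype).index_add_(0, shells, energy)

# a single plane wave cos(2*pi*(4x + 3y)) on a 64 x 64 unit-periodic grid
n = 64
coords = torch.arange(n, dtype=torch.float64) / n
Y, Xg = torch.meshgrid(coords, coords, indexing="ij")   # y along rows, x along columns
field = torch.cos(2 * math.pi * (4 * Xg + 3 * Y)).expand(2, 1, n, n)  # (b, c, h, w)

spectrum = radial_energy(field, 0)   # nearly all energy lands in shell 5
```

By Parseval's theorem the total energy is $n^4/2$ for this unit-amplitude cosine, which gives a second check on the normalization.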
Spherical Harmonics Transform Visualization
Here we study the output of the Spherical Fourier Neural Operator (SFNO).
First, we need to understand the output of the SFNO and ask: when we visualize spherical harmonics, what information do we want to extract?
We know that spherical harmonics are defined in spherical coordinates $(\theta, \phi)$, so a sphere is required to visualize them.
The RealSHT Module
The spherical harmonic transform (SHT) extends the Fourier transform (FT) to $S^2$. To understand how to visualize it, we must first study the formula and the code.
The basic intuition behind the SHT, and spectral analysis in general, is that functions defined on the domain live in a Hilbert space, which admits an inner product and an orthogonal basis that respects the symmetry of the underlying geometric domain.
For the standard Fourier transform, this is the sine and cosine basis (the Fourier basis), which can be written in complex form as plane waves $\exp(i \langle k, x\rangle)$. For the SHT, it is the spherical harmonics $Y_l^m$, built from the associated Legendre polynomials $P_l^m(\cos \theta)$, where $(\theta, \phi)$ are the spherical (polar) coordinates and the integers $(l, m)$ are the degree and order of the polynomial. For data on the Earth, we have $0\leq l \leq L$ and $-l\leq m \leq l$, where $L$ is the maximal number of zonal modes (in the SFNO paper, this is the resolution in latitude, i.e., the size $h$ of the grid). The spherical harmonics are defined as:
$$\begin{align} & Y_l^m (\theta,\phi) = (-1)^m c_l^m P_l^m(\cos \theta) e^{im\phi} = \hat{P}_l^m (\cos(\theta)) e^{im\phi} \\ & c_l^{m} := \sqrt{\frac{2l+1}{4\pi} \frac{(l-m)!}{(l+m)!}} \end{align}$$
which form an orthonormal basis of $L^2(S^2)$, the Hilbert space of square-integrable functions on the sphere. The SFNO paper defines the SHT $\mathcal{F}: u \mapsto \hat{u}$ for $u\in L^2(S^2)$ through the expansion and its coefficients:
$$\begin{align} & u(\theta, \phi) = \sum_{l \in \mathbb{N}}\sum_{|m|\leq l} \hat{u}(l,m) Y_l^m (\theta, \phi) \\ & \hat{u}(l,m) = \int_{S^2} \overline{Y_l^m} \cdot u \, \underbrace{\sin \theta \, d\theta \, d\phi}_{d\Omega} \end{align}$$
where $d\Omega$ is the surface (area) measure on $S^2$. In practice, this integral must be approximated with discrete operations, just as the DFT and FFT discretize the standard Fourier transform. Appendix B of the SFNO paper covers the implementation, where the discrete transform is defined as:
$$\begin{align} & \hat{u} (\cdots, l, m) = \sum_{j=0}^{H-1} P(l,m,j) \, FT(u(\cdots, j,k),k)(j,m) \\ & P(l,m,j) = (-1)^m c_l^m P_l^m (\cos \theta_j) w(\theta_j) \end{align}$$
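So in the implementation, a 1-D Fourier transform is taken along the longitude dimension ($\phi$, index $k$), and the result is then contracted along the latitude dimension ($\theta$, index $j$) as a weighted sum against the associated Legendre polynomials, where $w(\theta_j)$ are the quadrature weights.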
For the inverse transform, the formula is given in the same Appendix section:
$$\begin{align} & u(\cdots, j,k) = FT^{-1}\left[ \sum_{l=0}^L \hat{P}(l,m,j)\hat{u}(\cdots, l, m), m\right] (j,k) \\ & \hat{P}(l,m,j) = (-1)^m c_l^m P_l^m (\cos \theta_j) \end{align}$$
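To make the two formulas concrete, here is a minimal plain-PyTorch sketch of the forward and inverse transforms. It is not the torch_harmonics implementation: the Gauss-Legendre latitude grid (nodes from `numpy.polynomial.legendre.leggauss`, ordered south to north), the standard three-term recurrence for $\hat{P}_l^m$, and the helper names `normalized_legendre`, `toy_sht`, and `toy_isht` are all my assumptions. A round trip on band-limited coefficients should reproduce them to machine precision, because Gauss-Legendre quadrature with `nlat` nodes integrates the products $\hat{P}_l^m \hat{P}_{l'}^m$ exactly for $l, l' < \text{nlat}$.

```python
import math
import numpy as np
import torch

def normalized_legendre(lmax, x):
    """P-hat_l^m(x) = (-1)^m c_l^m P_l^m(x), shape (m, l, j) for 0 <= m, l < lmax."""
    x = torch.as_tensor(x, dtype=torch.float64)
    s = torch.sqrt(1.0 - x ** 2)                                  # sin(theta)
    P = torch.zeros(lmax, lmax, x.numel(), dtype=torch.float64)   # zero wherever l < m
    pmm = torch.full_like(x, 1.0 / math.sqrt(4.0 * math.pi))      # P-hat_0^0
    for m in range(lmax):
        if m > 0:
            # diagonal recurrence; the minus sign carries the (-1)^m factor
            pmm = -math.sqrt((2 * m + 1) / (2 * m)) * s * pmm
        P[m, m] = pmm
        if m + 1 < lmax:
            P[m, m + 1] = math.sqrt(2 * m + 3) * x * pmm
        for l in range(m + 2, lmax):
            a = math.sqrt((4 * l * l - 1) / (l * l - m * m))
            b = math.sqrt(((l - 1) ** 2 - m * m) / (4 * (l - 1) ** 2 - 1))
            P[m, l] = a * (x * P[m, l - 1] - b * P[m, l - 2])
    return P

def gauss_legendre(nlat):
    x, w = np.polynomial.legendre.leggauss(nlat)   # x_j = cos(theta_j), weights w(theta_j)
    return torch.from_numpy(x), torch.from_numpy(w)

def toy_sht(u, lmax):
    """Real u (..., nlat, nlon) -> complex u_hat (..., lmax, lmax), indexed [l, m]."""
    u = u.to(torch.float64)
    x, w = gauss_legendre(u.shape[-2])
    P = normalized_legendre(lmax, x).to(torch.complex128)
    # longitude: (2*pi / nlon) * sum_k approximates the integral of u e^{-i m phi} d phi
    f = 2.0 * math.pi * torch.fft.rfft(u, dim=-1, norm="forward")[..., :lmax]
    # latitude: Gauss-Legendre quadrature against P-hat(l, m, j) w(theta_j)
    return torch.einsum("...jm,mlj,j->...lm", f, P, w.to(torch.complex128))

def toy_isht(u_hat, nlat, nlon):
    """Complex u_hat (..., lmax, lmax) -> real u (..., nlat, nlon)."""
    x, _ = gauss_legendre(nlat)
    P = normalized_legendre(u_hat.shape[-2], x).to(u_hat.dtype)
    f = torch.einsum("...lm,mlj->...jm", u_hat, P)   # sum over l of P-hat(l, m, j) u_hat(l, m)
    return torch.fft.irfft(f, n=nlon, dim=-1, norm="forward")   # inverse FFT in longitude

# round trip on random band-limited coefficients
torch.manual_seed(0)
nlat, nlon, lmax = 16, 32, 16
l = torch.arange(lmax).view(-1, 1)
m = torch.arange(lmax).view(1, -1)
u_hat = torch.randn(lmax, lmax, dtype=torch.complex128) * (m <= l)   # only m <= l is valid
u_hat[:, 0] = u_hat[:, 0].real.to(torch.complex128)                  # m = 0 terms of a real field are real
u = toy_isht(u_hat, nlat, nlon)
err = (toy_sht(u, lmax) - u_hat).abs().max()
```

The longitude step mirrors the `2.0 * torch.pi * torch.fft.rfft(..., norm="forward")` line in the library code, while the latitude quadrature is written as a single complex `einsum`.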
The implementation lives in `torch_harmonics/sht.py`:
```python
# RealSHT.forward
def forward(self, x: torch.Tensor):
    assert(x.shape[-2] == self.nlat)
    assert(x.shape[-1] == self.nlon)

    # apply real fft in the longitudinal direction
    x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

    # do the Legendre-Gauss quadrature
    x = torch.view_as_real(x)

    # distributed contraction: fork
    out_shape = list(x.size())
    out_shape[-3] = self.lmax
    out_shape[-2] = self.mmax
    xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

    # contraction
    xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights.to(x.dtype) )
    xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights.to(x.dtype) )
    x = torch.view_as_complex(xout)

    return x

# InverseRealSHT.forward
def forward(self, x: torch.Tensor):
    assert(x.shape[-2] == self.lmax)
    assert(x.shape[-1] == self.mmax)

    # Evaluate associated Legendre functions on the output nodes
    x = torch.view_as_real(x)

    rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
    im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
    xs = torch.stack((rl, im), -1)

    # apply the inverse (real) FFT
    x = torch.view_as_complex(xs)
    x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

    return x
```
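The `RealSHT` contraction splits the complex FFT output into real and imaginary parts with `torch.view_as_real` and contracts each with the same real weight tensor. A small check with random tensors (the einsum string mirrors the excerpt; the sizes are arbitrary assumptions) shows that this equals a single complex contraction, which works because the quadrature weights are real.

```python
import torch

nlat, mmax, lmax = 8, 5, 5
x = torch.randn(2, nlat, mmax, dtype=torch.complex128)        # (..., k, m): FFT output, k = latitude
weights = torch.randn(mmax, lmax, nlat, dtype=torch.float64)  # (m, l, k): real quadrature weights

# split version, as in the excerpt
xr = torch.view_as_real(x)                                    # (..., k, m, 2)
out = torch.zeros(2, lmax, mmax, 2, dtype=torch.float64)
out[..., 0] = torch.einsum('...km,mlk->...lm', xr[..., 0], weights)
out[..., 1] = torch.einsum('...km,mlk->...lm', xr[..., 1], weights)
split = torch.view_as_complex(out)

# direct complex contraction with the same (real) weights
direct = torch.einsum('...km,mlk->...lm', x, weights.to(x.dtype))
```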
```python
# calling convention for the two classes
from torch_harmonics import RealSHT, InverseRealSHT

nlat = 60
nlon = 2 * nlat
lmax = mmax = nlat
sht = RealSHT(nlat, nlon, lmax=lmax, mmax=mmax)
isht = InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax)
```
The forward `RealSHT` takes an input of shape $(b, c, h, w)$, where $(h, w)$ must equal the `nlat` and `nlon` passed to the constructor, and returns a complex tensor of shape $(b, c, \text{lmax}, \text{mmax})$, which is $(b, c, h, h)$ for the settings above; the last two dimensions hold the coefficients (similar to the standard Fourier transform).
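In spherical-harmonic terms, the input dimensions $(h, w)$ index latitude and longitude $(\theta, \phi)$, while the output dimensions index the degree $l$ and the order $m$, so the transform maps $(b, c, h, w) \rightarrow (b, c, l, m)$. Because the input is real, only the orders $m \geq 0$ are stored, just as `rfft` stores only non-negative frequencies.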
From a software perspective, the calling conventions of the SHT and the FFT are close; the SHT adds the constraint that the input grid shape must agree with `nlat` and `nlon` (checked by the asserts in `forward`).
For the transformed matrix of size $(\text{lmax}, \text{mmax})$, we can interpret entry $(l, m)$ as the spherical harmonic coefficient $\hat{u}(l,m)$, which determines the amplitude of the contribution of the basis function $Y_l^m$ (entries with $m > l$ vanish, since $P_l^m \equiv 0$ there). We can therefore analyze the power spectrum in the same way as for the FFT.
Power Spectrum of SHT
Based on this, we can visualize the spectrum after the SHT following the same steps as for the FFT, except that the notion of wave number differs. For the spherical harmonic spectrum, the x-axis is not a radial wave number but simply the degree $l$ of the Legendre polynomial (the degree of the spherical harmonic). This yields a natural binning scheme: treat each row as a bin and compute the row-wise energy (sum of squared magnitudes). To obtain this plot:
- For the matrix $X$, apply the SHT to obtain $\hat{X}$.
- Compute the row-wise sum of squares, i.e., $e_l = \sum_{m} |\hat{X}_{lm}|^2$, where the magnitude is the absolute value of a complex number.
- Plot $e_l$ against $l$.
Here again, we select one channel and average the energy over the mini-batch.
```python
import torch
import matplotlib.pyplot as plt

def plot_spectrum_channel_spherical(X, i):
    # X: (b, c, nlat, nlon); uses the RealSHT instance `sht` constructed above
    X_hat = sht(X[:, i])                                   # (b, lmax, mmax), complex
    # row-wise energy: sum over m of |X_hat[l, m]|^2, then average over the batch
    energy = (X_hat.abs() ** 2).sum(dim=-1).mean(dim=0)    # (lmax,)
    degrees = torch.arange(energy.shape[0])                # l = 0, 1, ..., lmax - 1
    plt.figure(figsize=(10, 5))
    plt.plot(degrees.numpy(), energy.detach().cpu().numpy(), marker='o')
    plt.xlabel('degree of spherical harmonics')
    plt.ylabel('Spectral Energy')
    plt.title('Spectrum Energy Plot')
    plt.grid(True)
```
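A possible refinement, assuming the field is real: since `RealSHT` stores only $m \geq 0$, and for a real field $|\hat{u}(l,-m)| = |\hat{u}(l,m)|$, each stored order $m > 0$ stands in for two coefficients. The total energy of degree $l$ is then $|\hat{u}(l,0)|^2 + 2\sum_{m>0} |\hat{u}(l,m)|^2$. The helper `degree_energy` below is hypothetical and accepts any `(..., lmax, mmax)` coefficient tensor, such as the output of `sht`.

```python
import torch

def degree_energy(u_hat):
    """Energy per degree l from coefficients u_hat of shape (..., lmax, mmax), m >= 0 only."""
    power = u_hat.abs() ** 2                                        # (..., lmax, mmax)
    # orders m > 0 stand in for both +m and -m, so count them twice
    weight = torch.ones(power.shape[-1], dtype=power.dtype, device=power.device)
    weight[1:] = 2.0
    return (power * weight).sum(dim=-1)                             # (..., lmax)

# hand-built example: one m = 0 coefficient at l = 2 and one m = 1 coefficient at l = 3
u_hat = torch.zeros(4, 4, dtype=torch.complex64)
u_hat[2, 0] = 1.0
u_hat[3, 1] = 1.0j
print(degree_energy(u_hat))   # tensor([0., 0., 1., 2.])
```

For the batch-averaged plot, `degree_energy(sht(X[:, i])).mean(dim=0)` would replace the plain row-wise sum.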
Instead of the wave number used for the Fourier transform, the x-axis here is the degree $l$ of the Legendre polynomial, which is bounded by the number of latitude points on the spatial grid. To convert degrees into a wavelength (and hence a wave number $1/\lambda$), Jeans's formula can be used (although the resulting plot is less clean, owing to the periodic nature of polar coordinates):
$$ \lambda(l) = \frac{2\pi}{l + 1/2} $$
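Jeans's formula gives the wavelength on the unit sphere; for a physical scale, multiply by the sphere's radius. A minimal sketch, assuming the Earth's mean radius of about 6371 km (the helper name `degree_to_wavelength` is hypothetical):

```python
import math
import torch

def degree_to_wavelength(l, radius=6371.0):
    """Jeans's approximation: wavelength 2*pi*radius / (l + 1/2), in the units of radius."""
    l = torch.as_tensor(l, dtype=torch.float64)
    return 2.0 * math.pi * radius / (l + 0.5)

wavelengths_km = degree_to_wavelength(torch.arange(60))   # one value per degree l = 0..59
```

For example, degree $l = 10$ corresponds to a wavelength of roughly 3812 km.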
Possible future work: one immediate direction of interest is to extend the spherical FNO's 2D convolution to a 3D convolution, as has been done with FNO, where the third dimension is temporal. For FNO this is straightforward, since the dimension $t$ is treated the same as the dimensions of $\mathbb{R}^2$. With SFNO, however, the convolution would be defined over $S^2 \times \mathbb{R}$ rather than $\mathbb{R}^3$, so we would need to define a Hilbert space over this product domain and define the spectral basis and transform accordingly.