In JAX and PyTorch

Many areas of science and engineering encounter data defined on the sphere. Modelling and analysis of such data often requires the spherical counterpart to the Fourier transform — the spherical harmonic transform. We provide a brief overview of the spherical harmonic transform and present a new differentiable algorithm tailored towards acceleration on GPUs [1]. This algorithm is implemented in the recently released S2FFT Python package, which supports both JAX and PyTorch.

[Image created by authors.]

Increasingly often we are interested in analysing data that lives on the sphere. The diversity in applications is remarkable, ranging from quantum chemistry, biomedical imaging, climate physics and geophysics, to the wider cosmos.

The most well-known areas in which one encounters data on the sphere are within the physical sciences, particularly within atmospheric science, geophysical modelling, and astrophysics.

Examples of the most widely known cases of spherical data: the Earth (left) and an artist's impression of astronomical observations (right). [Earth image sourced from Wikipedia; astrophysics image sourced from Wikipedia.]

These problems are naturally spherical as observations are made at each point on the surface of a sphere: the surface of the Earth for geophysics and the sky for astrophysics. Other examples come from applications like computer graphics and vision, where 360° panoramic cameras capture the world around you in every direction.

In many cases the spherical nature of the problem at hand is fairly easy to see; however, this is not always the case. Perhaps surprisingly, spherical data is frequently encountered within the biological disciplines, though the spherical aspect is often much less obvious! Since biological studies are often concerned with local directions, such as the direction in which water diffuses within the brain, we encounter spherical data.

Diffusion tensor imaging of neuronal connections in the human brain. Within each voxel water may diffuse in any direction, so the problem is naturally spherical. [Animation by Alfred Anwander, CC-BY licence.]

Given the prevalence of such data, it isn’t surprising that many spherical analysis techniques have been developed. A frequency analysis of the data can be insightful, often to afford a statistical summary or an effective representation for further analysis or modelling. Recently geometric deep learning techniques have proven highly effective for the analysis of data on complex domains [2–6], particularly for highly complex problems such as molecular modelling and protein interactions (see our prior post on A Brief Introduction to Geometric Deep Learning).

So we have data on the sphere and a variety of techniques by which spherical data may be analysed, but we need mathematical tools to do so. Specifically, we need to know how to decompose spherical data into frequencies efficiently.

The Fourier transform provides a frequency decomposition that is often used to calculate statistical correlations within data. Many physical systems may also be described more straightforwardly in frequency space, as each frequency may evolve independently.
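As a small self-contained sketch (the signal and its frequencies are chosen arbitrarily for illustration), NumPy's fast Fourier transform recovers exactly the frequencies present in a signal:

```python
import numpy as np

# a signal containing two known frequencies, sampled at 1 kHz for one second
fs = 1000
t = np.arange(fs) / fs
f = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 120 * t)

spectrum = np.fft.rfft(f)                    # frequency decomposition
freqs = np.fft.rfftfreq(f.size, d=1 / fs)    # frequency of each bin in Hz

# the two dominant bins sit exactly at the frequencies we put in
peaks = freqs[np.argsort(np.abs(spectrum))[-2:]]
assert set(peaks) == {50.0, 120.0}
```

Because both input frequencies fall on exact frequency bins, all other bins are zero to machine precision.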

To extend the standard Fourier transform to the sphere, we need the meeting of minds of two 18th century French mathematicians: Joseph Fourier and Adrien-Marie Legendre.

Joseph Fourier (left) and Adrien-Marie Legendre (right). Tragically, the caricature of Legendre is the only known image of him. [Fourier image sourced from Wikipedia. Legendre image sourced from Wikipedia.]

First, let’s consider how to decompose Euclidean data into its various frequencies. Such a transformation of the data was first derived by Joseph Fourier and is given by

f̂(k) = ∫ f(x) e^(−2πikx) dx,

which is found almost everywhere and is a staple of undergraduate physics for a reason! This works by projecting our data f(x) onto a set of trigonometric functions, called a basis. One can do effectively the same thing on the sphere, but the basis functions are now given by the spherical harmonics Yₗₘ:

fₗₘ = ∫ f(θ, ϕ) Yₗₘ*(θ, ϕ) sin θ dθ dϕ,

where (θ, ϕ) are the usual spherical polar co-ordinates.

Spherical harmonic basis functions (real component). [Sourced from Wikipedia.]

The spherical harmonics (shown above) can be broken down further into the product of an exponential and associated Legendre polynomials — à la Adrien-Marie Legendre — as

Yₗₘ(θ, ϕ) = Nₗₘ Pₗₘ(cos θ) e^(imϕ),

where Nₗₘ is a normalisation constant and Pₗₘ are the associated Legendre functions.

And so the spherical harmonic transform can be written as a Fourier transform followed by an associated Legendre transform. The real difficulty comes in evaluating the Legendre part of the transform: it is either computationally expensive or memory hungry, depending on the method one chooses.
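To make this two-step structure concrete, here is a minimal sketch of a forward transform on a Gauss-Legendre grid: an FFT over ϕ, followed by a Legendre-quadrature sum over θ. This is a naive illustration built on SciPy's associated Legendre functions, not the S2FFT algorithm itself:

```python
import numpy as np
from math import factorial, pi, sqrt
from scipy.special import lpmv  # associated Legendre function P_l^m

def Y(l, m, th, ph):
    """Spherical harmonic Y_lm(theta, phi) = N_lm P_l^m(cos theta) e^(i m phi)."""
    am = abs(m)
    N = sqrt((2 * l + 1) / (4 * pi) * factorial(l - am) / factorial(l + am))
    Yp = N * lpmv(am, l, np.cos(th)) * np.exp(1j * am * ph)
    return Yp if m >= 0 else (-1) ** am * np.conj(Yp)

L = 8                                        # band-limit
x, wq = np.polynomial.legendre.leggauss(L)   # quadrature nodes/weights in cos(theta)
theta = np.arccos(x)
nphi = 2 * L - 1
phi = 2 * pi * np.arange(nphi) / nphi

# band-limited test signal synthesised from known random coefficients
rng = np.random.default_rng(1)
flm = {(l, m): rng.normal() + 1j * rng.normal()
       for l in range(L) for m in range(-l, l + 1)}
f = sum(c * Y(l, m, theta[:, None], phi[None, :]) for (l, m), c in flm.items())

# step 1: Fourier transform over phi (one FFT per theta ring)
Fm = np.fft.fft(f, axis=1) * (2 * pi / nphi)  # column m holds the phi integral

# step 2: associated Legendre transform over theta by Gauss-Legendre quadrature
flm_rec = {(l, m): np.sum(wq * Y(l, m, theta, 0.0).conj() * Fm[:, m % nphi])
           for l in range(L) for m in range(-l, l + 1)}

assert all(np.allclose(flm[k], flm_rec[k]) for k in flm)
```

Because the signal is band-limited, the quadrature is exact and the coefficients are recovered to machine precision. The O(L³) cost of step 2 is exactly where fast algorithms (and GPU acceleration) earn their keep.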

The growth of differentiable programming is opening up many new types of analysis. In particular, many applications require spherical transforms that are differentiable.

Machine learning models on the sphere require differentiable transforms so that models may be trained by gradient-based optimisation algorithms, i.e. through back-propagation.

Emerging physics-enhanced machine learning [7], which hybridises data-driven and model-based approaches [8], also requires differentiable physics models, which in many cases themselves require differentiable spherical transforms.

With this in mind it is clear that for modern applications an efficient algorithm for the spherical harmonic transform is necessary but not sufficient. Differentiability is key.
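As a toy illustration of what differentiability buys us (using JAX's FFT rather than a spherical transform), gradients propagate straight through a spectral transform like any other operation:

```python
import jax
import jax.numpy as jnp

def spectral_power(f):
    # total power of the signal in frequency space
    F = jnp.fft.fft(f)
    return jnp.sum(F * jnp.conj(F)).real

# differentiate straight through the transform
g = jax.grad(spectral_power)(jnp.ones(8))

# by Parseval's theorem the power equals N * sum(f**2),
# so the gradient with respect to f is 2 * N * f
```

Any model containing a transform like this can therefore be trained end-to-end by back-propagation.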

This is all well and good, but how does one efficiently evaluate the spherical harmonic transform? A variety of algorithms have been developed, with some great software packages. However, for modern applications we need one that is differentiable, can run on hardware accelerators like GPUs, and is computationally scalable.

By redesigning the core algorithms from the ground up (as described in depth in our corresponding paper [1]), we recently developed a Python package called S2FFT that should fit the bill.

S2FFT is implemented in JAX, a differentiable programming framework developed by Google, and also includes a PyTorch frontend.

S2FFT is a Python package implementing differentiable and accelerated spherical harmonic transforms, with interfaces in JAX and PyTorch. [Image created by authors.]

S2FFT provides two operating modes: precompute the associated Legendre functions, which are then accessed at run time; or compute them on-the-fly during the transform. The precompute approach is just about as fast as you can get, but the memory required to store all Legendre function values scales cubically with resolution, which can be a problem! The second approach instead recursively computes Legendre terms on-the-fly, and so can be scaled to very high resolutions.
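The on-the-fly mode relies on the fact that associated Legendre functions obey a three-term recursion, so each value needs only the previous two terms. A minimal sketch of the classic upward recursion (not S2FFT's actual recursion, which is redesigned for stability and parallelism on GPUs [1]):

```python
import numpy as np
from scipy.special import lpmv  # reference implementation to check against

def legendre_recursive(l, m, x):
    """P_l^m(x) by upward three-term recursion; O(1) memory per (m, x)."""
    # seed values: P_m^m and P_{m+1}^m
    pmm = (-1) ** m * np.prod(np.arange(1, 2 * m, 2)) * (1 - x ** 2) ** (m / 2)
    if l == m:
        return pmm
    pm1 = x * (2 * m + 1) * pmm
    if l == m + 1:
        return pm1
    # (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
    for ll in range(m + 2, l + 1):
        pmm, pm1 = pm1, ((2 * ll - 1) * x * pm1 - (ll + m - 1) * pmm) / (ll - m)
    return pm1

x = 0.3
assert np.isclose(legendre_recursive(5, 2, x), lpmv(2, 5, x))
```

Recomputing values this way trades a little arithmetic for a dramatic saving in memory, which is what makes very high resolutions tractable.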

In addition, S2FFT also supports a hybrid automatic and manual differentiation approach so that gradients can be computed efficiently.

The package is designed to support multiple different sampling schemes on the sphere. At launch, equiangular (McEwen & Wiaux [9], Driscoll & Healy [10]), Gauss-Legendre, and HEALPix [11] sampling schemes are supported, although others may easily be added in future.

Different sampling schemes on the sphere supported by S2FFT. [Original figure created by authors.]

The S2FFT package is available on PyPI, so anyone can install it straightforwardly by running:

pip install s2fft

Or to pick up PyTorch support by running:

pip install "s2fft[torch]"

From here the top-level transforms can be called simply by

import s2fft

# Compute forward spherical harmonic transform
flm = s2fft.forward_jax(f, L)

# Compute inverse spherical harmonic transform
f = s2fft.inverse_jax(flm, L)

These functions can be picked up out of the box and integrated as layers within existing models, both in JAX and PyTorch, with full support for both forward and reverse mode differentiation.

With researchers becoming increasingly interested in differentiable programming for scientific applications, there is a critical need for modern software packages that implement the foundational mathematical methods on which science is often based, like the spherical harmonic transform.

We hope S2FFT will be of great use in coming years and are excited to see what people will use it for!

[1] Price & McEwen, Differentiable and accelerated spherical harmonic and Wigner transforms, arXiv:2311.14670 (2023).

[2] Bronstein, Bruna, Cohen, Velickovic, Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges, arXiv:2104.13478 (2021).

[3] Ocampo, Price & McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023).

[4] Cobb, Wallis, Mavor-Parker, Marignier, Price, d’Avezac, McEwen, Efficient Generalised Spherical CNNs, ICLR (2021).

[5] Cohen, Geiger, Koehler, Welling, Spherical CNNs, ICLR (2018).

[6] Jumper et al., Highly accurate protein structure prediction with AlphaFold, Nature (2021).

[7] Karniadakis et al, Physics-informed machine learning, Nature Reviews Physics (2021).

[8] Campagne et al., Jax-cosmo: An end-to-end differentiable and GPU accelerated cosmology library, arXiv:2302.05163 (2023).

[9] McEwen & Wiaux, A novel sampling theorem on the sphere, IEEE TSP (2012).

[10] Driscoll & Healy, Computing Fourier Transforms and Convolutions on the 2-Sphere, AAM (1994).

[11] Gorski et al., HEALPix: a Framework for High Resolution Discretization, and Fast Analysis of Data Distributed on the Sphere, ApJ (2005).
