FURAX: a modular JAX toolbox for solving inverse problems in science

Pierre Chanial, Wassim Kabalan, Simon Biquart


FURAX Library



  • Motivations and Goals: why FURAX, and what is it for?

  • FURAX Building Blocks: Presentation of the FURAX PyTrees and Operators.

  • Optimizations: High-level algebraic reductions with FURAX.

  • CMB Applications: From map-making to component separation.

Motivations and Goals


  • Modularity, extensibility, simplicity: easy to experiment with new ideas, fail-fast approach
  • JAX: differentiation, Just-In-Time (JIT) compilation, the same code runs anywhere: on CPUs and GPUs, laptops and supercomputers
  • Framework for robust B-mode analysis
  • Able to handle SO- and S4-like data volumes; compatibility with TOAST
  • Non-ideal optical components
  • 1st steps: “max-L” and “template” map-making (following MAPPRAISER’s formalism)
  • Multi-GPU parallelization (soon)

You can try it with pip install furax, bearing in mind that the library is under active development and currently being refactored.

FURAX Building Blocks

FURAX PyTrees



FURAX relies on PyTrees to represent data. For example, in component separation, the generalized sky can be written as a nested PyTree:

sky = {
  'cmb': HealpixLandscape(NSIDE, 'IQU').normal(key1),
  'dust': HealpixLandscape(NSIDE, 'IQU').normal(key2),
  'synchrotron': HealpixLandscape(NSIDE, 'IQU').normal(key3),
}

HealpixLandscape(NSIDE, 'IQU').normal(key) returns an instance of StokesIQUPyTree, whose attributes i, q and u store the JAX arrays of the Stokes components.

StokesIPyTree, StokesQUPyTree and StokesIQUVPyTree are also available.
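Because the sky is a PyTree, JAX transformations act on all components and Stokes maps at once. A minimal sketch, using a hypothetical `StokesIQU` NamedTuple as a stand-in for FURAX's `StokesIQUPyTree` (NamedTuples are PyTrees natively; the real FURAX class is not reproduced here):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class StokesIQU(NamedTuple):
    """Hypothetical stand-in for StokesIQUPyTree."""
    i: jax.Array
    q: jax.Array
    u: jax.Array


def normal_iqu(key, npix):
    # Draw independent Gaussian I, Q and U maps (illustrative only)
    ki, kq, ku = jax.random.split(key, 3)
    return StokesIQU(jax.random.normal(ki, (npix,)),
                     jax.random.normal(kq, (npix,)),
                     jax.random.normal(ku, (npix,)))


NPIX = 12 * 4**2  # HEALPix pixel count for nside = 4
keys = jax.random.split(jax.random.PRNGKey(0), 3)
sky = {
    'cmb': normal_iqu(keys[0], NPIX),
    'dust': normal_iqu(keys[1], NPIX),
    'synchrotron': normal_iqu(keys[2], NPIX),
}

# A single call scales every Stokes map of every component
scaled = jax.tree.map(lambda leaf: 2.0 * leaf, sky)
# The nested structure (3 components x 3 Stokes maps) is preserved
n_leaves = len(jax.tree.leaves(scaled))
```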


PyTrees are then consumed by FURAX Operators:

# Given an acquisition H:
tod = H(sky)

FURAX Operators


The base class AbstractLinearOperator provides a default implementation for the usual linear algebra operations.


Operation | FURAX | Comment
Addition | A + B |
Composition | A @ B |
Multiplication by scalar | k * A | Returns the composition of a HomothetyOperator and A
Transpose | A.T | Derived through JAX autodiff, but can be overridden
Inverse | A.I | Uses the CG solver by default; can be overridden or configured with a context manager
Block assembly | BlockColumnOperator([A, B]), BlockDiagonalOperator([A, B]), BlockRowOperator([A, B]) | Accepts any PyTree of Operators: Block*Operator({'a': A, 'b': B})
Flattened dense matrix | A.as_matrix() |
Algebraic reduction | A.reduce() |
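To make the operations in the table concrete, here is a minimal sketch of a matrix-free operator class in plain JAX. `ToyOperator` is hypothetical and much simpler than `AbstractLinearOperator`; it only shows how composition, addition, scaling, a JAX-derived transpose (`jax.linear_transpose`) and a CG-based inverse can fit together.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


class ToyOperator:
    """Hypothetical matrix-free linear operator on 1-D arrays (not the FURAX class)."""

    def __init__(self, mv, in_size):
        self.mv = mv            # application x -> A x
        self.in_size = in_size  # length of the input vector

    def __call__(self, x):
        return self.mv(x)

    def __add__(self, other):      # A + B
        return ToyOperator(lambda x: self.mv(x) + other.mv(x), self.in_size)

    def __matmul__(self, other):   # A @ B: apply B first, then A
        return ToyOperator(lambda x: self.mv(other.mv(x)), other.in_size)

    def __rmul__(self, k):         # k * A
        return ToyOperator(lambda x: k * self.mv(x), self.in_size)

    @property
    def T(self):                   # transpose derived automatically by JAX
        primal = jnp.zeros(self.in_size)
        out_size = jax.eval_shape(self.mv, primal).shape[0]
        transpose = jax.linear_transpose(self.mv, primal)
        return ToyOperator(lambda y: transpose(y)[0], out_size)

    @property
    def I(self):                   # inverse via CG: assumes A is symmetric positive definite
        return ToyOperator(lambda y: cg(self.mv, y)[0], self.in_size)

    def as_matrix(self):           # dense matrix, built column by column
        return jnp.stack([self.mv(e) for e in jnp.eye(self.in_size)], axis=1)


M = jnp.array([[2.0, 1.0], [0.0, 3.0], [1.0, 1.0]])
A = ToyOperator(lambda x: M @ x, 2)
b = jnp.array([1.0, -2.0])
x = (A.T @ A).I(b)  # solves (M^T M) x = b without forming M^T M
```

Unlike FURAX, this toy class has no reduction rules: `A.T @ A` is just a nested closure.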

FURAX Operators



Generic Operator | Description
IdentityOperator |
HomothetyOperator |
DiagonalOperator |
BroadcastDiagonalOperator | Non-square operator for broadcasting
TensorOperator | For dense matrix operations
IndexOperator | Can be used for projecting skies onto time-ordered series
MoveAxisOperator |
ReshapeOperator |
RavelOperator |
SymmetricBandToeplitzOperator | Methods: direct, FFT, overlap-and-save
Block*Operator | Block assembly operators (column, diagonal, row)

Applied Operator | Description
QURotationOperator |
HWPOperator | Ideal HWP
LinearPolarizerOperator | Ideal linear polarizer
CMBOperator | Parametrized CMB SED
DustOperator | Parametrized dust SED
SynchrotronOperator | Parametrized synchrotron SED

Multi-level Optimizations

JAX GPU Compilation Chain






From the Python code to the GPU-native code

XLA simplifications


Mathematical identities

  • \(a\times 0 = a - a = 0\)
  • \(a - 0 = a\times 1 = a / 1 = a^1 = a\)
  • \(a^{-1} = 1/a\), \(a^{1/2} = \sqrt{a}\)
  • \(-(-x) = x\)
  • \((-a)(-b) = ab\)
  • \(ac + bc = (a+b)c\)
  • \(a / const = a \times (1 / const)\)
  • \((a + c1) + (b + c2) = a + b + (c1 + c2)\)
  • \((a / b) / (c / d) = (ad) / (bc)\)
  • \(\ln e^x = x\)
  • \(\exp a \exp b = \exp(a+b)\)
  • \(a / \exp b = a \exp(-b)\)

Array manipulations

  • slicing
  • reshaping
  • broadcasting
  • transposition
  • bitcast
  • copies

Ol’ Digger’s tricks

  • \(a^2 = a \times a\), \(a^3 = a \times a \times a\)
  • \(a / b = a \gg \log_2 b\) if \(b\) is a power of 2 (unsigned integers)
  • \(a \bmod b = a \,\&\, (b - 1)\) if \(b\) is a power of 2 (unsigned integers)

and many more (see xla/hlo/transforms/simplifiers/algebraic_simplifier.cc)
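These rewrites can be observed by compiling a function and inspecting the optimized HLO text returned by `jax.jit(...).lower(...).compile().as_text()`. A sketch, with the illustrative name `func_identities`, relying only on the \(-(-x) = x\) and \(a \times 1 = a\) rules listed above: after compilation, no `negate` or `multiply` instruction should remain.

```python
import jax
import jax.numpy as jnp


@jax.jit
def func_identities(x):
    # -(-x) and x * 1 are both rewritten to x by the algebraic simplifier
    return -(-x) * 1.0


x = jnp.arange(4.0)
optimized_hlo = func_identities.lower(x).compile().as_text()
```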

Dead Code Elimination (DCE)

import jax
import jax.numpy as jnp

@jax.jit
def func_dce(x):
    unused = jnp.sin(x)
    y = jnp.exp(x)
    return y[0]

Compiled StableHLO representation
import jax
import jax.numpy as jnp

@jax.jit
def func_full(x):
    return jnp.exp(x)

Full computation vs DCE
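As a sketch, DCE can be checked by searching the optimized HLO of `func_dce` for a `sine` instruction (XLA's opcode name for `jnp.sin`): since `unused` never reaches the output, it should not survive compilation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def func_dce(x):
    unused = jnp.sin(x)  # never contributes to the output
    y = jnp.exp(x)
    return y[0]


x = jnp.linspace(0.0, 1.0, 8)
optimized_hlo = func_dce.lower(x).compile().as_text()
```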

XLA Common Subexpression Elimination (CSE)

import jax
import jax.numpy as jnp

@jax.jit
def func_cse(theta):
    a = jnp.sin(theta)
    b = jnp.sin(theta) + 1
    return a + b

XLA Common Subexpression Elimination (CSE)

  • Definition: CSE identifies and eliminates duplicate computations within a function to optimize performance.

  • Example in Code:

    • Without CSE: jnp.sin(theta) is computed twice.
    • After CSE: jnp.sin(theta) is computed once and shared by a and b.
  • Benefits:

    • Reduces redundant computation.
    • Improves runtime efficiency and memory usage.
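The same HLO inspection, sketched for `func_cse`: the traced program contains two `jnp.sin` calls, but the optimized HLO should keep a single `sine` instruction shared by `a` and `b`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def func_cse(theta):
    a = jnp.sin(theta)
    b = jnp.sin(theta) + 1
    return a + b


theta = jnp.linspace(0.0, 1.0, 8)
optimized_hlo = func_cse.lower(theta).compile().as_text()
n_sine = optimized_hlo.count(" sine(")  # expected: computed once
```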

Compiled StableHLO representation

JAX GPU Compilation Chain with FURAX






From the Python code to the GPU-native code

FURAX Algebraic Reductions: Composition of Rotations

import jax
import jax.numpy as jnp

def rot(x, y, theta):
    rotated_x = x * jnp.cos(theta) - y * jnp.sin(theta)
    rotated_y = x * jnp.sin(theta) + y * jnp.cos(theta)
    return rotated_x, rotated_y

@jax.jit
def func(x, y, theta1, theta2):
    return rot(x, y, theta=theta1 + theta2)

Compiled StableHLO representation

FURAX Algebraic Reductions: Composition of Rotations

import jax
import jax.numpy as jnp

def rot(x, y, theta):
    rotated_x = x * jnp.cos(theta) - y * jnp.sin(theta)
    rotated_y = x * jnp.sin(theta) + y * jnp.cos(theta)
    return rotated_x, rotated_y

@jax.jit
def func(x, y, theta1, theta2):
    x, y = rot(x, y, theta=theta1)
    return rot(x, y, theta=theta2)

Compiled StableHLO representation
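XLA does not rewrite the two-rotation version into the single-rotation one; FURAX does so at the operator level instead. A minimal sketch of such a reduction rule, using a hypothetical `Rotation` class (not FURAX's `QURotationOperator`) whose composition returns one rotation by the summed angle:

```python
import jax.numpy as jnp


class Rotation:
    """Hypothetical 2-D rotation operator with one FURAX-style reduction rule."""

    def __init__(self, theta):
        self.theta = theta

    def __call__(self, x, y):
        c, s = jnp.cos(self.theta), jnp.sin(self.theta)
        return x * c - y * s, x * s + y * c

    def __matmul__(self, other):
        # Reduction rule: R(a) @ R(b) = R(a + b), so only one cos/sin pair is evaluated
        if isinstance(other, Rotation):
            return Rotation(self.theta + other.theta)
        return NotImplemented


x = jnp.array([1.0, 0.0, 2.0])
y = jnp.array([0.0, 1.0, -1.0])
theta1, theta2 = 0.3, 1.1

reduced = Rotation(theta2) @ Rotation(theta1)  # theta1 applied first, then theta2
x_red, y_red = reduced(x, y)
x_seq, y_seq = Rotation(theta2)(*Rotation(theta1)(x, y))
```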

FURAX Algebraic Reductions: Instrument Acquisition


Consider this model of the acquisition, using an ideal linear polarizer and an ideal half-wave plate (HWP): \[ \mathbf{H} = \mathbf{C}_{\textrm{LP}} \, \mathbf{R}_{2\theta} \, \mathbf{R}_{-2\phi} \, \mathbf{C}_{\textrm{HWP}} \, \mathbf{R}_{2\phi} \, \mathbf{R}_{2\psi} \, \mathbf{P} \] with

  • \(\theta\): detector polarization angle
  • \(\phi\): HWP rotation angle
  • \(\psi\): telescope rotation angle

FURAX reduces this expression to:

\[ \mathbf{H} = \mathbf{C}_{\textrm{LP}} \, \mathbf{R}_{-2\theta + 4\phi + 2\psi}\, \mathbf{P} \]
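The reduction can be checked numerically on a single pixel. The sketch takes \(\mathbf{P}\) as the identity and assumes one common sign convention for the \((Q, U)\) rotation. The result relies on \(\mathbf{C}_{\textrm{HWP}} \mathbf{R}_\alpha = \mathbf{R}_{-\alpha} \mathbf{C}_{\textrm{HWP}}\) and \(\mathbf{C}_{\textrm{LP}} \mathbf{R}_\alpha \mathbf{C}_{\textrm{HWP}} = \mathbf{C}_{\textrm{LP}} \mathbf{R}_{-\alpha}\), which hold for either convention. Helper names and angle values are illustrative.

```python
import jax.numpy as jnp


def rot(alpha):
    """Rotation of (Q, U) by alpha, acting on an (I, Q, U) vector (one sign convention)."""
    c, s = jnp.cos(alpha), jnp.sin(alpha)
    return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


C_HWP = jnp.diag(jnp.array([1.0, 1.0, -1.0]))  # ideal HWP: flips the sign of U
C_LP = 0.5 * jnp.array([[1.0, 1.0, 0.0]])      # ideal polarizer: (I + Q) / 2

theta, phi, psi = 0.4, 1.3, -0.7        # detector, HWP and telescope angles (arbitrary)
sky_iqu = jnp.array([1.0, 0.2, -0.3])   # one pixel; P taken as the identity

H_full = C_LP @ rot(2 * theta) @ rot(-2 * phi) @ C_HWP @ rot(2 * phi) @ rot(2 * psi)
H_reduced = C_LP @ rot(-2 * theta + 4 * phi + 2 * psi)
```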

FURAX Algebraic Reductions: Pointing Matrix


When the time-time noise covariance matrix \(\mathbf{N}\) is diagonal and \(\mathbf{P}\) is a “one-to-one” intensity projection matrix (a single 1 per row, i.e. each sample sees exactly one pixel):

\[ \mathbf{P} = \begin{bmatrix} 0 & \cdots & 0 & 1 & 0 & \cdots & 0 \\ 0 & 1 & 0 & \cdots & 0 & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 1 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 0 & 0 & \cdots & 0 & 1 & \cdots & 0 \end{bmatrix}, \]



  • The product \(\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P}\) is diagonal and can therefore be inverted trivially (e.g. for the pixel-pixel covariance matrix or for preconditioning).

  • Each diagonal term is the noise-weighted number of times the corresponding pixel has been observed: \((\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P})_{pp} = \sum_{t \,:\, p_t = p} \sigma_t^{-2}\), where \(p_t\) is the pixel observed at sample \(t\).

  • For IQU maps, the product is block diagonal, with 3×3 blocks that can also be easily inverted.

  • By adding a reduction rule for this product, we observed a speed-up of more than a factor of 10 in the forward application (WIP, currently only for scalar \(\mathbf{N}\)).
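For an intensity-only \(\mathbf{P}\) and diagonal \(\mathbf{N}\), the diagonal of \(\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P}\) can be accumulated with a scatter-add, without forming \(\mathbf{P}\). A minimal JAX sketch with toy sizes and random pointing (names are illustrative; this is not FURAX's own reduction rule), checked against the dense product:

```python
import jax
import jax.numpy as jnp

n_pix, n_samples = 5, 20
key_pix, key_sig = jax.random.split(jax.random.PRNGKey(1))
pixel_indices = jax.random.randint(key_pix, (n_samples,), 0, n_pix)  # pixel seen at each sample
sigma2 = jax.random.uniform(key_sig, (n_samples,), minval=0.5, maxval=2.0)  # diagonal of N

# Matrix-free: sum the inverse variances of all samples falling in each pixel
diag_ptnp = jnp.zeros(n_pix).at[pixel_indices].add(1.0 / sigma2)

# Dense reference, only feasible for toy sizes
P = jax.nn.one_hot(pixel_indices, n_pix)  # (n_samples, n_pix), a single 1 per row
dense_ptnp = P.T @ jnp.diag(1.0 / sigma2) @ P
```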

FURAX Algebraic Reductions: Block Assembly



Operation | Reduction
BlockDiagonalOperator([D1, D2]) @ BlockColumnOperator([C1, C2]) | BlockColumnOperator([D1 @ C1, D2 @ C2])
BlockRowOperator([R1, R2]) @ BlockDiagonalOperator([D1, D2]) | BlockRowOperator([R1 @ D1, R2 @ D2])
BlockRowOperator([R1, R2]) @ BlockColumnOperator([C1, C2]) | R1 @ C1 + R2 @ C2



Practical use case:

Given two observations,

\[ \mathbf{P} = \begin{bmatrix} \mathbf{P}_1 \\ \mathbf{P}_2 \end{bmatrix}, \quad \mathbf{N}^{-1} = \begin{bmatrix} \mathbf{N}_1^{-1} & 0 \\ 0 & \mathbf{N}_2^{-1} \end{bmatrix}, \]

the combination reduces to \[ \mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P} = \mathbf{P}_1^\top \mathbf{N}_1^{-1} \mathbf{P}_1^{\,} + \mathbf{P}_2^\top \mathbf{N}_2^{-1} \mathbf{P}_2^{\,}. \]
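A dense check of this reduction with toy sizes (illustrative names, random one-hot pointings). Both forms agree, but the reduced one never builds the stacked \(\mathbf{P}\) or the full block-diagonal \(\mathbf{N}^{-1}\):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

k1, k2 = jax.random.split(jax.random.PRNGKey(2))
n_pix = 4
P1 = jax.nn.one_hot(jax.random.randint(k1, (6,), 0, n_pix), n_pix)  # first observation
P2 = jax.nn.one_hot(jax.random.randint(k2, (9,), 0, n_pix), n_pix)  # second observation
invN1 = jnp.diag(jnp.full(6, 2.0))
invN2 = jnp.diag(jnp.full(9, 0.5))

# Unreduced: assemble the stacked P and the block-diagonal N^-1 explicitly
P = jnp.concatenate([P1, P2], axis=0)
invN = block_diag(invN1, invN2)
unreduced = P.T @ invN @ P

# Reduced: row @ diagonal @ column collapses to a sum of per-observation terms
reduced = P1.T @ invN1 @ P1 + P2.T @ invN2 @ P2
```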

CMB Applications





Credits: SciPol Ph.D. students

Maximum-Likelihood Map-Making


Classic data model

\[ d = \mathbf{P}s + n \]

  • \(d\): time-ordered data
  • \(\mathbf{P}\): pointing matrix (telescope scanning)
  • \(s\): discretized sky signal
  • \(n\): stochastic contribution (noise)


Optimal (GLS) solution:

\[ \widehat{s} = (\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P})^{-1} \mathbf{P}^\top \mathbf{N}^{-1} d \]
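A minimal matrix-free sketch of this GLS solution with JAX's conjugate gradient solver, assuming an intensity-only toy pointing and a diagonal \(\mathbf{N}\) (sizes and names are illustrative). With noiseless data, the estimate should recover the input sky:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n_pix, n_samples = 4, 16
pixel_indices = jnp.arange(n_samples) % n_pix   # toy scanning: each pixel hit 4 times
inv_var = jnp.linspace(1.0, 2.0, n_samples)     # diagonal of N^-1
s_true = jnp.array([1.0, -0.5, 2.0, 0.3])


def P(s):
    # Pointing: sky -> time-ordered data
    return s[pixel_indices]


def PT(d):
    # Transpose: accumulate the samples back into their pixels
    return jnp.zeros(n_pix).at[pixel_indices].add(d)


d = P(s_true)  # noiseless data for this check

rhs = PT(inv_var * d)                              # P^T N^-1 d
s_hat, _ = cg(lambda s: PT(inv_var * P(s)), rhs)   # solve (P^T N^-1 P) s = P^T N^-1 d
```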

Generalized parametric data model

\[ d_{\nu, i, t} = \int_{\textrm{BP}_\nu} d\nu' \mathbf{M}^{(\gamma)}_{\nu', i, t, p} \mathbf{A}^{(\beta)}_{\nu', t, c, p} s_{c, p} + n_{\nu, i, t} \]

  • \(\mathbf{M}\) includes HWP parameters, band passes, beam properties, gains, …
  • \(\mathbf{A}\) includes the frequency modeling of CMB, astrophysical foregrounds, atmosphere, ground
  • \(\mathbf{H} = \mathbf{MA}\) is the generalized pointing operator

Credit: Simon Biquart

Maximum-Likelihood Map-Making


\(\mathbf{N}^{-1}\): block symmetric band Toeplitz
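A symmetric band Toeplitz matrix with \(T_{ij} = c_{|i-j|}\) (zero beyond the band) can be applied as a convolution with the symmetric kernel \((c_K, \dots, c_1, c_0, c_1, \dots, c_K)\). The sketch shows only this direct approach; the helper `band_toeplitz_matvec` and the coefficients are hypothetical, and FURAX's `SymmetricBandToeplitzOperator` (which also offers FFT and overlap-and-save methods) is not reproduced here.

```python
import jax.numpy as jnp


def band_toeplitz_matvec(c, x):
    """y = T x with T[i, j] = c[|i - j|] for |i - j| < len(c), zero elsewhere."""
    kernel = jnp.concatenate([c[:0:-1], c])  # [c_K, ..., c_1, c_0, c_1, ..., c_K]
    return jnp.convolve(x, kernel, mode='same')


c = jnp.array([4.0, -1.0, 0.5])  # c_0, c_1, c_2: bandwidth K = 2 (illustrative)
n = 7
x = jnp.arange(1.0, n + 1)

# Dense reference matrix, for checking only
i, j = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing='ij')
lag = jnp.abs(i - j)
T = jnp.where(lag < len(c), c[jnp.minimum(lag, len(c) - 1)], 0.0)
```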

Furax example

h = bandpass @ pol @ rot @ hwp @ proj @ mixing
tod = h(sky)
solution = ((h.T @ invN @ h).I @ h.T @ invN)(gap_fill(key, tod))

Credit: Simon Biquart

Writing this with FURAX tools


polarizer = LinearPolarizerOperator.create(detector_angles)
hwp = HWPOperator.create(hwp_angles)
rotation = QURotationOperator.create(telescope_angle)
sampling = SamplingOperator(pixel_indices)
h = polarizer @ hwp @ rotation @ sampling
invN = SymmetricBandToeplitzOperator(noise_correlations)
L = (h.T @ invN @ h).I @ h.T @ invN
estimate = L(data)




FURAX Map-Making

Can easily be extended to more complex models:
  • non-ideal components
  • parametric data model

Credit: Simon Biquart

Quantifying Biases from Non-Ideal Half Wave Plates


Non-ideal Half Wave Plate

  • Made of several stacked layers
  • Takes into account the transmission and reflection of the incident electromagnetic field at the layer boundaries

Furax modeling

h = pol @ bandpass @ hwp @ mixing_matrix @ projection
sol = ((h.T @ h).I @ h.T)(tod)

Credit: Ema Tsang King Sang

Component Separation

  • The Cosmic Microwave Background (CMB) signal is obscured by various foregrounds, making it challenging to extract the cosmological information.
  • Dust: thermal emission from Galactic dust contaminates the CMB, particularly its polarization.
  • Synchrotron radiation: relativistic electrons spiraling in the Galactic magnetic field produce synchrotron radiation, another major contaminant.

Component separation methods

  • Blind methods: e.g. SMICA (Spectral Matching Independent Component Analysis)
  • Parametric methods: e.g. fgbuster (Foreground Buster)

Credit: Wassim Kabalan

On-going development of parametric component separation within the FURAX framework

Does everything fgbuster does, but “better”:

  • Uses FURAX linear algebra operators to efficiently represent the mixing matrix
  • Is written in JAX and is hardware accelerated (CPUs and GPUs)
  • Provides easy access to gradients

Beyond fgbuster

  • Automatic cluster detection for spectral index parameters
  • Flexible likelihood model (can be extended to include more components and different objective functions)

Credit: Wassim Kabalan

On-going development of parametric component separation within the FURAX framework


Creating a Mixing Matrix Operator for SED evaluation

from furax.seds import CMBOperator, DustOperator, SynchrotronOperator

# nu (frequencies), in_structure and BlockRowOperator are assumed
# to be defined or imported elsewhere
def make_A(TEMP_D, BETA_D, BETA_S):
    cmb = CMBOperator(nu, in_structure)
    dust = DustOperator(nu, TEMP_D, BETA_D)
    synchrotron = SynchrotronOperator(nu, BETA_S)

    # Block row: A(sky) sums the contributions of all components
    mixed_sed = BlockRowOperator({
        'cmb': cmb,
        'dust': dust,
        'synchrotron': synchrotron,
    })
    return mixed_sed

Straightforward construction of the spectral likelihood function

import operator

import jax

@jax.jit
def negative_log_prob(params, d):
    A = make_A(params['TEMP_D'],
               params['BETA_D'],
               params['BETA_S'])

    # x = A^T N^-1 d, then x^T (A^T N^-1 A)^-1 x, leaf by leaf
    x = (A.T @ invN)(d)
    l = jax.tree.map(lambda a, b: a @ b, x, (A.T @ invN @ A).I(x))
    summed_log_prob = jax.tree.reduce(operator.add, l)

    return -summed_log_prob


Easy to evaluate and extend

  • The likelihood function is readily available and can be easily extended to include more components or different objective functions.
  • Easy access to gradients and Hessians for optimization.
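To illustrate the gradient access, here is a self-contained toy spectral likelihood in plain JAX (not the FURAX API). It assumes two components (a flat SED and a power law with index \(\beta\)), hypothetical frequencies, and \(\mathbf{N}\) set to the identity. With noiseless data, the likelihood is minimal at the true \(\beta\), where its gradient vanishes:

```python
import jax
import jax.numpy as jnp

nu = jnp.array([1.0, 1.5, 2.0, 3.0])  # hypothetical frequencies (arbitrary units)
nu_ref = 2.0


def mixing_matrix(beta):
    # Two toy components: a flat "CMB-like" SED and a power-law foreground
    return jnp.stack([jnp.ones_like(nu), (nu / nu_ref) ** beta], axis=1)  # (n_freq, 2)


def spectral_likelihood(beta, d):
    # -sum_p (A^T d_p)^T (A^T A)^-1 (A^T d_p), with N = identity for simplicity
    A = mixing_matrix(beta)
    x = A.T @ d  # (2, n_pix)
    return -jnp.sum(x * jnp.linalg.solve(A.T @ A, x))


beta_true = -1.5
s = jnp.array([[1.0, 0.2, -0.4], [0.5, 1.0, 0.3]])  # (component, pixel)
d = mixing_matrix(beta_true) @ s                     # noiseless data, (n_freq, n_pix)

value_and_grad = jax.jit(jax.value_and_grad(spectral_likelihood))
L_true, g_true = value_and_grad(beta_true, d)
L_off, _ = value_and_grad(beta_true + 1.0, d)
```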

Credit: Wassim Kabalan

Atmosphere Decontamination Using Time-Domain Component Separation

Science Goal

For the Simons Observatory, estimate the atmospheric template from the recorded time-ordered data, in order to separate the atmosphere from the sky signal of interest.

Data Model

\[ d_{\text{atm}} = \mathbf{A}(\text{PWV}) \mathbf{P}(w_x, w_y) s_{\text{atm}} + n \]

Model Parameters

  • Wind velocity: \(\vec{w} = (w_x, w_y)\)
  • Precipitable Water Vapour (PWV): a proxy for the amplitude of atmospheric fluctuations.
  • The parameters are estimated by minimizing the spectral likelihood.
  • \(d_{\text{atm}}\): atmospheric time-ordered data recorded by our telescope at two frequencies

  • Mixing matrix \(\mathbf{A} = \begin{bmatrix} a(\text{PWV}_{1}) & 0 \\ 0 & a(\text{PWV}_{2}) \end{bmatrix}\)

  • Pointing matrix \(\mathbf{P} = \begin{bmatrix} P_{1} \\ P_{2} \end{bmatrix}\)

  • Atmospheric template: \(s_{\text{atm}}\)

  • Noise: \(n\), of covariance matrix \(\mathbf{N} = \begin{bmatrix} \sigma_{1}^{2} & 0 \\ 0 & \sigma_{2}^{2} \end{bmatrix}\)

Spectral Likelihood

\[ \boxed{ \langle \delta \mathcal{S}_\text{spec}(w_x, w_y, \text{PWV} \mid \vec{d}_{\text{atm}}) \rangle = -\vec{d}_{\text{atm}}^\top \mathbf{N}^{-1} \mathbf{AP} \Big[ (\mathbf{AP})^\top \mathbf{N}^{-1} (\mathbf{AP}) \Big]^{-1} (\mathbf{AP})^\top \mathbf{N}^{-1} \vec{d}_{\text{atm}} } \]

Credit: Amalia Villarrubia Aguilar

Atmosphere Decontamination Using Time-Domain Component Separation

Spectral likelihood computation using FURAX

def average_spectral_likelihood_noise(d_atm, w_vec, PWV):
    # POINTING matrix (shape: (n_detectors * N_obs * n_freq) x N_pix)
    P_single_freq = pointing_matrix_single_freq(detector_pointings_t, t_obs, n_detectors, z_atm, d_pix_sim, N_pix_sim, w_vec)
    P = BlockColumnOperator({'93': P_single_freq, '145': P_single_freq})
    # MIXING matrix (shape: (n_detectors * N_obs * n_freq) x (n_detectors * N_obs * n_freq)),
    # with the atmospheric emission normalized to the '93' channel
    A_block_structure = StokesIPyTree.structure_for((n_detectors, N_obs))
    A_93 = HomothetyOperator(atm_emission(PWV, mu_93) / atm_emission(PWV, mu_93), in_structure=A_block_structure)
    A_145 = HomothetyOperator(atm_emission(PWV, mu_145) / atm_emission(PWV, mu_93), in_structure=A_block_structure)
    A = BlockDiagonalOperator({'93': A_93, '145': A_145})
    # COMPOSITION: mixing matrix @ pointing matrix (shape: (n_detectors * N_obs * n_freq) x N_pix)
    C = A @ P
    # NOISE covariance matrix (shape: (n_detectors * N_obs * n_freq) x (n_detectors * N_obs * n_freq))
    N_block_structure = StokesIPyTree.structure_for((n_detectors, N_obs))
    N_93 = HomothetyOperator(noise_variance_93, in_structure=N_block_structure)
    N_145 = HomothetyOperator(noise_variance_145, in_structure=N_block_structure)
    N = BlockDiagonalOperator({'93': N_93, '145': N_145})
    # Spectral likelihood: S = -d^T N^-1 C (C^T N^-1 C)^-1 C^T N^-1 d
    core_op = (C.T @ N.I @ C).I
    full_op = N.I @ C @ core_op @ C.T @ N.I
    S = - StokesIPyTree(d_atm) @ StokesIPyTree(full_op(d_atm))
    return S

  • Easily extensible
  • Easy access to gradients

Credit: Amalia Villarrubia Aguilar

Atmosphere Decontamination Using Time-Domain Component Separation


Spectral likelihood minimization

Spectral likelihood gridding



Gridding process

➜ This spectral likelihood is minimised through gridding: we compute \(\langle \delta \mathcal{S}_\text{spec}(w_x, w_y \mid \text{PWV}_{\text{sim}}) \rangle\)
for 22,500 different combinations of \((w_x, w_y)\).
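Gridding maps naturally onto `jax.vmap`: the likelihood is evaluated at every grid point in one batched, compiled call. A sketch with a hypothetical quadratic stand-in for the spectral likelihood, assuming the 22,500 points form a 150 × 150 grid (the actual grid layout and bounds are not given here):

```python
import jax
import jax.numpy as jnp


def toy_likelihood(w):
    # Hypothetical stand-in for the spectral likelihood, minimal at w = (2.0, -1.0)
    w_true = jnp.array([2.0, -1.0])
    return jnp.sum((w - w_true) ** 2)


n_grid = 150  # 150 x 150 = 22,500 points (assumed layout)
wx = jnp.linspace(-5.0, 5.0, n_grid)
wy = jnp.linspace(-5.0, 5.0, n_grid)
WX, WY = jnp.meshgrid(wx, wy, indexing='ij')
grid = jnp.stack([WX.ravel(), WY.ravel()], axis=1)  # (22500, 2)

values = jax.jit(jax.vmap(toy_likelihood))(grid)  # all grid points in one batched call
best = grid[jnp.argmin(values)]
```

For a real likelihood over large time-ordered data, `jax.lax.map` (sequential) or a chunked `vmap` may be needed to keep memory bounded.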

Credit: Amalia Villarrubia Aguilar