Pierre Chanial, Wassim Kabalan, Simon Biquart
You can try it with pip install furax, bearing in mind that the library is under active development and currently in a refactoring phase.
FURAX relies on PyTrees to represent the data. For example, for component separation analysis, the generalized sky can be written as a nested PyTree:
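import jax

# NSIDE and the PRNG keys are example values; HealpixLandscape comes from FURAX.
NSIDE = 64
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
sky = {
    'cmb': HealpixLandscape(NSIDE, 'IQU').normal(key1),
    'dust': HealpixLandscape(NSIDE, 'IQU').normal(key2),
    'synchrotron': HealpixLandscape(NSIDE, 'IQU').normal(key3),
}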
HealpixLandscape(NSIDE, 'IQU') returns an instance of StokesIQUPyTree, which has the attributes i, q and u that store the JAX arrays of the Stokes components. Also available are StokesIPyTree, StokesQUPyTree and StokesIQUVPyTree.
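To make the idea concrete without depending on FURAX itself, here is a minimal sketch in plain Python/NumPy. A hypothetical StokesIQU named tuple stands in for StokesIQUPyTree, and a hand-written tree_map applies a function to every leaf of the nested sky dictionary, roughly what jax.tree_util.tree_map does for registered PyTrees. The names StokesIQU, tree_map and npix are illustrative assumptions, not FURAX API.

```python
from typing import NamedTuple

import numpy as np


class StokesIQU(NamedTuple):
    """Hypothetical stand-in for StokesIQUPyTree: one array per Stokes parameter."""
    i: np.ndarray
    q: np.ndarray
    u: np.ndarray


def tree_map(func, tree):
    """Apply `func` to every array leaf of a nested dict / NamedTuple structure."""
    if isinstance(tree, dict):
        return {key: tree_map(func, value) for key, value in tree.items()}
    if isinstance(tree, tuple) and hasattr(tree, '_fields'):  # NamedTuple node
        return type(tree)(*(tree_map(func, leaf) for leaf in tree))
    return func(tree)  # leaf: a NumPy array


npix = 12 * 4**2  # number of HEALPix pixels for NSIDE = 4
rng = np.random.default_rng(0)
sky = {
    name: StokesIQU(*(rng.normal(size=npix) for _ in range(3)))
    for name in ('cmb', 'dust', 'synchrotron')
}

# The function acts leaf-wise; the structure (components, Stokes fields) is preserved.
scaled = tree_map(lambda leaf: 2 * leaf, sky)
print(scaled['dust'].q.shape)  # (192,)
```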
PyTrees are then used by the FURAX operators:
The base class AbstractLinearOperator provides a default implementation for the usual linear algebra operations.
Operation | FURAX | Comment |
---|---|---|
Addition | A + B | |
Composition | A @ B | |
Multiplication by scalar | k * A | Returns the composition of a HomothetyOperator and A |
Transpose | A.T | Through JAX autodiff, but can be overridden |
Inverse | A.I | By default, the CG solver is used, but it can be overridden or configured using a context manager |
Block Assembly | BlockColumnOperator([A, B]), BlockDiagonalOperator([A, B]), BlockRowOperator([A, B]) | Handles any PyTree of operators: Block*Operator({'a': A, 'b': B}) |
Flattened dense matrix | A.as_matrix() | |
Algebraic reduction | A.reduce() | |
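The semantics in the table can be illustrated with a toy, matrix-free operator class in NumPy. This is a sketch under simplifying assumptions: operators act on flat 1-D arrays, the transpose is supplied explicitly instead of being derived with JAX autodiff, and .I always runs a plain conjugate-gradient solve, which is only valid for symmetric positive definite operators. The class Op and the function cg are hypothetical and are not FURAX's AbstractLinearOperator.

```python
import numpy as np


class Op:
    """Toy linear operator defined by its matvec and the matvec of its transpose."""

    def __init__(self, matvec, rmatvec, shape):
        self.matvec, self.rmatvec, self.shape = matvec, rmatvec, shape

    @classmethod
    def dense(cls, m):
        return cls(lambda x: m @ x, lambda y: m.T @ y, m.shape)

    def __call__(self, x):
        return self.matvec(x)

    def __add__(self, other):  # A + B
        return Op(lambda x: self(x) + other(x),
                  lambda y: self.rmatvec(y) + other.rmatvec(y), self.shape)

    def __matmul__(self, other):  # A @ B: apply B first, then A
        return Op(lambda x: self(other(x)),
                  lambda y: other.rmatvec(self.rmatvec(y)),
                  (self.shape[0], other.shape[1]))

    def __rmul__(self, k):  # k * A
        return Op(lambda x: k * self(x), lambda y: k * self.rmatvec(y), self.shape)

    @property
    def T(self):
        return Op(self.rmatvec, self.matvec, self.shape[::-1])

    @property
    def I(self):
        # Inverse through conjugate gradient (symmetric positive definite operators only).
        return Op(lambda b: cg(self, b), lambda b: cg(self.T, b), self.shape[::-1])

    def as_matrix(self):
        # Apply the operator to each unit vector to get the dense columns.
        return np.column_stack([self(e) for e in np.eye(self.shape[1])])


def cg(op, b, tol=1e-10, maxiter=1000):
    """Plain conjugate-gradient solve of op(x) = b."""
    x = np.zeros_like(b)
    r = b - op(x)
    p = r.copy()
    rr = r @ r
    for _ in range(maxiter):
        if np.sqrt(rr) <= tol * np.linalg.norm(b):
            break
        ap = op(p)
        alpha = rr / (p @ ap)
        x += alpha * p
        r -= alpha * ap
        rr_new = r @ r
        p = r + (rr_new / rr) * p
        rr = rr_new
    return x


rng = np.random.default_rng(1)
m = rng.normal(size=(5, 5))
spd = m.T @ m + 5 * np.eye(5)  # symmetric positive definite
A, B = Op.dense(m), Op.dense(spd)
x = rng.normal(size=5)

print(np.allclose((A + 2 * B)(x), (m + 2 * spd) @ x))    # True
print(np.allclose(B.I(B(x)), x))                         # True
print(np.allclose((A @ B).T.as_matrix(), (m @ spd).T))   # True
```

Keeping the operators matrix-free means that only .as_matrix() ever builds a dense array; everything else composes closures.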
Generic Operator | Description |
---|---|
IdentityOperator | |
HomothetyOperator | |
DiagonalOperator | |
BroadcastDiagonalOperator | Non-square operator for broadcasting |
TensorOperator | For dense matrix operations |
IndexOperator | Can be used for projecting skies onto time-ordered series |
MoveAxisOperator | |
ReshapeOperator | |
RavelOperator | |
SymmetricBandToeplitzOperator | Methods: direct, FFT, overlap and save |
Block*Operator | Block assembly operators (column, diagonal, row) |
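As an illustration of the "direct" and "FFT" methods listed for SymmetricBandToeplitzOperator, the NumPy sketch below applies a symmetric band Toeplitz matrix, defined by its first-row coefficients (for instance noise autocorrelations), to a time stream in two ways: a direct banded product, and an FFT product obtained by embedding the matrix in a larger circulant matrix. The function names are hypothetical; the overlap-and-save variant, which processes long streams in chunks, is not shown.

```python
import numpy as np


def toeplitz_direct(c, x):
    """y = T x for the symmetric band Toeplitz T with T[i, j] = c[|i - j|] (0 beyond the band)."""
    y = c[0] * x
    for k in range(1, len(c)):
        y[k:] += c[k] * x[:-k]   # sub-diagonal k
        y[:-k] += c[k] * x[k:]   # super-diagonal k
    return y


def toeplitz_fft(c, x):
    """Same product through a circulant embedding; padding prevents wrap-around."""
    n, band = len(x), len(c)
    size = 1 << (n + band - 2).bit_length()      # power of two >= n + band - 1
    first_column = np.zeros(size)
    first_column[:band] = c                      # lags 0 .. band-1
    first_column[size - band + 1:] = c[:0:-1]    # negative lags, wrapped to the end
    spectrum = np.fft.rfft(first_column) * np.fft.rfft(x, n=size)
    return np.fft.irfft(spectrum, n=size)[:n]


rng = np.random.default_rng(0)
c = np.array([1.0, 0.5, 0.2, 0.05])  # example correlations at lags 0..3
x = rng.normal(size=50)

# Dense reference matrix, only for checking.
lags = np.abs(np.subtract.outer(np.arange(50), np.arange(50)))
dense = np.where(lags < len(c), c[np.minimum(lags, len(c) - 1)], 0.0)

print(np.allclose(toeplitz_direct(c, x), dense @ x))  # True
print(np.allclose(toeplitz_fft(c, x), dense @ x))     # True
```

The direct method costs O(n × bandwidth), the FFT method O(n log n), so the better choice depends on the bandwidth.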
Applied Operator | Description |
---|---|
QURotationOperator | |
HWPOperator | Ideal HWP |
LinearPolarizerOperator | Ideal linear polarizer |
CMBOperator | Parametrized CMB SED |
DustOperator | Parametrized dust SED |
SynchrotronOperator | Parametrized synchrotron SED |
Mathematical identities
Array manipulations
Ol’ Digger’s tricks
and many more (see xla/hlo/transforms/simplifiers/algebraic_simplifier.cc)
import jax
import jax.numpy as jnp

@jax.jit
def func_cse(theta):
    a = jnp.sin(theta)
    b = jnp.sin(theta) + 1
    return a + b
XLA Common Subexpression Elimination (CSE)
Definition: CSE identifies and eliminates duplicate computations within a function to optimize performance.
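Example in code: jnp.sin(theta) is computed twice, once for a and once for b.
Benefits: after CSE, jnp.sin(theta) is evaluated only once and its result is reused for both a and b, avoiding redundant work.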
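The idea behind CSE can be shown with a toy expression graph in pure Python: nodes are interned in a table keyed by (operation, operands), so a second request for sin(theta) returns the node that already exists instead of creating a new one. This hash-consing approach is a simplified stand-in for illustration; it is not how XLA's CSE pass is implemented.

```python
class Graph:
    """Toy expression graph whose nodes are deduplicated by (op, operands)."""

    def __init__(self):
        self.nodes = {}  # (op, operands) -> node id
        self.ops = []    # node id -> (op, operands)

    def add(self, op, *operands):
        key = (op, operands)
        if key not in self.nodes:  # create the node only if no identical one exists
            self.nodes[key] = len(self.ops)
            self.ops.append(key)
        return self.nodes[key]


# Build the graph of func_cse: a = sin(theta); b = sin(theta) + 1; return a + b
g = Graph()
theta = g.add('param', 'theta')
one = g.add('const', 1)
a = g.add('sin', theta)
b = g.add('add', g.add('sin', theta), one)  # reuses the sin node of `a`
result = g.add('add', a, b)

n_sin = sum(op == 'sin' for op, _ in g.ops)
print(n_sin)  # 1
```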
import jax
import jax.numpy as jnp

def rot(x, y, theta):
    rotated_x = x * jnp.cos(theta) - y * jnp.sin(theta)
    rotated_y = x * jnp.sin(theta) + y * jnp.cos(theta)
    return rotated_x, rotated_y

@jax.jit
def func(x, y, theta1, theta2):
    return rot(x, y, theta=theta1 + theta2)
import jax
import jax.numpy as jnp

def rot(x, y, theta):
    rotated_x = x * jnp.cos(theta) - y * jnp.sin(theta)
    rotated_y = x * jnp.sin(theta) + y * jnp.cos(theta)
    return rotated_x, rotated_y

@jax.jit
def func(x, y, theta1, theta2):
    x, y = rot(x, y, theta=theta1)
    return rot(x, y, theta=theta2)
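Both versions compute the same result, because two successive rotations by θ1 and θ2 amount to a single rotation by θ1 + θ2, but the second version evaluates the sine and cosine twice as often. Recognizing the equivalence requires the trigonometric addition formulas, which a general-purpose compiler pass does not necessarily apply. The NumPy check below (a port of rot to NumPy; the data and angles are arbitrary) confirms the equivalence numerically.

```python
import numpy as np


def rot(x, y, theta):
    rotated_x = x * np.cos(theta) - y * np.sin(theta)
    rotated_y = x * np.sin(theta) + y * np.cos(theta)
    return rotated_x, rotated_y


rng = np.random.default_rng(0)
x, y = rng.normal(size=(2, 100))
theta1, theta2 = 0.3, 1.1

# A single rotation by the summed angle.
x_once, y_once = rot(x, y, theta1 + theta2)
# Two successive rotations.
x_tmp, y_tmp = rot(x, y, theta1)
x_twice, y_twice = rot(x_tmp, y_tmp, theta2)

print(np.allclose(x_once, x_twice) and np.allclose(y_once, y_twice))  # True
```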
Given this modeling of the acquisition, using an ideal linear polarizer and an ideal half wave plate: \[ \mathbf{H} = \mathbf{C}_{\textrm{LP}} \, \mathbf{R}_{2\theta} \, \mathbf{R}_{-2\phi} \, \mathbf{C}_{\textrm{HWP}} \, \mathbf{R}_{2\phi} \, \mathbf{R}_{2\psi} \, \mathbf{P} \] with
FURAX reduces this expression to:
\[ \mathbf{H} = \mathbf{C}_{\textrm{LP}} \, \mathbf{R}_{-2\theta + 4\phi + 2\psi}\, \mathbf{P} \]
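The reduction follows from two properties of the ideal components. The ideal HWP flips the sign of U, so for any angle C_HWP R_α = R_{-α} C_HWP, which turns R_{-2φ} C_HWP R_{2φ} into R_{-4φ} C_HWP and lets C_HWP be moved to the left. The ideal polarizer row is insensitive to that sign flip, C_LP C_HWP = C_LP. The NumPy sketch below checks the identity numerically with 3×3 matrices acting on (I, Q, U). The conventions (polarizer along the reference axis with C_LP = ½[1, 1, 0], counter-clockwise QU rotation) are assumptions chosen for the example, and P is left out because it appears unchanged on both sides.

```python
import numpy as np


def rotation(alpha):
    """Rotation of the (Q, U) Stokes parameters by angle alpha; I is unchanged."""
    c, s = np.cos(alpha), np.sin(alpha)
    return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])


C_HWP = np.diag([1.0, 1.0, -1.0])         # ideal half-wave plate: U -> -U
C_LP = 0.5 * np.array([[1.0, 1.0, 0.0]])  # ideal linear polarizer along the reference axis

theta, phi, psi = 0.4, 1.3, -0.7
full = (C_LP @ rotation(2 * theta) @ rotation(-2 * phi) @ C_HWP
        @ rotation(2 * phi) @ rotation(2 * psi))
reduced = C_LP @ rotation(-2 * theta + 4 * phi + 2 * psi)

print(np.allclose(full, reduced))  # True
```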
When the time-time noise covariance matrix \(\mathbf{N}\) is diagonal and \(\mathbf{P}\) is a “one-to-one” intensity projection matrix:
\[ \mathbf{P} = \begin{bmatrix} 0 & \cdots & 0 & 1 & 0 & \cdots & 0 \\ 0 & 1 & 0 & \cdots & 0 & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 1 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 0 & 0 & \cdots & 0 & 1 & \cdots & 0 \end{bmatrix}, \]
the product \(\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P}\) is diagonal and can therefore be easily inverted (to obtain the pixel-pixel covariance matrix or a preconditioner).
Each diagonal term is the noise-weighted number of times the corresponding pixel of the map has been observed.
For IQU maps, the product is block diagonal, with 3x3 blocks that can also be easily inverted.
By adding a rule for this operation, we have seen a performance improvement of more than a factor of 10 in the forward application (WIP, currently only for scalar \(\mathbf{N}\)).
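For a one-to-one pointing, the diagonal of P^T N^{-1} P can be accumulated directly from the pointing indices without forming any matrix: each time sample adds its inverse noise variance to the pixel it observes. Below is a NumPy sketch, assuming a diagonal N given by per-sample variances sigma2 and a pointing stored as an integer array pixel_indices (both illustrative names), checked against the dense product. Pixels that are never observed get a zero entry and must be excluded before inverting.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_pix = 1000, 20
pixel_indices = rng.integers(0, n_pix, size=n_samples)  # pixel seen at each time sample
sigma2 = rng.uniform(0.5, 2.0, size=n_samples)          # diagonal of N (white noise)

# Matrix-free: noise-weighted hit count per pixel.
weighted_hits = np.bincount(pixel_indices, weights=1 / sigma2, minlength=n_pix)

# Dense reference: P has a single 1 per row.
P = np.zeros((n_samples, n_pix))
P[np.arange(n_samples), pixel_indices] = 1
PtNP = P.T @ np.diag(1 / sigma2) @ P

print(np.allclose(PtNP, np.diag(weighted_hits)))  # True
```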
Operation | Reduction |
---|---|
BlockDiagonalOperator([D1, D2]) @ BlockColumnOperator([C1, C2]) | BlockColumnOperator([D1 @ C1, D2 @ C2]) |
BlockRowOperator([R1, R2]) @ BlockDiagonalOperator([D1, D2]) | BlockRowOperator([R1 @ D1, R2 @ D2]) |
BlockRowOperator([R1, R2]) @ BlockColumnOperator([C1, C2]) | R1 @ C1 + R2 @ C2 |
Given two observations
\[ \mathbf{P} = \begin{bmatrix} \mathbf{P}_1 \\ \mathbf{P}_2 \end{bmatrix}, \quad \mathbf{N}^{-1} = \begin{bmatrix} \mathbf{N}_1^{-1} & 0 \\ 0 & \mathbf{N}_2^{-1} \end{bmatrix}, \]
the combination reduces to \[ \mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P} = \mathbf{P}_1^\top \mathbf{N}_1^{-1} \mathbf{P}_1^{\,} + \mathbf{P}_2^\top \mathbf{N}_2^{-1} \mathbf{P}_2^{\,}. \]
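This follows from the table rules: P^T N^{-1} is a block row times a block diagonal, which becomes the block row [P_1^T N_1^{-1}, P_2^T N_2^{-1}], and a block row times a block column is the sum of the pairwise products. A quick NumPy check with small dense random blocks (the sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n_pix = 4
P1, P2 = rng.normal(size=(6, n_pix)), rng.normal(size=(9, n_pix))
invN1 = np.diag(rng.uniform(0.5, 2.0, size=6))
invN2 = np.diag(rng.uniform(0.5, 2.0, size=9))

# Stacked observations and block-diagonal inverse noise covariance.
P = np.vstack([P1, P2])
invN = np.block([[invN1, np.zeros((6, 9))], [np.zeros((9, 6)), invN2]])

full = P.T @ invN @ P
reduced = P1.T @ invN1 @ P1 + P2.T @ invN2 @ P2
print(np.allclose(full, reduced))  # True
```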
Credits: SciPol Ph.D. students
\[ d = \mathbf{P}s + n \]
\[ \widehat{s} = (\mathbf{P}^\top \mathbf{N}^{-1} \mathbf{P})^{-1} \mathbf{P}^\top \mathbf{N}^{-1} d \]
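A toy NumPy instance of these two equations, with hypothetical sizes, a one-to-one pointing and white noise. It simulates d = P s + n and forms the generalized least-squares estimate by solving the normal equations (P^T N^{-1} P) ŝ = P^T N^{-1} d, rather than inverting the matrix explicitly. In the noiseless case the input map is recovered exactly, provided every pixel is observed.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_pix = 500, 10
# Observe every pixel at least once, then random pixels.
pixel_indices = np.concatenate([np.arange(n_pix), rng.integers(0, n_pix, n_samples - n_pix)])
P = np.zeros((n_samples, n_pix))
P[np.arange(n_samples), pixel_indices] = 1   # one-to-one pointing matrix
sigma2 = rng.uniform(0.5, 2.0, n_samples)     # white-noise variances
invN = np.diag(1 / sigma2)

s_true = rng.normal(size=n_pix)


def gls_estimate(d):
    """Generalized least-squares map: solve (P^T N^-1 P) s = P^T N^-1 d."""
    return np.linalg.solve(P.T @ invN @ P, P.T @ invN @ d)


s_noiseless = gls_estimate(P @ s_true)
s_noisy = gls_estimate(P @ s_true + rng.normal(size=n_samples) * np.sqrt(sigma2))
print(np.allclose(s_noiseless, s_true))  # True
```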
\[ d_{\nu, i, t} = \int_{\textrm{BP}_\nu} d\nu' \mathbf{M}^{(\gamma)}_{\nu', i, t, p} \mathbf{A}^{(\beta)}_{\nu', t, c, p} s_{c, p} + n_{\nu, i, t} \]
Credit: Simon Biquart
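# Acquisition model composed from individual operators (applied right to left).
h = bandpass @ pol @ rot @ hwp @ proj @ mixing
# Simulated time-ordered data from the sky PyTree.
tod = h(sky)
# Map-making: GLS estimator applied to the time-ordered data after gap filling.
solution = ((h.T @ invN @ h).I @ h.T @ invN)(gap_fill(key, tod))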
Credit: Simon Biquart
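# Instrument model: ideal polarizer, ideal HWP, sky rotation and pointing.
polarizer = LinearPolarizerOperator.create(detector_angles)
hwp = HWPOperator.create(hwp_angles)
rotation = QURotationOperator.create(telescope_angle)
sampling = SamplingOperator(pixel_indices)
h = polarizer @ hwp @ rotation @ sampling
# Inverse noise covariance modeled as a symmetric band Toeplitz operator.
invN = SymmetricBandToeplitzOperator(noise_correlations)
# GLS map-making operator; .I solves the normal equations (CG by default).
L = (h.T @ invN @ h).I @ h.T @ invN
estimate = L(data)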
FURAX Map-Making
Can easily be extended and made more complex:
- non-ideal components
- parametric data model
Credit: Simon Biquart
Credit: Ema Tsang King Sang
Component separation methods
Credit: Wassim Kabalan
Credit: Wassim Kabalan
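from furax.seds import (
    CMBOperator,
    DustOperator,
    SynchrotronOperator,
)

# nu (frequencies), in_structure and BlockRowOperator are assumed to be defined/imported beforehand.
def make_A(TEMP_D, BETA_D, BETA_S):
    cmb = CMBOperator(nu, in_structure)
    dust = DustOperator(nu, TEMP_D, BETA_D)
    synchrotron = SynchrotronOperator(nu, BETA_S)
    # A block row sums the component contributions:
    # A(s) = cmb(s['cmb']) + dust(s['dust']) + synchrotron(s['synchrotron'])
    mixed_sed = BlockRowOperator({
        'cmb': cmb,
        'dust': dust,
        'synchrotron': synchrotron,
    })
    return mixed_sed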
Easy to evaluate and extend
Credit: Wassim Kabalan
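To show what such a parametric mixing matrix contains, here is a NumPy sketch that evaluates one column per component at a set of frequencies, in Rayleigh-Jeans brightness units. It uses a common parametrization (CMB as a blackbody at T_CMB = 2.7255 K, dust as a modified blackbody, synchrotron as a power law, with dust and synchrotron normalized at hypothetical reference frequencies of 150 and 20 GHz). This is an assumption for illustration, not necessarily the convention implemented by CMBOperator, DustOperator and SynchrotronOperator.

```python
import numpy as np

H_OVER_K = 6.62607015e-34 / 1.380649e-23  # Planck constant / Boltzmann constant [K s]
T_CMB = 2.7255                            # CMB temperature [K]


def cmb_sed(nu):
    """CMB in RJ units per unit CMB temperature: x^2 e^x / (e^x - 1)^2, nu in GHz."""
    x = H_OVER_K * nu * 1e9 / T_CMB
    return x**2 * np.exp(x) / np.expm1(x) ** 2


def dust_sed(nu, temp_d, beta_d, nu0=150.0):
    """Modified blackbody in RJ units, normalized to 1 at nu0 [GHz]."""
    x = H_OVER_K * nu * 1e9 / temp_d
    x0 = H_OVER_K * nu0 * 1e9 / temp_d
    return (nu / nu0) ** (beta_d + 1) * np.expm1(x0) / np.expm1(x)


def synchrotron_sed(nu, beta_s, nu0=20.0):
    """Power law in RJ units, normalized to 1 at nu0 [GHz]."""
    return (nu / nu0) ** beta_s


def mixing_matrix(nu, temp_d, beta_d, beta_s):
    """Hypothetical helper: matrix of shape (n_freq, 3), columns CMB, dust, synchrotron."""
    return np.stack([cmb_sed(nu),
                     dust_sed(nu, temp_d, beta_d),
                     synchrotron_sed(nu, beta_s)], axis=-1)


nu = np.array([20.0, 40.0, 90.0, 150.0, 220.0, 280.0])  # example frequencies [GHz]
A = mixing_matrix(nu, temp_d=20.0, beta_d=1.54, beta_s=-3.0)
print(A.shape)  # (6, 3)
```

Changing the spectral parameters only changes the columns of A, which is what makes the parametric model easy to evaluate on a grid or to differentiate.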
For the Simons Observatory: characterize the atmospheric template from the recorded time-ordered data, in order to separate the atmosphere from the sky signal we are after.
\[ d_{\text{atm}} = \mathbf{A}(\text{PWV}) \mathbf{P}(w_x, w_y) s_{\text{atm}} + n \]
Model Parameters
Atmospheric time-ordered data recorded by our telescope at two frequencies: \(d_{\text{atm}}\)
Mixing matrix \(\mathbf{A} = \begin{bmatrix} a(\text{PWV}_{1}) & 0 \\ 0 & a(\text{PWV}_{2}) \end{bmatrix}\)
Pointing matrix \(\mathbf{P} = \begin{bmatrix} P_{1} \\ P_{2} \end{bmatrix}\)
Atmospheric template: \(s_{\text{atm}}\)
Noise: \(n\), of covariance matrix \(\mathbf{N} = \begin{bmatrix} \sigma_{1}^{2} & 0 \\ 0 & \sigma_{2}^{2} \end{bmatrix}\)
\[ \boxed{ \langle \delta \mathcal{S}_\text{spec}(w_x, w_y, \text{PWV} \mid \vec{d}_{\text{atm}}) \rangle = -\,\vec{d}_{\text{atm}}^\top \mathbf{N}^{-1} \mathbf{AP} \Big[ (\mathbf{AP})^\top \mathbf{N}^{-1} (\mathbf{AP}) \Big]^{-1} (\mathbf{AP})^\top \mathbf{N}^{-1} \vec{d}_{\text{atm}} } \]
Credit: Amalia Villarrubia Aguilar
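A small dense NumPy sketch of this spectral likelihood for a two-frequency toy model. Assumptions for the example: the same pointing matrix at both frequencies, a mixing matrix that scales the 93 GHz block by 1 and the 145 GHz block by a single ratio (standing in for the PWV dependence), and white noise. The sign follows the average_spectral_likelihood_noise code, so that the quantity is minimized; in the noiseless case the minimum is reached at the true ratio, where the likelihood equals -d^T N^{-1} d.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_pix = 200, 8
pixel_indices = np.arange(n_samples) % n_pix  # every pixel is observed
P_single = np.zeros((n_samples, n_pix))
P_single[np.arange(n_samples), pixel_indices] = 1
sigma2_93, sigma2_145 = 0.5, 2.0              # white-noise variances per frequency
invN = np.diag(np.concatenate([np.full(n_samples, 1 / sigma2_93),
                               np.full(n_samples, 1 / sigma2_145)]))


def spectral_likelihood(d, ratio):
    """S(ratio) = -d^T N^-1 C (C^T N^-1 C)^-1 C^T N^-1 d with C = A(ratio) P."""
    C = np.vstack([P_single, ratio * P_single])  # mixing @ pointing
    y = C.T @ invN @ d
    return -y @ np.linalg.solve(C.T @ invN @ C, y)


true_ratio = 1.7
s_atm = rng.normal(size=n_pix)
d = np.vstack([P_single, true_ratio * P_single]) @ s_atm  # noiseless data at both frequencies

ratios = np.linspace(0.5, 3.0, 26)
values = [spectral_likelihood(d, r) for r in ratios]
best_ratio = ratios[int(np.argmin(values))]
print(best_ratio)
```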
def average_spectral_likelihood_noise(d_atm, w_vec, PWV):
    # POINTING matrix (shape: (n_detectors * N_obs * n_freq) x N_pix)
    P_single_freq = pointing_matrix_single_freq(detector_pointings_t, t_obs, n_detectors, z_atm, d_pix_sim, N_pix_sim, w_vec)
    P = BlockColumnOperator({'93': P_single_freq, '145': P_single_freq})
    # MIXING matrix, normalized to 93 GHz (shape: (n_detectors * N_obs * n_freq) x (n_detectors * N_obs * n_freq))
    A_block_structure = StokesIPyTree.structure_for((n_detectors, N_obs))
    A_93 = HomothetyOperator(atm_emission(PWV, mu_93) / atm_emission(PWV, mu_93), in_structure=A_block_structure)
    A_145 = HomothetyOperator(atm_emission(PWV, mu_145) / atm_emission(PWV, mu_93), in_structure=A_block_structure)
    A = BlockDiagonalOperator({'93': A_93, '145': A_145})
    # COMPOSITION matrix: mixing matrix @ pointing matrix (shape: (n_detectors * N_obs * n_freq) x N_pix)
    C = A @ P
    # NOISE covariance matrix (shape: (n_detectors * N_obs * n_freq) x (n_detectors * N_obs * n_freq))
    N_block_structure = StokesIPyTree.structure_for((n_detectors, N_obs))
    N_93 = HomothetyOperator(noise_variance_93, in_structure=N_block_structure)
    N_145 = HomothetyOperator(noise_variance_145, in_structure=N_block_structure)
    N = BlockDiagonalOperator({'93': N_93, '145': N_145})
    # Spectral likelihood computation:
    core_op = (C.T @ N.I @ C).I
    full_op = N.I @ C @ core_op @ C.T @ N.I
    S = - StokesIPyTree(d_atm) @ StokesIPyTree(full_op(d_atm))
    return S
Credit: Amalia Villarrubia Aguilar
Gridding process
➜ This spectral likelihood is minimised through gridding: we compute \(\langle \delta \mathcal{S}_\text{spec}(w_x, w_y \mid \text{PWV}_{\text{sim}}) \rangle\) for 22,500 different combinations of \((w_x, w_y)\).
Credit: Amalia Villarrubia Aguilar
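Gridding means evaluating the objective at every node of a regular grid and keeping the minimum; 22,500 evaluations correspond, for instance, to a 150 × 150 grid in (w_x, w_y). A NumPy sketch of the procedure, where the objective is a hypothetical smooth stand-in for the averaged spectral likelihood, and (w_x, w_y) are assumed to be wind-velocity components over an arbitrary range:

```python
import numpy as np


def objective(w_x, w_y):
    """Hypothetical stand-in for <delta S_spec(w_x, w_y | PWV_sim)>, minimal at (3.0, -1.5)."""
    return (w_x - 3.0) ** 2 + 2.0 * (w_y + 1.5) ** 2 + 0.5 * (w_x - 3.0) * (w_y + 1.5)


w_x_grid = np.linspace(-10.0, 10.0, 150)  # assumed wind velocity range [m/s]
w_y_grid = np.linspace(-10.0, 10.0, 150)
WX, WY = np.meshgrid(w_x_grid, w_y_grid, indexing='ij')

values = objective(WX, WY)                # 150 x 150 = 22,500 evaluations
i, j = np.unravel_index(np.argmin(values), values.shape)
best_w_x, best_w_y = w_x_grid[i], w_y_grid[j]
print(values.size, best_w_x, best_w_y)
```

The resolution of the estimate is limited by the grid spacing (about 0.13 here), so a finer grid or a local optimizer started from the best node would be needed for more precision.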