halox: Dark matter halos in JAX

The cosmology ecosystem has excellent tools for halo modeling (colossus, pyCCL, halotools), but almost none of them are differentiable or GPU-accelerated. This matters more and more as the field moves toward gradient-based inference methods (Hamiltonian Monte Carlo, variational inference) and ML-integrated pipelines: workflows where you need to compute gradients through your physical model, and where GPU throughput is the main bottleneck.

halox fills this gap. It is a JAX-native library that implements the core quantities needed for halo-based cosmology as JIT-compilable, GPU-accelerated, and fully differentiable functions:

  • NFW radial profiles - density and enclosed mass for Navarro-Frenk-White dark matter halos
  • Halo mass function - the abundance of halos as a function of mass, redshift, and cosmological parameters
  • Halo bias - the clustering bias between halos and the underlying dark matter field
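To give a feel for the first item, the NFW density and enclosed mass have closed forms that are easy to write in JAX. The sketch below is a standalone illustration, not halox's actual API: the function names and the (rho_s, r_s) parameterization are assumptions made for this example.

```python
import jax.numpy as jnp


def nfw_density(r, rho_s, r_s):
    """NFW density: rho_s / (x * (1 + x)^2), with x = r / r_s.

    Hypothetical helper for illustration; not a halox function.
    """
    x = r / r_s
    return rho_s / (x * (1.0 + x) ** 2)


def nfw_enclosed_mass(r, rho_s, r_s):
    """Mass enclosed within radius r for the same NFW profile.

    M(<r) = 4 pi rho_s r_s^3 * [ln(1 + x) - x / (1 + x)]
    """
    x = r / r_s
    return 4.0 * jnp.pi * rho_s * r_s**3 * (jnp.log1p(x) - x / (1.0 + x))


# Evaluate both quantities on a small grid of radii (arbitrary units).
radii = jnp.array([0.1, 1.0, 10.0])
densities = nfw_density(radii, 1.0, 1.0)
masses = nfw_enclosed_mass(radii, 1.0, 1.0)
```

Both functions use only jax.numpy operations, so they accept scalars or arrays and can be traced by JAX transformations.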

Because halox functions are written in JAX, they compose naturally with the rest of the JAX ecosystem: they can be vmapped over batches of halos or cosmologies, JIT-compiled for GPU, and differentiated with jax.grad. This makes them drop-in components for HMC samplers, ML training loops, or sensitivity analyses.
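The composition described above can be sketched with plain JAX. The example uses a hand-written NFW enclosed-mass function as a stand-in for a library call (the name and signature are assumptions, not halox's API), then batches it over halos with jax.vmap, differentiates it with jax.grad, and compiles the result with jax.jit.

```python
import jax
import jax.numpy as jnp


def nfw_enclosed_mass(r, rho_s, r_s):
    """Hypothetical stand-in for a differentiable halo quantity."""
    x = r / r_s
    return 4.0 * jnp.pi * rho_s * r_s**3 * (jnp.log1p(x) - x / (1.0 + x))


# A batch of three halos, in arbitrary units, evaluated at one radius.
r = 1.0
rho_s = jnp.array([1.0, 2.0, 0.5])
r_s = jnp.array([0.5, 1.0, 2.0])

# vmap over the halo parameters (axis 0), keep the radius fixed (None).
batched_mass = jax.jit(jax.vmap(nfw_enclosed_mass, in_axes=(None, 0, 0)))
mass = batched_mass(r, rho_s, r_s)

# Sensitivity of the enclosed mass to each halo parameter, per halo.
grad_fn = jax.grad(nfw_enclosed_mass, argnums=(1, 2))
batched_grad = jax.jit(jax.vmap(grad_fn, in_axes=(None, 0, 0)))
dM_drho_s, dM_dr_s = batched_grad(r, rho_s, r_s)
```

The same pattern applies to any function built from jax.numpy operations, which is what lets such functions sit inside a sampler's log-probability or a training loss.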

All implementations are validated against established reference libraries (Astropy, colossus) across wide ranges of halo masses, redshifts, and cosmological parameters. These tests run automatically on every commit via a CI/CD pipeline, with results visualized in the online documentation.
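A minimal sketch of what such a validation check can look like is given below. It is not taken from the halox test suite: instead of calling a reference library, it uses a simple trapezoidal integration of the density as the independent reference, and the function names and tolerance are assumptions for illustration.

```python
import jax.numpy as jnp


def nfw_density(r, rho_s, r_s):
    """Hypothetical NFW density helper (not a halox function)."""
    x = r / r_s
    return rho_s / (x * (1.0 + x) ** 2)


def nfw_enclosed_mass(r, rho_s, r_s):
    """Hypothetical closed-form NFW enclosed mass (not a halox function)."""
    x = r / r_s
    return 4.0 * jnp.pi * rho_s * r_s**3 * (jnp.log1p(x) - x / (1.0 + x))


def enclosed_mass_by_quadrature(r_max, rho_s, r_s, n=4001):
    """Reference value: integrate 4 pi r^2 rho(r) with the trapezoid rule.

    The grid starts slightly above zero to avoid the 0/0 at r = 0;
    the mass missed inside that tiny radius is negligible.
    """
    r = jnp.linspace(1e-4 * r_s, r_max, n)
    integrand = 4.0 * jnp.pi * r**2 * nfw_density(r, rho_s, r_s)
    dr = r[1:] - r[:-1]
    return jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * dr)


def check_enclosed_mass(r_max, rho_s, r_s, rtol=1e-3):
    """Return True if the closed form matches the numerical reference."""
    analytic = nfw_enclosed_mass(r_max, rho_s, r_s)
    reference = enclosed_mass_by_quadrature(r_max, rho_s, r_s)
    return bool(jnp.allclose(analytic, reference, rtol=rtol))


# Scan a few (r_max, rho_s, r_s) combinations, in arbitrary units.
cases = [(1.0, 1.0, 1.0), (5.0, 2.0, 1.0), (3.0, 0.5, 2.0)]
results = [check_enclosed_mass(*case) for case in cases]
```

In a real test suite the quadrature would be replaced by calls to the reference library, and the cases would span the ranges of mass, redshift, and cosmology being validated.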

GitHub: fkeruzore/halox
Documentation: halox.readthedocs.io