Using the picasso analytical gas model

This notebook shows how to use the picasso.polytrop and picasso.nonthermal modules to compute gas properties from a gravitational potential distribution. For full documentation of the functions available in both modules, see picasso.polytrop: Polytropic gas model and picasso.nonthermal: Non-thermal pressure support.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM

from picasso import polytrop, nonthermal
from picasso.utils.plots import NFW

import seaborn as sns
sns.set_style("darkgrid")
sns.set_theme("notebook")

benchmark = True

The analytical gas model

The polytropic gas model can be written as (see Kéruzoré+24):

\[ \rho(\phi, \, r) = \rho_0 \theta^{1 / (\Gamma(r) - 1)}(\phi), \\[10pt] P(\phi, \, r) = P_0 \theta^{\Gamma(r) / (\Gamma(r) - 1)}(\phi), \]

where \(\phi\) is the halo’s gravitational potential, \(\phi_0\) is its value at the halo center, and

\[ \theta(\phi) = 1 - \theta_0 (\phi - \phi_0). \]

The gas polytropic index, \(\Gamma\), is allowed to vary with radius as:

\[ \Gamma(r) = \begin{cases} \begin{aligned} & \; 1 + (\Gamma_0 - 1) \frac{1}{1 + e^{-x}} & c_\gamma > 0; \\ & \; \Gamma_0 & c_\gamma = 0; \\ & \; \Gamma_0 + (\Gamma_0 - 1) \left(1 - \frac{1}{1 + e^{x}}\right) & c_\gamma < 0, \\ \end{aligned} \end{cases} \]

with \(x \equiv r / (c_\gamma R_{500c})\).
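To illustrate the piecewise definition above, here is a minimal JAX sketch of \(\Gamma(r)\). The function name gamma_r_sketch and its signature are hypothetical and are not the picasso API. Radii are assumed to be given in units of \(R_{500c}\), so \(x = r / (c_\gamma R_{500c})\) becomes r_R500c / c_gamma:

```python
import jax.numpy as jnp

def gamma_r_sketch(r_R500c, Gamma_0, c_gamma):
    """Polytropic index Gamma(r) following the piecewise definition above (hypothetical helper)."""
    # Replace c_gamma = 0 by 1 before dividing; that branch is discarded below anyway
    c_safe = jnp.where(c_gamma == 0.0, 1.0, c_gamma)
    x = r_R500c / c_safe  # x = r / (c_gamma R500c), with r given in units of R500c
    gamma_pos = 1.0 + (Gamma_0 - 1.0) / (1.0 + jnp.exp(-x))  # c_gamma > 0
    gamma_neg = Gamma_0 + (Gamma_0 - 1.0) * (1.0 - 1.0 / (1.0 + jnp.exp(x)))  # c_gamma < 0
    return jnp.where(c_gamma > 0, gamma_pos, jnp.where(c_gamma < 0, gamma_neg, Gamma_0))

r = jnp.logspace(-3, 3, 50)
gamma_inc = gamma_r_sketch(r, 1.2, 0.5)   # rises towards Gamma_0
gamma_dec = gamma_r_sketch(r, 1.2, -0.5)  # decreases towards Gamma_0
```

Assume \(\Gamma_0 > 1\). Both non-zero branches tend to \(\Gamma_0\) as \(r \to \infty\). For \(c_\gamma > 0\), \(\Gamma\) rises towards \(\Gamma_0\) from \(1 + (\Gamma_0 - 1)/2\) at the center. For \(c_\gamma < 0\), it decreases towards \(\Gamma_0\) from \(\Gamma_0 + (\Gamma_0 - 1)/2\).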

This model has five parameters: \((\rho_0, P_0)\) are the central values of gas density and pressure, \(\Gamma_0\) is the asymptotic value of the polytropic index as \(r \rightarrow \infty\), \(c_\gamma\) is the polytropic concentration (\(c_\gamma = 0\) implies \(\Gamma(r) = \Gamma_0\)), and \(\theta_0\) is a shape parameter.
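The density and pressure equations can be sketched the same way. This is a minimal, hypothetical helper, not the picasso implementation. \(\Gamma\) is passed in already evaluated at each radius. \(\theta\) is clipped at zero, an assumption made here only so that fractional powers never act on negative numbers:

```python
import jax.numpy as jnp

def rho_P_sketch(phi, Gamma, rho_0, P_0, theta_0, phi_0=0.0):
    """Polytropic gas density and total pressure (hypothetical helper).

    phi   : potential at each radius (units of 1 / theta_0)
    Gamma : polytropic index evaluated at each radius
    phi_0 : central potential (0 if phi is already measured from the center)
    """
    # theta(phi) = 1 - theta_0 (phi - phi_0), clipped at 0 (assumption of this sketch)
    theta = jnp.maximum(1.0 - theta_0 * (phi - phi_0), 0.0)
    rho = rho_0 * theta ** (1.0 / (Gamma - 1.0))
    P = P_0 * theta ** (Gamma / (Gamma - 1.0))
    return rho, P

phi = jnp.linspace(0.0, 1.0e6, 20)  # toy potential values, increasing outwards
Gamma = jnp.full(20, 1.2)
rho, P = rho_P_sketch(phi, Gamma, 3.0e3, 2.0e2, 3.0e-7)
```

With the exponents \(1/(\Gamma - 1)\) and \(\Gamma/(\Gamma - 1)\), the sketch satisfies the polytropic relation \(P / P_0 = (\rho / \rho_0)^{\Gamma}\) at every radius.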

We further write the fraction of non-thermal pressure as a power law in radius plus a constant plateau:

\[ f_{\rm nt}(r) = A_{\rm nt} + (B_{\rm nt} - A_{\rm nt}) \left(\frac{r}{2R_{500c}}\right)^{C_{\rm nt}} \]

This adds three parameters to our gas model: \(A_{\rm nt}\) is the central value of the non-thermal pressure fraction, \(B_{\rm nt}\) is the non-thermal pressure fraction at \(r=2R_{500c}\), and \(C_{\rm nt}\) is the slope of the power-law evolution with radius.
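A short sketch of \(f_{\rm nt}(r)\) makes the role of each parameter explicit. The helpers below are hypothetical. Like the calls to nonthermal.f_nt_generic in this notebook, f_nt_sketch takes \(r / (2R_{500c})\) as its first argument. The thermal pressure then follows as \(P_{\rm th} = (1 - f_{\rm nt}) P_{\rm tot}\):

```python
import jax.numpy as jnp

def f_nt_sketch(x, A_nt, B_nt, C_nt):
    """Non-thermal pressure fraction; x = r / (2 R500c) (hypothetical helper)."""
    # For C_nt > 0: f_nt = A_nt at the center (x = 0) and B_nt at r = 2 R500c (x = 1)
    return A_nt + (B_nt - A_nt) * x ** C_nt

def thermal_pressure_sketch(P_tot, f_nt):
    """Thermal part of the total pressure: P_th = (1 - f_nt) * P_tot."""
    return (1.0 - f_nt) * P_tot

x = jnp.array([0.0, 0.5, 1.0])
# A_nt, B_nt, C_nt: the values used for theta_gas later in this notebook
f_nt = f_nt_sketch(x, 1.18e-2, 2.11e-1, 1.647)
P_th = thermal_pressure_sketch(jnp.ones(3), f_nt)
```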

Halo potential and gas model parameters

Making predictions of gas properties with this model requires two ingredients: a potential distribution, and a vector containing values for the eight parameters of the gas model, \(\vartheta_{\rm gas}\). The picasso model uses a neural network to predict the latter (see Using the picasso trained predictors). Here, we are interested in using the gas model on its own, assuming a prediction of the parameter vector was obtained by other means.

To make predictions, we will use simple NFW halos. Pre-stored data contain the mass, concentration, and \(R_{500c}\) of four halos from the simulations presented in Kéruzoré+24. We use them to compute the halos’ potential profiles:

halos = Table.read("../data/halos.hdf5")
r_R500c = jnp.logspace(-1, 0.5, 51)

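cosmo = FlatLambdaCDM(H0=70.0, Om0=0.3)

phi = []
for i in range(len(halos)):
    nfw_i = NFW(halos["M200c"][i], halos["c200c"][i], "200c", z=0.0, cosmo=cosmo)
    # Potential at radii r = (r / R500c) * R500c for this halo
    phi_i = nfw_i.potential(r_R500c * halos["R500c"][i])
    # Subtract the potential at a very small radius, i.e. measure phi relative to the halo center
    phi.append(phi_i - nfw_i.potential(1e-6))
phi = jnp.array(phi)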
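The loop subtracts the potential evaluated at a very small radius, so each shifted profile is zero at the halo center. For an NFW halo, this shifted potential has a standard closed form: \(\phi(r) - \phi_0 = 4\pi G \rho_s r_s^2 \left[1 - \ln(1 + x)/x\right]\), with \(x = r / r_s\). Here \(\rho_s\) and \(r_s\) are the NFW characteristic density and scale radius. The sketch below evaluates the dimensionless part of this expression. It is independent of picasso.utils.plots.NFW and makes no assumption about that class's internals:

```python
import jax.numpy as jnp

def nfw_delta_phi_dimless(x):
    """(phi(r) - phi_0) / (4 pi G rho_s r_s^2) for an NFW halo, with x = r / r_s > 0."""
    # log1p keeps ln(1 + x) accurate for small x, where the result tends to x / 2
    return 1.0 - jnp.log1p(x) / x

x = jnp.logspace(-3, 3, 50)
dphi = nfw_delta_phi_dimless(x)
```

The shifted potential vanishes at the center and increases monotonically with radius. Far from the halo, it tends to \(4\pi G \rho_s r_s^2\). For \(\theta_0 > 0\), \(\theta(\phi)\) therefore decreases outwards.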

For simplicity, we will use fixed values for the parameter vector \(\vartheta_{\rm gas}\):

# rho_0, P_0, Gamma_0, c_gamma, theta_0, A_nt, B_nt, C_nt
theta_gas = jnp.array([3.22e3, 1.91e2, 1.134, 0.0, 3.594e-7, 1.18e-2, 2.11e-1, 1.647])
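The eight values are identified only by their position, so attaching names to them can help. The sketch below is a plain-Python convenience; the PARAM_NAMES tuple is our own, copied from the comment above. It also shows the split used throughout this notebook: the first five entries go to polytrop.rho_P_g and the last three go to nonthermal.f_nt_generic:

```python
import jax.numpy as jnp

# Names copied from the comment above; the ordering must match theta_gas
PARAM_NAMES = ("rho_0", "P_0", "Gamma_0", "c_gamma", "theta_0", "A_nt", "B_nt", "C_nt")

theta_gas = jnp.array([3.22e3, 1.91e2, 1.134, 0.0, 3.594e-7, 1.18e-2, 2.11e-1, 1.647])
params = {name: float(value) for name, value in zip(PARAM_NAMES, theta_gas)}

theta_polytrop = theta_gas[:5]    # rho_0, P_0, Gamma_0, c_gamma, theta_0
theta_nonthermal = theta_gas[5:]  # A_nt, B_nt, C_nt
```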

Polytropic model: density and total pressure

First, focusing on one halo, we can use polytrop.rho_P_g to compute density and total pressure:

rho_g, P_tot = polytrop.rho_P_g(phi[0], r_R500c, *theta_gas[:5])

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
for ax, q in zip(axs, [rho_g, P_tot]):
    ax.loglog(r_R500c, q)
    ax.set_xlabel("$r / R_{500c}$")
axs[0].set_ylabel("$\\rho_{\\rm g} / 500 \\rho_{\\rm crit.}$")
axs[1].set_ylabel("$P_{\\rm tot} / P_{500c}$")
fig.tight_layout()
[Figure: gas density \(\rho_{\rm g} / 500 \rho_{\rm crit.}\) (left) and total pressure \(P_{\rm tot} / P_{500c}\) (right) as functions of \(r / R_{500c}\) for the first halo.]

The function can easily be compiled just-in-time:

rho_P_g = jax.jit(polytrop.rho_P_g)
rho_g, P_tot = rho_P_g(phi[0], r_R500c, *theta_gas[:5])

if benchmark:
    print("Not jitted:", end=" ")
    %timeit _ = polytrop.rho_P_g(phi[0], r_R500c, *theta_gas[:5])
    print("jitted:", end=" ")
    %timeit _ = rho_P_g(phi[0], r_R500c, *theta_gas[:5])
Not jitted: 349 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
jitted: 79.4 µs ± 1.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
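%timeit is an IPython magic. Outside a notebook, a comparable measurement can be sketched with the standard timeit module. Two details matter with JAX. First, the first call to a jitted function triggers compilation, so it should be excluded from the timing. Second, JAX dispatches work asynchronously, so each timed call should end with block_until_ready(). The workload toy_profile below is a hypothetical stand-in for polytrop.rho_P_g:

```python
import timeit
import jax
import jax.numpy as jnp

def toy_profile(phi, theta_0, Gamma):
    # Hypothetical stand-in workload: polytropic density shape theta^(1/(Gamma-1))
    theta = jnp.maximum(1.0 - theta_0 * phi, 0.0)
    return theta ** (1.0 / (Gamma - 1.0))

toy_profile_jit = jax.jit(toy_profile)
phi = jnp.linspace(0.0, 1.0e6, 51)

# Warm-up call: triggers compilation so that it is not included in the timing
toy_profile_jit(phi, 3.0e-7, 1.2).block_until_ready()

def time_call(fn, n=200):
    # block_until_ready() waits for the asynchronously dispatched computation,
    # so the measured time covers the actual work
    total = timeit.timeit(lambda: fn(phi, 3.0e-7, 1.2).block_until_ready(), number=n)
    return total / n

t_eager = time_call(toy_profile)
t_jit = time_call(toy_profile_jit)
print(f"eager: {t_eager * 1e6:.1f} µs, jitted: {t_jit * 1e6:.1f} µs")
```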

Non-thermal pressure fraction and thermal pressure

Similarly, the non-thermal pressure fraction can be computed using nonthermal.f_nt_generic and used to derive the thermal pressure, \(P_{\rm th} = (1 - f_{\rm nt}) P_{\rm tot}\):

f_nt = nonthermal.f_nt_generic(r_R500c / 2, *theta_gas[5:])
P_th = P_tot * (1 - f_nt)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
for ax, q in zip(axs, [P_th, f_nt]):
    ax.loglog(r_R500c, q)
    ax.set_xlabel("$r / R_{500c}$")
axs[0].set_ylabel("$P_{\\rm th} / P_{500c}$")
axs[1].set_ylabel("$f_{\\rm nt}$")
fig.tight_layout()
[Figure: thermal pressure \(P_{\rm th} / P_{500c}\) (left) and non-thermal pressure fraction \(f_{\rm nt}\) (right) as functions of \(r / R_{500c}\) for the first halo.]

These can also be compiled:

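@jax.jit
def f_nt_and_P_th(P_tot, *args):
    # P_tot is passed as an argument: an array read from the global scope
    # would be frozen into the compiled function at trace time
    f_nt = nonthermal.f_nt_generic(*args)
    return f_nt, P_tot * (1 - f_nt)

_ = f_nt_and_P_th(P_tot, r_R500c / 2, *theta_gas[5:])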
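Keep one jax.jit subtlety in mind when a compiled function reads an array from the enclosing scope instead of receiving it as an argument. The array is captured as a constant when the function is traced, and later calls that reuse the cached compilation keep the old value. A minimal, self-contained demonstration (all names are hypothetical):

```python
import jax
import jax.numpy as jnp

scale = jnp.array(2.0)  # global array read inside the function

@jax.jit
def scaled_global(x):
    # `scale` is not an argument, so its value is baked in at trace time
    return x * scale

@jax.jit
def scaled_arg(x, scale):
    # Passing the array explicitly makes it a traced input
    return x * scale

x = jnp.ones(3)
first = scaled_global(x)      # traces and compiles with scale = 2.0
scale = jnp.array(10.0)       # rebind the global
second = scaled_global(x)     # same input shape/dtype: cached version, still uses 2.0
third = scaled_arg(x, scale)  # uses the new value, 10.0
```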

Batch predictions

The picasso.polytrop and picasso.nonthermal modules can also be used to make predictions for several halos at a time:

theta_gas_v = jnp.array([
    [3.22e3, 1.91e2, 1.134, 0.0, 3.594e-7, 1.18e-2, 2.11e-1, 1.647],
    [1.15e3, 6.41e1, 1.136, 0.0, 2.601e-7, 4.27e-2, 3.20e-1, 1.045],
    [1.10e3, 6.12e1, 1.137, 0.0, 1.858e-7, 3.97e-2, 3.32e-1, 1.035],
    [8.49e2, 4.81e1, 1.142, 0.0, 1.382e-7, 4.13e-2, 3.51e-1, 0.953]
])

r_R500c_v = jnp.outer(jnp.ones(len(halos)), r_R500c)  # same radii for all halos

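Array shapes need some care here. Each parameter is indexed with theta_gas[..., i, None], which keeps a trailing axis of length one so that it broadcasts against the radial axis (alternatively, one may use jax.vmap):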

def thermodynamics(phi, theta_gas, r_pol, r_fnt):
    rho_g, P_tot = polytrop.rho_P_g(
        phi, r_pol,
        theta_gas[..., 0, None],
        theta_gas[..., 1, None],
        theta_gas[..., 2, None],
        theta_gas[..., 3, None],
        theta_gas[..., 4, None],
    )
    f_nt = nonthermal.f_nt_generic(
        r_fnt,
        theta_gas[..., 5, None],
        theta_gas[..., 6, None],
        theta_gas[..., 7, None],
    )
    return rho_g, P_tot, P_tot * (1 - f_nt), f_nt

rho_g, P_tot, P_th, f_nt = thermodynamics(phi, theta_gas_v, r_R500c, r_R500c_v / 2)
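The jax.vmap alternative mentioned above can be sketched on the non-thermal fraction alone, using a hypothetical single-halo function that is not part of picasso. You write the function for one halo and let vmap add the halo axis. This gives the same result as the explicit broadcasting with [..., i, None] used in thermodynamics:

```python
import jax
import jax.numpy as jnp

def f_nt_one_halo(x, params):
    """f_nt for one halo: x has shape (n_r,), params = [A_nt, B_nt, C_nt]."""
    A_nt, B_nt, C_nt = params[0], params[1], params[2]
    return A_nt + (B_nt - A_nt) * x ** C_nt

# A_nt, B_nt, C_nt for two of the halos above; shape (n_halos, 3)
params_v = jnp.array([[1.18e-2, 2.11e-1, 1.647], [4.27e-2, 3.20e-1, 1.045]])
x = jnp.logspace(-1, 0.5, 51) / 2  # r / (2 R500c), shared by all halos

# Option 1: explicit broadcasting, giving each parameter a trailing axis of length one
A, B, C = params_v[..., 0, None], params_v[..., 1, None], params_v[..., 2, None]
f_bcast = A + (B - A) * x ** C  # shape (2, 51)

# Option 2: vmap over the first axis of params_v; x is shared (in_axes=None)
f_vmap = jax.vmap(f_nt_one_halo, in_axes=(None, 0))(x, params_v)  # shape (2, 51)
```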

fig, axs = plt.subplots(1, 4, figsize=(13, 4))
for ax, q in zip(axs, [rho_g, P_tot, P_th, f_nt]):
    ax.loglog(r_R500c_v.T, q.T)
    ax.set_xlabel("$r / R_{500c}$")

axs[0].set_ylabel("$\\rho_{\\rm g} / 500 \\rho_{\\rm crit.}$")
axs[1].set_ylabel("$P_{\\rm tot} / P_{500c}$")
axs[2].set_ylabel("$P_{\\rm th} / P_{500c}$")
axs[3].set_ylabel("$f_{\\rm nt}$")
fig.tight_layout()
[Figure: \(\rho_{\rm g}\), \(P_{\rm tot}\), \(P_{\rm th}\), and \(f_{\rm nt}\) profiles of the four halos as functions of \(r / R_{500c}\).]

Again, these functions can be compiled:

if benchmark:
    print("Not jitted:", end=" ")
    %timeit _ = thermodynamics(phi, theta_gas_v, r_R500c, r_R500c_v / 2)

    thermodynamics = jax.jit(thermodynamics)
    print("jitted:", end=" ")
    %timeit _ = thermodynamics(phi, theta_gas_v, r_R500c, r_R500c_v / 2)
Not jitted: 1.19 ms ± 28.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
jitted: 21.2 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)