Using the picasso trained predictors#

This notebook shows how to use the trained models to make predictions of gas thermodynamics from halo properties. For full documentation of the predictor objects and their methods, see picasso.predictors: From halo properties to gas properties.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM

from picasso import predictors
from picasso.utils.plots import NFW

import seaborn as sns
sns.set_style("darkgrid")
sns.set_theme("notebook")

benchmark = True

We will use the minimal_576 trained model, which takes as input halo mass and concentration:

predictor = predictors.minimal_576
print(predictor.input_names)
['log M200', 'c200']

Predicting gas model parameters#

First, we want to compute predictions for the model parameter vector, \(\vartheta_{\rm gas}\). To do so, we simply need the vector of scalar halo properties, \(\vartheta_{\rm halo}\). We’ll use some pre-stored data (containing four halos from the simulations presented in Kéruzoré+24) and build the input vector:

halos = Table.read("../data/halos.hdf5")
logM200c = jnp.log10(halos["M200c"])
c200c = jnp.array(halos["c200c"])
theta_halo = jnp.array([logM200c, c200c]).T
print(theta_halo.shape)
(4, 2)
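Note that the columns of the input vector must follow the order of predictor.input_names. If you are building \(\vartheta_{\rm halo}\) from your own halo catalog instead of the pre-stored table, the same construction applies; a minimal sketch with purely illustrative values:

my_logM200c = jnp.array([14.2, 13.8])  # hypothetical log10(M200c) values, for illustration only
my_c200c = jnp.array([5.1, 6.3])       # hypothetical concentrations, for illustration only
my_theta_halo = jnp.stack([my_logM200c, my_c200c], axis=-1)  # columns ordered as predictor.input_names
print(my_theta_halo.shape)  # (2, 2)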

We can then use the predictor.predict_model_parameters() function to predict \(\vartheta_{\rm gas}\). For a single halo:

theta_gas_0 = predictor.predict_model_parameters(theta_halo[0])
print(theta_gas_0)
[3.2217349e+03 1.9191902e+02 1.1346726e+00 0.0000000e+00 3.5946820e-07
 1.1811828e-02 2.1121846e-01 1.6479003e+00]

The predictor.predict_model_parameters() function can also be used for several halos at a time:

theta_gas = predictor.predict_model_parameters(theta_halo)
print(f"{theta_gas=}")
print(f"{theta_gas.shape=}")
theta_gas=Array([[3.2217329e+03, 1.9191902e+02, 1.1346726e+00, 0.0000000e+00,
        3.5946820e-07, 1.1811828e-02, 2.1121846e-01, 1.6479003e+00],
       [1.1516166e+03, 6.4102837e+01, 1.1361028e+00, 0.0000000e+00,
        2.6019134e-07, 4.2736135e-02, 3.2087338e-01, 1.0452865e+00],
       [1.1092018e+03, 6.1245892e+01, 1.1370699e+00, 0.0000000e+00,
        1.8584475e-07, 3.9753534e-02, 3.3282000e-01, 1.0353637e+00],
       [8.4998859e+02, 4.8176445e+01, 1.1426771e+00, 0.0000000e+00,
        1.3825047e-07, 4.1373007e-02, 3.5185230e-01, 9.5307523e-01]],      dtype=float32)
theta_gas.shape=(4, 8)
fig, axs = plt.subplots(2, 4, figsize=(13, 8))
axs = axs.flatten()
for ax, q in zip(axs, theta_gas.T):
    for i in range(4):
        ax.plot([q[i], q[i]], [0, 1])
    ax.set_yticklabels([])
    ax.set_ylim(0, 1)
    
axs[0].set_xlabel("$\\rho_0$")
axs[1].set_xlabel("$P_0$")
axs[2].set_xlabel("$\\Gamma_0$")
axs[3].set_xlabel("$c_\\gamma$")
axs[4].set_xlabel("$\\theta_0$")
axs[5].set_xlabel("$A_{\\rm nt}$")
axs[6].set_xlabel("$B_{\\rm nt}$")
axs[7].set_xlabel("$C_{\\rm nt}$")
fig.tight_layout()
[Figure: values of the eight predicted gas model parameters for each of the four halos]

It can also be just-in-time compiled:

if benchmark:
    predict_jit = jax.jit(predictor.predict_model_parameters)
    print("Not jitted:", end=" ")
    %timeit _ = predictor.predict_model_parameters(theta_halo)
    print("jitted:", end=" ")
    _ = predict_jit(theta_halo)
    %timeit _ = predict_jit(theta_halo)
Not jitted: 
8.01 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jitted: 
8.19 µs ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
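Note that JAX dispatches work to the device asynchronously, so micro-benchmarks of very fast jitted calls can partly measure dispatch overhead; when in doubt, one can block on the result, as in this minimal sketch:

if benchmark:
    # Block until the device computation has actually finished before stopping the timer
    %timeit _ = predict_jit(theta_halo).block_until_ready()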

Predicting gas thermodynamics#

With a prediction for \(\vartheta_{\rm gas}\), we can use picasso.polytrop and picasso.nonthermal to predict gas thermodynamics (see Using the picasso analytical gas model). PicassoPredictor objects also offer a wrapper function that predicts all thermodynamic properties directly from an input vector \(\vartheta_{\rm halo}\) and a potential distribution. Assuming the halos above follow NFW profiles, we can compute their potential profiles:

r_R500c = jnp.logspace(-1, 0.5, 51)

phi = []
for i in range(4):
    nfw_i = NFW(halos["M200c"][i], halos["c200c"][i], "200c", z=0.0, cosmo=FlatLambdaCDM(70.0, 0.3))
    phi_i = nfw_i.potential(r_R500c * halos["R500c"][i])
    phi.append(phi_i - nfw_i.potential(1e-6))

Then, we can make predictions of gas thermodynamics for one halo:

rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c, r_R500c / 2)

Or for all halos at the same time (this function uses jax.vmap to vectorize the predictions):

r_R500c_v = jnp.outer(jnp.ones(4), r_R500c)
phi_v = jnp.array(phi)
rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
fig, axs = plt.subplots(1, 4, figsize=(13, 4))
for ax, q in zip(axs, [rho_g, P_tot, P_th, f_nt]):
    ax.loglog(r_R500c_v.T, q.T)
    ax.set_xlabel("$r / R_{500c}$")

axs[0].set_ylabel("$\\rho_{\\rm g} / 500 \\rho_{\\rm crit.}$")
axs[1].set_ylabel("$P_{\\rm tot} / P_{500c}$")
axs[2].set_ylabel("$P_{\\rm th} / P_{500c}$")
axs[3].set_ylabel("$f_{\\rm nt}$")
fig.tight_layout()
[Figure: radial profiles of gas density, total pressure, thermal pressure, and non-thermal pressure fraction for the four halos]

Again, these functions can be just-in-time compiled:

if benchmark:
    predict_jit = jax.jit(predictor.predict_gas_model)

    print("1 halo, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
    print("1 halo, jitted:", end=" ")
    _ = predict_jit(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
    %timeit _ = predict_jit(theta_halo[0], phi[0], r_R500c, r_R500c / 2)

    print("4 halo, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
    print("4 halo, jitted:", end=" ")
    _ = predict_jit(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
    %timeit _ = predict_jit(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
1 halo, not jitted: 
9.02 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1 halo, jitted: 
51 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
4 halos, not jitted: 
14.7 ms ± 526 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4 halos, jitted: 
28.6 µs ± 212 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
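Because these prediction functions are written in JAX, they can also be differentiated with respect to their inputs (differentiation with respect to network parameters is covered in the next section). As a minimal sketch, the sensitivity of the predicted gas density profile to the halo properties of the first halo could be computed with jax.jacfwd:

# Jacobian of rho_g(r) with respect to (log M200, c200) for the first halo
d_rho_d_theta = jax.jacfwd(
    lambda th: predictor.predict_gas_model(th, phi[0], r_R500c, r_R500c / 2)[0]
)(theta_halo[0])
print(d_rho_d_theta.shape)  # (51, 2): one derivative per radius and per halo property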

Differentiating predictions#

Since the model prediction functions described above are entirely implemented in JAX, they are fully differentiable. Let’s start by defining an untrained predictor, so that the prediction functions also take as input a dict of neural network parameters, \(\vartheta_{\rm net}\).

model = predictors.PicassoPredictor(
    predictors.FlaxRegMLP(2, 8, [8,], ["selu", "selu", "sigmoid"]),
    predictor._transform_x,
    predictor._transform_y,
    input_names=predictor.input_names
)

For demonstration purposes, we use an MLP with 2 input features, one hidden layer with 8 features, and 8 output features. flax allows us to initialize \(\vartheta_{\rm net}\) easily:

theta_nn = model.mlp.init(jax.random.PRNGKey(44), jnp.ones(model.mlp.X_DIM))
print(f"{theta_nn=}")
theta_nn={'params': {'input': {'kernel': Array([[ 0.03842447, -0.06375846],
       [ 1.4234861 , -0.33939242]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}, 'dense1': {'kernel': Array([[-0.042185  , -0.7779425 ,  1.1562326 , -0.2546578 , -0.5690918 ,
         0.56547344,  0.03781918, -0.47313142],
       [ 1.0343004 ,  0.09225573,  0.04856016, -0.36422998, -0.2005813 ,
        -1.1775992 ,  0.5434312 ,  1.2453555 ]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'output': {'kernel': Array([[ 7.8989536e-02,  7.0438385e-01,  1.0798373e-01, -3.7004817e-01,
        -7.8189468e-01,  3.6114028e-01,  7.4631339e-01,  8.4025212e-02],
       [-2.6573351e-04,  3.7197936e-01, -6.9597310e-01, -1.7070572e-01,
        -6.5567094e-01,  1.4083256e-01, -6.4938289e-01, -5.2993375e-01],
       [-6.3911790e-01,  8.5570492e-02, -2.9292062e-01, -7.5498748e-01,
         8.8445395e-02,  8.0396719e-02,  1.2768888e-01,  2.5018033e-01],
       [-1.9778068e-01,  5.1007611e-01,  5.9675574e-01, -4.7253093e-01,
         1.9520928e-01,  5.7592619e-02, -8.0524437e-02, -2.7744612e-01],
       [-1.5618019e-01,  2.4103884e-01, -5.6260284e-02,  1.4608471e-02,
         3.0849382e-01, -4.8495537e-01, -7.0060837e-01,  2.7153450e-01],
       [ 3.9369041e-01,  6.4290190e-01, -1.5023410e-01, -3.0567136e-01,
         6.6145092e-01, -2.0600930e-01,  1.0309051e-01,  4.7914842e-01],
       [ 5.4775792e-01, -1.3708287e-01,  2.3802257e-01, -6.9439328e-01,
        -2.7095252e-01,  4.1744873e-02, -9.0726621e-02,  5.4637486e-01],
       [-3.5675672e-01, -5.7854915e-01,  3.7428567e-01,  3.4006098e-01,
         4.1651636e-02, -3.5362524e-01, -1.4708665e-01, -1.8629548e-01]],      dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}}}

Now, we can write a loss function. For example, let’s write a mean squared error function that compares the predicted model parameters, \(\vartheta_{\rm gas}\), with the ones obtained above from the trained network:

def loss_fn_predict_parameters(theta_nn):
    # Mean squared error between the parameters predicted with theta_nn
    # and the reference predictions of the trained network
    diff = model.predict_model_parameters(theta_halo, theta_nn) - theta_gas
    return jnp.mean(diff ** 2)

We can use jax.value_and_grad to compute, for an input \(\vartheta_{\rm net}\), the value of the loss function and its gradients with respect to the components of \(\vartheta_{\rm net}\):

loss_fn_and_grads = jax.jit(jax.value_and_grad(loss_fn_predict_parameters))
loss, grads = loss_fn_and_grads(theta_nn)
print(f"{loss=}")
print(f"{grads=}")
loss=Array(165127.05, dtype=float32)
grads={'params': {'dense1': {'bias': Array([ 100206.84 ,   21614.707, -244034.03 ,  -91656.766,  -93063.48 ,
        189916.38 ,  340279.06 , -251444.38 ], dtype=float32), 'kernel': Array([[ 26503.604 ,  11982.22  ,  73294.48  ,  52443.27  ,  10371.293 ,
        -19767.486 , -75011.875 , -36658.6   ],
       [-15600.043 ,  -5187.1123,  -2184.661 ,  -8046.8896,   4345.722 ,
         -9154.322 ,  -4911.113 ,  30567.984 ]], dtype=float32)}, 'input': {'bias': Array([  12953.48, -330695.5 ], dtype=float32), 'kernel': Array([[ -21917.758, -193794.06 ],
       [  41602.273,  -53411.805]], dtype=float32)}, 'output': {'bias': Array([ 3.7065769e+05,  5.4172406e+04,  9.2098862e-04, -2.6323080e-02,
        1.1586923e-13, -1.4256586e-04,  9.6003921e-04,  3.2200915e-01],      dtype=float32), 'kernel': Array([[-1.57931260e+04, -2.42841211e+04, -3.35681165e-04,
         1.20395236e-02, -4.37351469e-14,  4.27405794e-05,
        -1.55504246e-03, -1.23349264e-01],
       [ 2.88966172e+04, -3.66464922e+04, -4.81067575e-04,
         1.83119718e-02, -6.31182389e-14,  5.76461171e-05,
        -2.74758250e-03, -1.79110184e-01],
       [-1.26881102e+05,  4.14549062e+04,  4.95018146e-04,
        -2.08275542e-02,  6.62612800e-14, -5.16625478e-05,
         3.81032354e-03,  1.89285502e-01],
       [ 4.18628008e+04, -6.02361523e+03, -6.02374239e-05,
         3.07562714e-03, -8.34886664e-15,  4.33183868e-06,
        -7.30396772e-04, -2.43033525e-02],
       [ 5.35147617e+04, -2.44235117e+04, -3.03071662e-04,
         1.22588547e-02, -4.01963009e-14,  3.36148078e-05,
        -2.08784547e-03, -1.14595592e-01],
       [-6.37317422e+04,  3.84080586e+04,  4.85661119e-04,
        -1.91901531e-02,  6.43300178e-14, -5.52192469e-05,
         3.12481867e-03,  1.82762071e-01],
       [-9.27105957e+03, -1.06851982e+04, -1.48742183e-04,
         5.28801745e-03, -1.93708625e-14,  1.90776464e-05,
        -6.66230044e-04, -5.45667335e-02],
       [-3.21776738e+04, -4.26220898e+04, -5.92439668e-04,
         2.11625788e-02, -7.70025116e-14,  7.59991526e-05,
        -2.69982382e-03, -2.17325643e-01]], dtype=float32)}}}

This loss function can then be optimized using, e.g., optax (see the optax documentation, in particular the tutorial on optimizing the parameters of a flax model).
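For illustration, here is a minimal sketch of such an optimization loop with optax (the optimizer choice, learning rate, and number of steps are arbitrary):

import optax

optimizer = optax.adam(learning_rate=1e-3)  # arbitrary optimizer and learning rate
opt_state = optimizer.init(theta_nn)

@jax.jit
def train_step(theta_nn, opt_state):
    # One gradient step on the network parameters
    loss, grads = jax.value_and_grad(loss_fn_predict_parameters)(theta_nn)
    updates, opt_state = optimizer.update(grads, opt_state)
    theta_nn = optax.apply_updates(theta_nn, updates)
    return theta_nn, opt_state, loss

for step in range(1000):  # arbitrary number of steps
    theta_nn, opt_state, loss = train_step(theta_nn, opt_state)
    if step % 100 == 0:
        print(f"{step=}, {loss=}")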