Using the picasso trained predictors#
This notebook shows how to use the trained models to make predictions of gas thermodynamics from halo properties. For full documentation of the predictor objects and their methods, see picasso.predictors: From halo properties to gas properties.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM
from picasso import predictors
from picasso.utils.plots import NFW
import seaborn as sns
sns.set_style("darkgrid")
sns.set_theme("notebook")
benchmark = True
We will use the minimal_576 trained model, which takes as input halo mass and concentration:
predictor = predictors.minimal_576
print(predictor.input_names)
['log M200', 'c200']
Predicting gas model parameters#
First, we want to compute predictions for the model parameter vector, \(\vartheta_{\rm gas}\). To do so, we simply need the vector of scalar halo properties \(\vartheta_{\rm halo}\). We’ll use some pre-stored data (containing four halos from the simulations presented in Kéruzoré+24) and build the input vector:
halos = Table.read("../data/halos.hdf5")
logM200c = jnp.log10(halos["M200c"])
c200c = jnp.array(halos["c200c"])
theta_halo = jnp.array([logM200c, c200c]).T
print(theta_halo.shape)
(4, 2)
We can then use the predictor.predict_model_parameters() function to predict \(\vartheta_{\rm gas}\). For a single halo:
theta_gas_0 = predictor.predict_model_parameters(theta_halo[0])
print(theta_gas_0)
[3.2217349e+03 1.9191902e+02 1.1346726e+00 0.0000000e+00 3.5946820e-07
1.1811828e-02 2.1121846e-01 1.6479003e+00]
The predictor.predict_model_parameters() function can also be used for several halos at a time:
theta_gas = predictor.predict_model_parameters(theta_halo)
print(f"{theta_gas=}")
print(f"{theta_gas.shape=}")
theta_gas=Array([[3.2217329e+03, 1.9191902e+02, 1.1346726e+00, 0.0000000e+00,
3.5946820e-07, 1.1811828e-02, 2.1121846e-01, 1.6479003e+00],
[1.1516166e+03, 6.4102837e+01, 1.1361028e+00, 0.0000000e+00,
2.6019134e-07, 4.2736135e-02, 3.2087338e-01, 1.0452865e+00],
[1.1092018e+03, 6.1245892e+01, 1.1370699e+00, 0.0000000e+00,
1.8584475e-07, 3.9753534e-02, 3.3282000e-01, 1.0353637e+00],
[8.4998859e+02, 4.8176445e+01, 1.1426771e+00, 0.0000000e+00,
1.3825047e-07, 4.1373007e-02, 3.5185230e-01, 9.5307523e-01]], dtype=float32)
theta_gas.shape=(4, 8)
fig, axs = plt.subplots(2, 4, figsize=(13, 8))
axs = axs.flatten()
for ax, q in zip(axs, theta_gas.T):
    for i in range(4):
        ax.plot([q[i], q[i]], [0, 1])
    ax.set_yticklabels([])
    ax.set_ylim(0, 1)
axs[0].set_xlabel("$\\rho_0$")
axs[1].set_xlabel("$P_0$")
axs[2].set_xlabel("$\\Gamma_0$")
axs[3].set_xlabel("$c_\\gamma$")
axs[4].set_xlabel("$\\theta_0$")
axs[5].set_xlabel("$A_{\\rm nt}$")
axs[6].set_xlabel("$B_{\\rm nt}$")
axs[7].set_xlabel("$C_{\\rm nt}$")
fig.tight_layout()

The prediction function can also be just-in-time compiled with jax.jit:
if benchmark:
    predict_jit = jax.jit(predictor.predict_model_parameters)
    print("Not jitted:", end=" ")
    %timeit _ = predictor.predict_model_parameters(theta_halo)
    print("jitted:", end=" ")
    _ = predict_jit(theta_halo)
    %timeit _ = predict_jit(theta_halo)
Not jitted:
8.01 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jitted:
8.19 µs ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
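Since predictor.predict_model_parameters() accepts batched inputs, the jitted function can be applied to much larger halo samples. The following is a minimal sketch using synthetic halo properties (the mass and concentration values below are placeholders, not taken from the catalog above):
# Predict gas model parameters for a large synthetic halo sample
n_halos = 10_000
logM200_grid = jnp.linspace(13.0, 15.0, n_halos)  # placeholder log M200 values
c200_grid = jnp.full(n_halos, 5.0)  # placeholder concentrations
theta_halo_grid = jnp.array([logM200_grid, c200_grid]).T

predict_jit = jax.jit(predictor.predict_model_parameters)
theta_gas_grid = predict_jit(theta_halo_grid)
print(theta_gas_grid.shape)  # expected: (10000, 8)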
Predicting gas thermodynamics#
With a prediction for \(\vartheta_{\rm gas}\), we can use picasso.polytrop and picasso.nonthermal to predict gas thermodynamics (see Using the picasso analytical gas model). PicassoPredictor objects also offer a wrapper function that predicts all thermodynamic properties directly from an input vector \(\vartheta_{\rm halo}\) and a potential distribution.
Assuming the halos above are NFW, we can predict their potential profiles:
r_R500c = jnp.logspace(-1, 0.5, 51)
phi = []
for i in range(4):
    nfw_i = NFW(halos["M200c"][i], halos["c200c"][i], "200c", z=0.0, cosmo=FlatLambdaCDM(70.0, 0.3))
    phi_i = nfw_i.potential(r_R500c * halos["R500c"][i])
    phi.append(phi_i - nfw_i.potential(1e-6))
Then, we can make predictions of gas thermodynamics for one halo:
rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
Or for all halos at the same time (this function uses jax.vmap to vectorize the predictions):
r_R500c_v = jnp.outer(jnp.ones(4), r_R500c)
phi_v = jnp.array(phi)
rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
fig, axs = plt.subplots(1, 4, figsize=(13, 4))
for ax, q in zip(axs, [rho_g, P_tot, P_th, f_nt]):
    ax.loglog(r_R500c_v.T, q.T)
    ax.set_xlabel("$r / R_{500c}$")
axs[0].set_ylabel("$\\rho_{\\rm g} / 500 \\rho_{\\rm crit.}$")
axs[1].set_ylabel("$P_{\\rm tot} / P_{500c}$")
axs[2].set_ylabel("$P_{\\rm th} / P_{500c}$")
axs[3].set_ylabel("$f_{\\rm nt}$")
fig.tight_layout()

Again, these functions can be just-in-time compiled:
if benchmark:
    predict_jit = jax.jit(predictor.predict_gas_model)
    print("1 halo, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
    print("1 halo, jitted:", end=" ")
    _ = predict_jit(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
    %timeit _ = predict_jit(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
    print("4 halos, not jitted:", end=" ")
    %timeit _ = predictor.predict_gas_model(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
    print("4 halos, jitted:", end=" ")
    _ = predict_jit(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
    %timeit _ = predict_jit(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
1 halo, not jitted:
9.02 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1 halo, jitted:
51 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
4 halos, not jitted:
14.7 ms ± 526 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4 halos, jitted:
28.6 µs ± 212 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Differentiating predictions#
Since the model prediction functions described above are entirely implemented in JAX, they are fully differentiable. Let’s start by defining an untrained predictor, such that the prediction functions take as input a dict of neural network parameters \(\vartheta_{\rm net}\).
model = predictors.PicassoPredictor(
    predictors.FlaxRegMLP(2, 8, [8,], ["selu", "selu", "sigmoid"]),
    predictor._transform_x,
    predictor._transform_y,
    input_names=predictor.input_names,
)
For demonstration purposes, we use an MLP with 2 input features, one hidden layer with 8 features, and 8 output features. flax allows us to initialize \(\vartheta_{\rm net}\) easily:
theta_nn = model.mlp.init(jax.random.PRNGKey(44), jnp.ones(model.mlp.X_DIM))
print(f"{theta_nn=}")
theta_nn={'params': {'input': {'kernel': Array([[ 0.03842447, -0.06375846],
[ 1.4234861 , -0.33939242]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}, 'dense1': {'kernel': Array([[-0.042185 , -0.7779425 , 1.1562326 , -0.2546578 , -0.5690918 ,
0.56547344, 0.03781918, -0.47313142],
[ 1.0343004 , 0.09225573, 0.04856016, -0.36422998, -0.2005813 ,
-1.1775992 , 0.5434312 , 1.2453555 ]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'output': {'kernel': Array([[ 7.8989536e-02, 7.0438385e-01, 1.0798373e-01, -3.7004817e-01,
-7.8189468e-01, 3.6114028e-01, 7.4631339e-01, 8.4025212e-02],
[-2.6573351e-04, 3.7197936e-01, -6.9597310e-01, -1.7070572e-01,
-6.5567094e-01, 1.4083256e-01, -6.4938289e-01, -5.2993375e-01],
[-6.3911790e-01, 8.5570492e-02, -2.9292062e-01, -7.5498748e-01,
8.8445395e-02, 8.0396719e-02, 1.2768888e-01, 2.5018033e-01],
[-1.9778068e-01, 5.1007611e-01, 5.9675574e-01, -4.7253093e-01,
1.9520928e-01, 5.7592619e-02, -8.0524437e-02, -2.7744612e-01],
[-1.5618019e-01, 2.4103884e-01, -5.6260284e-02, 1.4608471e-02,
3.0849382e-01, -4.8495537e-01, -7.0060837e-01, 2.7153450e-01],
[ 3.9369041e-01, 6.4290190e-01, -1.5023410e-01, -3.0567136e-01,
6.6145092e-01, -2.0600930e-01, 1.0309051e-01, 4.7914842e-01],
[ 5.4775792e-01, -1.3708287e-01, 2.3802257e-01, -6.9439328e-01,
-2.7095252e-01, 4.1744873e-02, -9.0726621e-02, 5.4637486e-01],
[-3.5675672e-01, -5.7854915e-01, 3.7428567e-01, 3.4006098e-01,
4.1651636e-02, -3.5362524e-01, -1.4708665e-01, -1.8629548e-01]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}}}
Now, we can write a loss function. For example, let’s write a mean squared error function that compares the predictions of model parameters, \(\vartheta_{\rm gas}\), with the ones obtained above from the trained network:
def loss_fn_predict_parameters(theta_nn):
    # Mean squared error between the untrained network's predictions and the
    # gas parameters obtained above from the trained network
    diff = model.predict_model_parameters(theta_halo, theta_nn) - theta_gas
    return jnp.mean(diff ** 2)
We can use jax.value_and_grad to compute, for an input \(\vartheta_{\rm net}\), the value of the loss function and its gradients with respect to the components of \(\vartheta_{\rm net}\):
loss_fn_and_grads = jax.jit(jax.value_and_grad(loss_fn_predict_parameters))
loss, grads = loss_fn_and_grads(theta_nn)
print(f"{loss=}")
print(f"{grads=}")
loss=Array(165127.05, dtype=float32)
grads={'params': {'dense1': {'bias': Array([ 100206.84 , 21614.707, -244034.03 , -91656.766, -93063.48 ,
189916.38 , 340279.06 , -251444.38 ], dtype=float32), 'kernel': Array([[ 26503.604 , 11982.22 , 73294.48 , 52443.27 , 10371.293 ,
-19767.486 , -75011.875 , -36658.6 ],
[-15600.043 , -5187.1123, -2184.661 , -8046.8896, 4345.722 ,
-9154.322 , -4911.113 , 30567.984 ]], dtype=float32)}, 'input': {'bias': Array([ 12953.48, -330695.5 ], dtype=float32), 'kernel': Array([[ -21917.758, -193794.06 ],
[ 41602.273, -53411.805]], dtype=float32)}, 'output': {'bias': Array([ 3.7065769e+05, 5.4172406e+04, 9.2098862e-04, -2.6323080e-02,
1.1586923e-13, -1.4256586e-04, 9.6003921e-04, 3.2200915e-01], dtype=float32), 'kernel': Array([[-1.57931260e+04, -2.42841211e+04, -3.35681165e-04,
1.20395236e-02, -4.37351469e-14, 4.27405794e-05,
-1.55504246e-03, -1.23349264e-01],
[ 2.88966172e+04, -3.66464922e+04, -4.81067575e-04,
1.83119718e-02, -6.31182389e-14, 5.76461171e-05,
-2.74758250e-03, -1.79110184e-01],
[-1.26881102e+05, 4.14549062e+04, 4.95018146e-04,
-2.08275542e-02, 6.62612800e-14, -5.16625478e-05,
3.81032354e-03, 1.89285502e-01],
[ 4.18628008e+04, -6.02361523e+03, -6.02374239e-05,
3.07562714e-03, -8.34886664e-15, 4.33183868e-06,
-7.30396772e-04, -2.43033525e-02],
[ 5.35147617e+04, -2.44235117e+04, -3.03071662e-04,
1.22588547e-02, -4.01963009e-14, 3.36148078e-05,
-2.08784547e-03, -1.14595592e-01],
[-6.37317422e+04, 3.84080586e+04, 4.85661119e-04,
-1.91901531e-02, 6.43300178e-14, -5.52192469e-05,
3.12481867e-03, 1.82762071e-01],
[-9.27105957e+03, -1.06851982e+04, -1.48742183e-04,
5.28801745e-03, -1.93708625e-14, 1.90776464e-05,
-6.66230044e-04, -5.45667335e-02],
[-3.21776738e+04, -4.26220898e+04, -5.92439668e-04,
2.11625788e-02, -7.70025116e-14, 7.59991526e-05,
-2.69982382e-03, -2.17325643e-01]], dtype=float32)}}}
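Gradients with respect to the predictor inputs are available in the same way. As an illustration (not part of the training setup above), here is a minimal sketch computing the Jacobian of the trained predictor’s outputs with respect to the halo properties of the first halo:
# Jacobian of the 8 predicted gas parameters with respect to (log M200, c200)
jac = jax.jacfwd(predictor.predict_model_parameters)(theta_halo[0])
print(jac.shape)  # expected: (8, 2)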
This loss function can then be optimized using, e.g., optax (see the optax docs here, in particular the tutorial on optimizing the parameters of a flax model here).
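For illustration, here is a minimal optimization sketch using optax; the optimizer choice, learning rate, and number of steps are arbitrary and left untuned:
import optax

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(theta_nn)

@jax.jit
def train_step(params, opt_state):
    # One gradient step on the neural network parameters
    loss, grads = jax.value_and_grad(loss_fn_predict_parameters)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params = theta_nn
for step in range(1_000):
    params, opt_state, loss = train_step(params, opt_state)
In a real training setup, one would of course monitor the loss (and a validation metric) to decide when to stop.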