Getting started#
Sbijax is a Python package for neural simulation-based inference and approximate Bayesian computation. Here we demonstrate its core functionality using a simple Gaussian model.
[1]:
import jax
import sbijax
%matplotlib inline
import matplotlib.pyplot as plt
Model definition#
To do approximate inference using sbijax, a user first has to define a prior model and a simulator function which can be used to generate synthetic data. We will be using a simple bivariate Gaussian as an example with the following generative model:
\begin{align} \mu &\sim \mathcal{N}_2(0, I)\\ \sigma &\sim \mathcal{N}^+(1)\\ y & \sim \mathcal{N}_2(\mu, \sigma^2 I) \end{align}
Using TensorFlow Probability, the prior model and simulator are implemented like this:
[2]:
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        mean=tfd.Normal(jnp.zeros(2), 1.0),
        scale=tfd.HalfNormal(jnp.ones(1)),
    ), batch_ndims=0)
    return prior


def simulator_fn(seed: jr.PRNGKey, theta: dict[str, jax.Array]):
    p = tfd.Normal(jnp.zeros_like(theta["mean"]), 1.0)
    y = theta["mean"] + theta["scale"] * p.sample(seed=seed)
    return y
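To make the generative model concrete, the same three sampling steps can also be written in plain NumPy. This is only an illustrative sketch and not part of sbijax: the helper name `sample_prior_predictive` is hypothetical, and the half-normal draw is obtained as the absolute value of a standard normal draw.

```python
import numpy as np


def sample_prior_predictive(rng, n):
    """Draw n parameter sets and matching observations from the model above."""
    mean = rng.normal(0.0, 1.0, size=(n, 2))            # mu ~ N_2(0, I)
    scale = np.abs(rng.normal(0.0, 1.0, size=(n, 1)))   # sigma ~ HalfNormal(1)
    noise = rng.normal(0.0, 1.0, size=(n, 2))           # standard normal noise
    y = mean + scale * noise                             # y ~ N_2(mu, sigma^2 I)
    return {"mean": mean, "scale": scale}, y


rng = np.random.default_rng(0)
theta, y = sample_prior_predictive(rng, 5)
print(theta["mean"].shape, theta["scale"].shape, y.shape)  # (5, 2) (5, 1) (5, 2)
```

The last line of the function mirrors simulator_fn: the observation is the mean plus scaled standard normal noise, with the single scale value broadcast over both dimensions.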
Algorithm definition#
Having defined a model of interest, i.e., the prior and simulator functions, we construct an inferential method. Here, we use flow matching posterior estimation (FMPE), an algorithm that learns an approximation to the posterior distribution directly, without intermediate steps such as neural likelihood or neural likelihood-ratio estimation.
FMPE requires a continuous normalizing flow as a neural network model. We can construct it using the functionality provided by sbijax.
[3]:
from sbijax.nn import make_cnf
n_dim_theta = 3
n_layers, hidden_size = 2, 128
neural_network = make_cnf(n_dim_theta, n_layers, hidden_size)
The neural network model in FMPE approximates the posterior distribution and hence needs to operate on a space with the same dimensionality as the parameters. One can either find the dimensionality by inspecting the prior model (above, two for the mean and one for the scale), or compute it with the JAX utility function ravel_pytree:
[4]:
prior_draw = prior_fn().sample(seed=jr.PRNGKey(1))
prior_draw
[4]:
{'scale': Array([0.9881485], dtype=float32),
'mean': Array([ 1.2772286 , -0.66140693], dtype=float32)}
[5]:
from jax.flatten_util import ravel_pytree
len(ravel_pytree(prior_draw)[0])
[5]:
3
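As a small illustration of what ravel_pytree does, the sketch below flattens a hand-written stand-in for a prior draw (the values are arbitrary) and then restores the original structure. ravel_pytree returns the flat vector together with a function that undoes the flattening.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hand-written stand-in for a prior draw; the values are arbitrary.
draw = {"mean": jnp.array([1.0, -0.5]), "scale": jnp.array([0.9])}

flat, unravel = ravel_pytree(draw)
n_dim_theta = flat.shape[0]   # total number of scalar parameters
restored = unravel(flat)      # dictionary with the same structure as `draw`

print(n_dim_theta)            # 3
print(sorted(restored))       # ['mean', 'scale']
```

Computing the dimensionality this way avoids hard-coding it, which helps if the prior model is changed later.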
We can then construct the method itself. It takes as arguments a tuple of prior and simulator functions and the neural network.
[6]:
from sbijax import FMPE
fns = prior_fn, simulator_fn
model = FMPE(fns, neural_network)
Training and inference#
Inference is then as easy as simulating some data, fitting the model to the data, and sampling from the approximate posterior. The data set is a dictionary of dictionaries (a PyTree in JAX lingo). It contains samples from the simulator function, called y, and parameter samples from the prior model, called theta.
[7]:
data, _ = model.simulate_data(
    jr.PRNGKey(1),
    n_simulations=10_000,
)
data
[7]:
{'y': Array([[ 0.07692909, 0.7882271 ],
[-1.2418504 , -0.25333643],
[ 1.1943562 , -2.124853 ],
...,
[ 1.3316323 , 0.5488601 ],
[ 4.862982 , -4.1227694 ],
[-0.00955033, 0.989019 ]], dtype=float32),
'theta': {'scale': Array([[0.9232396 ],
[0.36471593],
[0.6795394 ],
...,
[0.11454558],
[1.260745 ],
[0.5012804 ]], dtype=float32),
'mean': Array([[ 0.30212572, 0.67478853],
[-0.963459 , 0.086253 ],
[ 0.39044896, -2.2378268 ],
...,
[ 1.3652769 , 0.6302657 ],
[ 2.7543025 , -1.9100804 ],
[-0.16545922, 0.6475435 ]], dtype=float32)}}
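A compact way to inspect a nested data set like this is to map over its leaves and look only at the shapes. The sketch below uses a small hand-built dictionary with the same layout as the output above (four simulations instead of 10,000, filled with zeros); `toy_data` is a stand-in and not something returned by sbijax.

```python
import jax
import jax.numpy as jnp

# Stand-in with the same nesting as the simulated data set above.
toy_data = {
    "y": jnp.zeros((4, 2)),
    "theta": {"scale": jnp.zeros((4, 1)), "mean": jnp.zeros((4, 2))},
}

# Replace every array leaf by its shape while keeping the dictionary structure.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, toy_data)
print(shapes)
```

In every leaf, the leading axis indexes the simulations, so row i of y was generated from row i of each entry of theta.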
We then fit the model using the typical flow matching loss.
[8]:
params, losses = model.fit(jr.PRNGKey(2), data=data)
6%|███████████████████▍ | 64/1000 [00:41<10:05, 1.55it/s]
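For orientation, one common form of the conditional flow matching loss can be sketched in a few lines of JAX. This is an assumption-laden illustration and not sbijax's implementation: the function name `flow_matching_loss`, the straight-line probability path, the `sigma_min` value, and the uniform time sampling are choices made for this sketch, and the placeholder vector field stands in for the neural network.

```python
import jax
from jax import numpy as jnp, random as jr


def flow_matching_loss(vector_field, key, theta, y, sigma_min=1e-3):
    """Monte Carlo estimate of a conditional flow matching loss.

    vector_field(t, theta_t, y) is any callable returning an array shaped
    like theta; in practice it would be the neural network.
    """
    t_key, noise_key = jr.split(key)
    t = jr.uniform(t_key, (theta.shape[0], 1))    # one time in [0, 1) per sample
    theta_0 = jr.normal(noise_key, theta.shape)   # draws from the base distribution
    # Point on the straight path from noise (t=0) to the parameter draw (t=1).
    theta_t = (1.0 - (1.0 - sigma_min) * t) * theta_0 + t * theta
    # Time derivative of that path, used as the regression target.
    target = theta - (1.0 - sigma_min) * theta_0
    residual = vector_field(t, theta_t, y) - target
    return jnp.mean(jnp.sum(residual**2, axis=-1))


# Toy usage with a placeholder vector field that always predicts zero.
theta = jnp.ones((8, 3))
y = jnp.zeros((8, 2))
loss = flow_matching_loss(
    lambda t, th, obs: jnp.zeros_like(th), jr.PRNGKey(0), theta, y
)
print(float(loss))
```

The loss is a plain regression of the vector field onto the velocity of the interpolation path, so fitting needs no ODE solves. Those are only required later, when sampling from the trained flow.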
Finally, we sample from the posterior distribution for a specific observation \(y_{\text{obs}}\).
[9]:
y_obs = jnp.array([-1.0, 1.0])
inference_results, diagnostics = model.sample_posterior(
    jr.PRNGKey(3), params, y_obs, n_samples=1_000
)
print(inference_results)
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 1, draw: 1000, mean_dim: 2, scale_dim: 1)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
│ * mean_dim (mean_dim) int64 16B 0 1
│ * scale_dim (scale_dim) int64 8B 0
│ Data variables:
│ mean (chain, draw, mean_dim) float32 8kB ...
│ scale (chain, draw, scale_dim) float32 4kB ...
│ Attributes:
│ created_at: 2026-03-19T11:02:23.342542+00:00
│ creation_library: ArviZ
│ creation_library_version: 1.0.0
│ creation_library_language: Python
└── Group: /observed_data
Dimensions: (chain: 2)
Coordinates:
* chain (chain) int64 16B 0 1
Data variables:
y (chain) float32 8B ...
Attributes:
created_at: 2026-03-19T11:02:23.343127+00:00
creation_library: ArviZ
creation_library_version: 1.0.0
creation_library_language: Python
Visualization#
Sbijax provides basic functionality to analyse posterior draws. We show some visualizations below.
[10]:
sbijax.plot_loss_profile(losses)
plt.show()
[11]:
sbijax.plot_posterior(inference_results)
plt.show()
Session info#
[12]:
import session_info
session_info.show(html=False)
-----
arviz_plots 1.0.0
haiku 0.0.16
jax 0.8.1
jaxlib 0.8.1
matplotlib 3.10.8
sbijax 0.3.6
session_info v1.0.1
tensorflow_probability 0.26.0-dev20260318
xarray 2026.2.0
-----
IPython 9.11.0
jupyter_client 8.8.0
jupyter_core 5.9.1
jupyterlab 4.5.6
notebook 7.5.5
-----
Python 3.12.10 (main, May 30 2025, 05:53:56) [Clang 20.1.4 ]
macOS-26.2-arm64-arm-64bit
-----
Session information updated at 2026-03-19 12:02