Getting started#

Sbijax is a Python package for neural simulation-based inference and approximate Bayesian computation. Here we demonstrate its core functionality using a simple Gaussian model.

[1]:
import jax
import sbijax
%matplotlib inline
import matplotlib.pyplot as plt

Model definition#

To run approximate inference with sbijax, first define a prior model and a simulator function that generates synthetic data. We use a simple bivariate Gaussian as an example, with the following generative model:

\begin{align} \mu &\sim \mathcal{N}_2(0, I)\\ \sigma &\sim \mathcal{N}^+(1)\\ y & \sim \mathcal{N}_2(\mu, \sigma^2 I) \end{align}

Using TensorFlow Probability, the prior model and simulator are implemented like this:

[2]:
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        mean=tfd.Normal(jnp.zeros(2), 1.0),
        scale=tfd.HalfNormal(jnp.ones(1)),
    ), batch_ndims=0)
    return prior


def simulator_fn(seed: jax.Array, theta: dict[str, jax.Array]):
    # Draw standard normal noise, then scale it and shift it by the mean
    p = tfd.Normal(jnp.zeros_like(theta["mean"]), 1.0)
    y = theta["mean"] + theta["scale"] * p.sample(seed=seed)
    return y
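
As a quick sanity check, the same generative model can be written with `jax.random` alone. The sketch below is not part of sbijax: the helper names `sample_prior` and `simulate` are hypothetical and only mirror `prior_fn` and `simulator_fn` above.

```python
import jax
from jax import numpy as jnp, random as jr


def sample_prior(seed):
    # Hypothetical helper: one prior draw, returned as a dict like the one above
    mean_key, scale_key = jr.split(seed)
    mean = jr.normal(mean_key, (2,))             # mu ~ N_2(0, I)
    scale = jnp.abs(jr.normal(scale_key, (1,)))  # sigma ~ N^+(1), i.e. |N(0, 1)|
    return {"mean": mean, "scale": scale}


def simulate(seed, theta):
    # y = mu + sigma * eps with eps ~ N_2(0, I)
    eps = jr.normal(seed, theta["mean"].shape)
    return theta["mean"] + theta["scale"] * eps


prior_key, sim_key = jr.split(jr.PRNGKey(0))
theta = sample_prior(prior_key)
y = simulate(sim_key, theta)
print(theta["mean"].shape, theta["scale"].shape, y.shape)
```

The printed shapes should match the model: two entries for the mean, one for the scale, and two for the observation.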

Algorithm definition#

Having defined the model of interest, i.e., the prior and simulator functions, we construct an inferential method. Here we use flow matching posterior estimation (FMPE), an algorithm that learns an approximation to the posterior distribution directly, without intermediate steps such as neural likelihood or neural likelihood-ratio estimation.

FMPE uses a continuous normalizing flow as its neural network model. We can construct it using the functionality provided by sbijax.

[3]:
from sbijax.nn import make_cnf

n_dim_theta = 3
n_layers, hidden_size = 2, 128
neural_network = make_cnf(n_dim_theta, n_layers, hidden_size)
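
To give some intuition for what a continuous normalizing flow does: it represents a distribution through a vector field, and samples from a simple base distribution are transported by integrating an ordinary differential equation from t=0 to t=1. The sketch below illustrates only that idea, with a fixed-step Euler solver and a hand-written toy vector field. The names `toy_vector_field` and `euler_sample` are hypothetical, and this is not a description of how `make_cnf` is implemented. In FMPE the vector field would additionally be conditioned on the data.

```python
import jax
from jax import numpy as jnp, random as jr

shift = jnp.array([1.0, -2.0, 0.5])


def toy_vector_field(t, theta):
    # Hypothetical stand-in for a trained network: a constant velocity
    # that moves every base sample by `shift` over the unit time interval
    return jnp.broadcast_to(shift, theta.shape)


def euler_sample(seed, n_samples, n_steps=20):
    # Draw base samples and integrate d(theta)/dt = v(t, theta) from t=0 to t=1
    theta0 = jr.normal(seed, (n_samples, 3))
    dt = 1.0 / n_steps

    def step(i, theta):
        return theta + dt * toy_vector_field(i * dt, theta)

    theta1 = jax.lax.fori_loop(0, n_steps, step, theta0)
    return theta0, theta1


theta0, theta1 = euler_sample(jr.PRNGKey(0), n_samples=5)
print(theta1 - theta0)  # each row is approximately `shift`
```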

The neural network model in FMPE targets the posterior distribution, so it must operate on a space with the same dimensionality as the parameters. This dimensionality can be found either by inspecting the prior model (above, two for the mean and one for the scale) or by flattening a prior draw with the JAX utility ravel_pytree:

[4]:
prior_draw = prior_fn().sample(seed=jr.PRNGKey(1))
prior_draw
[4]:
{'scale': Array([0.9881485], dtype=float32),
 'mean': Array([ 1.2772286 , -0.66140693], dtype=float32)}
[5]:
from jax.flatten_util import ravel_pytree
len(ravel_pytree(prior_draw)[0])
[5]:
3
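
`ravel_pytree` returns two things: the flattened vector and a function that restores the original structure. The short sketch below shows both on a hand-written dictionary; the numerical values are made up for illustration.

```python
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree

# Illustrative values with the same structure as a prior draw
draw = {"mean": jnp.array([1.0, -0.5]), "scale": jnp.array([2.0])}

# ravel_pytree returns the flat vector and a function that inverts the flattening
flat, unravel = ravel_pytree(draw)
print(flat.shape)

restored = unravel(flat)
print(restored)
```

The length of `flat` is the value to pass as the parameter dimensionality, and `unravel` maps a flat vector back to a dictionary with the keys `mean` and `scale`.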

We can then construct the method itself. It takes as arguments a tuple of prior and simulator functions and the neural network.

[6]:
from sbijax import FMPE

fns = prior_fn, simulator_fn
model = FMPE(fns, neural_network)

Training and inference#

Inference is then as easy as simulating some data, fitting the model to the data, and sampling from the approximate posterior. The data set is a dictionary of dictionaries (a PyTree in JAX lingo). It contains samples from the simulator function, called y, and parameter samples from the prior model, called theta.

[7]:
data, _ = model.simulate_data(
    jr.PRNGKey(1),
    n_simulations=10_000,
)
data
[7]:
{'y': Array([[ 0.07692909,  0.7882271 ],
        [-1.2418504 , -0.25333643],
        [ 1.1943562 , -2.124853  ],
        ...,
        [ 1.3316323 ,  0.5488601 ],
        [ 4.862982  , -4.1227694 ],
        [-0.00955033,  0.989019  ]], dtype=float32),
 'theta': {'scale': Array([[0.9232396 ],
         [0.36471593],
         [0.6795394 ],
         ...,
         [0.11454558],
         [1.260745  ],
         [0.5012804 ]], dtype=float32),
  'mean': Array([[ 0.30212572,  0.67478853],
         [-0.963459  ,  0.086253  ],
         [ 0.39044896, -2.2378268 ],
         ...,
         [ 1.3652769 ,  0.6302657 ],
         [ 2.7543025 , -1.9100804 ],
         [-0.16545922,  0.6475435 ]], dtype=float32)}}
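
Since the printed arrays are truncated, it can be easier to look at the shapes of the leaves instead. The sketch below does this with `jax.tree_util.tree_map` on a placeholder dictionary, `toy_data`, that only imitates the nesting of the data set above. The same call should work on `data` itself.

```python
import jax
from jax import numpy as jnp

n = 4  # small stand-in for n_simulations
# Hypothetical placeholder arrays with the same nesting as the data set above
toy_data = {
    "y": jnp.zeros((n, 2)),
    "theta": {"scale": jnp.zeros((n, 1)), "mean": jnp.zeros((n, 2))},
}

# Map every leaf array to its shape; the nesting of the dictionary is preserved
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, toy_data)
print(shapes)
```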

We then fit the model using the typical flow matching loss.

[8]:
params, losses = model.fit(jr.PRNGKey(2), data=data)
  6%|███▍      | 64/1000 [00:41<10:05,  1.55it/s]
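
For readers unfamiliar with flow matching, the sketch below writes out a generic conditional flow matching objective with a straight-line probability path. It is an illustration under assumptions, not the sbijax implementation: the function names, the `sigma_min` value, and the exact form of the path are chosen here for the example, and the toy vector field stands in for a neural network.

```python
import jax
from jax import numpy as jnp, random as jr


def flow_matching_loss(seed, vector_field, theta, y, sigma_min=1e-3):
    # theta: (n, d) parameter draws, y: (n, p) simulated data
    time_key, noise_key = jr.split(seed)
    n = theta.shape[0]
    t = jr.uniform(time_key, (n, 1))
    theta0 = jr.normal(noise_key, theta.shape)  # base sample
    # Point on the straight path between the base sample and the parameter draw
    theta_t = (1.0 - (1.0 - sigma_min) * t) * theta0 + t * theta
    # Velocity of that path, which the network is regressed onto
    target = theta - (1.0 - sigma_min) * theta0
    pred = vector_field(t, theta_t, y)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


def toy_vector_field(t, theta_t, y):
    # Hypothetical placeholder for a neural network
    return jnp.zeros_like(theta_t)


theta = jnp.ones((8, 3))
y = jnp.ones((8, 2))
loss = flow_matching_loss(jr.PRNGKey(0), toy_vector_field, theta, y)
print(loss)
```

Training amounts to minimizing this kind of regression loss over the parameters of the vector field, using mini-batches of simulated pairs of theta and y.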

Finally, we sample from the posterior distribution for a specific observation \(y_{\text{obs}}\).

[9]:
y_obs = jnp.array([-1.0, 1.0])
inference_results, diagnostics = model.sample_posterior(
    jr.PRNGKey(3), params, y_obs, n_samples=1_000
)
print(inference_results)
<xarray.DataTree>
Group: /
├── Group: /posterior
│       Dimensions:    (chain: 1, draw: 1000, mean_dim: 2, scale_dim: 1)
│       Coordinates:
│         * chain      (chain) int64 8B 0
│         * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
│         * mean_dim   (mean_dim) int64 16B 0 1
│         * scale_dim  (scale_dim) int64 8B 0
│       Data variables:
│           mean       (chain, draw, mean_dim) float32 8kB ...
│           scale      (chain, draw, scale_dim) float32 4kB ...
│       Attributes:
│           created_at:                 2026-03-19T11:02:23.342542+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.0.0
│           creation_library_language:  Python
└── Group: /observed_data
        Dimensions:  (chain: 2)
        Coordinates:
          * chain    (chain) int64 16B 0 1
        Data variables:
            y        (chain) float32 8B ...
        Attributes:
            created_at:                 2026-03-19T11:02:23.343127+00:00
            creation_library:           ArviZ
            creation_library_version:   1.0.0
            creation_library_language:  Python
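
For a model this small, the approximate posterior can be compared against a simple reference. The sketch below is not part of sbijax: it uses self-normalized importance sampling with the prior as the proposal, which is feasible here only because the likelihood of the Gaussian model is available in closed form. All variable names are chosen for this example.

```python
import jax
from jax import numpy as jnp, random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

y_obs = jnp.array([-1.0, 1.0])
n_draws = 100_000

mean_key, scale_key = jr.split(jr.PRNGKey(0))
mean = jr.normal(mean_key, (n_draws, 2))             # mu ~ N_2(0, I)
scale = jnp.abs(jr.normal(scale_key, (n_draws, 1)))  # sigma ~ N^+(1)

# Log-likelihood of y_obs under N_2(mu, sigma^2 I) for every prior draw
log_lik = norm.logpdf(y_obs, loc=mean, scale=scale).sum(axis=-1)
# Guard against degenerate draws whose scale is numerically zero
log_lik = jnp.where(jnp.isnan(log_lik), -jnp.inf, log_lik)

# Self-normalized importance weights (the prior is the proposal)
weights = jnp.exp(log_lik - logsumexp(log_lik))

posterior_mean = (weights[:, None] * mean).sum(axis=0)
print(posterior_mean)
```

The weighted mean of `mean` gives a rough reference for the posterior mean of the location parameter, against which the FMPE draws can be compared.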

Visualization#

Sbijax provides basic functionality to analyse posterior draws. We show some visualizations below.

[10]:
sbijax.plot_loss_profile(losses)
plt.show()
[Figure: loss profile produced by sbijax.plot_loss_profile]
[11]:
sbijax.plot_posterior(inference_results)
plt.show()
[Figure: posterior plot produced by sbijax.plot_posterior]

Session info#

[12]:
import session_info

session_info.show(html=False)
-----
arviz_plots                 1.0.0
haiku                       0.0.16
jax                         0.8.1
jaxlib                      0.8.1
matplotlib                  3.10.8
sbijax                      0.3.6
session_info                v1.0.1
tensorflow_probability      0.26.0-dev20260318
xarray                      2026.2.0
-----
IPython             9.11.0
jupyter_client      8.8.0
jupyter_core        5.9.1
jupyterlab          4.5.6
notebook            7.5.5
-----
Python 3.12.10 (main, May 30 2025, 05:53:56) [Clang 20.1.4 ]
macOS-26.2-arm64-arm-64bit
-----
Session information updated at 2026-03-19 12:02