👋 Welcome to sbijax!

Simulation-based inference in JAX


Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. It implements recent methods, such as Simulated Annealing ABC, Surjective Neural Likelihood Estimation, Neural Approximate Sufficient Statistics or Neural Posterior Score Estimation, as well as calibration and convergence diagnostics.
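To make approximate Bayesian computation concrete, here is a minimal sketch of plain rejection ABC, its simplest variant, written directly in JAX. It uses the toy model from the example below: a standard normal prior and Gaussian noise with scale 0.1. The sketch does not use sbijax and is not its Simulated Annealing ABC, and the function names `simulator` and `rejection_abc` are hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr


def simulator(key, theta):
    # Toy model (assumed, as in the example): y = theta + Gaussian noise with sd 0.1.
    return theta + 0.1 * jr.normal(key, theta.shape)


def rejection_abc(key, y_observed, n_simulations=100_000, n_accept=1_000):
    """Hypothetical plain rejection ABC: keep the prior draws whose simulated
    data lie closest (in Euclidean distance) to the observed data."""
    prior_key, sim_key = jr.split(key)
    # Draw parameters from a standard normal prior.
    theta = jr.normal(prior_key, (n_simulations, y_observed.shape[0]))
    # Simulate one data set per parameter draw.
    y = simulator(sim_key, theta)
    distances = jnp.linalg.norm(y - y_observed, axis=-1)
    # Accept the n_accept draws with the smallest distance.
    idx = jnp.argsort(distances)[:n_accept]
    return theta[idx]


y_observed = jnp.array([-1.0, 1.0])
samples = rejection_abc(jr.key(0), y_observed)
print(samples.shape)  # (1000, 2)
```

Accepting a smaller fraction of simulations tightens the approximation but wastes more simulator calls, which is roughly the cost that the more sophisticated methods listed above try to reduce.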

Caution

⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Example

Sbijax implements a low-level, functional API in the idiom of dm-haiku and blackjax: every method is a factory that takes only the network and returns a record of pure functions, while training and sampling are standalone driver functions (train(), sample()). The prior and simulator define the data, and the optimizer and sampler are passed to the driver that uses them. For example, neural likelihood estimation (NLE) looks like this:

from jax import numpy as jnp, random as jr
from sbijax import nle, train, sample, simulate
from sbijax.mcmc import make_sampler, nuts
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

prior = tfd.JointDistributionNamed(dict(
    theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
), batch_ndims=0)

def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y

estimator = nle(make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
data = simulate(jr.key(1), prior, simulator_fn, n=10_000)
params, info = train(jr.key(2), estimator, data)
samples, _ = sample(
    jr.key(3), estimator, params, y_observed,
    sampler=make_sampler(nuts, prior=prior),
)
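The factory-and-driver split described above can be illustrated without sbijax itself. The following is a minimal sketch under explicit assumptions: a conditional Gaussian with a linear mean stands in for the normalizing flow returned by make_maf, plain gradient descent stands in for an optimizer, and `Estimator`, `gaussian_likelihood` and `fit` are hypothetical names that are not part of sbijax's API. The estimator only holds pure functions; the training loop and its step size live in the driver.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.random as jr


class Estimator(NamedTuple):
    # Hypothetical record of pure functions, mirroring the factory idiom.
    init: Callable
    log_prob: Callable


def gaussian_likelihood(dim):
    """Hypothetical factory: a conditional Gaussian p(y | theta) whose mean is
    linear in theta and whose per-dimension noise scale is learned."""

    def init(key):
        return {
            "w": 0.01 * jr.normal(key, (dim, dim)),
            "b": jnp.zeros(dim),
            "log_scale": jnp.zeros(dim),
        }

    def log_prob(params, y, theta):
        mean = theta @ params["w"] + params["b"]
        z = (y - mean) / jnp.exp(params["log_scale"])
        log_density = -0.5 * z**2 - params["log_scale"] - 0.5 * jnp.log(2 * jnp.pi)
        return jnp.sum(log_density, axis=-1)

    return Estimator(init, log_prob)


def fit(key, estimator, theta, y, learning_rate=0.005, n_steps=3_000):
    """Hypothetical driver: the optimizer (here plain gradient descent)
    lives in the driver, not in the estimator."""
    params = estimator.init(key)

    def loss(p):
        # Negative mean conditional log-likelihood of the simulated pairs.
        return -jnp.mean(estimator.log_prob(p, y, theta))

    def step(_, p):
        grads = jax.grad(loss)(p)
        return jax.tree_util.tree_map(lambda a, g: a - learning_rate * g, p, grads)

    return jax.lax.fori_loop(0, n_steps, step, params)


# Simulate (theta, y) pairs from the same toy prior and simulator as above.
theta_key, noise_key, init_key = jr.split(jr.key(1), 3)
theta = jr.normal(theta_key, (10_000, 2))
y = theta + 0.1 * jr.normal(noise_key, theta.shape)

estimator = gaussian_likelihood(2)
params = fit(init_key, estimator, theta, y)
print(jnp.exp(params["log_scale"]))  # close to the true noise scale 0.1
```

Because init and log_prob are pure, they can be jitted, vmapped or differentiated independently of any driver, which is the main appeal of this idiom.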

Installation

You can install sbijax from PyPI using:

pip install sbijax

To install the latest GitHub release, replace <RELEASE> with its tag and run the following on the command line:

pip install git+https://github.com/dirmeier/sbijax@<RELEASE>

See also the installation instructions for JAX, if you plan to use sbijax on GPU/TPU.
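After installing JAX for your hardware, one way to check which backend it picked up is the snippet below. It relies only on JAX's own jax.devices() and jax.default_backend(), nothing sbijax-specific, and the exact device names printed depend on your installation.

```python
import jax

# Lists the devices JAX can use, e.g. only CPU, or one or more GPUs/TPUs.
print(jax.devices())
# Name of the default backend, typically "cpu", "gpu" or "tpu".
print(jax.default_backend())
```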

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled “good first issue”.

In order to contribute:

  1. Clone sbijax and install uv (see the uv installation instructions),

  2. install all dependencies using uv sync,

  3. create a new branch locally with git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,

  4. implement your contribution and ideally a test case,

  5. test it by calling make tests, make lints and make format on the (Unix) command line,

  6. submit a PR 🙂

License

sbijax is licensed under the Apache 2.0 License.