sbijax

The top-level module, sbijax, contains all implemented methods for neural simulation-based inference and approximate Bayesian computation (ABC), as well as diagnostics and other utilities.

Every neural estimator and summary network is a factory function that takes only the network and returns a record of pure functions, following the low-level functional idiom of dm-haiku and blackjax. Training and sampling are free driver functions (train(), sample()): the optimizer is passed to train(), and for likelihood and ratio methods the sampler (which carries the prior) is passed to sample():

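```python
import optax
from sbijax import nle, sample, train
from sbijax.mcmc import make_sampler
from sbijax.nn import make_maf  # module path assumed
# key, data, y_observed, prior and the nuts kernel are defined elsewhere
est = nle(make_maf(2))
params, info = train(key, est, data, optimizer=optax.adam(3e-4))
samples, info = sample(
    key, est, params, y_observed, sampler=make_sampler(nuts, prior=prior)
)
```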

See Design philosophy for the full design and Migration guide: 0.3 → 0.4 for moving from the class-based API.
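The factory-plus-driver idiom described above can be illustrated without sbijax at all. The following is a minimal sketch under stated assumptions: ToyObjective, gaussian_mean_objective and fit are hypothetical names, the record holds only init and loss (the real ObjectiveFns carries more), and plain gradient descent stands in for an injected optax optimizer.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyObjective(NamedTuple):
    """Hypothetical record of pure functions, in the spirit of ObjectiveFns."""

    init: Callable  # () -> params
    loss: Callable  # (params, batch) -> scalar loss


def gaussian_mean_objective() -> ToyObjective:
    """Hypothetical factory: estimate the mean of unit-variance Gaussian data."""

    def init():
        return {"mu": jnp.zeros(())}

    def loss(params, batch):
        # Gaussian negative log-likelihood, up to an additive constant.
        return 0.5 * jnp.mean((batch - params["mu"]) ** 2)

    return ToyObjective(init=init, loss=loss)


def fit(objective, data, *, learning_rate=0.1, n_iter=200):
    """Free driver: the update rule is injected here, not in the factory."""
    params = objective.init()
    grad_fn = jax.jit(jax.grad(objective.loss))
    for _ in range(n_iter):
        grads = grad_fn(params, data)
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
    return params


params = fit(gaussian_mean_objective(), jnp.array([1.0, 2.0, 3.0]))
print(float(params["mu"]))  # close to the sample mean, 2.0
```

Because the record contains only pure functions, the driver is free to jit and differentiate them, and changing the optimizer never requires touching the factory.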

npe(network, *[, num_atoms])

Construct a neural posterior estimator.

fmpe(network)

Construct a flow-matching posterior objective.

npse(network)

Construct a neural posterior score estimator.

nle(network)

Construct a neural likelihood objective.

snle(network)

Construct a surjective neural likelihood objective.

nre(network, *[, num_classes, gamma])

Construct a neural ratio objective.

sabc(prior, simulator, *[, summary_fn, ...])

Construct a simulated annealing ABC sampler.

smcabc(prior, simulator, summary_fn, distance_fn)

Construct a sequential Monte Carlo ABC sampler.

nass(network)

Construct a neural approximate sufficient statistics summary network.

nasss(network)

Construct a neural approximate slice sufficient statistics summary network.

summarized_estimator(estimator, summary_net, ...)

Adapt an objective to operate on learned summaries.

train(rng_key, objective, data, *[, ...])

Train any objective's TrainFns with early stopping.

sample(rng_key, objective, params, observable, *)

Draw posterior samples from a trained objective.

run_sequential(rng_key, objective, prior, ...)

Run multi-round sequential inference.

simulate(rng_key, prior, simulator, *[, ...])

Draw a dataset of parameters and observations.

stack(data, new_data)

Append two datasets along the sample axis.

sbc(rng_key, objective, params, prior, ...)

Compute simulation-based calibration ranks for a fitted objective.

Data pipeline

sbijax.simulate(rng_key, prior, simulator, *, proposal=None, n)

Draw a dataset of parameters and observations.

Parameters are drawn from proposal if given, otherwise from prior, and observations are produced by running simulator on them.

Parameters:
  • rng_key – a jax random key

  • prior – a distribution to draw parameters from when proposal is None

  • simulator – a callable (rng_key, theta) -> y

  • proposal – an optional callable (rng_key, n) -> theta used to draw parameters instead of the prior (e.g. a fitted posterior in a sequential round)

  • n – the number of parameter/observation pairs to draw

Returns:

a dictionary with keys y and theta
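A minimal sketch of the contract described above, written against plain JAX rather than sbijax. Unlike the documented function, it assumes that the prior is passed as a sampling callable (rng_key, n) -> theta instead of a distribution, and that simulator maps a whole batch of parameters to a batch of observations; simulate_sketch, prior_sample and simulator are hypothetical names.

```python
import jax
import jax.numpy as jnp


def simulate_sketch(rng_key, prior_sample, simulator, *, proposal=None, n):
    """Hypothetical stand-in for the documented simulate() contract.

    Assumptions: prior_sample and proposal are callables (rng_key, n) -> theta,
    and simulator maps a batch of parameters to a batch of observations.
    """
    theta_key, sim_key = jax.random.split(rng_key)
    # Draw parameters from the proposal if one is given, else from the prior.
    draw = proposal if proposal is not None else prior_sample
    theta = draw(theta_key, n)
    y = simulator(sim_key, theta)
    return {"y": y, "theta": theta}


def prior_sample(rng_key, n):
    # Toy standard-normal prior over a 2-dimensional parameter.
    return jax.random.normal(rng_key, (n, 2))


def simulator(rng_key, theta):
    # Toy simulator: the parameter plus a little Gaussian noise.
    return theta + 0.1 * jax.random.normal(rng_key, theta.shape)


data = simulate_sketch(jax.random.PRNGKey(0), prior_sample, simulator, n=500)
print(data["theta"].shape, data["y"].shape)  # (500, 2) (500, 2)
```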

sbijax.stack(data, new_data)

Append two datasets along the sample axis.

Parameters:
  • data – a dataset as returned by simulate()

  • new_data – a second dataset with the same structure

Returns:

a dataset whose leaves are the concatenation of both inputs
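Assuming a dataset is a pytree (such as the {"y", "theta"} dictionary) whose leaves share a leading sample axis, appending along that axis reduces to a tree_map over the leaves. The hypothetical stack_sketch below shows the idea; it is not the library's implementation.

```python
import jax
import jax.numpy as jnp


def stack_sketch(data, new_data):
    """Concatenate every leaf of two same-structured pytrees along axis 0."""
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=0), data, new_data
    )


round_0 = {"y": jnp.zeros((3, 2)), "theta": jnp.zeros((3, 1))}
round_1 = {"y": jnp.ones((2, 2)), "theta": jnp.ones((2, 1))}
merged = stack_sketch(round_0, round_1)
print(merged["y"].shape, merged["theta"].shape)  # (5, 2) (5, 1)
```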

Posterior estimation

sbijax.npe(network, *, num_atoms=10)

Construct a neural posterior estimator.

In round 0, use the returned ObjectiveFns directly for amortized maximum-likelihood training. For rounds > 0 (i.e. when training on data simulated from a fitted posterior rather than the prior), call its extra field, obj.extra(prior) where obj is the returned ObjectiveFns, to obtain the atomic proposal-posterior objective of Greenberg et al. [2019].

Parameters:
  • network – a conditional density estimator with log_prob and sample methods (e.g. from make_maf())

  • num_atoms – the number of atoms in the contrastive proposal-posterior loss used by extra(prior)

Returns:

an ObjectiveFns; its extra field is a callable (prior) -> ObjectiveFns for sequential rounds
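The atomic loss can be sketched in a few lines of JAX. This is an illustration, not sbijax's code: atomic_npe_loss, log_prior and make_log_q are hypothetical, a toy Gaussian q(theta | y) = N(theta; w * y, 1) replaces the normalizing flow, and the contrastive atoms are taken by cyclically shifting the batch rather than drawn at random. Each row's logits are log q(theta_j | y_i) - log p(theta_j) over its atoms, and the loss is the cross-entropy of picking the true parameter.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def atomic_npe_loss(log_q, log_prior, theta, y, num_atoms):
    """Atomic proposal-posterior loss in the style of Greenberg et al. [2019].

    log_q(theta, y) and log_prior(theta) return one log density per row.
    The contrastive atoms for row i are the parameters of rows i+1, ...,
    i+num_atoms-1 (cyclically): a deterministic simplification of sampling
    atoms at random from the batch.
    """
    # atoms[k] holds the k-th atom for every row; atom 0 is the true theta.
    atoms = jnp.stack([jnp.roll(theta, -k, axis=0) for k in range(num_atoms)])
    # Proposal-posterior logits: log q(atom | y_i) - log p(atom).
    logits = jax.vmap(lambda a: log_q(a, y) - log_prior(a))(atoms)
    # Cross-entropy with the true atom (index 0) as the target class.
    return -jnp.mean(logits[0] - logsumexp(logits, axis=0))


def log_prior(theta):
    # Toy standard-normal prior.
    return jnp.sum(norm.logpdf(theta), axis=-1)


def make_log_q(w):
    # Toy conditional density q(theta | y) = N(theta; w * y, 1).
    return lambda t, obs: jnp.sum(norm.logpdf(t, loc=w * obs, scale=1.0), axis=-1)


key_theta, key_noise = jax.random.split(jax.random.PRNGKey(1))
theta = jax.random.normal(key_theta, (256, 1))
y = theta + 0.1 * jax.random.normal(key_noise, (256, 1))
for w in (-1.0, 0.0, 1.0):
    print(w, float(atomic_npe_loss(make_log_q(w), log_prior, theta, y, num_atoms=10)))
```

With w = 0 the toy density equals the prior, every logit is zero and the loss is exactly log(num_atoms); a density that tracks the data (w = 1) scores lower, and an anti-correlated one (w = -1) scores higher.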

sbijax.fmpe(network)

Construct a flow-matching posterior objective.

Parameters:

network – a continuous normalizing flow with loss and sample methods

Returns:

an ObjectiveFns
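A flow-matching objective regresses a velocity field onto the velocity of a known path between noise and data. The sketch below uses the simplest straight-line path from a standard-normal sample (t = 0) to theta (t = 1); whether fmpe uses this exact path (or, say, a small minimum noise scale) is not stated here, so treat it as an assumption. flow_matching_loss and the linear velocity_fn are hypothetical.

```python
import jax
import jax.numpy as jnp


def flow_matching_loss(velocity_fn, params, rng_key, theta, y):
    """Conditional flow-matching loss along a straight noise-to-data path.

    velocity_fn(params, t, theta_t, y) predicts d theta_t / dt.
    """
    t_key, noise_key = jax.random.split(rng_key)
    t = jax.random.uniform(t_key, (theta.shape[0], 1))
    noise = jax.random.normal(noise_key, theta.shape)
    # Point at time t on the line from noise (t = 0) to the data (t = 1).
    theta_t = (1.0 - t) * noise + t * theta
    # That path has constant velocity theta - noise: the regression target.
    target = theta - noise
    pred = velocity_fn(params, t, theta_t, y)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


def velocity_fn(params, t, theta_t, y):
    # Hypothetical toy velocity field: linear in concat(t, theta_t, y).
    inputs = jnp.concatenate([t, theta_t, y], axis=-1)
    return inputs @ params["w"] + params["b"]


theta_key, y_key, loss_key = jax.random.split(jax.random.PRNGKey(0), 3)
theta = jax.random.normal(theta_key, (128, 2))
y = theta + 0.1 * jax.random.normal(y_key, (128, 2))
params = {"w": jnp.zeros((5, 2)), "b": jnp.zeros(2)}  # 5 = 1 (t) + 2 + 2
loss, grads = jax.value_and_grad(flow_matching_loss, argnums=1)(
    velocity_fn, params, loss_key, theta, y
)
```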

sbijax.npse(network)

Construct a neural posterior score estimator.

Parameters:

network – a score network with loss, sample and log_prob methods (e.g. sbijax.experimental.nn.make_score_model())

Returns:

an ObjectiveFns
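Score networks of this kind are usually trained by denoising score matching. The sketch below shows the idea at a single fixed noise level sigma, which is a simplification: diffusion-based NPSE methods train across a continuum of noise levels. dsm_loss and the toy score_fn, which ignores y, are hypothetical.

```python
import jax
import jax.numpy as jnp


def dsm_loss(score_fn, params, rng_key, theta, y, sigma=0.5):
    """Denoising score-matching loss at one fixed noise level sigma."""
    noise = jax.random.normal(rng_key, theta.shape)
    theta_noisy = theta + sigma * noise
    # Score of the perturbation kernel N(theta_noisy; theta, sigma^2 I).
    target = -noise / sigma
    pred = score_fn(params, theta_noisy, y)
    # Weight by sigma^2 so the loss stays O(1) as sigma shrinks.
    return jnp.mean(jnp.sum((sigma * (pred - target)) ** 2, axis=-1))


def score_fn(params, theta_noisy, y):
    # Hypothetical toy score model that ignores y: a * theta_noisy.
    return params["a"] * theta_noisy


theta_key, loss_key = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(theta_key, (4000, 2))  # theta ~ N(0, I)
y = jnp.zeros((4000, 1))  # unused by the toy score model
# The noised marginal is N(0, (1 + sigma^2) I), whose score is
# -theta_noisy / (1 + sigma^2) = -0.8 * theta_noisy for sigma = 0.5.
for a in (0.0, -0.8):
    print(a, float(dsm_loss(score_fn, {"a": a}, loss_key, theta, y)))
```

On this toy data the exact score of the noised marginal (a = -0.8) attains a lower loss than the zero score (a = 0).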

Likelihood estimation

sbijax.nle(network)

Construct a neural likelihood objective.

Parameters:

network – a conditional density estimator exposing a log_prob method

Returns:

an ObjectiveFns
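Two pieces of NLE can be sketched with a toy one-dimensional Gaussian model standing in for the density estimator: the maximum-likelihood training loss, and the unnormalized log posterior that an MCMC sampler targets once the likelihood is learned (which is why the sampler carries the prior). log_q, nle_loss and make_log_posterior are hypothetical names, not sbijax API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_q(params, y, theta):
    # Hypothetical toy likelihood model q(y | theta) = N(y; w * theta, 1).
    return norm.logpdf(y, loc=params["w"] * theta, scale=1.0)


def nle_loss(params, theta, y):
    # Maximum likelihood: minimize -mean log q(y | theta) over simulated pairs.
    return -jnp.mean(log_q(params, y, theta))


def make_log_posterior(params, y_observed, log_prior):
    # MCMC target: log q(y_obs | theta) + log p(theta), up to a constant.
    def target(theta):
        return log_q(params, y_observed, theta) + log_prior(theta)

    return target


theta_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(theta_key, (1000,))
y = theta + 0.1 * jax.random.normal(noise_key, (1000,))
print(float(nle_loss({"w": 1.0}, theta, y)), float(nle_loss({"w": 0.0}, theta, y)))

target = make_log_posterior({"w": 1.0}, 1.0, norm.logpdf)
# With a N(0, 1) prior and y_obs = 1 the posterior mode is theta = 0.5,
# so the gradient of the target vanishes there.
print(float(jax.grad(target)(0.5)))  # approximately 0.0
```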

sbijax.snle(network)

Construct a surjective neural likelihood objective.

Parameters:

network – a surjective conditional density estimator with a log_prob method that reduces the dimensionality of the data

Returns:

an ObjectiveFns

Likelihood-ratio estimation

sbijax.nre(network, *, num_classes=10, gamma=1.0)

Construct a neural ratio objective.

Parameters:
  • network – a classifier network mapping concat(y, theta) to logits

  • num_classes – number of contrastive classes

  • gamma – relative weight of the contrastive classes

Returns:

an ObjectiveFns
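As an illustration of ratio estimation, the sketch below implements only the classic two-class loss: a classifier separates joint pairs (y, theta) from pairs whose theta has been shuffled across the batch. To our understanding the num_classes/gamma objective above is a contrastive generalization of this binary case, so the sketch is not the library's loss. At the optimum the logit approximates log p(y | theta) - log p(y), so adding log p(theta) gives an unnormalized log posterior for MCMC. binary_ratio_loss and the bilinear logit_fn are hypothetical.

```python
import jax
import jax.numpy as jnp


def binary_ratio_loss(logit_fn, params, rng_key, theta, y):
    """Two-class ratio-estimation loss: joint vs. shuffled (marginal) pairs."""
    # Shuffling theta across the batch breaks its dependence on y.
    theta_marginal = jax.random.permutation(rng_key, theta, axis=0)
    logits_joint = logit_fn(params, jnp.concatenate([y, theta], axis=-1))
    logits_marginal = logit_fn(
        params, jnp.concatenate([y, theta_marginal], axis=-1)
    )
    # Binary cross-entropy: joint pairs are class 1, marginal pairs class 0.
    return -jnp.mean(jax.nn.log_sigmoid(logits_joint)) - jnp.mean(
        jax.nn.log_sigmoid(-logits_marginal)
    )


def logit_fn(params, z):
    # Hypothetical toy classifier on concat(y, theta): a * y * theta + b.
    return params["a"] * z[:, 0] * z[:, 1] + params["b"]


theta_key, noise_key, perm_key = jax.random.split(jax.random.PRNGKey(0), 3)
theta = jax.random.normal(theta_key, (4096, 1))
y = theta + 0.1 * jax.random.normal(noise_key, (4096, 1))
for a in (0.0, 1.0):
    print(a, float(binary_ratio_loss(logit_fn, {"a": a, "b": 0.0}, perm_key, theta, y)))
```

An uninformative classifier (a = 0, b = 0) scores exactly 2 log 2; one that rewards agreement between y and theta scores lower.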

Approximate Bayesian computation

sbijax.sabc(prior, simulator, *, summary_fn=<function <lambda>>, distance_fn=<function abs_distance>)

Construct a simulated annealing ABC sampler.

Parameters:
  • prior – a tfd distribution serving as the prior over parameters

  • simulator – a callable (rng_key, theta) -> y

  • summary_fn – maps simulated data to summary statistics

  • distance_fn – distance between simulated and observed summaries

Returns:

an ABCSampler

sbijax.smcabc(prior, simulator, summary_fn, distance_fn)

Construct a sequential Monte Carlo ABC sampler.

Parameters:
  • prior – a tfd distribution serving as the prior over parameters

  • simulator – a callable (rng_key, theta) -> y

  • summary_fn – maps simulated data to summary statistics

  • distance_fn – distance between simulated and observed summaries

Returns:

an ABCSampler
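Both samplers build on the ingredients listed above: a prior, a simulator, a summary function and a distance. The sketch below shows the simplest algorithm using those ingredients, plain rejection ABC that keeps the closest fraction of prior draws; it is neither simulated annealing nor SMC, which instead tighten the acceptance threshold adaptively over iterations. rejection_abc, l1_distance, prior_sample and simulator are hypothetical, and the prior is again assumed to be a sampling callable rather than a tfd distribution.

```python
import jax
import jax.numpy as jnp


def rejection_abc(rng_key, prior_sample, simulator, y_observed, *,
                  summary_fn, distance_fn, n_simulations, n_accept):
    """Keep the n_accept prior draws whose simulated summaries are closest."""
    theta_key, sim_key = jax.random.split(rng_key)
    theta = prior_sample(theta_key, n_simulations)
    y = simulator(sim_key, theta)
    distances = distance_fn(summary_fn(y), summary_fn(y_observed))
    # Accepting the closest draws sets the tolerance implicitly.
    keep = jnp.argsort(distances)[:n_accept]
    return theta[keep]


def l1_distance(s, s_observed):
    # Hypothetical distance: sum of absolute differences per simulation.
    return jnp.sum(jnp.abs(s - s_observed), axis=-1)


def prior_sample(rng_key, n):
    return 3.0 * jax.random.normal(rng_key, (n, 1))  # theta ~ N(0, 9)


def simulator(rng_key, theta):
    return theta + 0.1 * jax.random.normal(rng_key, theta.shape)


accepted = rejection_abc(
    jax.random.PRNGKey(0), prior_sample, simulator, jnp.array([1.0]),
    summary_fn=lambda y: y, distance_fn=l1_distance,
    n_simulations=20_000, n_accept=200,
)
print(accepted.shape, float(accepted.mean()))  # (200, 1), mean near 1.0
```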

Summary statistics

sbijax.nass(network)

Construct a neural approximate sufficient statistics summary network.

Parameters:

network – a NASS summary network with forward, summary and critic methods

Returns:

a SummaryFns

sbijax.nasss(network)

Construct a neural approximate slice sufficient statistics summary network.

Parameters:

network – a NASSS summary network with forward, summary, secondary_summary and critic methods

Returns:

a SummaryFns

A summary network is chained into a downstream estimator with:

sbijax.summarized_estimator(estimator, summary_net, summary_params)

Adapt an objective to operate on learned summaries.

Given a pre-fitted summary network, returns an ObjectiveFns whose training step summarizes each batch before delegating to the wrapped estimator, and whose sample_fn summarizes the observation before sampling. Because the result is an ordinary ObjectiveFns record, it conforms to the same interface as any other objective and is driven by the generic train() and sample() drivers.

Fit the summary network first, then wrap the estimator:

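```python
from sbijax import nass, nle, summarized_estimator, train
from sbijax.nn import make_maf, make_nass_net  # module path assumed

# key, data, y_observed and sampler are defined elsewhere
sn = nass(make_nass_net(2, [64, 64]))
sn_params, _ = train(key, sn, data)            # fit the summary network first
est = summarized_estimator(nle(make_maf(2)), sn, sn_params)
params, info = train(key, est, data)           # trains on summaries
samples, _ = est.sample_fn(key, params, y_observed, sampler=sampler)
```
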
Parameters:
  • estimator – the downstream ObjectiveFns consuming the summaries

  • summary_net – a fitted SummaryFns

  • summary_params – the summary network’s fitted parameters

Returns:

an ObjectiveFns
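The adapter pattern behind summarized_estimator can be sketched with plain functions: wrap the loss so it summarizes batch["y"] first, and wrap the sampler so it summarizes the observation first. summarized_sketch and the toy summary_fn, loss_fn and sample_fn are hypothetical; the real function works on ObjectiveFns and SummaryFns records.

```python
import jax.numpy as jnp


def summarized_sketch(loss_fn, sample_fn, summary_fn, summary_params):
    """Hypothetical adapter: wrap a loss and a sampler to act on summaries."""

    def loss(params, batch):
        # Replace the raw observations by their learned summaries.
        summarized = dict(batch, y=summary_fn(summary_params, batch["y"]))
        return loss_fn(params, summarized)

    def sample(rng_key, params, observable, **kwargs):
        # Summarize the observation before conditioning on it.
        summary = summary_fn(summary_params, observable)
        return sample_fn(rng_key, params, summary, **kwargs)

    return loss, sample


def summary_fn(summary_params, y):
    # Toy "summary network": a scaled mean over the last axis.
    return summary_params["scale"] * jnp.mean(y, axis=-1, keepdims=True)


def loss_fn(params, batch):
    # Toy downstream loss that expects one summary per row.
    return jnp.mean((batch["y"] - params["w"] * batch["theta"]) ** 2)


def sample_fn(rng_key, params, observable):
    # Toy "sampler" that just returns what it was conditioned on.
    return observable


loss, sample = summarized_sketch(loss_fn, sample_fn, summary_fn, {"scale": 1.0})
batch = {"y": jnp.ones((4, 10)), "theta": jnp.ones((4, 1))}
print(float(loss({"w": 1.0}, batch)))  # 0.0
print(sample(None, {"w": 1.0}, jnp.arange(4.0)))  # [1.5]
```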

Training and sampling

Trainable objectives are trained and sampled with the two free drivers. The sampler for likelihood/ratio methods is built with sbijax.mcmc.make_sampler().

sbijax.train(rng_key, objective, data, *, optimizer=None, info=None, n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001)

Train any objective’s TrainFns with early stopping.

Parameters:
  • rng_key – a jax random key

  • objective – an ObjectiveFns or SummaryFns (any record exposing train)

  • data – a {"y", "theta"} dataset pytree

  • optimizer – an optax optimizer; train() binds it into the objective's training primitives

  • info – previous round’s Info (None for round 0)

  • n_iter – number of epochs

  • batch_size – minibatch size

  • percentage_data_as_validation_set – validation split fraction

  • n_early_stopping_patience – early-stopping patience

  • n_early_stopping_delta – minimum early-stopping improvement

Returns:

a tuple (params, Info)
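The documented early-stopping parameters suggest a rule like the one sketched below: an epoch counts as an improvement only if the validation loss beats the best value so far by more than delta, and training stops after patience consecutive epochs without one. This is an assumption about the semantics of n_early_stopping_patience and n_early_stopping_delta, not sbijax's code; early_stopping_epoch is hypothetical.

```python
def early_stopping_epoch(val_losses, *, patience=10, delta=1e-3):
    """Return (stop_epoch, best_epoch) for a sequence of validation losses."""
    best, best_epoch, bad_epochs = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - delta:
            # A large enough improvement resets the patience counter.
            best, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch, best_epoch
    # Ran out of epochs without triggering the stopping rule.
    return len(val_losses) - 1, best_epoch


print(early_stopping_epoch([1.0, 0.8, 0.6, 0.6, 0.6, 0.6], patience=3))  # (5, 2)
```

Improvements smaller than delta count as stalls, so a slowly creeping loss still triggers stopping.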

sbijax.sample(rng_key, objective, params, observable, *, sampler=None, **kwargs)

Draw posterior samples from a trained objective.

A thin wrapper that dispatches to objective.sample_fn; it exists for symmetry with sbijax.train().

Parameters:
  • rng_key – a jax random key

  • objective – an ObjectiveFns

  • params – the trained parameters

  • observable – the observation to condition on

  • sampler – a sampler from make_sampler() (required for MCMC methods, ignored by amortized methods)

  • **kwargs – forwarded to sample_fn

Returns:

(samples, info)

Sequential inference

sbijax.run_sequential(rng_key, objective, prior, simulator, observable, *, n_rounds, n_simulations_per_round, sampler=None, proposal_fn=None, **train_kwargs)

Run multi-round sequential inference.

Round 0 simulates from the prior; each later round draws parameters from a proposal built from the current posterior estimate, simulates new data, appends it to the existing dataset, and refits. NPE switches to its atomic objective (objective.extra(prior)) in rounds > 0; proposal-invariant methods reuse objective unchanged.

Parameters:
  • rng_key – a jax random key

  • objective – an ObjectiveFns

  • prior – the prior distribution

  • simulator – a callable (rng_key, theta) -> y

  • observable – the observation to condition on

  • n_rounds – number of simulate/append/refit rounds

  • n_simulations_per_round – pairs drawn each round

  • sampler – a sampler (from make_sampler) for MCMC-based objectives

  • proposal_fn – optional (objective, params, observable, sampler) -> ((rng, n) -> theta); defaults to sampling the fitted posterior

  • **train_kwargs – forwarded to train() in every round

Returns:

(params, Info) from the final round

Diagnostics

sbijax.sbc(rng_key, objective, params, prior, simulator, *, sampler=None, n_simulations=100, n_posterior_samples=1000, **sample_kwargs)

Compute simulation-based calibration ranks for a fitted objective.

For each of the n_simulations calibration draws, samples theta* ~ prior and y* ~ simulator(theta*), draws n_posterior_samples samples from the posterior given y*, and ranks each dimension of theta* among them. For a calibrated posterior the ranks are uniformly distributed over the integers 0, 1, ..., n_posterior_samples.

Parameters:
  • rng_key – a jax random key

  • objective – an ObjectiveFns returned by a factory such as npe()

  • params – the fitted parameters

  • prior – the prior distribution

  • simulator – a callable (rng_key, theta) -> y

  • sampler – a sampler from make_sampler(); required for MCMC methods, ignored by amortized methods

  • n_simulations – number of calibration draws

  • n_posterior_samples – posterior draws per calibration draw

  • **sample_kwargs – forwarded to sample

Returns:

an integer array of shape (n_simulations, n_dims) of ranks
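The rank computation itself is short. The sketch below assumes posterior draws are already available as an array and counts, per dimension, how many fall strictly below theta*; it then checks the procedure on a conjugate Gaussian model whose posterior N(y / 2, 1 / 2) is known exactly. sbc_ranks is a hypothetical helper, and how sbijax handles ties is not specified here.

```python
import jax
import jax.numpy as jnp


def sbc_ranks(theta_star, posterior_samples):
    """Ranks of theta_star among posterior draws, per calibration draw and dimension.

    theta_star: (n_simulations, n_dims)
    posterior_samples: (n_simulations, n_posterior_samples, n_dims)
    """
    # Count the posterior draws strictly below the true value.
    return jnp.sum(posterior_samples < theta_star[:, None, :], axis=1)


# Hand-checkable case: one calibration draw, two dimensions, three samples.
print(sbc_ranks(jnp.array([[0.5, 2.0]]),
                jnp.array([[[0.0, 1.0], [1.0, 3.0], [0.2, 2.5]]])))  # [[2 1]]

# Conjugate model: theta ~ N(0, 1), y | theta ~ N(theta, 1), posterior N(y / 2, 1 / 2).
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n_sim, n_post = 1000, 99
theta_star = jax.random.normal(k1, (n_sim, 1))
y_star = theta_star + jax.random.normal(k2, (n_sim, 1))
posterior = y_star[:, None, :] / 2 + jnp.sqrt(0.5) * jax.random.normal(
    k3, (n_sim, n_post, 1)
)
ranks = sbc_ranks(theta_star, posterior)
# Exact posterior, so ranks should be roughly uniform on 0, ..., n_post.
print(ranks.shape, float(ranks.mean()))  # (1000, 1), mean near 49.5
```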