sbijax#

The top-level module, sbijax, contains all implemented methods for neural simulation-based inference and approximate Bayesian inference, as well as functionality for visualization and other utilities.

CMPE(model_fns, network[, t_max, t_min])

Consistency model posterior estimation.

FMPE(model_fns, density_estimator)

Flow matching posterior estimation.

NPE(model_fns, density_estimator[, ...])

Neural posterior estimation.

NLE(model_fns, network)

Neural likelihood estimation.

SNLE(model_fns, network)

Surjective neural likelihood estimation.

NRE(model_fns, classifier[, num_classes, gamma])

Neural ratio estimation.

SMCABC(model_fns, summary_fn, distance_fn)

Sequential Monte Carlo approximate Bayesian computation.

NASS(model_fns, summary_net)

Neural approximate summary statistics.

NASSS(model_fns, summary_net)

Neural approximate slice sufficient statistics.

plot_ess(inference_data)

Effective sample size plot.

plot_loss_profile(losses[, axes])

Visualize the training and validation loss profile.

plot_rank(inference_data)

Rank statistics plots.

plot_rhat_and_ress(inference_data[, axes])

Split-\(\hat{R}\) and relative effective sample size plot.

plot_posterior(inference_data)

Posterior histogram plot.

plot_trace(inference_data)

MCMC trace plot.

as_inference_data(samples, observed)

Convert a PyTree to an inference data object.

inference_data_as_dictionary(inference_data)

Convert inference data to a PyTree.

Posterior estimation#

class sbijax.CMPE(model_fns, network, t_max=50.0, t_min=0.001)[source]#

Consistency model posterior estimation.

Implements the CMPE algorithm introduced in Schmitt et al. [2023].

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • network – a consistency model

  • t_min – minimal time point for ODE integration

  • t_max – maximal time point for ODE integration

Examples

>>> from sbijax import CMPE
>>> from sbijax.nn import make_cm
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_cm(1)
>>> model = CMPE(fns, neural_network)

References

Schmitt, Marvin, et al. “Consistency Models for Scalable and Fast Simulation-Based Inference”. arXiv preprint arXiv:2312.05440, 2023.

fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)#

Fit the model.

Parameters:
  • rng_key (PRNGKey) – a jax random key

  • data (Any) – data set obtained from calling simulate_data_and_possibly_append

  • optimizer (GradientTransformation) – an optax optimizer object

  • n_iter (int) – maximal number of training iterations per round

  • batch_size (int) – batch size used for training the model

  • percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience (int) – number of iterations of no improvement of training the flow before stopping optimisation

  • **kwargs – optional keyword arguments

  • n_early_stopping_delta (float) – minimal value for improvement for early stopping

Returns:

a tuple of parameters and a tuple of the training information
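How n_iter, n_early_stopping_patience and n_early_stopping_delta interact can be illustrated with a minimal, library-independent sketch. The helper `train_with_early_stopping` and its exact stopping rule are assumptions made for illustration; the criterion sbijax uses internally may differ.

```python
def train_with_early_stopping(val_losses, n_iter=1000, patience=10, delta=0.001):
    """Return the number of iterations run and the best validation loss.

    `val_losses` is a hypothetical, precomputed sequence of validation losses;
    in real training each value would come from one pass over the data.
    """
    best = float("inf")
    n_bad = 0  # consecutive iterations without sufficient improvement
    n_done = 0
    for loss in val_losses[:n_iter]:
        n_done += 1
        if loss < best - delta:
            # Improvement larger than delta: remember it and reset the counter.
            best = loss
            n_bad = 0
        else:
            n_bad += 1
        if n_bad >= patience:
            break
    return n_done, best


# The loss improves for five iterations and then plateaus.
losses = [1.0, 0.8, 0.6, 0.5, 0.45] + [0.4495] * 50
n_done, best = train_with_early_stopping(losses, patience=10, delta=0.001)
```

In this sketch the plateau value 0.4495 improves on 0.45 by less than delta, so training stops after the patience of 10 further iterations is exhausted.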

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

returns an array of samples from the posterior distribution of dimension (n_samples times p)

simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#

Simulate data from the posterior or prior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized posterior using ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#

Simulate data and parameters from the prior or posterior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

returns a NamedTuple with two elements, y and theta
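The "simulate and possibly append" pattern for multi-round inference can be sketched without the library as follows. The `Dataset` type and both helper functions are hypothetical stand-ins. For simplicity every round draws from the prior here, whereas the real method draws from the current posterior approximation once params are given.

```python
import random
from typing import NamedTuple, Optional


class Dataset(NamedTuple):
    # Hypothetical stand-in for the returned NamedTuple with fields y and theta.
    y: list
    theta: list


def simulate(rng, n_simulations):
    """Draw theta from a N(0, 1) prior and y from a N(theta, 1) simulator."""
    theta = [rng.gauss(0.0, 1.0) for _ in range(n_simulations)]
    y = [rng.gauss(t, 1.0) for t in theta]
    return Dataset(y=y, theta=theta)


def simulate_and_possibly_append(rng, data: Optional[Dataset], n_simulations):
    """Simulate new pairs and append them to `data` if it is given."""
    new = simulate(rng, n_simulations)
    if data is None:
        return new
    return Dataset(y=data.y + new.y, theta=data.theta + new.theta)


rng = random.Random(0)
data = None
for _ in range(3):  # three rounds, each adding 100 simulations
    data = simulate_and_possibly_append(rng, data, n_simulations=100)
```

After three rounds the data set holds the old and new simulations together, which is what a subsequent call to fit would train on.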

class sbijax.FMPE(model_fns, density_estimator)[source]#

Flow matching posterior estimation.

Implements the FMPE algorithm introduced in Wildberger et al. [2023].

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • density_estimator – a continuous normalizing flow model

Examples

>>> from sbijax import FMPE
>>> from sbijax.nn import make_cnf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_cnf(1)
>>> model = FMPE(fns, neural_network)

References

Wildberger, Jonas, et al. “Flow Matching for Scalable Simulation-Based Inference.” Advances in Neural Information Processing Systems, 2024.

fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)[source]#

Fit the model.

Parameters:
  • rng_key (PRNGKey) – a jax random key

  • data (Any) – data set obtained from calling simulate_data_and_possibly_append

  • optimizer (GradientTransformation) – an optax optimizer object

  • n_iter (int) – maximal number of training iterations per round

  • batch_size (int) – batch size used for training the model

  • percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience (int) – number of iterations of no improvement of training the flow before stopping optimisation

  • **kwargs – optional keyword arguments

  • n_early_stopping_delta (float) – minimal value for improvement for early stopping

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

returns an array of samples from the posterior distribution of dimension (n_samples times p)

simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#

Simulate data from the posterior or prior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized posterior using ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#

Simulate data and parameters from the prior or posterior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

returns a NamedTuple with two elements, y and theta

class sbijax.NPE(model_fns, density_estimator, num_atoms=10, use_event_space_bijections=True)[source]#

Neural posterior estimation.

Implements the method introduced in Greenberg et al. [2019]. In the literature, the method is usually referred to as APT or NPE-C, but here we refer to it simply as NPE.

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • density_estimator – a (neural) conditional density estimator to model the posterior distribution

  • num_atoms – number of atoms used for atomic proposals

Examples

>>> from sbijax import NPE
>>> from sbijax.nn import make_maf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_maf(1)
>>> model = NPE(fns, neural_network)

References

Greenberg, David, et al. “Automatic posterior transformation for likelihood-free inference.” International Conference on Machine Learning, 2019.

fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit the NPE model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_samples=4000, check_proposal_probs=True, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

  • check_proposal_probs – check if the proposal draws have finite density and only accept a proposal if it is. This is convenient to turn off if the density estimator has to learn a density of a constrained variable, and the RejectionMC takes a long time to draw valid samples.

Returns:

returns an array of samples from the posterior distribution of dimension (n_samples times p)
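The effect of check_proposal_probs can be pictured with a small rejection loop. This is a sketch under stated assumptions: the Gaussian proposal stands in for a trained density estimator, the Uniform(0, 1) prior is chosen only for illustration, and it is assumed that "finite density" refers to the prior log-density of the proposed value.

```python
import math
import random


def log_prior(theta):
    """Log density of a Uniform(0, 1) prior; -inf outside the support."""
    return 0.0 if 0.0 <= theta <= 1.0 else -math.inf


def sample_with_check(rng, n_samples, check_proposal_probs=True):
    """Draw from a stand-in proposal, optionally rejecting invalid draws."""
    samples = []
    while len(samples) < n_samples:
        proposal = rng.gauss(0.5, 1.0)  # stand-in for a density estimator draw
        if check_proposal_probs and not math.isfinite(log_prior(proposal)):
            continue  # reject draws outside the prior support
        samples.append(proposal)
    return samples


rng = random.Random(1)
checked = sample_with_check(rng, 200)
unchecked = sample_with_check(rng, 200, check_proposal_probs=False)
```

With the check enabled, all returned draws lie in the prior support at the cost of extra proposals; with it disabled, sampling is faster but draws may fall outside the support.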

simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#

Simulate data from the posterior or prior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized posterior using ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#

Simulate data and parameters from the prior or posterior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

returns a NamedTuple with two elements, y and theta

Likelihood estimation#

class sbijax.NLE(model_fns, network)[source]#

Neural likelihood estimation.

Implements the method introduced in Papamakarios et al. [2019].

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • network – a (neural) conditional density estimator to model the likelihood function

Examples

>>> from sbijax import NLE
>>> from sbijax.nn import make_mdn
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_mdn(1, 5)
>>> model = NLE(fns, neural_network)

References

Papamakarios, George, et al. “Sequential neural likelihood: Fast likelihood-free inference with autoregressive flows.” International Conference on Artificial Intelligence and Statistics, 2019.

fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit the model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation

  • **kwargs – additional keyword arguments (not used for NLE)

Returns:

a tuple of parameters and a tuple of the training information

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Simulate data from the prior or posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data

  • n_chains – number of MCMC chains

  • n_samples – number of samples to draw in total

  • n_warmup – number of draws to discard

Keyword Arguments:
  • sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)

  • n_thin (int) – number of thinning steps (only used if sampler=’slice’)

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)

  • step_size (float) – step size of the initial interval (only used if sampler=’slice’)

Returns:

returns a NamedTuple with two elements, y and theta

simulate_data(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Simulate data from the posterior or prior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized posterior using ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)

  • n_thin (int) – number of thinning steps (only used if sampler=’slice’)

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)

  • step_size (float) – step size of the initial interval (only used if sampler=’slice’)

Returns:

an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
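The roles of n_chains, n_samples and n_warmup can be illustrated with a generic MCMC loop. The sketch below uses a random-walk Metropolis sampler as a simple stand-in, not the NUTS or slice samplers named above, and targets a standard normal density chosen for illustration. It also assumes that the warmup draws are removed from the n_samples draws of each chain, which may not match the library's exact bookkeeping.

```python
import math
import random


def log_density(theta):
    """Unnormalised log density of a standard normal target (an assumption)."""
    return -0.5 * theta * theta


def run_chain(rng, n_samples, n_warmup, step_size=1.0):
    """Random-walk Metropolis chain; the first n_warmup draws are discarded."""
    theta = rng.gauss(0.0, 1.0)
    draws = []
    for _ in range(n_samples):
        proposal = theta + rng.gauss(0.0, step_size)
        log_accept = log_density(proposal) - log_density(theta)
        # Accept with probability min(1, exp(log_accept)).
        if rng.random() < math.exp(min(0.0, log_accept)):
            theta = proposal
        draws.append(theta)
    return draws[n_warmup:]


rng = random.Random(42)
chains = [run_chain(rng, n_samples=2000, n_warmup=1000) for _ in range(4)]
pooled = [draw for chain in chains for draw in chain]
pooled_mean = sum(pooled) / len(pooled)
```

Running several independent chains is what makes convergence diagnostics such as split-R-hat possible afterwards.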

class sbijax.SNLE(model_fns, network)[source]#

Surjective neural likelihood estimation.

Implements the method introduced in Dirmeier et al. [2023]. SNLE is particularly useful when dealing with high-dimensional data since it reduces the dimensionality of the data.

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • network – a (neural) conditional density estimator to model the likelihood function

Examples

>>> from jax import numpy as jnp
>>> from sbijax import SNLE
>>> from sbijax.nn import make_maf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
...     theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_maf(10, n_layer_dimensions=[10, 10, 5, 5, 5])
>>> model = SNLE(fns, neural_network)

References

Dirmeier, Simon, et al. “Simulation-based inference using surjective sequential neural likelihood estimation.” arXiv preprint arXiv:2308.01054, 2023.

fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#

Fit the model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation

  • **kwargs – additional keyword arguments (not used for NLE)

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)

  • n_thin (int) – number of thinning steps (only used if sampler=’slice’)

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)

  • step_size (float) – step size of the initial interval (only used if sampler=’slice’)

Returns:

an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics

simulate_data(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#

Simulate data from the posterior or prior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized posterior using ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#

Simulate data from the prior or posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data

  • n_chains – number of MCMC chains

  • n_samples – number of samples to draw in total

  • n_warmup – number of draws to discard

Keyword Arguments:
  • sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)

  • n_thin (int) – number of thinning steps (only used if sampler=’slice’)

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)

  • step_size (float) – step size of the initial interval (only used if sampler=’slice’)

Returns:

returns a NamedTuple with two elements, y and theta

Likelihood-ratio estimation#

class sbijax.NRE(model_fns, classifier, num_classes=10, gamma=1.0)[source]#

Neural ratio estimation.

Implements the method by Miller et al. [2022]. The original publication refers to the method as CNRE or NRE-C, but here we refer to it as NRE.

Parameters:
  • model_fns (tuple[Callable, Callable]) – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • classifier (Transformed) – a neural network for classification

  • num_classes (int) – number of classes to classify against

  • gamma (float) – relative weight of classes

Examples

>>> from sbijax import NRE
>>> from sbijax.nn import make_resnet
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_resnet()
>>> model = NRE(fns, neural_network)

References

Miller, Benjamin K., et al. “Contrastive neural ratio estimation.” Advances in Neural Information Processing Systems, 2022.

fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=25, n_early_stopping_delta=0.001, **kwargs)[source]#

Fit the model.

Parameters:
  • rng_key (Array) – a jax random key

  • data (NamedTuple) – data set obtained from calling simulate_data_and_possibly_append

  • optimizer (GradientTransformation) – an optax optimizer object

  • n_iter (int) – maximal number of training iterations per round

  • batch_size (int) – batch size used for training the model

  • percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience (int) – number of iterations of no improvement of training the flow before stopping optimisation

  • n_early_stopping_delta – minimal value for improvement for early stopping

Returns:

a tuple of parameters and a tuple of the training information

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Simulate data from the prior or posterior.

Simulate new parameters and observables from the prior or posterior (when params and data given). If a data argument is provided, append the new samples to the data set and return the old+new data.

Parameters:
  • rng_key (Array) – a jax random key

  • params (Mapping[str, Mapping[str, Array]] | None) – a dictionary of neural network parameters

  • observable (Array) – an observation

  • data (tuple | None) – existing data set or None

  • n_simulations (int) – number of newly simulated data

  • n_chains (int) – number of MCMC chains

  • n_samples (int) – number of samples to draw in total

  • n_warmup (int) – number of draws to discard

Keyword Arguments:
  • sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)

  • n_thin (int) – number of thinning steps (only used if sampler=’slice’)

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)

  • step_size (float) – step size of the initial interval (only used if sampler=’slice’)

Returns:

returns a NamedTuple with two elements, y and theta

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)

  • n_thin (int) – number of thinning steps (only used if sampler=’slice’)

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)

  • step_size (float) – step size of the initial interval (only used if sampler=’slice’)

Returns:

returns an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics

simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#

Simulate data from the posterior or prior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, will draw from prior. If parameters given, will draw from amortized posterior using ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired

  • n_simulations – number of newly simulated data

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

Approximate Bayesian computation#

class sbijax.SMCABC(model_fns, summary_fn, distance_fn)[source]#

Sequential Monte Carlo approximate Bayesian computation.

Implements the algorithm from Beaumont et al. [2009].

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • summary_fn – summary function

  • distance_fn – distance function

Examples

>>> import jax
>>> from jax import numpy as jnp
>>> from sbijax import SMCABC
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> summary_fn = lambda x: x
>>> distance_fn = lambda x, y: jax.vmap(lambda z: jnp.linalg.norm(z))(x - y)
>>> model = SMCABC(fns, summary_fn, distance_fn)

References

Beaumont, Mark A, et al. “Adaptive approximate Bayesian computation”. Biometrika, 2009.

sample_posterior(rng_key, observable, n_rounds=10, n_particles=10000, eps_step=0.825, ess_min=2000, cov_scale=1.0)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • observable – the observation to condition on

  • n_rounds – maximal number of SMC rounds

  • n_particles – number of particles to draw for each parameter

  • eps_step – decay of initial epsilon per simulation round

  • ess_min – minimal effective sample size

  • cov_scale – scaling of the transition kernel covariance

Returns:

an array of samples from the posterior distribution of dimension (n_samples times p)
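Two quantities in the signature above, eps_step and ess_min, can be made concrete with a short sketch. It assumes a geometric decay of the acceptance threshold and the standard Kish formula for the effective sample size of importance weights. Both helper functions are hypothetical, and the library's actual schedule and resampling rule may differ.

```python
def epsilon_schedule(eps_init, eps_step=0.825, n_rounds=10):
    """Geometrically decaying acceptance thresholds, one per round."""
    return [eps_init * eps_step**r for r in range(n_rounds)]


def effective_sample_size(weights):
    """Kish effective sample size of (possibly unnormalised) weights."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)


eps = epsilon_schedule(eps_init=1.0, n_rounds=3)

# Equal weights give an ESS equal to the number of particles ...
ess_uniform = effective_sample_size([1.0] * 10000)
# ... while a single dominant particle gives an ESS of one.
ess_skewed = effective_sample_size([1.0] + [0.0] * 9999)

# A particle system would typically be resampled once ESS drops below ess_min.
needs_resampling = ess_skewed < 2000
```

A smaller eps_step tightens the threshold faster, which usually lowers the acceptance rate per round.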

Summary statistics#

class sbijax.NASS(model_fns, summary_net)[source]#

Neural approximate summary statistics.

Implements the NASS algorithm introduced in Chen et al. [2021]. NASS can be used to automatically learn summary statistics of a data set. With the learned summaries, inferential algorithms like NLE or SMCABC can be used to infer posterior distributions.

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • summary_net – a SNASSNet object

Examples

>>> from jax import numpy as jnp
>>> from sbijax import NASS
>>> from sbijax.nn import make_nass_net
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
...     theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_nass_net([64, 64, 5], [64, 64, 1])
>>> model = NASS(fns, neural_network)

References

Chen, Yanzhi et al. “Neural Approximate Sufficient Statistics for Implicit Models”. ICLR, 2021

fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit the model to data.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation

  • **kwargs – additional keyword arguments (not used for NASS)

Returns:

tuple of parameters and a tuple of the training information

class sbijax.NASSS(model_fns, summary_net)[source]#

Neural approximate slice sufficient statistics.

Implements the NASSS algorithm introduced in Chen et al. [2023]. NASSS can be used to automatically learn summary statistics of a data set. With the learned summaries, inferential algorithms like NLE or SMCABC can be used to infer posterior distributions.

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.

  • summary_net – a (neural) conditional density estimator to model the likelihood function of summary statistics, i.e., the modelled dimensionality is that of the summaries

  • summary_net – a SNASSSNet object

Examples

>>> from jax import numpy as jnp
>>> from sbijax import NASSS
>>> from sbijax.nn import make_nasss_net
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
...     theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_nasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1])
>>> model = NASSS(fns, neural_network)

References

Yanzhi Chen et al. “Is Learning Summary Statistics Necessary for Likelihood-free Inference”. ICML, 2023

fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#

Fit the model to data.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation

  • **kwargs – additional keyword arguments (not used for NASS)

Returns:

tuple of parameters and a tuple of the training information

Visualization#

sbijax.plot_ess(inference_data)[source]#

Effective sample size plot.

Parameters:

inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm

Returns:

the same array of matplotlib axes with added plots

sbijax.plot_loss_profile(losses, axes=None)[source]#

Visualize the training and validation loss profile.

Return type:

Axes

Parameters:
  • losses (Array) – a jax.Array of training and validation losses

  • axes (Axes) – a matplotlib axes

Returns:

the same array of matplotlib axes with added plots

sbijax.plot_rank(inference_data)[source]#

Rank statistics plots.

Parameters:

inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm

Returns:

the same array of matplotlib axes with added plots

sbijax.plot_rhat_and_ress(inference_data, axes=None)[source]#

Split-\(\hat{R}\) and relative effective sample size plot.

Return type:

ndarray[Axes]

Parameters:
  • inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm

  • axes (ndarray[Axes]) – an array of matplotlib axes

Returns:

the same array of matplotlib axes with added plots
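For readers unfamiliar with the diagnostic being plotted, the following sketch computes a basic split-\(\hat{R}\) for scalar draws. It follows the classic between/within variance formulation and is not the rank-normalised variant that some libraries use. The function `split_rhat` and the synthetic chains are assumptions for illustration only.

```python
import random
import statistics


def split_rhat(chains):
    """Split-R-hat for a list of equally long chains of scalar draws."""
    half = len(chains[0]) // 2
    # Split every chain into two halves to detect within-chain drift.
    halves = [c[:half] for c in chains] + [c[half:2 * half] for c in chains]
    n = half
    means = [statistics.fmean(h) for h in halves]
    # W: average within-chain variance; B: between-chain variance.
    within = statistics.fmean([statistics.variance(h) for h in halves])
    between = n * statistics.variance(means)
    var_hat = (n - 1) / n * within + between / n
    return (var_hat / within) ** 0.5


rng = random.Random(0)
# Four chains that all sample the same N(0, 1) distribution.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# Four chains stuck around different locations.
stuck = [[rng.gauss(3.0 * k, 1.0) for _ in range(1000)] for k in range(4)]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
```

Values close to 1 indicate that the chains agree, while values clearly above 1 (a common rule of thumb is 1.01 to 1.1) suggest that the chains have not converged to the same distribution.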

sbijax.plot_posterior(inference_data)[source]#

Posterior histogram plot.

Parameters:
  • inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm

  • axes – an array of matplotlib axes

Returns:

the same array of matplotlib axes with added plots

sbijax.plot_trace(inference_data)[source]#

MCMC trace plot.

Parameters:
  • inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm

  • axes – an array of matplotlib axes

  • **kwargs – additional parameters passed to Arviz

Returns:

the same array of matplotlib axes with added plots

Utility#

sbijax.as_inference_data(samples, observed)[source]#

Convert a PyTree to an inference data object.

Return type:

DataTree

Parameters:
  • samples (Any) – a PyTree of posterior samples

  • observed (Array) – a jax.Array representing the observed data

Returns:

an inference data object

sbijax.inference_data_as_dictionary(inference_data)[source]#

Convert inference data to a PyTree.

Return type:

Any

Parameters:

inference_data (DataTree) – the posterior variable of an inference data object

Returns:

a PyTree