sbijax#
The top-level module, sbijax, contains all implemented methods for neural
simulation-based inference and approximate Bayesian inference as well as
functionality for visualization and other utilities.
- CMPE – Consistency model posterior estimation.
- FMPE – Flow matching posterior estimation.
- NPE – Neural posterior estimation.
- NLE – Neural likelihood estimation.
- SNLE – Surjective neural likelihood estimation.
- NRE – Neural ratio estimation.
- SMCABC – Sequential Monte Carlo approximate Bayesian computation.
- NASS – Neural approximate summary statistics.
- NASSS – Neural approximate slice sufficient statistics.
- plot_ess – Effective sample size plot.
- plot_loss_profile – Visualize the training and validation loss profile.
- plot_rank – Rank statistics plots.
- plot_rhat_and_ress – Split-\(\hat{R}\) and relative effective sample size plot.
- plot_posterior – Posterior histogram plot.
- plot_trace – MCMC trace plot.
- Convert a PyTree to an inference data object.
- Convert inference data to a PyTree.
Posterior estimation#
- class sbijax.CMPE(model_fns, network, t_max=50.0, t_min=0.001)[source]#
Consistency model posterior estimation.
Implements the CMPE algorithm introduced in Schmitt et al. [2023].
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
network – a consistency model
t_max – maximal time point for ODE integration
t_min – minimal time point for ODE integration
Examples
>>> from sbijax import CMPE
>>> from sbijax.nn import make_cm
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_cm(1)
>>> model = CMPE(fns, neural_network)
References
Schmitt, Marvin, et al. “Consistency Models for Scalable and Fast Simulation-Based Inference”. arXiv preprint arXiv:2312.05440, 2023.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)#
Fit the model.
- Parameters:
rng_key (PRNGKey) – a jax random key
data (Any) – data set obtained from calling simulate_data_and_possibly_append
optimizer (GradientTransformation) – an optax optimizer object
n_iter (int) – maximal number of training iterations per round
batch_size (int) – batch size used for training the model
percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience (int) – number of iterations of no improvement of training the flow before stopping optimisation
n_early_stopping_delta (float) – minimal value for improvement for early stopping
**kwargs – optional keyword arguments
- Returns:
a tuple of parameters and a tuple of the training information
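The interplay of n_early_stopping_patience and n_early_stopping_delta can be illustrated with a generic patience-based stopping rule. This is a minimal sketch in plain Python, not sbijax's actual training loop: the function name `train_with_early_stopping` is hypothetical, and the exact criterion used internally may differ.

```python
def train_with_early_stopping(val_losses, patience=10, delta=0.001):
    """Return the iteration index at which training would stop.

    `val_losses` stands in for the validation loss computed after each
    training iteration. The rule below is an assumed, generic one.
    """
    best = float("inf")
    n_no_improvement = 0
    for i, loss in enumerate(val_losses):
        if loss < best - delta:
            # improvement by more than delta: reset the counter
            best = loss
            n_no_improvement = 0
        else:
            n_no_improvement += 1
        if n_no_improvement >= patience:
            return i
    # the iteration budget (n_iter) ran out before the rule triggered
    return len(val_losses) - 1


# the loss plateaus after the third iteration
losses = [1.0, 0.5, 0.4] + [0.3999] * 20
stop_at = train_with_early_stopping(losses, patience=10, delta=0.001)
```

A larger delta makes the rule stricter about what counts as an improvement, while a larger patience tolerates longer plateaus.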
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p)
- simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#
Simulate data from the posterior or prior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, will draw from the prior. If parameters are given, will draw from the amortized posterior using ‘observable’
observable – an observation. Needs to be given if posterior draws are desired
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#
Simulate data and parameters from the prior or posterior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
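The "possibly append" behaviour can be pictured as concatenating newly simulated pairs onto an existing data set along the first axis. The sketch below uses NumPy and a hypothetical `SimulatedData` container with fields y and theta. The real container layout in sbijax is an assumption here (it may, for instance, store theta as a dictionary keyed by parameter name).

```python
from typing import NamedTuple, Optional

import numpy as np


class SimulatedData(NamedTuple):
    # hypothetical container mirroring the documented fields
    y: np.ndarray
    theta: np.ndarray


def append(data: Optional[SimulatedData], new_data: SimulatedData) -> SimulatedData:
    """Return new_data if there is no old data, else old and new stacked."""
    if data is None:
        return new_data
    return SimulatedData(
        y=np.concatenate([data.y, new_data.y], axis=0),
        theta=np.concatenate([data.theta, new_data.theta], axis=0),
    )


rng = np.random.default_rng(1)


def simulate(n: int) -> SimulatedData:
    # toy model: theta ~ N(0, 1), y | theta ~ N(theta, 1)
    theta = rng.normal(size=(n, 1))
    y = theta + rng.normal(size=(n, 1))
    return SimulatedData(y=y, theta=theta)


data = None
for _ in range(3):  # three simulation rounds
    data = append(data, simulate(1000))
```

After three rounds of 1000 simulations each, the data set holds 3000 rows, which is the pattern used in sequential (multi-round) inference.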
- class sbijax.FMPE(model_fns, density_estimator)[source]#
Flow matching posterior estimation.
Implements the FMPE algorithm introduced in Wildberger et al. [2023].
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
density_estimator – a continuous normalizing flow model
Examples
>>> from sbijax import FMPE
>>> from sbijax.nn import make_cnf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_cnf(1)
>>> model = FMPE(fns, neural_network)
References
Wildberger, Jonas, et al. “Flow Matching for Scalable Simulation-Based Inference.” Advances in Neural Information Processing Systems, 2024.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)[source]#
Fit the model.
- Parameters:
rng_key (PRNGKey) – a jax random key
data (Any) – data set obtained from calling simulate_data_and_possibly_append
optimizer (GradientTransformation) – an optax optimizer object
n_iter (int) – maximal number of training iterations per round
batch_size (int) – batch size used for training the model
percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience (int) – number of iterations of no improvement of training the flow before stopping optimisation
n_early_stopping_delta (float) – minimal value for improvement for early stopping
**kwargs – optional keyword arguments
- Returns:
a tuple of parameters and a tuple of the training information
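To make percentage_data_as_validation_set concrete, the following NumPy sketch holds out a random fraction of the simulated pairs for validation. The helper name `split_train_validation` is hypothetical, and how sbijax shuffles and rounds internally is an assumption.

```python
import numpy as np


def split_train_validation(rng, y, theta, percentage_data_as_validation_set=0.1):
    """Randomly hold out a fraction of the (y, theta) pairs for validation."""
    n = y.shape[0]
    n_val = int(n * percentage_data_as_validation_set)
    idx = rng.permutation(n)
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    return (y[train_idx], theta[train_idx]), (y[val_idx], theta[val_idx])


rng = np.random.default_rng(0)
theta = rng.normal(size=(1000, 2))
y = theta + rng.normal(size=(1000, 2))
(train_y, train_theta), (val_y, val_theta) = split_train_validation(rng, y, theta)
```

With the default of 0.1 and 1000 simulations, 900 pairs are used for gradient updates and 100 for monitoring the validation loss that drives early stopping.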
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p)
- simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#
Simulate data from the posterior or prior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, will draw from the prior. If parameters are given, will draw from the amortized posterior using ‘observable’
observable – an observation. Needs to be given if posterior draws are desired
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#
Simulate data and parameters from the prior or posterior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- class sbijax.NPE(model_fns, density_estimator, num_atoms=10, use_event_space_bijections=True)[source]#
Neural posterior estimation.
Implements the method introduced in Greenberg et al. [2019]. In the literature, the method is usually referred to as APT or NPE-C, but here we refer to it simply as NPE.
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
density_estimator – a (neural) conditional density estimator to model the posterior distribution
num_atoms – number of atoms used in the atomic proposal
Examples
>>> from sbijax import NPE
>>> from sbijax.nn import make_maf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_maf(1)
>>> model = NPE(fns, neural_network)
References
Greenberg, David, et al. “Automatic posterior transformation for likelihood-free inference.” International Conference on Machine Learning, 2019.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit an NPE model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation
- Returns:
a tuple of parameters and a tuple of the training information
- sample_posterior(rng_key, params, observable, *, n_samples=4000, check_proposal_probs=True, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
check_proposal_probs – check whether the proposal draws have finite density and only accept a proposal if it does. Turning this off is convenient if the density estimator has to learn the density of a constrained variable and the RejectionMC takes a long time to draw valid samples.
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p)
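The effect of check_proposal_probs can be sketched as a rejection loop that discards draws with non-finite prior log-density. The NumPy code below is an illustration under stated assumptions: the Uniform(0, 1) prior, the normal stand-in for the trained density estimator, and the function names are all hypothetical and not part of sbijax.

```python
import numpy as np


def log_prior(theta):
    # log density of a Uniform(0, 1) prior: 0 inside the support, -inf outside
    inside = (theta >= 0.0) & (theta <= 1.0)
    return np.where(inside, 0.0, -np.inf)


def sample_with_check(rng, n_samples, check_proposal_probs=True):
    """Draw until n_samples proposals have been accepted."""
    accepted = []
    n_accepted = 0
    while n_accepted < n_samples:
        # stand-in for draws from a trained conditional density estimator
        proposal = rng.normal(loc=0.5, scale=0.5, size=n_samples)
        if check_proposal_probs:
            # keep only draws that the prior assigns finite density
            proposal = proposal[np.isfinite(log_prior(proposal))]
        accepted.append(proposal)
        n_accepted += proposal.shape[0]
    return np.concatenate(accepted)[:n_samples]


checked = sample_with_check(np.random.default_rng(0), 4000)
unchecked = sample_with_check(np.random.default_rng(0), 4000, check_proposal_probs=False)
```

With the check enabled, every returned draw lies in the prior support, at the cost of extra sampling rounds when the estimator places much mass outside it. With the check disabled, the first batch is returned as-is.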
- simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#
Simulate data from the posterior or prior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, will draw from the prior. If parameters are given, will draw from the amortized posterior using ‘observable’
observable – an observation. Needs to be given if posterior draws are desired
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#
Simulate data and parameters from the prior or posterior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
Likelihood estimation#
- class sbijax.NLE(model_fns, network)[source]#
Neural likelihood estimation.
Implements the method introduced in Papamakarios et al. [2019].
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
network – a (neural) conditional density estimator to model the likelihood function
Examples
>>> from sbijax import NLE
>>> from sbijax.nn import make_mdn
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_mdn(1, 5)
>>> model = NLE(fns, neural_network)
References
Papamakarios, George, et al. “Sequential neural likelihood: Fast likelihood-free inference with autoregressive flows.” International Conference on Artificial Intelligence and Statistics, 2019.
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit the model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation
**kwargs – additional keyword arguments (not used for NLE)
- Returns:
a tuple of parameters and a tuple of the training information
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
n_chains – number of MCMC chains
n_samples – number of samples to draw in total
n_warmup – number of draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns a NamedTuple with two elements, y and theta
- simulate_data(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Simulate data from the posterior or prior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, will draw from the prior. If parameters are given, will draw from the amortized posterior using ‘observable’
observable – an observation. Needs to be given if posterior draws are desired
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
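How n_chains, n_samples and n_warmup combine into the final number of posterior draws can be shown with array shapes alone. The NumPy sketch below assumes that n_samples counts draws per chain including warm-up and that raw draws are laid out as (n_samples, n_chains, p). Both are assumptions suggested by the parameter descriptions, not statements about sbijax internals.

```python
import numpy as np

n_chains, n_samples, n_warmup, p = 4, 2000, 1000, 2

rng = np.random.default_rng(0)
# stand-in for raw MCMC output: one slice per iteration, one column per chain
raw_draws = rng.normal(size=(n_samples, n_chains, p))

# discard the warm-up iterations of every chain
kept = raw_draws[n_warmup:]

# pool the remaining draws of all chains into one (n_draws, p) array
pooled = kept.reshape(-1, p)
```

Under these assumptions the defaults yield (2000 - 1000) * 4 = 4000 pooled draws.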
- class sbijax.SNLE(model_fns, network)[source]#
Surjective neural likelihood estimation.
Implements the method introduced in Dirmeier et al. [2023]. SNLE is particularly useful when dealing with high-dimensional data since it embeds the data in a lower-dimensional space before modelling the likelihood.
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
network – a (neural) conditional density estimator to model the likelihood function
Examples
>>> from jax import numpy as jnp
>>> from sbijax import SNLE
>>> from sbijax.nn import make_maf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
...     theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_maf(10, n_layer_dimensions=[10, 10, 5, 5, 5])
>>> model = SNLE(fns, neural_network)
References
Dirmeier, Simon, et al. “Simulation-based inference using surjective sequential neural likelihood estimation.” arXiv preprint arXiv:2308.01054, 2023.
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#
Fit the model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation
**kwargs – additional keyword arguments (not used for NLE)
- Returns:
a tuple of parameters and a tuple of the training information
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
- simulate_data(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#
Simulate data from the posterior or prior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, will draw from the prior. If parameters are given, will draw from the amortized posterior using ‘observable’
observable – an observation. Needs to be given if posterior draws are desired
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
n_chains – number of MCMC chains
n_samples – number of samples to draw in total
n_warmup – number of draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns a NamedTuple with two elements, y and theta
Likelihood-ratio estimation#
- class sbijax.NRE(model_fns, classifier, num_classes=10, gamma=1.0)[source]#
Neural ratio estimation.
Implements the method by Miller et al. [2022]. The original publication refers to the method as CNRE or NRE-C, but here we refer to it as NRE.
- Parameters:
model_fns (tuple[Callable, Callable]) – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
classifier (Transformed) – a neural network for classification
num_classes (int) – number of classes to classify against
gamma (float) – relative weight of classes
Examples
>>> from sbijax import NRE
>>> from sbijax.nn import make_resnet
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_resnet()
>>> model = NRE(fns, neural_network)
References
Miller, Benjamin K., et al. “Contrastive neural ratio estimation.” Advances in Neural Information Processing Systems, 2022.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=25, n_early_stopping_delta=0.001, **kwargs)[source]#
Fit the model.
- Parameters:
rng_key (Array) – a jax random key
data (NamedTuple) – data set obtained from calling simulate_data_and_possibly_append
optimizer (GradientTransformation) – an optax optimizer object
n_iter (int) – maximal number of training iterations per round
batch_size (int) – batch size used for training the model
percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience (int) – number of iterations of no improvement of training the flow before stopping optimisation
n_early_stopping_delta – minimal value for improvement for early stopping
- Returns:
a tuple of parameters and a tuple of the training information
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Simulate data from the prior or posterior.
Simulate new parameters and observables from the prior or posterior (when params and data given). If a data argument is provided, append the new samples to the data set and return the old+new data.
- Parameters:
rng_key (Array) – a jax random key
params (Mapping[str, Mapping[str, Array]] | None) – a dictionary of neural network parameters
observable (Array) – an observation
data (tuple | None) – existing data set or None
n_simulations (int) – number of newly simulated data
n_chains (int) – number of MCMC chains
n_samples (int) – number of samples to draw in total
n_warmup (int) – number of draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
a NamedTuple with two elements, y and theta
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
- simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)#
Simulate data from the posterior or prior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, will draw from the prior. If parameters are given, will draw from the amortized posterior using ‘observable’
observable – an observation. Needs to be given if posterior draws are desired
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
Approximate Bayesian computation#
- class sbijax.SMCABC(model_fns, summary_fn, distance_fn)[source]#
Sequential Monte Carlo approximate Bayesian computation.
Implements the algorithm from Beaumont et al. [2009].
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
summary_fn – summary function
distance_fn – distance function
Examples
>>> import jax
>>> from jax import numpy as jnp
>>> from sbijax import SMCABC
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> summary_fn = lambda x: x
>>> distance_fn = lambda x, y: jax.vmap(lambda z: jnp.linalg.norm(z))(x - y)
>>> model = SMCABC(fns, summary_fn, distance_fn)
References
Beaumont, Mark A, et al. “Adaptive approximate Bayesian computation”. Biometrika, 2009.
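The roles of summary_fn and distance_fn are easiest to see in plain rejection ABC, the simplest relative of the SMC variant implemented here. The NumPy sketch below is not the SMCABC algorithm: the function `rejection_abc`, the tolerance `epsilon` and the toy Gaussian model are illustrative assumptions.

```python
import numpy as np


def summary_fn(y):
    # identity summary, as in the example above
    return y


def distance_fn(s_sim, s_obs):
    # Euclidean distance between each simulated summary and the observed one
    return np.linalg.norm(s_sim - s_obs, axis=-1)


def rejection_abc(rng, observable, n_simulations=20000, epsilon=0.1):
    """Keep prior draws whose simulated data fall within epsilon of the observation."""
    theta = rng.normal(0.0, 1.0, size=(n_simulations, 1))  # draws from the prior
    y_sim = rng.normal(theta, 1.0)                          # simulator: y | theta ~ N(theta, 1)
    distances = distance_fn(summary_fn(y_sim), summary_fn(observable))
    return theta[distances < epsilon]


rng = np.random.default_rng(0)
posterior_draws = rejection_abc(rng, observable=np.array([1.0]))
```

For this conjugate toy model the exact posterior given y = 1 is N(0.5, 0.5), so the mean of the accepted draws should lie close to 0.5. SMC-ABC improves on this scheme by reusing accepted particles as proposals and shrinking the tolerance over several rounds.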
- sample_posterior(rng_key, observable, n_rounds=10, n_particles=10000, eps_step=0.825, ess_min=2000, cov_scale=1.0)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
observable – the observation to condition on
n_rounds – maximal number of SMC rounds
n_particles – number of particles to draw for each parameter
eps_step – decay of initial epsilon per simulation round
ess_min – minimal effective sample size
cov_scale – scaling of the transition kernel covariance
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p)
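Two of the tuning parameters above, ess_min and eps_step, refer to standard SMC quantities. The sketch below computes the usual effective sample size of a set of importance weights and a geometric tolerance schedule. It assumes that ess_min is compared against this ESS estimate and that eps_step multiplies the tolerance once per round. Both are assumptions about the implementation, and the function names are hypothetical.

```python
import numpy as np


def effective_sample_size(weights):
    """ESS = 1 / sum(w_i^2) for weights normalised to sum to one."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    return 1.0 / np.sum(w ** 2)


def epsilon_schedule(eps_init, eps_step=0.825, n_rounds=10):
    """Geometric decay of the ABC tolerance over SMC rounds (assumed form)."""
    return [eps_init * eps_step ** r for r in range(n_rounds)]


uniform_ess = effective_sample_size(np.ones(10000))          # all particles equally weighted
degenerate_ess = effective_sample_size([1.0, 0.0, 0.0, 0.0])  # one particle carries all weight
eps = epsilon_schedule(1.0, n_rounds=3)
```

Equal weights give an ESS equal to the number of particles, while a single dominant particle drives it towards one. A low ESS is the usual signal that the particle set has degenerated.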
Summary statistics#
- class sbijax.NASS(model_fns, summary_net)[source]#
Neural approximate summary statistics.
Implements the NASS algorithm introduced in Chen et al. [2021]. NASS can be used to automatically learn summary statistics of a data set. With the learned summaries, inferential algorithms like NLE or SMCABC can be used to infer posterior distributions.
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
summary_net – a SNASSNet object
Examples
>>> from jax import numpy as jnp
>>> from sbijax import NASS
>>> from sbijax.nn import make_nass_net
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
...     theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_nass_net([64, 64, 5], [64, 64, 1])
>>> model = NASS(fns, neural_network)
References
Chen, Yanzhi et al. “Neural Approximate Sufficient Statistics for Implicit Models”. ICLR, 2021
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit the model to data.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation
**kwargs – additional keyword arguments (not used for NASS)
- Returns:
tuple of parameters and a tuple of the training information
- class sbijax.NASSS(model_fns, summary_net)[source]#
Neural approximate slice sufficient statistics.
Implements the NASSS algorithm introduced in Chen et al. [2023]. NASSS can be used to automatically learn summary statistics of a data set. With the learned summaries, inferential algorithms like NLE or SMCABC can be used to infer posterior distributions.
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
summary_net – a SNASSSNet object
Examples
>>> from jax import numpy as jnp
>>> from sbijax import NASSS
>>> from sbijax.nn import make_nasss_net
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
...     theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_nasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1])
>>> model = NASSS(fns, neural_network)
References
Chen, Yanzhi, et al. “Is Learning Summary Statistics Necessary for Likelihood-free Inference?”. ICML, 2023.
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#
Fit the model to data.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation
**kwargs – additional keyword arguments (not used for NASSS)
- Returns:
tuple of parameters and a tuple of the training information
Visualization#
- sbijax.plot_ess(inference_data)[source]#
Effective sample size plot.
- Parameters:
inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm
- Returns:
the same array of matplotlib axes with added plots
- sbijax.plot_loss_profile(losses, axes=None)[source]#
Visualize the training and validation loss profile.
- Return type:
Axes
- Parameters:
losses (Array) – a jax.Array of training and validation losses
axes (Axes) – a matplotlib axes
- Returns:
the same array of matplotlib axes with added plots
- sbijax.plot_rank(inference_data)[source]#
Rank statistics plots.
- Parameters:
inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm
- Returns:
the same array of matplotlib axes with added plots
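Rank plots are usually built by ranking all draws of a parameter across the pooled chains and then histogramming the ranks chain by chain. Well-mixed chains give roughly uniform histograms. The NumPy sketch below shows that construction under the assumption that plot_rank follows it. The helper `pooled_ranks` is hypothetical and no plotting is performed.

```python
import numpy as np


def pooled_ranks(chains):
    """Rank every draw among the draws of all chains (rank 0 = smallest)."""
    chains = np.asarray(chains)
    flat = chains.reshape(-1)
    ranks = np.empty(flat.shape[0], dtype=int)
    ranks[np.argsort(flat)] = np.arange(flat.shape[0])
    return ranks.reshape(chains.shape)


rng = np.random.default_rng(0)
chains = rng.normal(size=(4, 500))  # 4 chains with 500 draws each
ranks = pooled_ranks(chains)

# one histogram of pooled ranks per chain, 10 bins each
counts = np.stack(
    [np.histogram(r, bins=10, range=(0, ranks.size))[0] for r in ranks]
)
```

Each row of `counts` holds the bar heights for one chain. A chain stuck in one region of the posterior would show ranks piling up in a few bins.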
- sbijax.plot_rhat_and_ress(inference_data, axes=None)[source]#
Split-\(\hat{R}\) and relative effective sample size plot.
- Return type:
ndarray[Axes]
- Parameters:
inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm
axes (ndarray[Axes]) – an array of matplotlib axes
- Returns:
the same array of matplotlib axes with added plots
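As background for the quantity being plotted, the classic split-\(\hat{R}\) diagnostic splits every chain in half and compares between-sequence and within-sequence variance. The NumPy sketch below implements that textbook formula for a single scalar parameter. It is not the routine used by plot_rhat_and_ress, which may rely on a rank-normalised variant, and the relative effective sample size is not covered.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # split every chain into two halves, giving 2 * n_chains sequences
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()       # W: mean within-sequence variance
    between = n * split.mean(axis=1).var(ddof=1)    # B: between-sequence variance
    var_hat = (n - 1) / n * within + between / n    # pooled variance estimate
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))          # four chains targeting the same distribution
stuck = mixed + 3.0 * np.arange(4)[:, None]  # four chains centred at different locations

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
```

Values close to one indicate that the chains agree, whereas chains exploring different regions produce values well above one.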
- sbijax.plot_posterior(inference_data)[source]#
Posterior histogram plot.
- Parameters:
inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm
axes – an array of matplotlib axes
- Returns:
the same array of matplotlib axes with added plots
- sbijax.plot_trace(inference_data)[source]#
MCMC trace plot.
- Parameters:
inference_data (DataTree) – an inference data object received from calling sample_posterior of an SBI algorithm
axes – an array of matplotlib axes
**kwargs – additional parameters passed to Arviz
- Returns:
the same array of matplotlib axes with added plots