sbijax#
sbijax contains the implemented methods for simulation-based inference.
Methods#
- SMCABC – Sequential Monte Carlo approximate Bayesian computation.
- SNL – Sequential neural likelihood.
- SNP – Sequential neural posterior estimation.
- SNR – Sequential (contrastive) neural ratio estimation.
- SFMPE – Sequential flow matching posterior estimation.
- SCMPE – Sequential consistency model posterior estimation.
- SNASS – Sequential neural approximate summary statistics.
- SNASSS – Sequential neural approximate slice sufficient statistics.
SMCABC#
- class sbijax.SMCABC(model_fns, summary_fn, distance_fn)[source]#
Sequential Monte Carlo approximate Bayesian computation.
Implements algorithm 4.8 from [1].
References
[1] Sisson, Scott A., et al. Handbook of Approximate Bayesian Computation, 2018.
- sample_posterior(rng_key, observable, n_rounds, n_particles, n_simulations_per_theta, eps_step, ess_min, cov_scale=1.0)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
observable – the observation to condition on
n_rounds – maximal number of SMC rounds
n_particles – number of particles to draw
n_simulations_per_theta – number of simulations for each parameter sample
eps_step – decay of the initial epsilon per SMC round
ess_min – minimal effective sample size
cov_scale – scaling of the transition kernel covariance
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p)
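A minimal usage sketch, assuming the model_fns convention described for the classes below; the Gaussian toy model, identity summary, and Euclidean distance are illustrative stand-ins, not part of the API:

>>> import distrax
>>> import jax.numpy as jnp
>>> from jax import random as jr
>>> from sbijax import SMCABC
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> summary_fn = lambda y: y  # identity summary statistic (stand-in)
>>> distance_fn = lambda y_sim, y_obs: jnp.linalg.norm(y_sim - y_obs)  # Euclidean distance (stand-in)
>>>
>>> smc = SMCABC(fns, summary_fn, distance_fn)
>>> y_observed = jnp.array([-1.0])  # illustrative observation
>>> samples = smc.sample_posterior(
...     jr.PRNGKey(0), y_observed, n_rounds=5, n_particles=1000,
...     n_simulations_per_theta=10, eps_step=0.9, ess_min=500)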
SNL+SSNL#
- class sbijax.SNL(model_fns, density_estimator)[source]#
Sequential neural likelihood.
Implements both the SNL [1] and SSNL [2] estimation methods.
- Parameters:
model_fns – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
density_estimator – a (neural) conditional density estimator to model the likelihood function
References
[1] Papamakarios, George, et al. “Sequential neural likelihood: Fast likelihood-free inference with autoregressive flows.” International Conference on Artificial Intelligence and Statistics, 2019.
[2] Dirmeier, Simon, et al. “Simulation-based inference using surjective sequential neural likelihood estimation.” arXiv preprint arXiv:2308.01054, 2023.
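The estimators in this module share a three-step workflow: simulate_data_and_possibly_append draws new (theta, y) pairs, fit trains the estimator, and sample_posterior draws from the approximation. Below is a minimal two-round sketch for SNL, assuming the return conventions documented in the method docstrings that follow; make_affine_maf is the same helper used in the SNP example further down, and y_observed is an illustrative observation:

>>> import distrax
>>> import jax.numpy as jnp
>>> from jax import random as jr
>>> from sbijax import SNL
>>> from sbijax.nn import make_affine_maf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> snl = SNL(fns, make_affine_maf(1))
>>>
>>> y_observed = jnp.array([-1.0])  # illustrative observation
>>> data, params = None, None
>>> for i in range(2):  # two sequential rounds
...     data = snl.simulate_data_and_possibly_append(
...         jr.PRNGKey(i), params=params, observable=y_observed, data=data)
...     params, losses = snl.fit(jr.PRNGKey(i + 10), data=data)
>>> samples, diagnostics = snl.sample_posterior(jr.PRNGKey(42), params, y_observed)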
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit a SNL or SSNL model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early
- Returns:
a tuple of parameters and a tuple of the training information
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
n_chains – number of MCMC chains
n_samples – number of samples to draw in total
n_warmup – number of warmup draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns a NamedTuple with two elements, y and theta
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
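Continuing the sketch above, the slice sampler and its interval options are selected through the keyword arguments listed here:

>>> samples, diagnostics = snl.sample_posterior(
...     jr.PRNGKey(2), params, y_observed,
...     sampler="slice", n_thin=2, n_doubling=5, step_size=1.0)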
SNP#
- class sbijax.SNP(model_fns, density_estimator, num_atoms=10)[source]#
Sequential neural posterior estimation.
- Parameters:
model_fns – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
density_estimator – a (neural) conditional density estimator to model the posterior distribution
num_atoms – number of atoms used for the atomic loss
Examples
>>> import distrax
>>> from sbijax import SNP
>>> from sbijax.nn import make_affine_maf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)
>>>
>>> estim = SNP(fns, flow)
References
[1] Greenberg, David, et al. “Automatic posterior transformation for likelihood-free inference.” International Conference on Machine Learning, 2019.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit an SNP model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations of no improvement of training the flow before stopping optimisation
- Returns:
a tuple of parameters and a tuple of the training information
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p)
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#
Simulate data and parameters from the prior or posterior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
returns a NamedTuple with two elements, y and theta
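Unlike SNL, SNP draws posterior samples directly from the trained flow, so no MCMC keyword arguments are needed. A two-round sketch continuing the Examples above, with y_observed as an illustrative observation and return conventions taken from the docstrings:

>>> import jax.numpy as jnp
>>> from jax import random as jr
>>>
>>> y_observed = jnp.array([-1.0])  # illustrative observation
>>> data, params = None, None
>>> for i in range(2):  # two sequential rounds
...     data = estim.simulate_data_and_possibly_append(
...         jr.PRNGKey(i), params=params, observable=y_observed, data=data)
...     params, losses = estim.fit(jr.PRNGKey(i + 10), data=data)
>>> samples = estim.sample_posterior(jr.PRNGKey(42), params, y_observed)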
SNR#
- class sbijax.SNR(model_fns, classifier, num_classes=10, gamma=1.0)[source]#
Sequential (contrastive) neural ratio estimation.
- Parameters:
model_fns (tuple[tuple[Callable, Callable], Callable]) – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
classifier (Transformed) – a neural network for classification
num_classes (int) – number of classes to classify against
gamma (float) – relative weight of classes
Examples
>>> import distrax
>>> from sbijax import SNR
>>> from sbijax.nn import make_resnet
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> resnet = make_resnet()
>>>
>>> snr = SNR(fns, resnet)
References
[1] Miller, Benjamin K., et al. “Contrastive neural ratio estimation.” Advances in Neural Information Processing Systems, 2022.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit an SNR model.
- Parameters:
rng_key (Array) – a jax random key
data (NamedTuple) – data set obtained from calling simulate_data_and_possibly_append
optimizer (GradientTransformation) – an optax optimizer object
n_iter (int) – maximal number of training iterations per round
batch_size (int) – batch size used for training the model
percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience (float) – number of iterations without improvement of the validation loss before optimisation is stopped early
- Returns:
a tuple of parameters and a tuple of the training information
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Simulate data from the prior or posterior.
Simulate new parameters and observables from the prior or, when params and data are given, from the posterior. If a data argument is provided, the new samples are appended to the existing data set and the combined old and new data are returned.
- Parameters:
rng_key (Array) – a jax random key
params (Mapping[str, Mapping[str, Array]] | None) – a dictionary of neural network parameters
observable (Array) – an observation
data (NamedTuple) – existing data set or None
n_simulations (int) – number of newly simulated data
n_chains (int) – number of MCMC chains
n_samples (int) – number of samples to draw in total
n_warmup (int) – number of warmup draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns a NamedTuple with two elements, y and theta
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
SFMPE#
- class sbijax.SFMPE(model_fns, density_estimator)[source]#
Sequential flow matching posterior estimation.
Implements a sequential version of the FMPE algorithm introduced in [1]. For all rounds \(r > 1\) parameter samples \(\theta \sim \hat{p}^r(\theta)\) are drawn from the approximate posterior instead of the prior when computing the flow matching loss. Note that the implementation does not strictly follow the paper.
- Parameters:
model_fns – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
density_estimator – a continuous normalizing flow model
Examples
>>> import distrax
>>> from sbijax import SFMPE
>>> from sbijax.nn import make_ccnf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_ccnf(1)
>>>
>>> estim = SFMPE(fns, flow)
References
[1] Wildberger, Jonas, et al. “Flow Matching for Scalable Simulation-Based Inference.” Advances in Neural Information Processing Systems, 2024.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit the model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early
- Returns:
a tuple of parameters and a tuple of the training information
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p)
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#
Simulate data and parameters from the prior or posterior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
returns a NamedTuple with two elements, y and theta
SCMPE#
- class sbijax.SCMPE(model_fns, network, t_max=50.0, t_min=0.001)[source]#
Sequential consistency model posterior estimation.
Implements a sequential version of the CMPE algorithm introduced in [1]. For all rounds \(r > 1\) parameter samples \(\theta \sim \hat{p}^r(\theta)\) are drawn from the approximate posterior instead of the prior when computing the consistency loss. Note that the implementation does not strictly follow the paper.
- Parameters:
model_fns – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
network – a neural network
t_min – minimal time point for ODE integration
t_max – maximal time point for ODE integration
Examples
>>> import distrax
>>> from sbijax import SCMPE
>>> from sbijax.nn import make_consistency_model
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> net = make_consistency_model(1)
>>>
>>> estim = SCMPE(fns, net)
References
[1] Schmitt, Marvin, et al. “Consistency Models for Scalable and Fast Simulation-Based Inference”. arXiv preprint arXiv:2312.05440, 2023.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#
Fit the model.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early
- Returns:
a tuple of parameters and a tuple of the training information
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#
Simulate data and parameters from the prior or posterior and append.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
kwargs – dictionary of key-value pairs passed to sample_posterior
- Returns:
returns a NamedTuple with two elements, y and theta
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
returns an array of samples from the posterior distribution of dimension (n_samples times p)
SNASS#
- class sbijax.SNASS(model_fns, density_estimator, summary_net)[source]#
Sequential neural approximate summary statistics.
- Parameters:
model_fns – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
density_estimator – a (neural) conditional density estimator to model the likelihood function of summary statistics, i.e., the modelled dimensionality is that of the summaries
summary_net – a SNASSNet object
References
[1] Chen, Yanzhi, et al. “Neural Approximate Sufficient Statistics for Implicit Models.” International Conference on Learning Representations, 2021.
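A construction sketch in the style of the Examples above. The make_snass_net helper and its signature are assumptions here (any callable that builds a SNASSNet with a summary and a critic network should do); the flow models the likelihood of the summaries:

>>> import distrax
>>> from sbijax import SNASS
>>> from sbijax.nn import make_affine_maf, make_snass_net  # make_snass_net assumed
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)  # density estimator over the summaries
>>> summary_net = make_snass_net([64, 64, 1], [64, 64, 1])  # assumed signature: summary dims, critic dims
>>>
>>> snass = SNASS(fns, flow, summary_net)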
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#
Fit the model to data.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
a tuple of parameters and a tuple of the training information
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
n_chains – number of MCMC chains
n_samples – number of samples to draw in total
n_warmup – number of warmup draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns a NamedTuple with two elements, y and theta
SNASSS#
- class sbijax.SNASSS(model_fns, density_estimator, summary_net)[source]#
Sequential neural approximate slice sufficient statistics.
- Parameters:
model_fns – a tuple of tuples. The first element is a tuple that consists of functions to sample and evaluate the log-probability of a data point. The second element is a simulator function.
density_estimator – a (neural) conditional density estimator to model the likelihood function of summary statistics, i.e., the modelled dimensionality is that of the summaries
summary_net – a SNASSSNet object
References
[1] Chen, Yanzhi, et al. “Is Learning Summary Statistics Necessary for Likelihood-free Inference?” International Conference on Machine Learning, 2023.
- fit(rng_key, data, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#
Fit the model to data.
- Parameters:
rng_key – a jax random key
data – data set obtained from calling simulate_data_and_possibly_append
optimizer – an optax optimizer object
n_iter – maximal number of training iterations per round
batch_size – batch size used for training the model
percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
a tuple of parameters and a tuple of the training information
- simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of newly simulated data
n_chains – number of MCMC chains
n_samples – number of samples to draw in total
n_warmup – number of warmup draws to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
returns a NamedTuple with two elements, y and theta
- sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_chains – number of MCMC chains
n_samples – number of samples per chain
n_warmup – number of samples to discard
- Keyword Arguments:
sampler (str) – either ‘nuts’, ‘slice’ or None (defaults to nuts)
n_thin (int) – number of thinning steps (only used if sampler=’slice’)
n_doubling (int) – number of doubling steps of the interval (only used if sampler=’slice’)
step_size (float) – step size of the initial interval (only used if sampler=’slice’)
- Returns:
an array of samples from the posterior distribution of dimension (n_samples times p) and posterior diagnostics