sbijax#


sbijax contains the implemented methods for simulation-based inference.

Methods#

SMCABC(model_fns, summary_fn, distance_fn)

Sequential Monte Carlo approximate Bayesian computation.

SNL(model_fns, density_estimator)

Sequential neural likelihood.

SNP(model_fns, density_estimator[, num_atoms])

Sequential neural posterior estimation.

SNR(model_fns, classifier[, num_classes, gamma])

Sequential (contrastive) neural ratio estimation.

SFMPE(model_fns, density_estimator)

Sequential flow matching posterior estimation.

SCMPE(model_fns, network[, t_max, t_min])

Sequential consistency model posterior estimation.

SNASS(model_fns, density_estimator, summary_net)

Sequential neural approximate summary statistics.

SNASSS(model_fns, density_estimator, summary_net)

Sequential neural approximate slice sufficient statistics.

SMCABC#

class sbijax.SMCABC(model_fns, summary_fn, distance_fn)[source]#

Sequential Monte Carlo approximate Bayesian computation.

Implements algorithm 4.8 from [1].
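
Examples

A minimal construction sketch in the style of the other classes in this module. The summary and distance functions below are illustrative placeholders chosen for this example, not part of sbijax:

>>> import distrax
>>> import jax.numpy as jnp
>>> from sbijax import SMCABC
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> # illustrative summary and Euclidean distance, purely for this sketch
>>> summary_fn = lambda y: jnp.mean(y, axis=-1, keepdims=True)
>>> distance_fn = lambda y_simulated, y_observed: jnp.linalg.norm(y_simulated - y_observed)
>>>
>>> smc = SMCABC(fns, summary_fn, distance_fn)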

References

sample_posterior(rng_key, observable, n_rounds, n_particles, n_simulations_per_theta, eps_step, ess_min, cov_scale=1.0)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • observable – the observation to condition on

  • n_rounds – maximal number of SMC rounds

  • n_particles – number of particles to draw

  • n_simulations_per_theta – number of simulations for each parameter sample

  • eps_step – decay of the initial epsilon per simulation round

  • ess_min – minimal effective sample size

  • cov_scale – scaling of the transition kernel covariance

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p)
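
A hedged usage sketch of sample_posterior, continuing the example above; all argument values are illustrative and chosen only to show the call signature:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> y_observed = jnp.array([-1.0])  # a toy observation
>>> rng_key = jax.random.PRNGKey(0)
>>> samples = smc.sample_posterior(
...     rng_key, y_observed, n_rounds=5, n_particles=1_000,
...     n_simulations_per_theta=10, eps_step=0.9, ess_min=0.5,
... )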

SNL+SSNL#

class sbijax.SNL(model_fns, density_estimator)[source]#

Sequential neural likelihood.

Implements both SNL and SSNL estimation methods.

Parameters:
  • model_fns – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • density_estimator – a (neural) conditional density estimator to model the likelihood function
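
Examples

A minimal construction sketch, mirroring the SNP example below and using the same flow constructor:

>>> import distrax
>>> from sbijax import SNL
>>> from sbijax.nn import make_affine_maf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)
>>>
>>> estim = SNL(fns, flow)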

References

fit(rng_key, data, optimizer=<optax optimizer>, n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit a SNL or SSNL model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early

Returns:

a tuple of parameters and a tuple of the training information

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Simulate data from the prior or posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data points

  • n_chains – number of MCMC chains

  • n_samples – number of samples to draw in total

  • n_warmup – number of draws to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

a NamedTuple with two elements, y and theta

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p) and posterior diagnostics
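
The three methods above are typically combined into a round-based loop: simulate, fit, repeat, then sample from the final posterior. A hedged sketch, assuming the estim object from the example above; round counts and argument values are illustrative:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> y_observed = jnp.array([-1.0])  # a toy observation
>>> rng_key = jax.random.PRNGKey(23)
>>> data, params = None, None
>>> for _ in range(2):  # two rounds of simulation and training
...     rng_key, sim_key, fit_key = jax.random.split(rng_key, 3)
...     data = estim.simulate_data_and_possibly_append(
...         sim_key, params=params, observable=y_observed, data=data,
...     )
...     params, info = estim.fit(fit_key, data=data)
>>> rng_key, sample_key = jax.random.split(rng_key)
>>> samples, diagnostics = estim.sample_posterior(sample_key, params, y_observed)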

SNP#

class sbijax.SNP(model_fns, density_estimator, num_atoms=10)[source]#

Sequential neural posterior estimation.

Parameters:
  • model_fns – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • density_estimator – a (neural) conditional density estimator to model the posterior distribution

  • num_atoms – number of atoms used in the atomic (contrastive) loss

Examples

>>> import distrax
>>> from sbijax import SNP
>>> from sbijax.nn import make_affine_maf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)
>>>
>>> estim = SNP(fns, flow)

References

fit(rng_key, data, *, optimizer=<optax optimizer>, n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit an SNP model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p)

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#

Simulate data and parameters from the prior or posterior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data points

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta
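
Since SNP models the posterior directly, sampling requires no MCMC. A hedged single-round sketch using the estim object from the class example above; argument values are illustrative:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> y_observed = jnp.array([-1.0])  # a toy observation
>>> k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
>>> data = estim.simulate_data_and_possibly_append(k1, None, y_observed)
>>> params, info = estim.fit(k2, data=data)
>>> samples = estim.sample_posterior(k3, params, y_observed, n_samples=1_000)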

SNR#

class sbijax.SNR(model_fns, classifier, num_classes=10, gamma=1.0)[source]#

Sequential (contrastive) neural ratio estimation.

Parameters:
  • model_fns (tuple[tuple[Callable, Callable], Callable]) – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • classifier (Transformed) – a neural network for classification

  • num_classes (int) – number of classes to classify against

  • gamma (float) – relative weight of classes

Examples

>>> import distrax
>>> from sbijax import SNR
>>> from sbijax.nn import make_resnet
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> resnet = make_resnet()
>>>
>>> snr = SNR(fns, resnet)

References

fit(rng_key, data, *, optimizer=<optax optimizer>, n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit an SNR model.

Parameters:
  • rng_key (Array) – a jax random key

  • data (NamedTuple) – data set obtained from calling simulate_data_and_possibly_append

  • optimizer (GradientTransformation) – an optax optimizer object

  • n_iter (int) – maximal number of training iterations per round

  • batch_size (int) – batch size used for training the model

  • percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience (float) – number of iterations without improvement of the validation loss before optimisation is stopped early

Returns:

a tuple of parameters and a tuple of the training information

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Simulate data from the prior or posterior.

Simulate new parameters and observables from the prior, or from the posterior when params and data are given. If a data argument is provided, the new samples are appended to the existing data set and the combined data are returned.

Parameters:
  • rng_key (Array) – a jax random key

  • params (Mapping[str, Mapping[str, Array]] | None) – a dictionary of neural network parameters

  • observable (Array) – an observation

  • data (NamedTuple) – existing data set or None

  • n_simulations (int) – number of newly simulated data points

  • n_chains (int) – number of MCMC chains

  • n_samples (int) – number of samples to draw in total

  • n_warmup (int) – number of draws to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

a NamedTuple with two elements, y and theta

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p) and posterior diagnostics

SFMPE#

class sbijax.SFMPE(model_fns, density_estimator)[source]#

Sequential flow matching posterior estimation.

Implements a sequential version of the FMPE algorithm introduced in [1]. For all rounds \(r > 1\), parameter samples \(\theta \sim \hat{p}^r(\theta)\) are drawn from the approximate posterior instead of the prior when computing the flow matching loss. Note that the implementation does not strictly follow the paper.

Parameters:
  • model_fns – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • density_estimator – a continuous normalizing flow model

Examples

>>> import distrax
>>> from sbijax import SFMPE
>>> from sbijax.nn import make_ccnf
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_ccnf(1)
>>>
>>> estim = SFMPE(fns, flow)

References

fit(rng_key, data, *, optimizer=<optax optimizer>, n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit the model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p)

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#

Simulate data and parameters from the prior or posterior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data points

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

SCMPE#

class sbijax.SCMPE(model_fns, network, t_max=50.0, t_min=0.001)[source]#

Sequential consistency model posterior estimation.

Implements a sequential version of the CMPE algorithm introduced in [1]. For all rounds \(r > 1\), parameter samples \(\theta \sim \hat{p}^r(\theta)\) are drawn from the approximate posterior instead of the prior when computing the consistency loss. Note that the implementation does not strictly follow the paper.

Parameters:
  • model_fns – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • network – a neural network

  • t_min – minimal time point for ODE integration

  • t_max – maximal time point for ODE integration

Examples

>>> import distrax
>>> from sbijax import SCMPE
>>> from sbijax.nn import make_consistency_model
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> net = make_consistency_model(1)
>>>
>>> estim = SCMPE(fns, net)

References

fit(rng_key, data, *, optimizer=<optax optimizer>, n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#

Fit the model.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early

Returns:

a tuple of parameters and a tuple of the training information

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)#

Simulate data and parameters from the prior or posterior and append.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data points

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two elements, y and theta

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p)

SNASS#

class sbijax.SNASS(model_fns, density_estimator, summary_net)[source]#

Sequential neural approximate summary statistics.

Parameters:
  • model_fns – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • density_estimator – a (neural) conditional density estimator to model the likelihood function of summary statistics, i.e., the modelled dimensionality is that of the summaries

  • summary_net – a SNASSNet object
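
Examples

A minimal construction sketch. The likelihood flow follows the SNL example; the summary-network helper make_snass_net and its arguments are assumptions for illustration only and may not match the library's actual API:

>>> import distrax
>>> from sbijax import SNASS
>>> from sbijax.nn import make_affine_maf, make_snass_net  # make_snass_net is assumed
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)
>>> summary_net = make_snass_net([32, 32, 1], [32, 32, 1])  # assumed signature
>>>
>>> estim = SNASS(fns, flow, summary_net)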

References

fit(rng_key, data, optimizer=<optax optimizer>, n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)[source]#

Fit the model to data.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)[source]#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p) and posterior diagnostics

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#

Simulate data from the prior or posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data points

  • n_chains – number of MCMC chains

  • n_samples – number of samples to draw in total

  • n_warmup – number of draws to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

a NamedTuple with two elements, y and theta

SNASSS#

class sbijax.SNASSS(model_fns, density_estimator, summary_net)[source]#

Sequential neural approximate slice sufficient statistics.

Parameters:
  • model_fns – a tuple of two elements. The first element is a tuple of functions to sample from and evaluate the log-probability of the prior distribution. The second element is a simulator function.

  • density_estimator – a (neural) conditional density estimator to model the likelihood function of summary statistics, i.e., the modelled dimensionality is that of the summaries

  • summary_net – a SNASSSNet object
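
Examples

A minimal construction sketch analogous to the SNASS example above; the helper make_snasss_net and its arguments are assumptions for illustration only:

>>> import distrax
>>> from sbijax import SNASSS
>>> from sbijax.nn import make_affine_maf, make_snasss_net  # make_snasss_net is assumed
>>>
>>> prior = distrax.Normal(0.0, 1.0)
>>> s = lambda seed, theta: distrax.Normal(theta, 1.0).sample(seed=seed)
>>> fns = (prior.sample, prior.log_prob), s
>>> flow = make_affine_maf(1)
>>> summary_net = make_snasss_net([32, 32, 1], [32, 32, 1])  # assumed signature
>>>
>>> estim = SNASSS(fns, flow, summary_net)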

References

fit(rng_key, data, optimizer=<optax optimizer>, n_iter=1000, batch_size=128, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, **kwargs)#

Fit the model to data.

Parameters:
  • rng_key – a jax random key

  • data – data set obtained from calling simulate_data_and_possibly_append

  • optimizer – an optax optimizer object

  • n_iter – maximal number of training iterations per round

  • batch_size – batch size used for training the model

  • percentage_data_as_validation_set – percentage of the simulated data that is used for validation and early stopping

  • n_early_stopping_patience – number of iterations without improvement of the validation loss before optimisation is stopped early

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

a tuple of parameters and a tuple of the training information

simulate_data_and_possibly_append(rng_key, params=None, observable=None, data=None, n_simulations=1000, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#

Simulate data from the prior or posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of newly simulated data points

  • n_chains – number of MCMC chains

  • n_samples – number of samples to draw in total

  • n_warmup – number of draws to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

a NamedTuple with two elements, y and theta

sample_posterior(rng_key, params, observable, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)#

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_chains – number of MCMC chains

  • n_samples – number of samples per chain

  • n_warmup – number of samples to discard

Keyword Arguments:
  • sampler (str) – either 'nuts', 'slice' or None (defaults to 'nuts')

  • n_thin (int) – number of thinning steps (only used if sampler='slice')

  • n_doubling (int) – number of doubling steps of the interval (only used if sampler='slice')

  • step_size (float) – step size of the initial interval (only used if sampler='slice')

Returns:

an array of samples from the posterior distribution of dimension (n_samples × p) and posterior diagnostics