sbijax#
The top-level module, sbijax, contains all implemented methods for neural
simulation-based inference and approximate Bayesian computation, as well as
diagnostics and other utilities.
Every neural method is a factory function that takes a network (plus a few
method-specific options, but not the optimizer or the prior) and returns a
record of pure functions, following the low-level functional idiom of dm-haiku
and blackjax. Training and sampling are free driver functions (train(),
sample()): the optimizer is passed to train(), and, for likelihood and ratio
methods, the sampler (which carries the prior) is passed to sample():
est = nle(make_maf(2))
params, info = train(key, est, data, optimizer=optax.adam(3e-4))
samples, info = sample(
    key, est, params, y_observed, sampler=make_sampler(nuts, prior=prior)
)
See Design philosophy for the full design and Migration guide: 0.3 → 0.4 for moving from the class-based API.
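A minimal sketch of this factory/driver split in plain JAX follows. All names are hypothetical stand-ins (ToyObjectiveFns, toy_gaussian_posterior and toy_train are not part of sbijax), and plain gradient descent replaces the optax optimizer so the example has no extra dependencies. It only illustrates the shape of the design: the factory closes over model settings, and the driver owns the update rule.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyObjectiveFns(NamedTuple):
    """Hypothetical record of pure functions (a stand-in, not sbijax's ObjectiveFns)."""
    init: Callable       # rng_key -> params
    loss: Callable       # (params, batch) -> scalar
    sample_fn: Callable  # (rng_key, params, observable) -> (samples, info)


def toy_gaussian_posterior(dim: int) -> ToyObjectiveFns:
    """Factory: takes only model settings, no optimizer and no prior."""

    def init(rng_key):
        return {"W": jnp.zeros((dim, dim)), "b": jnp.zeros(dim)}

    def loss(params, batch):
        # Gaussian negative log-likelihood of theta given y, up to a constant
        mean = batch["y"] @ params["W"].T + params["b"]
        return 0.5 * jnp.mean(jnp.sum((batch["theta"] - mean) ** 2, axis=-1))

    def sample_fn(rng_key, params, observable, n=1000):
        mean = params["W"] @ observable + params["b"]
        return mean + jax.random.normal(rng_key, (n, dim)), None

    return ToyObjectiveFns(init, loss, sample_fn)


def toy_train(rng_key, objective, data, *, learning_rate=0.1, n_iter=200):
    """Free driver: the update rule (plain gradient descent here) is injected at train time."""
    params = objective.init(rng_key)
    grad_fn = jax.jit(jax.grad(objective.loss))
    for _ in range(n_iter):
        grads = grad_fn(params, data)
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params


key, noise_key = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(key, (500, 2))
data = {"theta": theta, "y": theta + 0.1 * jax.random.normal(noise_key, (500, 2))}

est = toy_gaussian_posterior(2)
params = toy_train(key, est, data)
samples, _ = est.sample_fn(key, params, jnp.ones(2))
```

Keeping the update rule out of the factory means the same record can be trained with different optimizers without being rebuilt.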
| npe() | Construct a neural posterior estimator. |
| fmpe() | Construct a flow-matching posterior objective. |
| npse() | Construct a neural posterior score estimator. |
| nle() | Construct a neural likelihood objective. |
| | Construct a surjective neural likelihood objective. |
| | Construct a neural ratio objective. |
| sabc() | Construct a simulated annealing ABC sampler. |
| smcabc() | Construct a sequential Monte Carlo ABC sampler. |
| nass() | Construct a neural approximate sufficient statistics summary network. |
| nasss() | Construct a neural approximate slice sufficient statistics summary network. |
| summarized_estimator() | Adapt an objective to operate on learned summaries. |
| train() | Train any objective's TrainFns with early stopping. |
| sample() | Draw posterior samples from a trained objective. |
| run_sequential() | Run multi-round sequential inference. |
| simulate() | Draw a dataset of parameters and observations. |
| stack() | Append two datasets along the sample axis. |
| sbc() | Compute simulation-based calibration ranks for a fitted objective. |
Data pipeline#
- sbijax.simulate(rng_key, prior, simulator, *, proposal=None, n)
Draw a dataset of parameters and observations.
Parameters are drawn from proposal if given, otherwise from prior, and observations are produced by running simulator on them.
- Parameters:
rng_key – a jax random key
prior – a distribution to draw parameters from when proposal is None
simulator – a callable (rng_key, theta) -> y
proposal – an optional callable (rng_key, n) -> theta used to draw parameters instead of the prior (e.g. a fitted posterior in a sequential round)
n – the number of parameter/observation pairs to draw
- Returns:
a dictionary with keys "y" and "theta"
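A sketch of the documented simulate() contract, assuming for simplicity that the prior and the proposal are both sampling callables (rng_key, n) -> theta (sbijax's prior is a distribution object) and that the simulator handles one theta at a time, so it is vectorized with jax.vmap. simulate_sketch, prior_sample and simulator are hypothetical names.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def simulate_sketch(rng_key, prior_sample: Callable, simulator: Callable, *,
                    proposal: Optional[Callable] = None, n: int):
    """Hypothetical re-implementation of the simulate() contract."""
    theta_key, sim_key = jax.random.split(rng_key)
    # draw parameters from the proposal if one is given, otherwise from the prior
    draw = proposal if proposal is not None else prior_sample
    theta = draw(theta_key, n)
    # one independent simulator key per parameter draw
    y = jax.vmap(simulator)(jax.random.split(sim_key, n), theta)
    return {"y": y, "theta": theta}


def prior_sample(rng_key, n):
    return jax.random.normal(rng_key, (n, 2))


def simulator(rng_key, theta):
    # a single draw y ~ N(theta, 0.1^2 I)
    return theta + 0.1 * jax.random.normal(rng_key, theta.shape)


data = simulate_sketch(jax.random.PRNGKey(0), prior_sample, simulator, n=100)
```

If a simulator is already batched over theta, it can be called directly instead of through jax.vmap.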
- sbijax.stack(data, new_data)
Append two datasets along the sample axis.
- Parameters:
data – a dataset as returned by simulate()
new_data – a second dataset with the same structure
- Returns:
a dataset whose leaves are the concatenation of both inputs
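Because a dataset is a pytree, appending can be sketched as a leaf-wise concatenation along axis 0. stack_sketch is a hypothetical stand-in for stack(), not its source.

```python
import jax
import jax.numpy as jnp


def stack_sketch(data, new_data):
    """Concatenate every leaf of two same-structured datasets along axis 0."""
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=0), data, new_data)


a = {"theta": jnp.zeros((3, 2)), "y": jnp.zeros((3, 5))}
b = {"theta": jnp.ones((4, 2)), "y": jnp.ones((4, 5))}
merged = stack_sketch(a, b)
# merged["theta"].shape == (7, 2), merged["y"].shape == (7, 5)
```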
Posterior estimation#
- sbijax.npe(network, *, num_atoms=10)
Construct a neural posterior estimator.
In round 0, use the returned ObjectiveFns directly for amortized maximum-likelihood training. For rounds > 0 (i.e. when training on data simulated from a fitted posterior rather than the prior), call obj.extra(prior) to obtain the atomic proposal-posterior objective of Greenberg et al. [2019].
- Parameters:
network – a conditional density estimator with log_prob and sample methods (e.g. from make_maf())
num_atoms – the number of atoms in the contrastive proposal-posterior loss used by extra(prior)
- Returns:
an ObjectiveFns; its extra field is a callable (prior) -> ObjectiveFns for sequential rounds
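The contrastive idea behind the atomic objective of Greenberg et al. [2019] can be sketched on precomputed log-densities. This is an illustration under assumptions, not sbijax's extra(prior): atomic_apt_loss is hypothetical, num_atoms is assumed to count the true parameter as one of the atoms, and how atoms are drawn or the flow is evaluated is left out.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def atomic_apt_loss(log_q, log_prior):
    """Atomic proposal-posterior loss, sketched on (batch, num_atoms) arrays.

    log_q[i, j]     = log q(theta_j | y_i) under the current estimator
    log_prior[i, j] = log p(theta_j)
    Column 0 holds each example's own parameter; the other columns are
    contrastive atoms (e.g. parameters of other examples in the batch).
    """
    logits = log_q - log_prior
    # log of the atom-normalized proposal posterior at the true parameter
    log_proposal_post = logits[:, 0] - logsumexp(logits, axis=1)
    return -jnp.mean(log_proposal_post)


# with a flat estimator every atom is equally likely, so the loss is log(num_atoms)
num_atoms = 10
loss = atomic_apt_loss(jnp.zeros((32, num_atoms)), jnp.zeros((32, num_atoms)))
```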
- sbijax.fmpe(network)
Construct a flow-matching posterior objective.
- Parameters:
network – a continuous normalizing flow with loss and sample methods
- Returns:
an ObjectiveFns
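Flow-matching objectives regress a network onto the velocity of a simple interpolation path between base noise and the parameters. The sketch below uses a straight-line path and a hypothetical velocity_fn signature (params, theta_t, t, y); it is not fmpe()'s implementation, and published variants differ in details such as a minimum noise level and how t is sampled.

```python
import jax
import jax.numpy as jnp


def flow_matching_loss(velocity_fn, params, rng_key, theta, y):
    """Conditional flow-matching loss on the path theta_t = (1 - t) * theta_0 + t * theta."""
    t_key, noise_key = jax.random.split(rng_key)
    t = jax.random.uniform(t_key, (theta.shape[0], 1))
    theta_0 = jax.random.normal(noise_key, theta.shape)  # base (noise) samples
    theta_t = (1.0 - t) * theta_0 + t * theta            # point on the path at time t
    target = theta - theta_0                             # d theta_t / dt along that path
    pred = velocity_fn(params, theta_t, t, y)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


def linear_velocity(params, theta_t, t, y):
    # toy velocity network: linear in (theta_t, t, y)
    inputs = jnp.concatenate([theta_t, t, y], axis=-1)
    return inputs @ params["W"] + params["b"]


key, data_key = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(data_key, (64, 2))
y = theta + 0.1
params = {"W": jnp.zeros((5, 2)), "b": jnp.zeros(2)}
loss = flow_matching_loss(linear_velocity, params, key, theta, y)
grads = jax.grad(lambda p: flow_matching_loss(linear_velocity, p, key, theta, y))(params)
```

At sampling time, integrating the learned velocity from t = 0 to t = 1, starting from base noise, transports samples toward the posterior.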
- sbijax.npse(network)
Construct a neural posterior score estimator.
- Parameters:
network – a score network with loss, sample and log_prob methods (e.g. sbijax.experimental.nn.make_score_model())
- Returns:
an ObjectiveFns
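Score estimators are commonly trained by denoising score matching: perturb the parameters with Gaussian noise and regress onto the known score of the perturbation kernel. The sketch uses a hypothetical score_fn signature (params, theta_noisy, sigma, y) and a fixed set of noise levels; npse() may use a different noise schedule or weighting.

```python
import jax
import jax.numpy as jnp


def dsm_loss(score_fn, params, rng_key, theta, y, sigmas):
    """Denoising score-matching loss, averaged over randomly chosen noise levels."""
    sigma_key, eps_key = jax.random.split(rng_key)
    sigma = jax.random.choice(sigma_key, sigmas, (theta.shape[0], 1))
    eps = jax.random.normal(eps_key, theta.shape)
    theta_noisy = theta + sigma * eps
    # The score of N(theta, sigma^2 I) at theta_noisy is -eps / sigma; weighting the
    # squared error by sigma^2 gives || sigma * predicted_score + eps ||^2.
    pred = score_fn(params, theta_noisy, sigma, y)
    return jnp.mean(jnp.sum((sigma * pred + eps) ** 2, axis=-1))


def linear_score(params, theta_noisy, sigma, y):
    # toy score network: linear in (theta_noisy, sigma, y)
    inputs = jnp.concatenate([theta_noisy, sigma, y], axis=-1)
    return inputs @ params["W"] + params["b"]


key, data_key = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(data_key, (64, 2))
y = theta + 0.1
sigmas = jnp.array([0.1, 0.5, 1.0])
params = {"W": jnp.zeros((5, 2)), "b": jnp.zeros(2)}
loss = dsm_loss(linear_score, params, key, theta, y, sigmas)
grads = jax.grad(lambda p: dsm_loss(linear_score, p, key, theta, y, sigmas))(params)
```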
Likelihood estimation#
Likelihood-ratio estimation#
Approximate Bayesian computation#
- sbijax.sabc(prior, simulator, *, summary_fn=<function <lambda>>, distance_fn=<function abs_distance>)
Construct a simulated annealing ABC sampler.
- Parameters:
prior – a tfd distribution serving as the prior over parameters
simulator – a callable (rng_key, theta) -> y
summary_fn – maps simulated data to summary statistics
distance_fn – distance between simulated and observed summaries
- Returns:
an ABCSampler
- sbijax.smcabc(prior, simulator, summary_fn, distance_fn)
Construct a sequential Monte Carlo ABC sampler.
- Parameters:
prior – a tfd distribution serving as the prior over parameters
simulator – a callable (rng_key, theta) -> y
summary_fn – maps simulated data to summary statistics
distance_fn – distance between simulated and observed summaries
- Returns:
an ABCSampler
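Both samplers compare summary_fn outputs through distance_fn. That contract can be illustrated with plain rejection ABC at a fixed tolerance, which is much simpler than the annealing and sequential Monte Carlo schemes that sabc() and smcabc() implement. rejection_abc and abs_distance_sketch are hypothetical, the L1 form of the distance is an assumption, and the prior is a sampling callable rather than a tfd distribution.

```python
import jax
import jax.numpy as jnp


def abs_distance_sketch(s_sim, s_obs):
    # assumption: L1 distance between summaries; sbijax's abs_distance may differ
    return jnp.sum(jnp.abs(s_sim - s_obs), axis=-1)


def rejection_abc(rng_key, prior_sample, simulator, y_obs, *, n, epsilon,
                  summary_fn=lambda y: y, distance_fn=abs_distance_sketch):
    """Keep prior draws whose simulated summaries fall within epsilon of the observed ones."""
    prior_key, sim_key = jax.random.split(rng_key)
    theta = prior_sample(prior_key, n)
    y = jax.vmap(simulator)(jax.random.split(sim_key, n), theta)
    dist = distance_fn(summary_fn(y), summary_fn(y_obs))
    return theta[dist <= epsilon]


def prior_sample(rng_key, n):
    return jax.random.normal(rng_key, (n, 1))


def simulator(rng_key, theta):
    return theta + 0.1 * jax.random.normal(rng_key, theta.shape)


accepted = rejection_abc(jax.random.PRNGKey(0), prior_sample, simulator,
                         jnp.array([0.5]), n=10_000, epsilon=0.1)
```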
Summary statistics#
- sbijax.nass(network)
Construct a neural approximate sufficient statistics summary network.
- Parameters:
network – a NASS summary network with forward, summary and critic methods
- Returns:
a SummaryFns
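NASS-style summary networks are trained to maximize a mutual-information bound between the summary and the parameters, estimated with a critic: a good summary lets the critic tell jointly simulated (y, theta) pairs from shuffled ones. The sketch below uses an InfoNCE bound, which may not be the estimator nass() optimizes; infonce_loss and bilinear_critic_scores are hypothetical.

```python
import jax
import jax.numpy as jnp


def infonce_loss(critic_scores):
    """InfoNCE loss from a (B, B) matrix of critic scores.

    critic_scores[i, j] scores the pair (summary(y_i), theta_j); the diagonal
    holds jointly simulated pairs and off-diagonal entries act as negatives.
    """
    log_probs = jax.nn.log_softmax(critic_scores, axis=1)
    return -jnp.mean(jnp.diag(log_probs))


def bilinear_critic_scores(params, y, theta):
    summaries = jnp.tanh(y @ params["S"])     # toy summary network
    return summaries @ params["C"] @ theta.T  # a score for every (i, j) pair


uninformative = infonce_loss(jnp.zeros((8, 8)))  # equals log(8)
confident = infonce_loss(10.0 * jnp.eye(8))      # close to 0

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {"S": jax.random.normal(k1, (3, 4)), "C": jax.random.normal(k2, (4, 2))}
scores = bilinear_critic_scores(params, jax.random.normal(k3, (16, 3)),
                                jax.random.normal(k4, (16, 2)))
loss = infonce_loss(scores)
```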
- sbijax.nasss(network)
Construct a neural approximate slice sufficient statistics summary network.
- Parameters:
network – a NASSS summary network with forward, summary, secondary_summary and critic methods
- Returns:
a SummaryFns
A summary network is chained into a downstream estimator with:
- sbijax.summarized_estimator(estimator, summary_net, summary_params)
Adapt an objective to operate on learned summaries.
Given a pre-fitted summary network, returns an ObjectiveFns whose training summarizes each batch before delegating to the wrapped estimator and whose sample_fn summarizes the observation before sampling. Because it returns an ObjectiveFns record, it stays conformant and is driven by the generic train() / sampling helpers.
Fit the summary network first, then wrap the estimator:
sn = nass(make_nass_net(2, [64, 64]))
sn_params, _ = train(key, sn, data)
est = summarized_estimator(nle(make_maf(2)), sn, sn_params)
params, info = train(key, est, data)  # trains on summaries
samples, _ = est.sample_fn(key, params, y_observed, sampler=sampler)
- Parameters:
estimator – the downstream ObjectiveFns consuming the summaries
summary_net – a fitted SummaryFns
summary_params – the summary network’s fitted parameters
- Returns:
an ObjectiveFns
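The wrapping pattern can be sketched with a hypothetical record type: the returned functions summarize y (during training) or the observation (during sampling) and then delegate to the inner estimator. summarized_sketch and ToyObjectiveFns are illustrative, not sbijax's implementation, and the summary function signature (params, y) -> s is an assumption.

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp


class ToyObjectiveFns(NamedTuple):
    """Hypothetical record type, not sbijax's ObjectiveFns."""
    loss: Callable       # (params, batch) -> scalar
    sample_fn: Callable  # (rng_key, params, observable, **kwargs) -> (samples, info)


def summarized_sketch(estimator, summary_fn, summary_params):
    """Wrap an estimator so it sees summaries instead of raw observations."""

    def loss(params, batch):
        # summarize the batch's observations, then delegate to the wrapped loss
        summarized = {**batch, "y": summary_fn(summary_params, batch["y"])}
        return estimator.loss(params, summarized)

    def sample_fn(rng_key, params, observable, **kwargs):
        # summarize the single observation, then delegate to the wrapped sampler
        s_obs = summary_fn(summary_params, observable)
        return estimator.sample_fn(rng_key, params, s_obs, **kwargs)

    return ToyObjectiveFns(loss, sample_fn)


def toy_summary_fn(summary_params, y):
    # toy "summary network": the scaled mean of the observation vector
    return summary_params * jnp.mean(y, axis=-1, keepdims=True)


inner = ToyObjectiveFns(
    loss=lambda params, batch: jnp.mean((batch["theta"] - params * batch["y"]) ** 2),
    sample_fn=lambda rng_key, params, obs, **kwargs: (params * obs, None),
)
wrapped = summarized_sketch(inner, toy_summary_fn, 2.0)
batch = {"theta": jnp.ones((4, 1)), "y": jnp.ones((4, 3))}
value = wrapped.loss(1.0, batch)                        # summaries are 2.0, so (1 - 2)^2 = 1.0
samples, _ = wrapped.sample_fn(None, 1.0, jnp.ones(3))  # 1.0 * [2.0] = [2.0]
```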
Training and sampling#
Trainable objectives are trained and sampled with the two free drivers. The
sampler for likelihood/ratio methods is built with
sbijax.mcmc.make_sampler().
- sbijax.train(rng_key, objective, data, *, optimizer=None, info=None, n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001)
Train any objective’s TrainFns with early stopping.
- Parameters:
rng_key – a jax random key
objective – an ObjectiveFns / SummaryFns (anything with train)
data – a {"y", "theta"} dataset pytree
optimizer – an optax optimizer; it is bound into the objective's training primitives here rather than at construction
info – the previous round’s Info (None for round 0)
n_iter – number of epochs
batch_size – minibatch size
percentage_data_as_validation_set – fraction of the data held out for validation
n_early_stopping_patience – early-stopping patience
n_early_stopping_delta – minimum improvement that counts for early stopping
- Returns:
a tuple (params, Info)
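The early-stopping rule that n_early_stopping_patience and n_early_stopping_delta suggest can be sketched on a list of per-epoch validation losses. The exact rule inside train() is an assumption here: an epoch counts as an improvement only if it beats the best loss so far by more than delta, and training stops after patience epochs in a row without improvement. early_stopping_epoch is hypothetical.

```python
def early_stopping_epoch(val_losses, patience=10, delta=1e-3):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses."""
    best, best_epoch, bad_epochs = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - delta:
            # a large enough improvement resets the patience counter
            best, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch, best_epoch
    # ran out of epochs before patience was exhausted
    return len(val_losses) - 1, best_epoch


losses = [1.0, 0.5, 0.4995, 0.4992, 0.3, 0.2999, 0.2998, 0.2997]
stop, best = early_stopping_epoch(losses, patience=3, delta=1e-3)
# stop == 7, best == 4: improvements smaller than delta do not reset the counter
```

Returning the best epoch as well as the stopping epoch lets a driver restore the best parameters rather than the last ones.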
- sbijax.sample(rng_key, objective, params, observable, *, sampler=None, **kwargs)
Draw posterior samples from a trained objective.
A one-line dispatch to objective.sample_fn, kept for symmetry with sbijax.train().
- Parameters:
rng_key – a jax random key
objective – an ObjectiveFns
params – the trained parameters
observable – the observation to condition on
sampler – a sampler from make_sampler() (required for MCMC methods, ignored by amortized methods)
**kwargs – forwarded to sample_fn
- Returns:
a tuple (samples, info)
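Since sample() is described as a one-line dispatch, a sketch is short. Whether sbijax forwards sampler as a keyword exactly like this is an assumption; ToyObjectiveFns, sample_sketch and amortized_sample_fn are hypothetical.

```python
from typing import Callable, NamedTuple


class ToyObjectiveFns(NamedTuple):
    """Hypothetical stand-in for ObjectiveFns, reduced to the one field used here."""
    sample_fn: Callable


def sample_sketch(rng_key, objective, params, observable, *, sampler=None, **kwargs):
    # a single dispatch: everything is forwarded to the objective's own sample_fn
    return objective.sample_fn(rng_key, params, observable, sampler=sampler, **kwargs)


def amortized_sample_fn(rng_key, params, observable, sampler=None, n=3):
    # an amortized method can simply ignore the sampler argument
    return [observable] * n, {"sampler_used": sampler is not None}


amortized = ToyObjectiveFns(sample_fn=amortized_sample_fn)
samples, info = sample_sketch(None, amortized, None, 1.5, n=2)
# samples == [1.5, 1.5], info == {"sampler_used": False}
```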
Sequential inference#
- sbijax.run_sequential(rng_key, objective, prior, simulator, observable, *, n_rounds, n_simulations_per_round, sampler=None, proposal_fn=None, **train_kwargs)
Run multi-round sequential inference.
Round 0 simulates from the prior; each later round simulates from a proposal built from the current posterior, appends the new pairs to the dataset, and refits. NPE switches to its atomic objective (objective.extra(prior)) in rounds > 0; proposal-invariant methods reuse objective unchanged.
- Parameters:
rng_key – a jax random key
objective – an ObjectiveFns
prior – the prior distribution
simulator – a callable (rng_key, theta) -> y
observable – the observation to condition on
n_rounds – number of simulate/append/refit rounds
n_simulations_per_round – number of parameter/observation pairs simulated in each round
sampler – a sampler from make_sampler(), for MCMC-based objectives
proposal_fn – an optional callable (objective, params, observable, sampler) -> ((rng, n) -> theta); defaults to sampling the fitted posterior
**train_kwargs – forwarded to train() in each round
- Returns:
(params, Info) from the final round
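The simulate / append / refit loop can be sketched with hypothetical hooks. Here fit is a naive least-squares regression that, unlike NPE's atomic objective, does not correct for drawing parameters from a proposal, so the toy only exercises the control flow; run_sequential_sketch, fit and posterior_proposal are not sbijax APIs.

```python
import jax
import jax.numpy as jnp


def run_sequential_sketch(rng_key, fit, posterior_proposal, prior_sample, simulator,
                          observable, *, n_rounds, n_per_round):
    """Simulate / append / refit loop (hypothetical hooks).

    fit(data) -> params; posterior_proposal(params, observable) -> (rng_key, n) -> theta.
    """
    data, params = None, None
    for round_idx in range(n_rounds):
        rng_key, theta_key, sim_key = jax.random.split(rng_key, 3)
        # round 0 draws from the prior, later rounds from the current posterior
        draw = prior_sample if round_idx == 0 else posterior_proposal(params, observable)
        theta = draw(theta_key, n_per_round)
        y = jax.vmap(simulator)(jax.random.split(sim_key, n_per_round), theta)
        new = {"y": y, "theta": theta}
        # append the new pairs along the sample axis
        data = new if data is None else jax.tree_util.tree_map(
            lambda a, b: jnp.concatenate([a, b], axis=0), data, new)
        params = fit(data)  # refit on everything simulated so far
    return params, data


def fit(data):
    # naive least-squares regression of theta on y (no proposal correction)
    y, theta = data["y"][:, 0], data["theta"][:, 0]
    slope = jnp.mean((y - y.mean()) * (theta - theta.mean())) / jnp.var(y)
    intercept = theta.mean() - slope * y.mean()
    resid_sd = jnp.std(theta - (slope * y + intercept))
    return slope, intercept, resid_sd


def posterior_proposal(params, observable):
    slope, intercept, resid_sd = params
    mean = slope * observable[0] + intercept
    return lambda key, n: mean + resid_sd * jax.random.normal(key, (n, 1))


params, data = run_sequential_sketch(
    jax.random.PRNGKey(0), fit, posterior_proposal,
    prior_sample=lambda key, n: jax.random.normal(key, (n, 1)),
    simulator=lambda key, theta: theta + 0.5 * jax.random.normal(key, theta.shape),
    observable=jnp.array([1.0]), n_rounds=3, n_per_round=200)
```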
Diagnostics#
- sbijax.sbc(rng_key, objective, params, prior, simulator, *, sampler=None, n_simulations=100, n_posterior_samples=1000, **sample_kwargs)
Compute simulation-based calibration ranks for a fitted objective.
Repeats n_simulations times: draw theta* ~ prior and y* ~ simulator(theta*), draw n_posterior_samples from the posterior given y*, and rank each dimension of theta* among those draws. For a calibrated posterior, the ranks are uniformly distributed over the integers 0 to n_posterior_samples.
- Parameters:
rng_key – a jax random key
objective – an ObjectiveFns returned by a factory such as npe()
params – the fitted parameters
prior – the prior distribution
simulator – a callable (rng_key, theta) -> y
sampler – a sampler from make_sampler(); required for MCMC methods, ignored by amortized methods
n_simulations – number of calibration draws
n_posterior_samples – posterior draws per calibration draw
**sample_kwargs – forwarded to sample()
- Returns:
an integer array of shape (n_simulations, n_dims) of ranks
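The rank statistic behind sbc() can be sketched on a toy conjugate model whose exact posterior is known, so its ranks should look uniform. sbc_ranks takes sampling callables instead of an objective and params; all names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def sbc_ranks(rng_key, prior_sample, simulator, posterior_sample, *,
              n_simulations=100, n_posterior_samples=1000):
    """Return an (n_simulations, d) integer array of SBC ranks."""

    def one_rank(key):
        prior_key, sim_key, post_key = jax.random.split(key, 3)
        theta_star = prior_sample(prior_key)                             # (d,)
        y_star = simulator(sim_key, theta_star)
        draws = posterior_sample(post_key, y_star, n_posterior_samples)  # (L, d)
        # number of posterior draws below the true value: an integer in [0, L]
        return jnp.sum(draws < theta_star, axis=0)

    return jax.vmap(one_rank)(jax.random.split(rng_key, n_simulations))


# Toy conjugate model: theta ~ N(0, 1), y | theta ~ N(theta, 1),
# so the exact posterior is N(y / 2, 1 / 2).
ranks = sbc_ranks(
    jax.random.PRNGKey(0),
    prior_sample=lambda key: jax.random.normal(key, (1,)),
    simulator=lambda key, theta: theta + jax.random.normal(key, theta.shape),
    posterior_sample=lambda key, y, n: y / 2 + jnp.sqrt(0.5) * jax.random.normal(key, (n, 1)),
    n_simulations=200, n_posterior_samples=99)
```

A histogram of ranks that is U-shaped, humped or skewed rather than flat points to an over-confident, under-confident or biased posterior respectively.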