# Source code for sbijax._src.diagnostics.sbc
"""Simulation-based calibration (SBC).
SBC checks posterior calibration: if a posterior approximation is calibrated,
then for parameters drawn from the prior the rank of each ground-truth
parameter among posterior draws is uniformly distributed. Systematic departures
from uniformity reveal miscalibration (over- or under-dispersed posteriors,
bias). This is the calibration tier of the correctness harness (DR-009).
"""
# ruff: noqa: PLR0913
import jax
from jax import numpy as jnp
from jax import random as jr
from jax._src.flatten_util import ravel_pytree
from sbijax._src.train.sample import sample
from sbijax._src.util.data import flatten_chains
# [docs]
def sbc(
    rng_key,
    objective,
    params,
    prior,
    simulator,
    *,
    sampler=None,
    n_simulations=100,
    n_posterior_samples=1_000,
    **sample_kwargs,
):
    """Compute simulation-based calibration ranks for a fitted objective.

    For each of ``n_simulations`` draws ``theta* ~ prior`` and
    ``y* ~ simulator(theta*)``, draws ``n_posterior_samples`` from the posterior
    given ``y*`` and ranks each dimension of ``theta*`` among them. The rank of
    a dimension is the number of posterior draws strictly smaller than the true
    value. For ``M`` posterior draws, ranks therefore lie in ``[0, M]``, and a
    calibrated posterior yields ranks uniform on that range. ``M`` is the
    number of draws remaining after ``flatten_chains``; presumably this equals
    ``n_posterior_samples``, but confirm for multi-chain samplers.

    Args:
        rng_key: a jax random key
        objective: an ``ObjectiveFns`` returned by a factory such as
            :func:`~sbijax.npe`
        params: the fitted parameters
        prior: the prior distribution
        simulator: a callable ``(rng_key, theta) -> y``
        sampler: a sampler from
            :func:`~sbijax.mcmc.make_sampler`; required for
            MCMC methods, ignored by amortized methods
        n_simulations: number of calibration draws (must be >= 1)
        n_posterior_samples: posterior draws per calibration draw
            (must be >= 1)
        **sample_kwargs: forwarded to ``sample``

    Returns:
        an integer array of shape ``(n_simulations, n_dims)`` of ranks, where
        ``n_dims`` is the length of the flattened parameter vector

    Raises:
        ValueError: if ``n_simulations`` or ``n_posterior_samples`` is less
            than one, or if the flattened posterior draws and the flattened
            true parameter have different numbers of dimensions
    """
    # Fail fast with a clear message. Otherwise, zero simulations surfaces as
    # an opaque ``jnp.stack`` error, and zero posterior draws silently yields
    # all-zero ranks.
    if n_simulations < 1:
        raise ValueError(
            f"n_simulations must be at least 1, got {n_simulations}"
        )
    if n_posterior_samples < 1:
        raise ValueError(
            "n_posterior_samples must be at least 1, "
            f"got {n_posterior_samples}"
        )

    def rank_one(key):
        """Return the per-dimension rank of one prior draw among its posterior."""
        theta_key, y_key, post_key = jr.split(key, 3)
        # Draw a single-element batch. The leading axis is stripped from
        # ``y`` and ``theta_true`` again below.
        theta_true = prior.sample(seed=theta_key, sample_shape=(1,))
        y = simulator(y_key, theta_true)
        samples, _ = sample(
            post_key,
            objective,
            params,
            y[0],
            sampler=sampler,
            n_samples=n_posterior_samples,
            **sample_kwargs,
        )
        # Flatten every posterior draw (a pytree) into a vector, giving a
        # (n_draws, n_dims) matrix.
        posterior = jax.vmap(lambda x: ravel_pytree(x)[0])(
            flatten_chains(samples)
        )
        theta_flat, _ = ravel_pytree(jax.tree.map(lambda a: a[0], theta_true))
        # Without this check, a mismatch such as (M, d) vs (1,) would broadcast
        # silently and produce meaningless ranks.
        if posterior.shape[-1] != theta_flat.shape[0]:
            raise ValueError(
                "posterior draws flatten to "
                f"{posterior.shape[-1]} dimensions but the true parameter "
                f"flattens to {theta_flat.shape[0]}"
            )
        return jnp.sum(posterior < theta_flat, axis=0)

    keys = jr.split(rng_key, n_simulations)
    return jnp.stack([rank_one(k) for k in keys])