# Source code for sbijax._src.experimental._truncated

"""Truncated-prior proposal for sequential score-based posterior estimation.

Ports the truncated-prior sampling of NPSE/AiO (:cite:t:`sharrock2024sequential`,
:cite:t:`gloeckler2024allinone`) to a standalone proposal compatible with
:func:`~sbijax.run_sequential`'s ``proposal_fn`` hook. A trained posterior
defines a truncation boundary (a low quantile of posterior log-densities) and a
bounding hypercube; the proposal draws uniformly in the hypercube and keeps
draws whose posterior log-density clears the boundary. This resolves the
truncated-proposal open question in the backlog: truncation is a driver option,
not a separate driver.
"""

# ruff: noqa: PLR0913
import jax
from jax import numpy as jnp
from jax import random as jr
from jax._src.flatten_util import ravel_pytree

from sbijax._src.train.sample import sample
from sbijax._src.util.data import flatten_chains


def make_truncated_proposal(
    prior,
    network,
    *,
    quantile=5e-4,
    n_calibration=100_000,
    n_prior=1_000_000,
    max_iter=1_000,
):
    """Build a truncated-prior ``proposal_fn`` for :func:`~sbijax.run_sequential`.

    Args:
        prior: the prior distribution
        network: the score network the estimator wraps (exposes ``log_prob``)
        quantile: lower-tail quantile of posterior log-densities taken as the
            truncation boundary
        n_calibration: number of posterior draws used to calibrate the boundary
            and the bounding hypercube
        n_prior: number of prior draws used to bound the hypercube
        max_iter: maximum rejection rounds before giving up

    Returns:
        a ``proposal_fn(objective, params, observable, sampler)`` returning a
        callable ``(rng_key, n) -> theta`` that draws parameters from the
        truncated prior, in the pytree structure the prior and simulator use

    Raises:
        ValueError: raised by the returned ``(rng_key, n)`` callable when fewer
            than ``n`` candidates were accepted after ``max_iter`` rejection
            rounds.
    """
    # A single prior draw fixes the pytree structure used to unflatten results.
    _, unravel_fn = ravel_pytree(prior.sample(seed=jr.key(0)))

    def _flatten_batch(batch):
        # Ravel every pytree along the leading axis into one flat vector.
        return jax.vmap(lambda x: ravel_pytree(x)[0])(batch)

    def proposal_fn(objective, params, observable, sampler=None):
        def log_prob(rng_key, theta_flat):
            # The observable is tiled so each candidate row gets its own
            # copy of the context.
            return network.apply(
                params,
                rng=rng_key,
                method="log_prob",
                inputs=theta_flat,
                context=jnp.tile(observable, [theta_flat.shape[0], 1]),
                is_training=False,
            )

        def proposal(rng_key, n):
            # Calibration consumes the caller's key, so the boundary and the
            # hypercube are recomputed on every call.
            calib_key, bound_key, rng_key = jr.split(rng_key, 3)
            samples, _ = sample(
                calib_key,
                objective,
                params,
                observable,
                sampler=sampler,
                n_samples=n_calibration,
            )
            flat_posterior = _flatten_batch(flatten_chains(samples))

            # Truncation boundary: low quantile of posterior log-densities.
            lp_key, rng_key = jr.split(rng_key)
            boundary = jnp.quantile(log_prob(lp_key, flat_posterior), quantile)

            # Bounding box: intersection of the per-dimension ranges spanned
            # by the posterior draws and by the prior draws.
            flat_prior = _flatten_batch(
                prior.sample(seed=bound_key, sample_shape=(n_prior,))
            )
            lo = jnp.maximum(flat_posterior.min(0), flat_prior.min(0))
            hi = jnp.minimum(flat_posterior.max(0), flat_prior.max(0))

            # NOTE(review): candidates are uniform on the box, so accepted
            # draws follow the truncated prior only if the prior is uniform
            # there -- confirm this is the intended use.
            accepted, n_curr, it = [], 0, 0
            while n_curr < n and it < max_iter:
                u_key, lp_key, rng_key = jr.split(rng_key, 3)
                cand = jr.uniform(
                    u_key, (n, lo.shape[0]), minval=lo, maxval=hi
                )
                keep = cand[log_prob(lp_key, cand) > boundary]
                accepted.append(keep)
                n_curr += keep.shape[0]
                it += 1

            # Fail only when the draw budget was really missed; reaching
            # ``max_iter`` on the round that completes the batch is a success.
            if n_curr < n:
                raise ValueError("truncated proposal did not converge")

            thetas = jnp.concatenate(accepted, axis=0)[:n]
            return jax.vmap(unravel_fn)(thetas)

        return proposal

    return proposal_fn