Source code for sbijax._src.experimental.aio

import jax
from jax import numpy as jnp
from jax import random as jr
from jax._src.flatten_util import ravel_pytree

from sbijax import FMPE, as_inference_data, inference_data_as_dictionary


# ruff: noqa: PLR0913, E501
class AiO(FMPE):
    """All-in-one simulation-based inference.

    Implements all-in-one posterior estimation as introduced in
    :cite:t:`gloeckler2024allinone`. In comparison to the original paper,
    this implementation (so far) only infers the posterior distribution of
    all latent variables, so no marginals or other conditional
    distributions. As a consequence, when training the model, we use the
    same mask for all latent/conditioning variables, and don't sample it
    every step. Hence, this implementation is basically the same as NPSE
    only that we use a transformer as score network and a mask to encode
    the conditional dependencies.

    Args:
        model_fns: a tuple of callables. The first element needs to be a
            function that constructs a tfd.JointDistributionNamed, the
            second element is a simulator function.
        density_estimator: a score estimator

    Examples:
        >>> from jax import numpy as jnp
        >>> from sbijax.experimental import AiO
        >>> from sbijax.experimental.nn import make_simformer_based_score_model
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        ...
        >>> prior = lambda: tfd.JointDistributionNamed(
        ...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
        ... )
        >>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
        >>> fns = prior, s
        >>> neural_network = make_simformer_based_score_model(2, jnp.eye(4))
        >>> model = AiO(fns, neural_network)

    References:
        Gloeckler, Manuel, et al. "All-in-one simulation-based inference."
        International Conference on Machine Learning, 2024.
    """

    def __init__(self, model_fns, density_estimator):
        """Construct the estimator by forwarding both arguments to ``FMPE``."""
        super().__init__(model_fns, density_estimator)

    def _simulate_parameters_with_model(
        self, rng_key, params, observable, *, n_samples=4_000, **kwargs
    ):
        """Draw parameters from the truncated prior built from ``params``.

        Args:
            rng_key: a jax random key
            params: parameters passed on to the model
            observable: the observation to condition on
            n_samples: number of parameter draws to return
            **kwargs: accepted for interface compatibility; unused here

        Returns:
            whatever the sampler returned by :meth:`get_truncated_prior`
            returns, i.e. a tuple of inference data and acceptance rate
        """
        prior_key, rng_key = jr.split(rng_key)
        # The truncation itself is always calibrated on 1e5 posterior draws,
        # independently of how many parameters the caller asks for.
        prior_fn = self.get_truncated_prior(
            prior_key, params, observable, n_samples=int(1e5)
        )
        return prior_fn(rng_key, n_samples)

    def _init_params(self, rng_key, **init_data):
        """Initialise the model parameters through its ``loss`` method.

        Args:
            rng_key: a jax random key
            **init_data: must contain the keys ``"theta"`` (used as model
                inputs) and ``"y"`` (used as model context)

        Returns:
            the parameters returned by ``self.model.init``
        """
        params = self.model.init(
            rng_key,
            method="loss",
            inputs=init_data["theta"],
            context=init_data["y"],
            is_training=False,
        )
        return params

    def get_truncated_prior(self, rng_key, params, observable, n_samples):
        """Build a rejection sampler for a truncated proposal.

        Posterior samples are drawn and scored with the model's ``log_prob``
        method; the ``5e-4`` quantile of these values is the acceptance
        threshold. Proposals are uniform on a hypercube that is the
        intersection of the range of the posterior samples and the range of
        ``1e6`` prior draws.

        Args:
            rng_key: a jax random key
            params: parameters passed on to the model
            observable: the observation to condition on
            n_samples: number of posterior samples used to calibrate the
                threshold and the hypercube

        Returns:
            a function ``truncated_prior_fn(rng_key, n_samples, n_iter=1_000)``
            that returns a tuple of inference data and the fraction of
            accepted proposals, and raises ``ValueError`` if fewer than
            ``n_samples`` proposals were accepted within ``n_iter`` rounds
        """
        # A single prior draw is only needed for its pytree structure, so a
        # fixed key is sufficient here.
        samp = self.prior_sampler_fn(seed=jr.PRNGKey(0), sample_shape=())
        _, unravel_fn = ravel_pytree(samp)

        sample_key, rng_key = jr.split(rng_key)
        inf_data, _ = self.sample_posterior(
            sample_key, params, observable, n_samples=n_samples
        )
        posterior_samples = inference_data_as_dictionary(inf_data.posterior)

        lp_key, rng_key = jr.split(rng_key)
        flat_posterior_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(
            posterior_samples
        )
        log_probs = self.model.apply(
            params,
            rng=lp_key,
            method="log_prob",
            inputs=flat_posterior_samples,
            context=jnp.tile(observable, [flat_posterior_samples.shape[0], 1]),
            is_training=False,
        )
        trunc_boundary = jnp.quantile(log_probs, 5e-4)

        min_posterior = jax.tree.map(lambda x: x.min(axis=0), posterior_samples)
        max_posterior = jax.tree.map(lambda x: x.max(axis=0), posterior_samples)

        sample_key, rng_key = jr.split(rng_key)
        prior_samples = self.prior_sampler_fn(
            seed=sample_key, sample_shape=(int(1e6),)
        )
        min_prior = jax.tree.map(lambda x: x.min(axis=0), prior_samples)
        max_prior = jax.tree.map(lambda x: x.max(axis=0), prior_samples)

        # Intersection of the two bounding boxes: largest lower bound and
        # smallest upper bound per dimension.
        hypercube_min = jax.tree.map(jnp.maximum, min_posterior, min_prior)
        hypercube_max = jax.tree.map(jnp.minimum, max_posterior, max_prior)

        # Flatten the bounds with the same routine that flattens the samples
        # so that both share one layout (also for scalar or multi-dimensional
        # leaves), and do it once instead of on every proposal round.
        flat_min, _ = ravel_pytree(hypercube_min)
        flat_max, _ = ravel_pytree(hypercube_max)
        n_dim = flat_posterior_samples.shape[-1]

        def hypercube_uniform_prior(rng_key, n_samples):
            return jr.uniform(
                rng_key,
                (n_samples, n_dim),
                minval=flat_min,
                maxval=flat_max,
            )

        def truncated_prior_fn(rng_key, n_samples, n_iter=1_000):
            cnt = n_curr = 0
            samples_out = []
            # Every round proposes exactly n_samples points, so the tiled
            # context is the same in each iteration.
            context = jnp.tile(observable, [n_samples, 1])
            while n_curr < n_samples and cnt < n_iter:
                sample_key, lp_key, rng_key = jr.split(rng_key, 3)
                samples = hypercube_uniform_prior(sample_key, n_samples)
                log_probs = self.model.apply(
                    params,
                    rng=lp_key,
                    method="log_prob",
                    inputs=samples,
                    context=context,
                    is_training=False,
                )
                accepted_samples = samples[log_probs > trunc_boundary]
                samples_out.append(accepted_samples)
                n_curr += len(accepted_samples)
                cnt += 1
            # Only a shortfall of accepted samples is a failure; reaching the
            # iteration budget on a round that completed the set is fine.
            if n_curr < n_samples:
                raise ValueError("truncated sampling did not converge")
            thetas = jnp.concatenate(samples_out, axis=0)[:n_samples]

            def reshape(p):
                # Ensure a trailing feature axis, then add a leading axis of
                # size one.
                if p.ndim == 1:
                    p = p.reshape(p.shape[0], 1)
                p = p.reshape(1, *p.shape)
                return p

            # Fraction of all proposed points that were accepted.
            ess = n_curr / (cnt * n_samples)
            thetas = jax.tree_util.tree_map(
                reshape, jax.vmap(unravel_fn)(thetas)
            )
            inference_data = as_inference_data(thetas, jnp.squeeze(observable))
            return inference_data, ess

        return truncated_prior_fn