# Source code for sbijax._src.abc.smc_abc

from collections import namedtuple

import chex
import jax
from blackjax.smc import resampling
from blackjax.smc.ess import ess
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from jax._src.flatten_util import ravel_pytree
from jax.tree_util import tree_map
from tensorflow_probability.substrates.jax import distributions as tfd
from tqdm import tqdm

from sbijax._src._sbi_base import SBI
from sbijax._src.util.data import _tree_stack, as_inference_data


# ruff: noqa: PLR0913, E501
class SMCABC(SBI):
    r"""Sequential Monte Carlo approximate Bayesian computation.

    Implements the algorithm from :cite:t:`beaumont2009adaptive`.

    Args:
        model_fns: a tuple of callables. The first element needs to be a
            function that constructs a tfd.JointDistributionNamed, the second
            element is a simulator function.
        summary_fn: summary function
        distance_fn: distance function

    Examples:
        >>> import jax
        >>> from jax import numpy as jnp
        >>> from sbijax import SMCABC
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        ...
        >>> prior = lambda: tfd.JointDistributionNamed(
        ...     dict(theta=tfd.Normal(0.0, 1.0))
        ... )
        >>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
        >>> fns = prior, s
        >>> summary_fn = lambda x: x
        >>> distance_fn = lambda x, y: jax.vmap(lambda z: jnp.linalg.norm(z))(x - y)
        >>> model = SMCABC(fns, summary_fn, distance_fn)

    References:
        Beaumont, Mark A, et al. "Adaptive approximate Bayesian computation".
        Biometrika, 2009.
    """

    def __init__(self, model_fns, summary_fn, distance_fn):
        super().__init__(model_fns)
        self.summary_fn = summary_fn
        self.distance_fn = distance_fn
        # Declared for type checkers only; never assigned in this class.
        self.summarized_observed: chex.Array
        # Running count of simulator calls; it is never reset, so it keeps
        # accumulating across repeated calls to `sample_posterior`.
        self.n_total_simulations = 0

    def sample_posterior(
        self,
        rng_key,
        observable,
        n_rounds=10,
        n_particles=10_000,
        eps_step=0.825,
        ess_min=2_000,
        cov_scale=1.0,
    ):
        r"""Sample from the approximate posterior.

        Args:
            rng_key: a jax random key
            observable: the observation to condition on
            n_rounds: number of rounds of SMC
            n_particles: number of particles to draw for each parameter
            eps_step: multiplicative decay applied to epsilon at the start of
                every round
            ess_min: minimal effective sample size; the particles are
                resampled whenever the ESS drops below this value
            cov_scale: scaling of the transition kernel covariance

        Returns:
            a tuple ``(inference_data, smc_info)``. ``inference_data`` is
            built from the particles of the last round, ``smc_info`` is a
            named tuple with fields ``particles`` (the particles after every
            round) and ``n_simulations`` (the cumulative simulation count
            after every round).
        """
        observable = jnp.atleast_2d(observable)

        init_key, rng_key = jr.split(rng_key)
        particles, log_weights, epsilon = self._init_particles(
            init_key, observable, n_particles
        )

        all_particles, all_n_simulations = [], []
        for n in tqdm(range(n_rounds)):
            # Tighten the acceptance threshold before moving the particles.
            epsilon *= eps_step
            rng_key = jr.fold_in(rng_key, n)
            particle_key, rng_key = jr.split(rng_key)
            particles, log_weights = self._move(
                particle_key,
                observable,
                n_particles,
                particles,
                log_weights,
                epsilon,
                cov_scale,
            )
            curr_ess = ess(log_weights)
            if curr_ess < ess_min:
                resample_key, rng_key = jr.split(rng_key)
                # There is one log-weight per particle, so this is the current
                # population size regardless of the particle pytree layout.
                particles, log_weights = self._resample(
                    resample_key,
                    particles,
                    log_weights,
                    log_weights.shape[0],
                )
            all_particles.append(particles.copy())
            all_n_simulations.append(self.n_total_simulations)

        # Prepend an axis of size one to every leaf before conversion.
        thetas = tree_map(lambda x: x.reshape(1, *x.shape), particles)
        inference_data = as_inference_data(thetas, jnp.squeeze(observable))
        smc_info = namedtuple("smc_info", "particles n_simulations")
        return inference_data, smc_info(all_particles, all_n_simulations)

    def _chol_factor(self, particles, cov_scale):
        """Return the Cholesky factor of the scaled particle covariance."""
        # Flatten every particle of the pytree into one row vector.
        particles = jax.vmap(lambda x: ravel_pytree(x)[0])(particles)
        chol = jnp.linalg.cholesky(jnp.cov(particles.T) * cov_scale)
        return chol

    def _init_particles(self, rng_key, observable, n_particles):
        """Draw the initial population from the prior.

        Returns:
            a tuple of the particles ordered by increasing distance to the
            observation, uniform log-weights, and the initial epsilon.
        """
        self.n_total_simulations += n_particles

        init_key, rng_key = jr.split(rng_key)
        particles = self.prior.sample(
            seed=init_key, sample_shape=(n_particles,)
        )

        simulator_key, rng_key = jr.split(rng_key)
        ys = self.simulator_fn(seed=simulator_key, theta=particles)
        summary_statistics = self.summary_fn(ys)
        distances = self.distance_fn(
            summary_statistics, self.summary_fn(observable)
        )

        sort_idx = jnp.argsort(distances)
        particles = tree_map(lambda x: x[sort_idx][:n_particles], particles)
        log_weights = -jnp.log(jnp.full(n_particles, n_particles))
        # NOTE(review): `distances` is NOT sorted here, so this is the
        # distance of the last simulated particle rather than the largest
        # one (`distances[sort_idx][-1]`). Kept as is because changing it
        # alters the epsilon schedule; confirm the intended behaviour.
        initial_epsilon = distances[-1]

        return particles, log_weights, initial_epsilon

    def _sample_candidates(
        self, rng_key, particles, log_weights, n, cov_chol_factor
    ):
        """Propose new particles by resampling and perturbing old ones.

        Candidates whose prior log-density is not finite are dropped, hence
        fewer candidates than were proposed may be returned.
        """
        # Batch size is clamped to [100, 1000]. Plain Python arithmetic keeps
        # the simulation counter a Python int instead of a device array.
        n_sim = int(max(min(n, 1000), 100))
        self.n_total_simulations += n_sim

        sample_key, perturb_key, rng_key = jr.split(rng_key, 3)
        new_candidate_particles, _ = self._resample(
            sample_key, particles, log_weights, n_sim
        )
        new_candidate_particles = self._perturb(
            perturb_key, new_candidate_particles, cov_chol_factor
        )

        cand_lps = self.prior.log_prob(new_candidate_particles)
        # `isfinite` rejects NaN as well as +/-inf; a NaN log-density would
        # otherwise turn every normalised weight into NaN later on.
        is_finite = jnp.isfinite(cand_lps)
        new_candidate_particles = tree_map(
            lambda x: x[is_finite], new_candidate_particles
        )
        return new_candidate_particles

    def _simulate_and_distance(
        self, rng_key, observable, new_candidate_particles
    ):
        """Simulate data for the candidates and return their distances."""
        ys = self.simulator_fn(
            seed=rng_key,
            theta=new_candidate_particles,
        )
        summary_statistics = self.summary_fn(ys)
        ds = self.distance_fn(summary_statistics, self.summary_fn(observable))
        return ds

    # pylint: disable=too-many-arguments
    def _move(
        self,
        rng_key,
        observable,
        n_particles,
        particles,
        log_weights,
        epsilon,
        cov_scale,
    ):
        """Replace the population by particles accepted under `epsilon`.

        Candidates are proposed in batches until at least `n_particles` of
        them have a distance below `epsilon`; surplus particles are dropped.

        Returns:
            a tuple of the new particles and their normalised log-weights.
        """
        new_particles = None
        cov_chol_factor = self._chol_factor(particles, cov_scale)
        n = n_particles
        # NOTE(review): this loop has no iteration cap; it does not terminate
        # if no candidate is ever accepted.
        while n > 0:
            sample_key, simulate_key, rng_key = jr.split(rng_key, 3)
            new_candidate_particles = self._sample_candidates(
                sample_key, particles, log_weights, n, cov_chol_factor
            )
            ds = self._simulate_and_distance(
                simulate_key,
                observable,
                new_candidate_particles,
            )

            idxs = jnp.where(ds < epsilon)[0]
            new_candidate_particles = tree_map(
                lambda x: x[idxs], new_candidate_particles
            )
            if new_particles is None:
                new_particles = new_candidate_particles
            else:
                new_particles = _tree_stack(
                    [new_particles, new_candidate_particles]
                )
            n -= len(idxs)

        new_particles = tree_map(lambda x: x[:n_particles], new_particles)
        new_log_weights = self._new_log_weights(
            new_particles, particles, log_weights, cov_chol_factor
        )
        return new_particles, new_log_weights

    def _resample(self, rng_key, particles, log_weights, n_samples):
        """Multinomially resample `n_samples` particles.

        Returns:
            a tuple of the resampled particles and uniform log-weights
            ``-log(n_samples)``.
        """
        idxs = resampling.multinomial(rng_key, jnp.exp(log_weights), n_samples)
        particles = tree_map(lambda x: x[idxs], particles)
        return particles, -jnp.log(jnp.full(n_samples, n_samples))

    def _new_log_weights(
        self, new_particles, old_particles, old_log_weights, cov_chol_factor
    ):
        """Compute normalised log importance weights of the new particles.

        The unnormalised log-weight of a particle is its prior log-density
        minus the log of the old-weight-weighted sum of kernel densities
        centred at the old particles.
        """
        prior_log_density = self.prior.log_prob(new_particles)
        K = self._kernel(old_particles, cov_chol_factor)

        def _particle_weight(partcl):
            probs = old_log_weights + K.log_prob(partcl)
            weight = jsp.special.logsumexp(probs)
            return weight

        new_particles = jax.vmap(lambda x: ravel_pytree(x)[0])(new_particles)
        # Extra axis so each particle broadcasts against all kernel centres.
        new_particles = new_particles[:, None, :]
        log_weighted_sum = jax.vmap(_particle_weight)(new_particles)

        new_log_weights = prior_log_density - log_weighted_sum
        # Normalise such that the weights sum to one.
        new_log_weights -= jsp.special.logsumexp(new_log_weights)
        return new_log_weights

    def _kernel(self, mus, cov_chol_factor):
        """Build a Gaussian kernel centred at each (flattened) particle."""
        mus = jax.vmap(lambda x: ravel_pytree(x)[0])(mus)
        return tfd.MultivariateNormalTriL(loc=mus, scale_tril=cov_chol_factor)

    def _perturb(self, rng_key, mus, cov_chol_factor):
        """Sample from the kernel and restore the prior's pytree layout."""
        # A throw-away prior draw is used only to obtain the unravel function.
        _, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(0)))
        samples = self._kernel(mus, cov_chol_factor).sample(seed=rng_key)
        samples = jax.vmap(unravel_fn)(samples)
        return samples