Source code for sbijax._src.mcmc.sampler

"""Posterior samplers bundling an MCMC kernel, a prior, and Gaussian init."""

from collections.abc import Callable
from typing import NamedTuple

import jax
from jax import numpy as jnp
from jax import random as jr

from sbijax._src.mcmc.util import run_blackjax


class Kernel(NamedTuple):
  """Handle identifying a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

  The handle wraps the algorithm's initializer so that :func:`make_sampler`
  can defer building the concrete kernel and the initial chain states until
  sampling time.

  Attributes:
    init_fn: initializer with signature
      ``(rng_key, initial_positions, lp) -> (initial_states, kernel)``.
      ``lp`` is presumably the target log-density function (the sampler
      built by :func:`make_sampler` passes its log density along with this
      initializer) -- confirm against ``run_blackjax``.
  """

  # The only field: the algorithm-specific initializer described above.
  init_fn: Callable


def _gaussian_init(prior, rng_key, n_chains):
  """``N(0, 1)`` initial positions with the prior's pytree structure."""
  template = prior.sample(seed=jr.key(0))
  leaves, treedef = jax.tree_util.tree_flatten(template)
  keys = jr.split(rng_key, len(leaves))
  drawn = [
    jr.normal(k, (n_chains, *jnp.shape(leaf)))
    for k, leaf in zip(keys, leaves, strict=True)
  ]
  return jax.tree_util.tree_unflatten(treedef, drawn)


def make_sampler(kernel, *, prior, **kernel_kwargs):
  """Build a posterior sampler from an MCMC ``kernel`` and a ``prior``.

  Args:
    kernel: a ``Kernel`` handle (e.g. ``sbijax.mcmc.nuts``)
    prior: the prior; used for the target density and Gaussian chain init
    **kernel_kwargs: forwarded to the kernel, i.e. passed as keyword
      arguments to ``kernel.init_fn`` whenever it is invoked. Without any
      keyword arguments ``kernel.init_fn`` is used unchanged.

  Returns:
    a callable
    ``(rng_key, loglik_fn, *, n_chains, n_samples, n_warmup)
    -> (samples, MCMCSampleInfo)``
    where the target is ``loglik_fn(theta) + prior.log_prob(theta)`` and
    chains start at ``N(0, I)``. The return value is whatever
    ``run_blackjax`` returns.
  """

  def init_with_kwargs(*args, **kwargs):
    # Keywords given at call time take precedence over the ones fixed when
    # the sampler was built (the same rule ``functools.partial`` applies).
    return kernel.init_fn(*args, **{**kernel_kwargs, **kwargs})

  def sampler(
    rng_key, loglik_fn, *, n_chains=4, n_samples=2_000, n_warmup=1_000
  ):
    def logdensity(theta):
      # Both terms are summed so that non-scalar outputs of either function
      # are reduced to a single scalar log density.
      return jnp.sum(loglik_fn(theta)) + jnp.sum(prior.log_prob(theta))

    # Independent keys for drawing the start positions and for running MCMC.
    init_key, sample_key = jr.split(rng_key)
    positions = _gaussian_init(prior, init_key, n_chains)

    # The kernel options used to be dropped silently; bind them to the
    # initializer so they actually reach the kernel. With no options the
    # original initializer object is passed through untouched.
    init_fn = init_with_kwargs if kernel_kwargs else kernel.init_fn
    return run_blackjax(
      sample_key,
      init_fn,
      positions,
      logdensity,
      n_chains=n_chains,
      n_samples=n_samples,
      n_warmup=n_warmup,
    )

  return sampler