# Source code for sbijax._src.inference.posterior.npe

"""Neural posterior estimation.

Implements the NPE objective of :cite:t:`greenberg2019automatic` as a
functional estimator. The network models the posterior directly in the
parameter space; no event-space bijectors are applied.

In round 0 the network is trained by maximum likelihood (amortized).
In later rounds (driven by :func:`sbijax.run_sequential`), call
``obj.extra(prior)`` to obtain the atomic proposal-posterior objective
(:cite:t:`greenberg2019automatic`) which corrects for the proposal no
longer being the prior.
"""

# ruff: noqa: PLR0913
from functools import partial

import jax
import jax.scipy as jsp
import optax
from jax import numpy as jnp
from jax import random as jr
from jax._src.flatten_util import ravel_pytree

from sbijax._src.inference._sample_info import DirectSampleInfo
from sbijax._src.train._types import ObjectiveFns, TrainFns, TrainingState


def _maximum_likelihood_loss(params, _rng, network, **batch):
  """Round-0 loss: maximum likelihood against draws from the prior.

  Args:
      params: the network parameter pytree
      _rng: unused rng key (kept for a uniform ``(params, rng, **batch)``
          signature)
      network: a conditional density estimator with a ``log_prob`` method
      **batch: must contain ``theta`` (flat ``(n, d)`` array) and ``y``
          (observations ``(n, obs_dim)``)

  Returns:
      a scalar loss
  """
  lp = network.apply(
    params, None, method="log_prob", y=batch["theta"], x=batch["y"]
  )
  return -jnp.mean(lp)


def _atomic_loss(params, rng, network, prior, num_atoms, **batch):
  """Round->0 atomic proposal-posterior loss of NPE-C / APT.

  Corrects for the proposal no longer being the prior by contrasting the
  true parameter against ``num_atoms - 1`` others drawn from the batch,
  reweighted by the prior.  Needs only the network, the prior and
  ``num_atoms`` — no proposal density (:cite:t:`greenberg2019automatic`).

  Args:
      params: the network parameter pytree
      rng: a jax random key
      network: a conditional density estimator with a ``log_prob`` method
      prior: a tfd distribution; only ``log_prob`` is called
      num_atoms: number of atoms in the contrastive loss
      **batch: must contain ``theta`` (flat ``(n, d)`` array) and ``y``
          (observations ``(n, obs_dim)``)

  Returns:
      a scalar loss
  """
  theta, y = batch["theta"], batch["y"]
  n = theta.shape[0]
  m = min(num_atoms, n)

  # Unravel flat theta to the prior's pytree structure for log_prob.
  _, unravel_fn = ravel_pytree(prior.sample(seed=jr.key(0)))

  # For each row draw m-1 contrasting rows without replacement (exclude self).
  probs = jnp.ones((n, n)) * (1.0 - jnp.eye(n)) / (n - 1.0)
  choices = jax.vmap(
    lambda key, p: jr.choice(key, n, (m - 1,), replace=False, p=p)
  )(jr.split(rng, n), probs)
  idx = jnp.concatenate([jnp.arange(n)[:, None], choices], axis=1)

  lp_net = network.apply(
    params,
    None,
    method="log_prob",
    y=theta[idx].reshape(n * m, -1),
    x=jnp.repeat(y, m, axis=0),
  ).reshape(n, m)

  lp_prior = prior.log_prob(
    jax.vmap(unravel_fn)(theta[idx].reshape(n * m, -1))
  ).reshape(n, m)

  # Importance-reweighted contrast; the true theta sits at atom index 0.
  unnormalized = lp_net - lp_prior
  log_prob = unnormalized[:, 0] - jsp.special.logsumexp(unnormalized, axis=-1)
  return -jnp.mean(log_prob)


def npe(network, *, num_atoms=10):
  """Construct a neural posterior estimator.

  In round 0 use the returned ``ObjectiveFns`` directly for amortized
  maximum-likelihood training. For round > 0 (i.e. when training on data
  simulated from a fitted posterior rather than the prior), call
  ``obj.extra(prior)`` to obtain the atomic proposal-posterior objective of
  :cite:t:`greenberg2019automatic`.

  Args:
      network: a conditional density estimator with ``log_prob`` and
          ``sample`` methods (e.g. from :func:`~sbijax.nn.make_maf`)
      num_atoms: the number of atoms in the contrastive proposal-posterior
          loss used by ``extra(prior)``; must be at least 2

  Returns:
      an ``ObjectiveFns``; its ``extra`` field is a callable
      ``(prior) -> ObjectiveFns`` for sequential rounds

  Raises:
      ValueError: if ``num_atoms`` is smaller than 2
  """
  # With a single atom the contrast set holds only the true parameter, so the
  # atomic loss is identically zero (no training signal). With fewer atoms it
  # asks for a negative number of contrasting draws.
  if num_atoms < 2:
    raise ValueError(f"num_atoms must be at least 2, got {num_atoms}")

  def _objective(loss_fn, extra):
    """Bundle init/step/eval/sample functions around ``loss_fn``.

    Args:
        loss_fn: a loss with signature ``(params, rng, **batch) -> scalar``
        extra: value stored in the ``extra`` field of the result

    Returns:
        an ``ObjectiveFns``
    """

    def init_fn(optimizer, rng_key, batch):
      """Initialise network params and optimizer state from a sample batch.

      Args:
          optimizer: an optax optimizer
          rng_key: a jax random key
          batch: a ``{"theta", "y"}`` batch dict

      Returns:
          a ``TrainingState``
      """
      params = network.init(
        rng_key, method="log_prob", y=batch["theta"], x=batch["y"]
      )
      return TrainingState(params=params, opt_state=optimizer.init(params))

    def step_fn(optimizer, rng_key, state, batch):
      """Apply one gradient update.

      Args:
          optimizer: an optax optimizer
          rng_key: a jax random key
          state: the current ``TrainingState``
          batch: a ``{"theta", "y"}`` batch dict

      Returns:
          a tuple ``({"loss": scalar}, new_state)``
      """
      loss, grads = jax.value_and_grad(loss_fn)(state.params, rng_key, **batch)
      updates, opt_state = optimizer.update(
        grads, state.opt_state, state.params
      )
      new_state = TrainingState(
        params=optax.apply_updates(state.params, updates),
        opt_state=opt_state,
      )
      return {"loss": loss}, new_state

    def eval_fn(rng_key, state, batch):
      """Evaluate the loss without updating parameters.

      Args:
          rng_key: a jax random key
          state: the current ``TrainingState``
          batch: a ``{"theta", "y"}`` batch dict

      Returns:
          ``{"loss": scalar}``
      """
      return {"loss": loss_fn(state.params, rng_key, **batch)}

    def sample_fn(rng_key, params, observable, *, n_samples=4_000, **kwargs):
      """Draw posterior samples from the trained flow.

      Args:
          rng_key: a jax random key
          params: the trained network parameters
          observable: a 1-D observation, or an array whose leading axis
              has a single row
          n_samples: number of posterior draws to return
          **kwargs: ignored (e.g. a ``sampler`` argument passed for API
              symmetry with other estimators)

      Returns:
          a tuple ``(samples, DirectSampleInfo)`` where ``samples`` is a
          named pytree with each leaf shaped ``(1, n_samples, dim)``
      """
      observable = jnp.atleast_2d(observable)
      # Repeat the observation along the leading axis only. Unlike
      # ``jnp.tile(observable, [n_samples, 1])`` this keeps the trailing
      # shape intact when the observation has more than one trailing axis.
      conditioning = jnp.repeat(observable, n_samples, axis=0)
      thetas = network.apply(
        params,
        rng_key,
        method="sample",
        sample_shape=(n_samples,),
        x=conditioning,
      )

      def add_leading_axis(leaf):
        # Promote 1-D draws to (n_samples, 1), then prepend an axis of size 1.
        if leaf.ndim == 1:
          leaf = leaf.reshape(leaf.shape[0], 1)
        return leaf.reshape(1, *leaf.shape)

      samples = jax.tree_util.tree_map(add_leading_axis, {"theta": thetas})
      return samples, DirectSampleInfo(n_samples=n_samples)

    return ObjectiveFns(TrainFns(init_fn, step_fn, eval_fn), sample_fn, extra)

  ml = partial(_maximum_likelihood_loss, network=network)

  def atomic(prior):
    """Build the round > 0 (atomic proposal-posterior) objective for ``prior``."""
    return _objective(
      partial(_atomic_loss, network=network, prior=prior, num_atoms=num_atoms),
      extra=None,
    )

  return _objective(ml, extra=atomic)