# Source code for sbijax._src.nre

# Parts of this codebase have been adapted from https://github.com/bkmi/cnre
from collections.abc import Callable
from functools import partial
from typing import NamedTuple

import chex
import jax
import numpy as np
import optax
from absl import logging
from haiku import Params, Transformed
from jax import Array
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from jax._src.flatten_util import ravel_pytree
from tqdm import tqdm

from sbijax._src import mcmc
from sbijax._src._ne_base import NE
from sbijax._src.mcmc.util import mcmc_diagnostics
from sbijax._src.util.data import as_inference_data
from sbijax._src.util.early_stopping import EarlyStopping


def _get_prior_probs_marginal_and_joint(K, gamma):
  """Return the class prior probabilities ``(p_marginal, p_joint)``.

  ``p_marginal = 1 / (1 + gamma * K)`` and ``p_joint = gamma / (1 + gamma * K)``,
  so that ``p_marginal + K * p_joint == 1``.

  Args:
    K: number of contrasting classes
    gamma: relative weight of the joint class

  Returns:
    a tuple ``(p_marginal, p_joint)``
  """
  normalizer = 1 + gamma * K
  return 1 / normalizer, gamma / normalizer


# pylint: disable=too-many-arguments
def _as_logits(params, rng_key, model, K, theta, y):
  """Evaluate the classifier on every true pair and K contrasting pairs.

  For each row ``i`` of ``theta``, ``K`` other rows are drawn at random
  without replacement. Row ``i`` itself has zero probability, so ``K`` should
  satisfy ``K <= n - 1``, and ``n == 1`` divides by zero. The drawn rows are
  paired with ``y[i]``. Inputs are laid out observation by observation as
  ``[true_i, contrast_i_1, ..., contrast_i_K]``.

  Args:
    params: classifier parameters forwarded to ``model.apply``
    rng_key: a jax random key used to draw the contrasting parameters
    model: transformed classifier exposing ``apply(params, inputs, is_training)``
    K: number of contrasting parameters per observation
    theta: parameters; the leading axis has size ``n``
    y: observations aligned row-wise with ``theta``

  Returns:
    the classifier output on the ``n * (K + 1)`` concatenated ``[y, theta]``
    inputs
  """
  n_rows = theta.shape[0]
  # Uniform distribution over all rows except the row itself.
  probs = (1.0 - jnp.eye(n_rows)) / (n_rows - 1.0)
  row_keys = jr.split(rng_key, n_rows)

  def _draw(key, p):
    return jr.choice(key, n_rows, (K,), replace=False, p=p)

  contrast_idx = jax.vmap(_draw)(row_keys, probs)
  # Place the true parameter first in every group, followed by its K contrasts.
  grouped_theta = jnp.concatenate(
    [theta[:, None, :], theta[contrast_idx]], axis=1
  )
  flat_theta = grouped_theta.reshape(n_rows * (K + 1), -1)
  flat_y = jnp.repeat(y, K + 1, axis=0)

  classifier_inputs = jnp.concatenate([flat_y, flat_theta], axis=-1)
  # NOTE(review): is_training=False is used here even when called from the
  # training loss; confirm this is intended if the classifier has dropout or
  # batch-norm layers.
  return model.apply(params, classifier_inputs, is_training=False)


def _marginal_joint_loss(gamma, num_classes, log_marg, log_joint):
  """Compute the per-example contrastive NRE objective (to be maximised).

  Args:
    gamma: relative weight of the joint class
    num_classes: number of contrasting classes ``K``
    log_marg: classifier logits of shape ``(n, K)`` containing only
      contrasting pairs (``_loss`` drops the true pair before calling this)
    log_joint: classifier logits of shape ``(n, K)`` whose first column holds
      the logit of the true (joint) pair

  Returns:
    an array of shape ``(n,)``: the prior-weighted log-probabilities of the
    correct class in the marginal and in the joint case
  """
  loggamma = jnp.log(gamma)
  # Column vector so that it can be appended as an extra logit below.
  logK = jnp.full((log_marg.shape[0], 1), jnp.log(num_classes))

  denominator_marginal = jnp.concatenate(
    [loggamma + log_marg, logK],
    axis=-1,
  )
  denominator_joint = jnp.concatenate(
    [loggamma + log_joint, logK],
    axis=-1,
  )

  # Use the (n,) view of logK: subtracting the (n,) logsumexp from the (n, 1)
  # column would silently broadcast to an (n, n) matrix.
  log_prob_marginal = logK[:, 0] - jsp.special.logsumexp(
    denominator_marginal, axis=-1
  )
  log_prob_joint = (
    loggamma
    + log_joint[:, 0]
    - jsp.special.logsumexp(denominator_joint, axis=-1)
  )

  p_marg, p_joint = _get_prior_probs_marginal_and_joint(num_classes, gamma)
  return p_marg * log_prob_marginal + p_joint * num_classes * log_prob_joint


def _loss(params, rng_key, model, gamma, num_classes, **batch):
  """Return the negated mean contrastive NRE objective for one batch.

  Args:
    params: classifier parameters
    rng_key: a jax random key
    model: the transformed classifier, forwarded to ``_as_logits``
    gamma: relative weight of the joint class
    num_classes: number of contrasting parameters ``K`` per observation
    **batch: keyword arrays ``theta`` and ``y`` forwarded to ``_as_logits``;
      ``y`` must be 2D

  Returns:
    a scalar loss to be minimised
  """
  n_obs, _ = batch["y"].shape
  # Split into three (third key unused) so that the two subkeys match the
  # historical random stream.
  key_marg, key_joint, _ = jr.split(rng_key, 3)

  def _grouped_logits(key):
    # One row per observation: [true pair, K contrasting pairs].
    flat = _as_logits(params, key, model, num_classes, **batch)
    return flat.reshape(n_obs, num_classes + 1)

  # Marginal term: drop the true pair so only contrasting pairs remain.
  log_marg = _grouped_logits(key_marg)[:, 1:]
  # Joint term: keep the true pair in column 0, drop the last contrast.
  log_joint = _grouped_logits(key_joint)[:, :-1]

  per_example = _marginal_joint_loss(gamma, num_classes, log_marg, log_joint)
  return -jnp.mean(per_example)


# ruff: noqa: PLR0913, E501
class NRE(NE):
  r"""Neural ratio estimation.

  Implements the method by :cite:t:`miller2022contrast`. The original
  publication refers to the method as CNRE or NRE-C, but here we refer to it
  as NRE.

  Args:
    model_fns: a tuple of callables. The first element needs to be a function
      that constructs a tfd.JointDistributionNamed, the second element is a
      simulator function.
    classifier: a neural network for classification
    num_classes: number of classes to classify against, i.e., the number of
      contrasting parameters paired with each observation; must be >= 1
    gamma: relative weight of classes. The joint class gets prior probability
      ``gamma / (1 + gamma * num_classes)`` and the marginal class
      ``1 / (1 + gamma * num_classes)``; must be > 0

  Examples:
    >>> from sbijax import NRE
    >>> from sbijax.nn import make_resnet
    >>> from tensorflow_probability.substrates.jax import distributions as tfd
    ...
    >>> prior = lambda: tfd.JointDistributionNamed(
    ...    dict(theta=tfd.Normal(0.0, 1.0))
    ... )
    >>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
    >>> fns = prior, s
    >>> neural_network = make_resnet()
    >>> model = NRE(fns, neural_network)

  References:
    Miller, Benjamin K., et al. "Contrastive neural ratio estimation."
    Advances in Neural Information Processing Systems, 2022.
  """

  def __init__(
    self,
    model_fns: tuple[Callable, Callable],
    classifier: Transformed,
    num_classes: int = 10,
    gamma: float = 1.0,
  ):
    # The loss takes log(num_classes) and log(gamma) and draws num_classes
    # contrasts per observation, so both must be strictly positive.
    if num_classes < 1:
      raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if not gamma > 0:
      raise ValueError(f"gamma must be > 0, got {gamma}")
    super().__init__(model_fns, classifier)
    self.gamma = gamma
    self.num_classes = num_classes

  # pylint: disable=arguments-differ,too-many-locals
  def fit(
    self,
    rng_key: Array,
    data: NamedTuple,
    *,
    optimizer: optax.GradientTransformation = optax.adam(0.003),
    n_iter: int = 1000,
    batch_size: int = 100,
    percentage_data_as_validation_set: float = 0.1,
    n_early_stopping_patience: int = 25,
    n_early_stopping_delta=0.001,
    **kwargs,
  ):
    """Fit the model.

    Args:
      rng_key: a jax random key
      data: data set obtained from calling
        `simulate_data_and_possibly_append`
      optimizer: an optax optimizer object
      n_iter: maximal number of training iterations per round
      batch_size: batch size used for training the model
      percentage_data_as_validation_set: percentage of the simulated data
        that is used for validation and early stopping
      n_early_stopping_patience: number of iterations without improvement
        of the classifier before stopping optimisation
      n_early_stopping_delta: minimal value for improvement for early
        stopping
      **kwargs: ignored

    Returns:
      a tuple of the parameters with the lowest validation loss and an array
      of shape (n_completed_iterations, 2) holding the training and
      validation loss of each iteration
    """
    itr_key, rng_key = jr.split(rng_key)
    train_iter, val_iter = self.as_iterators(
      itr_key, data, batch_size, percentage_data_as_validation_set
    )
    params, losses = self._fit_model_single_round(
      rng_key=rng_key,
      train_iter=train_iter,
      val_iter=val_iter,
      optimizer=optimizer,
      n_iter=n_iter,
      n_early_stopping_patience=n_early_stopping_patience,
      n_early_stopping_delta=n_early_stopping_delta,
    )
    return params, losses

  def _fit_model_single_round(
    self,
    rng_key,
    train_iter,
    val_iter,
    optimizer,
    n_iter,
    n_early_stopping_patience,
    n_early_stopping_delta,
  ):
    """Train the classifier for up to ``n_iter`` epochs with early stopping.

    Returns:
      the parameters with the lowest validation loss (or the latest
      parameters if no validation loss ever improved on ``np.inf``, e.g.
      ``n_iter == 0`` or NaN losses) and an array of shape
      ``(n_completed_iterations, 2)`` with training and validation losses
    """
    init_key, rng_key = jr.split(rng_key)
    params = self._init_params(init_key, **next(iter(train_iter)))
    state = optimizer.init(params)
    loss_fn = partial(_loss, gamma=self.gamma, num_classes=self.num_classes)

    @jax.jit
    def step(params, rng, state, **batch):
      loss, grads = jax.value_and_grad(loss_fn)(
        params, rng, self.model, **batch
      )
      updates, new_state = optimizer.update(grads, state, params)
      new_params = optax.apply_updates(params, updates)
      return loss, new_params, new_state

    losses = np.zeros([n_iter, 2])
    early_stop = EarlyStopping(
      n_early_stopping_delta, n_early_stopping_patience
    )
    best_params, best_loss = None, np.inf
    # Tracked explicitly rather than reading the loop variable after the
    # loop, which is undefined when n_iter == 0.
    n_completed = 0
    logging.info("training model")
    for i in tqdm(range(n_iter)):
      train_loss = 0.0
      rng_key = jr.fold_in(rng_key, i)
      for batch in train_iter:
        train_key, rng_key = jr.split(rng_key)
        batch_loss, params, state = step(params, train_key, state, **batch)
        # Weight by batch size / train_iter.num_samples (presumably the size
        # of the training set) so that unequal batches are averaged correctly.
        train_loss += batch_loss * (
          batch["y"].shape[0] / train_iter.num_samples
        )
      val_key, rng_key = jr.split(rng_key)
      validation_loss = self._validation_loss(val_key, params, val_iter)
      losses[i] = jnp.array([train_loss, validation_loss])
      n_completed = i + 1

      _, early_stop = early_stop.update(validation_loss)
      if early_stop.should_stop:
        logging.info("early stopping criterion found")
        break
      if validation_loss < best_loss:
        best_loss = validation_loss
        best_params = params.copy()

    if best_params is None:
      # No validation loss ever beat np.inf: return the latest parameters
      # instead of None.
      best_params = params
    losses = jnp.vstack(losses)[:n_completed, :]
    return best_params, losses

  def _init_params(self, rng_key, **init_data):
    """Initialise classifier parameters from one ``y``/``theta`` batch."""
    params = self.model.init(
      rng_key,
      jnp.concatenate([init_data["y"], init_data["theta"]], axis=-1),
    )
    return params

  def _validation_loss(self, rng_key, params, val_iter):
    """Return the sample-weighted average loss over the validation batches."""
    loss_fn = partial(_loss, gamma=self.gamma, num_classes=self.num_classes)

    # NOTE(review): a new jitted closure is created on every call (i.e. every
    # epoch), so it is re-traced and re-compiled each time.
    @jax.jit
    def body_fn(rng_key, **batch):
      loss = loss_fn(params, rng_key, self.model, **batch)
      return loss * (batch["y"].shape[0] / val_iter.num_samples)

    loss = 0.0
    for batch in val_iter:
      val_key, rng_key = jr.split(rng_key)
      loss += body_fn(val_key, **batch)
    return loss

  def simulate_data_and_possibly_append(
    self,
    rng_key: Array,
    params: Params | None = None,
    observable: Array | None = None,
    data: tuple | None = None,
    n_simulations: int = 1_000,
    n_chains: int = 4,
    n_samples: int = 2_000,
    n_warmup: int = 1_000,
    **kwargs,
  ):
    """Simulate data from the prior or posterior.

    Simulate new parameters and observables from the prior or posterior
    (when params and data given). If a data argument is provided, append
    the new samples to the data set and return the old+new data.

    Args:
      rng_key: a jax random key
      params: a dictionary of neural network parameters
      observable: an observation
      data: existing data set or None
      n_simulations: number of newly simulated data
      n_chains: number of MCMC chains
      n_samples: number of samples to draw in total
      n_warmup: number of warmup draws to discard

    Keyword Args:
      sampler (str): either 'nuts', 'slice' or None (defaults to nuts)
      n_thin (int): number of thinning steps
        (only used if sampler='slice')
      n_doubling (int): number of doubling steps of the interval
        (only used if sampler='slice')
      step_size (float): step size of the initial interval
        (only used if sampler='slice')

    Returns:
      a NamedTuple with two elements, y and theta
    """
    return super().simulate_data_and_possibly_append(
      rng_key=rng_key,
      params=params,
      observable=observable,
      data=data,
      n_simulations=n_simulations,
      n_chains=n_chains,
      n_samples=n_samples,
      n_warmup=n_warmup,
      **kwargs,
    )

  # ruff: noqa: D417
  def sample_posterior(
    self,
    rng_key,
    params,
    observable,
    *,
    n_chains=4,
    n_samples=2_000,
    n_warmup=1_000,
    **kwargs,
  ):
    r"""Sample from the approximate posterior.

    Args:
      rng_key: a jax random key
      params: a pytree of neural network parameters
      observable: observation to condition on
      n_chains: number of MCMC chains
      n_samples: number of samples per chain, including warmup
      n_warmup: number of samples to discard

    Keyword Args:
      sampler (str): either 'nuts', 'slice' or None (defaults to nuts)
      n_thin (int): number of thinning steps
        (only used if sampler='slice')
      n_doubling (int): number of doubling steps of the interval
        (only used if sampler='slice')
      step_size (float): step size of the initial interval
        (only used if sampler='slice')

    Returns:
      the posterior samples converted with ``as_inference_data``
      (``n_chains`` chains of ``n_samples - n_warmup`` draws each) and
      MCMC diagnostics
    """
    observable = jnp.atleast_2d(observable)
    return self._sample_posterior(
      rng_key,
      params,
      observable,
      n_chains=n_chains,
      n_samples=n_samples,
      n_warmup=n_warmup,
      **kwargs,
    )

  def _sample_posterior(
    self,
    rng_key,
    params,
    observable,
    *,
    n_chains=4,
    n_samples=2_000,
    n_warmup=1_000,
    **kwargs,
  ):
    """Run MCMC on prior log-density plus classifier log-ratio."""
    part = partial(self.model.apply, params, is_training=False)

    def _joint_logdensity_fn(theta):
      lp_prior = self.prior.log_prob(theta)
      theta, _ = ravel_pytree(theta)
      # One parameter row per observation row.
      theta = theta.reshape(observable.shape[0], -1)
      lp = part(jnp.concatenate([observable, theta], axis=-1))
      return jnp.sum(lp_prior) + jnp.sum(lp)

    # An explicit sampler=None is documented to mean the default (nuts);
    # without this check "sample_with_" + None raises a TypeError.
    sampler = kwargs.pop("sampler", None)
    if sampler is None:
      sampler = "nuts"
    sampling_fn = getattr(mcmc, "sample_with_" + sampler)
    samples = sampling_fn(
      rng_key=rng_key,
      lp=_joint_logdensity_fn,
      prior=self.prior,
      n_chains=n_chains,
      n_samples=n_samples,
      n_warmup=n_warmup,
      **kwargs,
    )
    for v in samples.values():
      chex.assert_shape(v, [n_chains, n_samples - n_warmup, None])
    inference_data = as_inference_data(samples, jnp.squeeze(observable))
    diagnostics = mcmc_diagnostics(inference_data)
    return inference_data, diagnostics

  def _simulate_parameters_with_model(
    self,
    rng_key,
    params,
    observable,
    *,
    n_chains=4,
    n_samples=2_000,
    n_warmup=1_000,
    **kwargs,
  ):
    """Draw new parameters from the approximate posterior."""
    return self.sample_posterior(
      rng_key=rng_key,
      params=params,
      observable=observable,
      n_samples=n_samples,
      n_warmup=n_warmup,
      n_chains=n_chains,
      **kwargs,
    )