"""Neural ratio estimation.
Implements the contrastive NRE method of :cite:t:`miller2022contrast` as a
functional objective. The network is a classifier whose logits define a
likelihood-to-evidence ratio; the posterior is obtained at sample time by
handing the ratio log-density to an injected MCMC sampler that adds the prior.
"""
# ruff: noqa: PLR0913
from functools import partial
import jax
import optax
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from jax._src.flatten_util import ravel_pytree
from sbijax._src.train._types import ObjectiveFns, TrainFns, TrainingState
def _get_prior_probs_marginal_and_joint(k, gamma):
p_marginal = 1 / (1 + gamma * k)
p_joint = gamma / (1 + gamma * k)
return p_marginal, p_joint
def _as_logits(params, rng_key, model, k, theta, y):
n = theta.shape[0]
y = jnp.repeat(y, k + 1, axis=0)
ps = jnp.ones((n, n)) * (1.0 - jnp.eye(n)) / (n - 1.0)
choices = jax.vmap(
lambda key, p: jr.choice(key, n, (k,), replace=False, p=p)
)(jr.split(rng_key, n), ps)
contrasting_theta = theta[choices]
atomic_theta = jnp.concatenate(
[theta[:, None, :], contrasting_theta], axis=1
).reshape(n * (k + 1), -1)
inputs = jnp.concatenate([y, atomic_theta], axis=-1)
return model.apply(params, inputs, is_training=False)
def _marginal_joint_loss(gamma, num_classes, log_marg, log_joint):
    """Compute the per-observation contrastive objective.

    Combines the log-probability of the "marginal" class (no ``theta`` in
    the row generated the observation) and of the "joint" class (the
    ``theta`` in column 0 generated it), weighted by the class priors from
    ``_get_prior_probs_marginal_and_joint``.

    Args:
        gamma: relative weight of the contrastive classes
        num_classes: number of contrastive classes
        log_marg: two-dimensional array of logits, one row per observation,
            used for the marginal term
        log_joint: two-dimensional array of logits, one row per
            observation, whose column 0 is the numerator of the joint term

    Returns:
        an array of shape ``(n,)`` with one objective value per observation
        (larger is better; the caller negates its mean)
    """
    log_gamma = jnp.log(gamma)
    log_k = jnp.log(num_classes)
    # Extra ``log K`` entry appended as a column to every denominator row.
    log_k_column = jnp.full((log_marg.shape[0], 1), log_k)
    denominator_marginal = jnp.concatenate(
        [log_gamma + log_marg, log_k_column], axis=-1
    )
    denominator_joint = jnp.concatenate(
        [log_gamma + log_joint, log_k_column], axis=-1
    )
    # Subtract from the *scalar* log K: using the (n, 1) column here would
    # broadcast against the (n,) logsumexp and silently produce an (n, n)
    # matrix of repeated rows instead of one value per observation.
    log_prob_marginal = log_k - jsp.special.logsumexp(
        denominator_marginal, axis=-1
    )
    log_prob_joint = (
        log_gamma
        + log_joint[:, 0]
        - jsp.special.logsumexp(denominator_joint, axis=-1)
    )
    p_marg, p_joint = _get_prior_probs_marginal_and_joint(num_classes, gamma)
    return p_marg * log_prob_marginal + p_joint * num_classes * log_prob_joint
def _classifier_loss(params, rng_key, model, gamma, num_classes, **batch):
    """Return the negated batch mean of the contrastive objective.

    Two independent contrastive batches are scored with ``_as_logits``. The
    first contributes its contrastive columns (column 0, the own ``theta``,
    is dropped); the second contributes its first ``num_classes`` columns
    (own ``theta`` plus all but the last contrastive draw).

    Args:
        params: network parameters
        rng_key: a JAX random key
        model: the classifier network
        gamma: relative weight of the contrastive classes
        num_classes: number of contrastive classes
        **batch: the ``y`` and ``theta`` arrays of the batch; ``y`` must be
            two-dimensional

    Returns:
        a scalar loss
    """
    batch_size, _ = batch["y"].shape
    # Keep the three-way split: a different split count would change the
    # keys (and hence the contrastive draws) even though one is unused.
    marginal_key, joint_key, _ = jr.split(rng_key, 3)

    def _logit_table(key):
        # One row per observation, one column per paired theta.
        flat_logits = _as_logits(params, key, model, num_classes, **batch)
        return flat_logits.reshape(batch_size, num_classes + 1)

    log_marg = _logit_table(marginal_key)[:, 1:]
    log_joint = _logit_table(joint_key)[:, :-1]
    objective = _marginal_joint_loss(gamma, num_classes, log_marg, log_joint)
    return -jnp.mean(objective)
def nre(network, *, num_classes=10, gamma=1.0):
    """Construct a neural ratio objective.

    Args:
        network: a classifier network mapping ``concat(y, theta)`` to logits
        num_classes: number of contrastive classes
        gamma: relative weight of the contrastive classes

    Returns:
        an ``ObjectiveFns`` bundling ``TrainFns(init_fn, step_fn, eval_fn)``
        with ``sample_fn``
    """

    def _loss(params, rng, batch):
        # Scalar training loss for one batch (dict with "y" and "theta").
        return _classifier_loss(
            params, rng, network, gamma=gamma, num_classes=num_classes, **batch
        )

    def init_fn(optimizer, rng_key, batch):
        """Initialise network parameters and optimizer state from a batch."""
        params = network.init(
            rng_key,
            jnp.concatenate([batch["y"], batch["theta"]], axis=-1),
        )
        return TrainingState(params=params, opt_state=optimizer.init(params))

    def step_fn(optimizer, rng_key, state, batch):
        """Apply one optimizer update; return ``({"loss": ...}, new_state)``."""
        loss, grads = jax.value_and_grad(_loss)(state.params, rng_key, batch)
        updates, opt_state = optimizer.update(
            grads, state.opt_state, state.params
        )
        new_params = optax.apply_updates(state.params, updates)
        return {"loss": loss}, TrainingState(new_params, opt_state)

    def eval_fn(rng_key, state, batch):
        """Evaluate the loss on a batch without updating the state."""
        return {"loss": _loss(state.params, rng_key, batch)}

    def sample_fn(
        rng_key,
        params,
        observable,
        *,
        sampler=None,
        n_chains=4,
        n_samples=2_000,
        n_warmup=1_000,
        **kwargs,
    ):
        """Draw posterior samples by handing the learned ratio to ``sampler``.

        Args:
            rng_key: a JAX random key forwarded to ``sampler``
            params: trained network parameters
            observable: the observation to condition on; promoted to at
                least two dimensions
            sampler: callable invoked as ``sampler(rng_key, loglik_fn,
                n_chains=..., n_samples=..., n_warmup=...)``
            n_chains: number of chains forwarded to ``sampler``
            n_samples: number of samples forwarded to ``sampler``
            n_warmup: number of warmup steps forwarded to ``sampler``
            **kwargs: accepted for interface compatibility; currently not
                forwarded to ``sampler``

        Returns:
            whatever ``sampler`` returns

        Raises:
            ValueError: if no ``sampler`` is supplied
        """
        if sampler is None:
            raise ValueError(
                "nre sampling requires a sampler, e.g. make_sampler(nuts, prior=prior)"
            )
        observable = jnp.atleast_2d(observable)
        classifier = partial(network.apply, params, is_training=False)

        def loglik_fn(theta):
            # Flatten the parameter pytree and give it one row per row of
            # the observable so both can be concatenated feature-wise.
            theta_flat, _ = ravel_pytree(theta)
            theta_flat = theta_flat.reshape(observable.shape[0], -1)
            return classifier(
                jnp.concatenate([observable, theta_flat], axis=-1)
            )

        return sampler(
            rng_key,
            loglik_fn,
            n_chains=n_chains,
            n_samples=n_samples,
            n_warmup=n_warmup,
        )

    return ObjectiveFns(TrainFns(init_fn, step_fn, eval_fn), sample_fn)