from functools import partial
import jax
import numpy as np
import optax
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from sbijax._src.nass import NASS
from sbijax._src.util.early_stopping import EarlyStopping
def _sample_unit_sphere(rng_key, n, dim):
u = jr.normal(rng_key, (n, dim))
norm = jnp.linalg.norm(u, ord=2, axis=-1, keepdims=True)
return u / norm
# pylint: disable=too-many-locals
def _jsd_summary_loss(params, rng_key, apply_fn, **batch):
y, theta = batch["y"], batch["theta"]
n, p = theta.shape
phi_key, rng_key = jr.split(rng_key)
summr = apply_fn(params, method="summary", y=y)
summr = jnp.tile(summr, [10, 1])
theta = jnp.tile(theta, [10, 1])
phi = _sample_unit_sphere(phi_key, 10, p)
phi = jnp.repeat(phi, n, axis=0)
second_summr = apply_fn(
params, method="secondary_summary", y=summr, theta=phi
)
theta_prime = jnp.sum(theta * phi, axis=1).reshape(-1, 1)
idx_pos = jnp.tile(jnp.arange(n), 10)
perm_key, rng_key = jr.split(rng_key)
idx_neg = jax.vmap(lambda x: jr.permutation(x, n))(
jr.split(perm_key, 10)
).reshape(-1)
f_pos = apply_fn(params, method="critic", y=second_summr, theta=theta_prime)
f_neg = apply_fn(
params,
method="critic",
y=second_summr[idx_pos],
theta=theta_prime[idx_neg],
)
a, b = -jax.nn.softplus(-f_pos), jax.nn.softplus(f_neg)
mi = a.mean() - b.mean()
return -mi
# ruff: noqa: PLR0913, E501
[docs]
class NASSS(NASS):
"""Neural approximate slice sufficient statistics.
Implements the NASSS algorithm introduced in :cite:t:`chen2021neural`.
NASS can be used to automatically summary statistics of a data set.
With the learned summaries, inferential algorithms like NLE or SMCABC
can be used to infer posterior distributions.
Args:
model_fns: a tuple of calalbles. The first element needs to be a
function that constructs a tfd.JointDistributionNamed, the second
element is a simulator function.
summary_net: a (neural) conditional density estimator
to model the likelihood function of summary statistics, i.e.,
the modelled dimensionality is that of the summaries
summary_net: a SNASSSNet object
Examples:
>>> from jax import numpy as jnp
>>> from sbijax import NASSS
>>> from sbijax.nn import make_nasss_net
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
... dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(
... theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
... ).reshape(-1, 10)
>>> fns = prior, s
>>> neural_network = make_nasss_net([64, 64, 5], [64, 64, 1], [64, 64, 1])
>>> model = NASSS(fns, neural_network)
References:
Yanzhi Chen et al. "Is Learning Summary Statistics Necessary for Likelihood-free Inference". ICML, 2023
"""
# pylint: disable=useless-parent-delegation
def __init__(self, model_fns, summary_net):
super().__init__(model_fns, summary_net)
# pylint: disable=undefined-loop-variable
def _fit_summary_net(
self,
rng_key,
train_iter,
val_iter,
optimizer,
n_iter,
n_early_stopping_patience,
):
init_key, rng_key = jr.split(rng_key)
params = self._init_summary_net_params(init_key, **next(iter(train_iter)))
state = optimizer.init(params)
loss_fn = jax.jit(partial(_jsd_summary_loss, apply_fn=self.model.apply))
@jax.jit
def step(rng, params, state, **batch):
loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch)
updates, new_state = optimizer.update(grads, state, params)
new_params = optax.apply_updates(params, updates)
return loss, new_params, new_state
losses = np.zeros([n_iter, 2])
early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
best_params, best_loss = None, np.inf
logging.info("training summary net")
for i in range(n_iter):
train_loss = 0.0
epoch_key, rng_key = jr.split(rng_key)
for j, batch in enumerate(train_iter):
batch_loss, params, state = step(
jr.fold_in(epoch_key, j), params, state, **batch
)
train_loss += batch_loss * (
batch["y"].shape[0] / train_iter.num_samples
)
val_key, rng_key = jr.split(rng_key)
validation_loss = self._summary_validation_loss(params, val_key, val_iter)
losses[i] = jnp.array([train_loss, validation_loss])
_, early_stop = early_stop.update(validation_loss)
if early_stop.should_stop:
logging.info("early stopping criterion found")
break
if validation_loss < best_loss:
best_loss = validation_loss
best_params = params.copy()
losses = jnp.vstack(losses)[: (i + 1), :]
return best_params, losses
def _summary_validation_loss(self, params, rng_key, val_iter):
loss_fn = jax.jit(partial(_jsd_summary_loss, apply_fn=self.model.apply))
def body_fn(batch_key, **batch):
loss = loss_fn(params, batch_key, **batch)
return loss * (batch["y"].shape[0] / val_iter.num_samples)
losses = 0.0
for batch in val_iter:
batch_key, rng_key = jr.split(rng_key)
losses += body_fn(batch_key, **batch)
return losses