# Source code for sbijax._src.nass

from functools import partial

import jax
import numpy as np
import optax
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from tqdm import tqdm

from sbijax._src._ne_base import NE
from sbijax._src.util.dataloader import as_numpy_iterator_from_slices
from sbijax._src.util.early_stopping import EarlyStopping


def _jsd_summary_loss(params, rng, apply_fn, **batch):
  """Compute the negated JSD-style mutual-information estimate for a batch.

  The summary network is applied to ``batch["y"]`` and the critic scores
  (summary, theta) pairs. Pairs in their original row order are treated as
  positives; negatives are built by pairing the summaries (tiled ten times)
  with ``theta`` rows drawn from ten independent random permutations.

  Args:
    params: parameters forwarded as first argument to ``apply_fn``
    rng: a jax random key used to draw the permutations
    apply_fn: callable invoked as ``apply_fn(params, method=..., ...)`` with
      ``method="summary"`` and ``method="critic"``
    **batch: must contain the entries ``"y"`` (two-dimensional) and
      ``"theta"``

  Returns:
    the negative of ``mean(-softplus(-f_pos)) - mean(softplus(f_neg))``
  """
  n_shuffles = 10
  observed = batch["y"]
  theta = batch["theta"]
  n_rows, _ = observed.shape
  summaries = apply_fn(params, method="summary", y=observed)

  # every summary row is repeated once per shuffle ...
  pos_rows = jnp.tile(jnp.arange(n_rows), n_shuffles)
  # ... and matched with a theta row from an independent permutation
  shuffle_keys = jr.split(rng, n_shuffles)
  shuffled = jax.vmap(lambda key: jr.permutation(key, n_rows))(shuffle_keys)
  neg_rows = shuffled.reshape(-1)

  scores_joint = apply_fn(params, method="critic", y=summaries, theta=theta)
  scores_shuffled = apply_fn(
    params,
    method="critic",
    y=summaries[pos_rows],
    theta=theta[neg_rows],
  )
  joint_term = -jax.nn.softplus(-scores_joint)
  shuffled_term = jax.nn.softplus(scores_shuffled)
  mutual_info = joint_term.mean() - shuffled_term.mean()
  return -mutual_info


# ruff: noqa: PLR0913, E501
class NASS(NE):
  """Neural approximate summary statistics.

  Implements the NASS algorithm introduced in :cite:t:`chen2023learning`.
  NASS can be used to automatically learn summary statistics of a data set.
  With the learned summaries, inferential algorithms like NLE or SMCABC
  can be used to infer posterior distributions.

  Args:
    model_fns: a tuple of callables. The first element needs to be a
      function that constructs a tfd.JointDistributionNamed, the second
      element is a simulator function.
    summary_net: a SNASSNet object

  Examples:
    >>> from jax import numpy as jnp
    >>> from sbijax import NASS
    >>> from sbijax.nn import make_nass_net
    >>> from tensorflow_probability.substrates.jax import distributions as tfd
    ...
    >>> prior = lambda: tfd.JointDistributionNamed(
    ...   dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
    ... )
    >>> s = lambda seed, theta: tfd.Normal(
    ...   theta["theta"], 1.0).sample(seed=seed, sample_shape=(2,)
    ... ).reshape(-1, 10)
    >>> fns = prior, s
    >>> neural_network = make_nass_net([64, 64, 5], [64, 64, 1])
    >>> model = NASS(fns, neural_network)

  References:
    Chen, Yanzhi et al. "Neural Approximate Sufficient Statistics for
    Implicit Models". ICLR, 2021
  """

  def __init__(self, model_fns, summary_net):
    """Construct a NASS object.

    Args:
      model_fns: tuple of prior constructor and simulator function
      summary_net: the network that is forwarded to the base class
    """
    super().__init__(model_fns, summary_net)

  # pylint: disable=arguments-differ,too-many-locals
  def fit(
    self,
    rng_key,
    data,
    optimizer=optax.adam(0.0003),
    n_iter=1000,
    batch_size=128,
    percentage_data_as_validation_set=0.1,
    n_early_stopping_patience=10,
    **kwargs,
  ):
    """Fit the model to data.

    Args:
      rng_key: a jax random key
      data: data set obtained from calling
        `simulate_data_and_possibly_append`
      optimizer: an optax optimizer object
      n_iter: maximal number of training iterations per round
      batch_size: batch size used for training the model
      percentage_data_as_validation_set: percentage of the simulated
        data that is used for validation and early stopping
      n_early_stopping_patience: number of iterations of no improvement
        of the validation loss before stopping optimisation
      **kwargs: additional keyword arguments (not used for NASS)

    Returns:
      tuple of the fitted parameters and an array of losses that has one
      row per training iteration and the columns
      (training loss, validation loss)
    """
    itr_key, rng_key = jr.split(rng_key)
    train_iter, val_iter = self.as_iterators(
      itr_key, data, batch_size, percentage_data_as_validation_set
    )
    snet_params, snet_losses = self._fit_summary_net(
      rng_key=rng_key,
      train_iter=train_iter,
      val_iter=val_iter,
      optimizer=optimizer,
      n_iter=n_iter,
      n_early_stopping_patience=n_early_stopping_patience,
    )
    return snet_params, snet_losses

  # TODO(Simon): this is not very nicely solved
  def summarize(self, params, data, batch_size=512):
    """Compute summary statistics of a data set.

    Args:
      params: parameters of the summary network. If `None` or empty,
        `data` is returned unchanged.
      data: either an array of observations or a dictionary that has
        the observations stored under the key "y"
      batch_size: number of observations that are summarized per batch

    Returns:
      the summaries as an array if `data` was an array, or a shallow copy
      of `data` where the entry "y" is replaced by the summaries if `data`
      was a dictionary
    """
    if params is None or len(params) == 0:
      return data
    # plain numpy arrays are wrapped like jax arrays, since the batches
    # are indexed with the key "y" below
    is_array = isinstance(data, (jnp.ndarray, np.ndarray))
    y = {"y": data} if is_array else data
    itr = as_numpy_iterator_from_slices(y, batch_size)

    @jax.jit
    def _summarize(batch):
      return self.model.apply(params, method="summary", y=batch["y"])

    summaries = jnp.concatenate([_summarize(batch) for batch in itr], axis=0)
    if isinstance(data, dict):
      ret_summaries = data.copy()
      ret_summaries["y"] = summaries
    else:
      ret_summaries = summaries
    return ret_summaries

  def _get_summary_loss_fn(self):
    """Return the jitted summary loss, creating it only when needed.

    Wrapping a new `functools.partial` object with `jax.jit` creates a new
    compilation cache entry every time, so the jitted function is stored
    on the instance and only rebuilt if `self.model.apply` changed.
    """
    apply_fn = self.model.apply
    cached = getattr(self, "_summary_loss_cache", None)
    if cached is None or cached[0] != apply_fn:
      loss_fn = jax.jit(partial(_jsd_summary_loss, apply_fn=apply_fn))
      cached = (apply_fn, loss_fn)
      self._summary_loss_cache = cached
    return cached[1]

  def _fit_summary_net(
    self,
    rng_key,
    train_iter,
    val_iter,
    optimizer,
    n_iter,
    n_early_stopping_patience,
  ):
    """Train the summary network with early stopping.

    Args:
      rng_key: a jax random key
      train_iter: iterator over training batches
      val_iter: iterator over validation batches
      optimizer: an optax optimizer object
      n_iter: maximal number of passes over `train_iter`
      n_early_stopping_patience: patience handed to `EarlyStopping`

    Returns:
      tuple of the parameters with the lowest validation loss and an
      array of (training loss, validation loss) rows, one per completed
      iteration. If no iteration improved on the initial best loss (for
      instance when `n_iter` is zero), the current parameters are
      returned.
    """
    init_key, rng_key = jr.split(rng_key)
    params = self._init_summary_net_params(init_key, **next(iter(train_iter)))
    state = optimizer.init(params)
    loss_fn = self._get_summary_loss_fn()

    @jax.jit
    def step(rng, params, state, **batch):
      loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch)
      updates, new_state = optimizer.update(grads, state, params)
      new_params = optax.apply_updates(params, updates)
      return loss, new_params, new_state

    losses = np.zeros([n_iter, 2])
    early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
    best_params, best_loss = None, np.inf
    # counted explicitly so that the slicing below also works if the
    # loop body never runs
    n_completed = 0
    logging.info("training summary net")
    for i in tqdm(range(n_iter)):
      train_loss = 0.0
      epoch_key, rng_key = jr.split(rng_key)
      for j, batch in enumerate(train_iter):
        batch_loss, params, state = step(
          jr.fold_in(epoch_key, j), params, state, **batch
        )
        # weight each batch loss by its share of the training samples
        train_loss += batch_loss * (
          batch["y"].shape[0] / train_iter.num_samples
        )
      val_key, rng_key = jr.split(rng_key)
      validation_loss = self._summary_validation_loss(params, val_key, val_iter)
      losses[i] = jnp.array([train_loss, validation_loss])
      n_completed = i + 1
      _, early_stop = early_stop.update(validation_loss)
      if early_stop.should_stop:
        logging.info("early stopping criterion found")
        break
      if validation_loss < best_loss:
        best_loss = validation_loss
        best_params = params.copy()

    if best_params is None:
      best_params = params
    losses = jnp.asarray(losses[:n_completed])
    return best_params, losses

  def _init_summary_net_params(self, rng_key, **init_data):
    """Initialize the network parameters using the "forward" method.

    Args:
      rng_key: a jax random key
      **init_data: a batch of data that is forwarded to `self.model.init`

    Returns:
      the initialized parameters
    """
    params = self.model.init(rng_key, method="forward", **init_data)
    return params

  def _summary_validation_loss(self, params, rng_key, val_iter):
    """Compute the sample-weighted loss over all validation batches.

    Args:
      params: parameters of the summary network
      rng_key: a jax random key; a new key is split off for every batch
      val_iter: iterator over validation batches

    Returns:
      the sum of the batch losses, each weighted by the batch's share of
      `val_iter.num_samples`
    """
    loss_fn = self._get_summary_loss_fn()
    losses = 0.0
    for batch in val_iter:
      batch_key, rng_key = jr.split(rng_key)
      weight = batch["y"].shape[0] / val_iter.num_samples
      losses += loss_fn(params, batch_key, **batch) * weight
    return losses

  def simulate_data(
    self,
    rng_key,
    *,
    n_simulations=1000,
    **kwargs,
  ):
    """Simulate data by delegating to the base class implementation.

    Args:
      rng_key: a jax random key
      n_simulations: number of simulations that is requested
      **kwargs: keyword arguments forwarded to the base class

    Returns:
      whatever the base class `simulate_data` returns
    """
    return super().simulate_data(rng_key, n_simulations=n_simulations, **kwargs)

  def _simulate_parameters_with_model(
    self, rng_key, params, observable, *args, **kwargs
  ):
    """Not implemented for NASS.

    Raises:
      NotImplementedError: always
    """
    raise NotImplementedError()

  def sample_posterior(self, rng_key, params, observable, *args, **kwargs):
    """Not implemented for NASS.

    Raises:
      NotImplementedError: always
    """
    raise NotImplementedError()