# Source code for sbijax._src.inference.summary.nass

"""Neural approximate sufficient statistics.

Implements the neural approximate sufficient statistics (NASS) method of
:cite:t:`chen2023learning` as a functional ``SummaryFns``. The network learns a
summary of the data by maximising a Jensen-Shannon bound on the mutual
information between the summary and the parameters; the JSD loss is reused
from the existing implementation.
"""

import jax
from jax import numpy as jnp
from jax import random as jr

from sbijax._src.inference.summary._summary_net import make_summary_net


def _jsd_summary_loss(
  params, rng, apply_fn, *, n_negatives: int = 10, **batch
):
  """Negated Jensen-Shannon mutual-information estimate used to train NASS.

  Positive pairs are the (summary, theta) pairs of the batch as given.
  Negative pairs are built by pairing each summary with ``theta`` under a
  random permutation, using ``n_negatives`` independent permutations.

  Args:
    params: network parameters forwarded to ``apply_fn``
    rng: a JAX random key used to draw the negative-pair permutations
    apply_fn: called as ``apply_fn(params, method="summary", y=...)`` to
      compute summaries and as
      ``apply_fn(params, method="critic", y=..., theta=...)`` to score pairs
    n_negatives: number of independent permutations used to build negative
      pairs (default 10, the value previously hard-coded)
    **batch: must contain ``y`` and ``theta``; only the leading axis of ``y``
      is inspected, and ``theta`` is assumed to share that leading axis

  Returns:
    the scalar ``-(mean(-softplus(-f_pos)) - mean(softplus(f_neg)))``

  Raises:
    ValueError: if ``n_negatives`` is smaller than one (the negative term
      would be a mean over an empty array, i.e. NaN)
  """
  if n_negatives < 1:
    raise ValueError("n_negatives must be at least one")
  y, theta = batch["y"], batch["theta"]
  # Only the batch axis is needed, so ``y`` may have any rank >= 1
  # (e.g. (n, p) vectors or (n, T, p) sequences for a sequence summary net).
  n_obs = y.shape[0]
  summr = apply_fn(params, method="summary", y=y)
  # Summary i is repeated once per permutation ...
  idx_pos = jnp.tile(jnp.arange(n_obs), n_negatives)
  # ... and paired with theta under independent random permutations. A
  # permutation can have fixed points, so a few "negative" pairs may
  # coincide with positive ones.
  idx_neg = jax.vmap(lambda key: jr.permutation(key, n_obs))(
    jr.split(rng, n_negatives)
  ).reshape(-1)
  f_pos = apply_fn(params, method="critic", y=summr, theta=theta)
  f_neg = apply_fn(
    params, method="critic", y=summr[idx_pos], theta=theta[idx_neg]
  )
  pos_term = -jax.nn.softplus(-f_pos)
  neg_term = jax.nn.softplus(f_neg)
  mi = pos_term.mean() - neg_term.mean()
  # The network maximises the MI estimate, so the loss is its negation.
  return -mi


def nass(network):
  """Build a summary network trained with the NASS objective.

  Args:
    network: a NASS summary network exposing ``forward``, ``summary`` and
      ``critic`` methods

  Returns:
    a ``SummaryFns`` produced by ``make_summary_net`` using this module's
    Jensen-Shannon mutual-information loss
  """
  loss_fn = _jsd_summary_loss
  summary_fns = make_summary_net(network, loss_fn)
  return summary_fns