Source code for sbijax._src.inference.summary.nasss
"""Neural approximate slice sufficient statistics.
Implements the NASSS method of :cite:t:`chen2021neural` as a functional
``SummaryFns``. It differs from NASS only in the
loss (a slice-based JSD bound with a secondary summary); the training and
summarization logic are shared.
"""
import jax
from jax import numpy as jnp
from jax import random as jr
from sbijax._src.inference.summary._summary_net import make_summary_net
def _sample_unit_sphere(rng_key, n, dim):
    """Draw ``n`` random points lying on the unit sphere in ``dim`` dimensions.

    Standard normal vectors are sampled and every row is divided by its
    Euclidean norm.

    Args:
        rng_key: a JAX random key
        n: number of points to draw
        dim: dimensionality of each point

    Returns:
        an array of shape ``(n, dim)`` whose rows have unit L2 norm
    """
    gaussians = jr.normal(rng_key, shape=(n, dim))
    # keepdims=True keeps a trailing axis of size one so the division
    # broadcasts row-wise.
    lengths = jnp.linalg.norm(gaussians, ord=2, axis=-1, keepdims=True)
    return jnp.divide(gaussians, lengths)
def _jsd_summary_loss(params, rng_key, apply_fn, **batch):
    """Compute the negative slice-based JSD bound used to train NASSS.

    The summaries of ``y`` and the parameters ``theta`` are replicated once
    per random slice direction. For each direction, the parameters are
    projected onto that direction and a secondary summary is computed from
    the summary and the direction. A critic then scores matched
    (secondary summary, projected parameter) pairs against mismatched pairs
    obtained by permuting the projected parameters.

    Args:
        params: network parameters, forwarded unchanged to ``apply_fn``
        rng_key: a JAX random key used for the slice directions and the
            permutations of the negative pairs
        apply_fn: callable invoked as ``apply_fn(params, method=..., y=...,
            theta=...)`` with the methods ``"summary"``,
            ``"secondary_summary"`` and ``"critic"``
        **batch: must contain the entries ``"y"`` and ``"theta"``; ``theta``
            has to be two-dimensional with shape ``(n, p)``

    Returns:
        the negated bound ``-(mean(-softplus(-f_pos)) - mean(softplus(f_neg)))``
    """
    # Number of random slice directions drawn per loss evaluation.
    n_slices = 10

    y, theta = batch["y"], batch["theta"]
    n, p = theta.shape

    # Same key derivation as before: the first split yields the key for the
    # directions, a split of the remainder yields the permutation key.
    phi_key, rng_key = jr.split(rng_key)
    perm_key, _ = jr.split(rng_key)

    summr = apply_fn(params, method="summary", y=y)
    # Row ``k * n + i`` of the tiled arrays belongs to slice ``k`` and
    # observation ``i``; ``jnp.repeat`` lays the directions out accordingly.
    summr = jnp.tile(summr, [n_slices, 1])
    theta = jnp.tile(theta, [n_slices, 1])
    phi = _sample_unit_sphere(phi_key, n_slices, p)
    phi = jnp.repeat(phi, n, axis=0)

    second_summr = apply_fn(
        params, method="secondary_summary", y=summr, theta=phi
    )
    theta_prime = jnp.sum(theta * phi, axis=1).reshape(-1, 1)

    # Negative pairs: keep every secondary summary and shuffle the projected
    # parameters *within* each slice. The per-slice offsets are required
    # because the arrays hold ``n_slices * n`` rows; indexing with values in
    # ``[0, n)`` only would read the first slice for all negatives.
    slice_offsets = jnp.repeat(jnp.arange(n_slices) * n, n)
    idx_neg = jax.vmap(lambda key: jr.permutation(key, n))(
        jr.split(perm_key, n_slices)
    ).reshape(-1)
    idx_neg = idx_neg + slice_offsets

    f_pos = apply_fn(params, method="critic", y=second_summr, theta=theta_prime)
    f_neg = apply_fn(
        params,
        method="critic",
        y=second_summr,
        theta=theta_prime[idx_neg],
    )

    pos_term = -jax.nn.softplus(-f_pos)
    neg_term = jax.nn.softplus(f_neg)
    mi = pos_term.mean() - neg_term.mean()
    # The bound is maximised, so its negative is returned as the loss.
    return -mi
[docs]
def nasss(network):
    """Build the summary functions for neural approximate slice sufficient statistics.

    The given network is handed to ``make_summary_net`` together with the
    slice-based JSD loss defined in this module.

    Args:
        network: a NASSS summary network with ``forward``, ``summary``,
            ``secondary_summary`` and ``critic`` methods

    Returns:
        a ``SummaryFns``
    """
    summary_fns = make_summary_net(network, _jsd_summary_loss)
    return summary_fns