from functools import partial
import jax
import numpy as np
import optax
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from jax._src.flatten_util import ravel_pytree
from tqdm import tqdm
from sbijax._src._ne_base import NE
from sbijax._src.util.data import as_inference_data
from sbijax._src.util.early_stopping import EarlyStopping
# ruff: noqa: PLR0913, E501
[docs]
class NPE(NE):
"""Neural posterior estimation.
Implements the method introduced in :cite:t:`greenberg2019automatic`.
In the literature, the method is usually referred to as APT or NPE-C, but
here we refer to it simply as NPE.
Args:
model_fns: a tuple of calalbles. The first element needs to be a
function that constructs a tfd.JointDistributionNamed, the second
element is a simulator function.
density_estimator: a (neural) conditional density estimator
to model the posterior distribution
num_atoms: number of atomic atoms
Examples:
>>> from sbijax import NPE
>>> from sbijax.nn import make_maf
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
... dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_maf(1)
>>> model = NPE(fns, neural_network)
References:
Greenberg, David, et al. "Automatic posterior transformation for likelihood-free inference." International Conference on Machine Learning, 2019.
"""
def __init__(
self,
model_fns,
density_estimator,
num_atoms=10,
use_event_space_bijections=True,
):
"""Construct an SNP object.
Args:
model_fns: a tuple of tuples. The first element is a tuple that
consists of functions to sample and evaluate the
log-probability of a data point. The second element is a
simulator function.
density_estimator: a (neural) conditional density estimator
to model the posterior distribution
num_atoms: number of atomic atoms
use_event_space_bijections: if True uses a unconstraining bijection
to map the constrained parameters onto the real line and do
training there
"""
super().__init__(model_fns, density_estimator)
self.num_atoms = num_atoms
self.n_round = 0
prior = model_fns[0]()
# TODO(simon): check out event bijections
if (
hasattr(prior, "experimental_default_event_space_bijector")
and use_event_space_bijections
):
self._prior_bijectors = prior.experimental_default_event_space_bijector()
# ruff: noqa: D417
[docs]
def fit(
self,
rng_key,
data,
*,
optimizer=optax.adam(0.0003),
n_iter=1000,
batch_size=128,
percentage_data_as_validation_set=0.1,
n_early_stopping_patience=10,
**kwargs,
):
"""Fit an SNP model.
Args:
rng_key: a jax random key
data: data set obtained from calling
`simulate_data_and_possibly_append`
optimizer: an optax optimizer object
n_iter: maximal number of training iterations per round
batch_size: batch size used for training the model
percentage_data_as_validation_set: percentage of the simulated
data that is used for validation and early stopping
n_early_stopping_patience: number of iterations of no improvement
of training the flow before stopping optimisation
Returns:
a tuple of parameters and a tuple of the training information
"""
itr_key, rng_key = jr.split(rng_key)
train_iter, val_iter = self.as_iterators(
itr_key, data, batch_size, percentage_data_as_validation_set
)
params, losses = self._fit_model_single_round(
seed=rng_key,
train_iter=train_iter,
val_iter=val_iter,
optimizer=optimizer,
n_iter=n_iter,
n_early_stopping_patience=n_early_stopping_patience,
n_atoms=self.num_atoms,
)
return params, losses
# pylint: disable=undefined-loop-variable
def _fit_model_single_round(
self,
seed,
train_iter,
val_iter,
optimizer,
n_iter,
n_early_stopping_patience,
n_atoms,
):
init_key, seed = jr.split(seed)
params = self._init_params(init_key, **next(iter(train_iter)))
state = optimizer.init(params)
n_round = self.n_round
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))
if n_round == 0:
def loss_fn(params, rng, **batch):
theta, y = batch["theta"], batch["y"]
log_det = 0
if hasattr(self, "_prior_bijectors"):
theta_map = jax.vmap(unravel_fn)(theta)
theta = self._prior_bijectors.inverse(theta_map)
log_det = self._prior_bijectors.inverse_log_det_jacobian(theta_map)
theta = jax.vmap(lambda x: ravel_pytree(x)[0])(theta)
lp = self.model.apply(
params,
None,
method="log_prob",
y=theta,
x=y,
)
lp = lp + log_det
return -jnp.mean(lp)
else:
# TODO(simon): do bijections here? probably
def loss_fn(params, rng, **batch):
lp = self._proposal_posterior_log_prob(
params,
rng,
n_atoms,
theta=batch["theta"],
y=batch["y"],
)
return -jnp.mean(lp)
@jax.jit
def step(params, rng, state, **batch):
loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch)
updates, new_state = optimizer.update(grads, state, params)
new_params = optax.apply_updates(params, updates)
return loss, new_params, new_state
losses = np.zeros([n_iter, 2])
early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
best_params, best_loss = None, np.inf
logging.info("training model")
for i in tqdm(range(n_iter)):
train_loss = 0.0
rng_key = jr.fold_in(seed, i)
for batch in train_iter:
train_key, rng_key = jr.split(rng_key)
batch_loss, params, state = step(params, train_key, state, **batch)
train_loss += batch_loss * (
batch["y"].shape[0] / train_iter.num_samples
)
val_key, rng_key = jr.split(rng_key)
validation_loss = self._validation_loss(
val_key, params, val_iter, n_atoms
)
losses[i] = jnp.array([train_loss, validation_loss])
_, early_stop = early_stop.update(validation_loss)
if early_stop.should_stop:
logging.info("early stopping criterion found")
break
if validation_loss < best_loss:
best_loss = validation_loss
best_params = params.copy()
self.n_round += 1
losses = jnp.vstack(losses)[: (i + 1), :]
return best_params, losses
def _init_params(self, rng_key, **init_data):
params = self.model.init(
rng_key, method="log_prob", y=init_data["theta"], x=init_data["y"]
)
return params
def _proposal_posterior_log_prob(self, params, rng, n_atoms, theta, y):
n = theta.shape[0]
n_atoms = np.maximum(2, np.minimum(n_atoms, n))
repeated_y = jnp.repeat(y, n_atoms, axis=0)
probs = jnp.ones((n, n)) * (1 - jnp.eye(n)) / (n - 1)
choice = partial(
jr.choice, a=jnp.arange(n), replace=False, shape=(n_atoms - 1,)
)
sample_keys = jr.split(rng, probs.shape[0])
choices = jax.vmap(lambda key, prob: choice(key, p=prob))(
sample_keys, probs
)
contrasting_theta = theta[choices]
atomic_theta = jnp.concatenate(
(theta[:, None, :], contrasting_theta), axis=1
)
atomic_theta = atomic_theta.reshape(n * n_atoms, -1)
log_prob_posterior = self.model.apply(
params, None, method="log_prob", y=atomic_theta, x=repeated_y
)
log_prob_posterior = log_prob_posterior.reshape(n, n_atoms)
log_prob_prior = self.prior.log_prob(atomic_theta)
log_prob_prior = log_prob_prior.reshape(n, n_atoms)
unnormalized_log_prob = log_prob_posterior - log_prob_prior
log_prob_proposal_posterior = unnormalized_log_prob[
:, 0
] - jsp.special.logsumexp(unnormalized_log_prob, axis=-1)
return log_prob_proposal_posterior
def _validation_loss(self, rng_key, params, val_iter, n_atoms):
if self.n_round == 0:
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))
def loss_fn(rng, **batch):
theta, y = batch["theta"], batch["y"]
log_det = 0
if hasattr(self, "_prior_bijectors"):
theta_map = jax.vmap(unravel_fn)(theta)
theta = self._prior_bijectors.inverse(theta_map)
log_det = self._prior_bijectors.inverse_log_det_jacobian(theta_map)
theta = jax.vmap(lambda x: ravel_pytree(x)[0])(theta)
lp = self.model.apply(
params,
None,
method="log_prob",
y=theta,
x=y,
)
lp = lp + log_det
return -jnp.mean(lp)
else:
def loss_fn(rng, **batch):
lp = self._proposal_posterior_log_prob(
params, rng, n_atoms, batch["theta"], batch["y"]
)
return -jnp.mean(lp)
def body_fn(batch, rng_key):
loss = jax.jit(loss_fn)(rng_key, **batch)
return loss * (batch["y"].shape[0] / val_iter.num_samples)
loss = 0.0
for batch in val_iter:
val_key, rng_key = jr.split(rng_key)
loss += body_fn(batch, val_key)
return loss
[docs]
def sample_posterior(
self,
rng_key,
params,
observable,
*,
n_samples=4_000,
check_proposal_probs=True,
**kwargs,
):
r"""Sample from the approximate posterior.
Args:
rng_key: a jax random key
params: a pytree of neural network parameters
observable: observation to condition on
n_samples: number of samples to draw
check_proposal_probs: check if the proposal draws have finite density
and only accept a proposal if it is. This is convenient to turn
off if the density estimator has to learn a density of a
constrained variable, and the RejectionMC takes a long time to
draw valid samples.
Returns:
returns an array of samples from the posterior distribution of
dimension (n_samples \times p)
"""
observable = jnp.atleast_2d(observable)
thetas = None
n_curr = n_samples
n_total_simulations_round = 0
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))
while n_curr > 0:
n_sim = jnp.minimum(200, jnp.maximum(200, n_curr))
n_total_simulations_round += n_sim
sample_key, rng_key = jr.split(rng_key)
proposal = self.model.apply(
params,
sample_key,
method="sample",
sample_shape=(n_sim,),
x=jnp.tile(observable, [n_sim, 1]),
)
if hasattr(self, "_prior_bijectors"):
proposal = jax.vmap(unravel_fn)(proposal)
proposal = self._prior_bijectors.forward(proposal)
proposal_probs = self.prior.log_prob(proposal)
proposal = jax.vmap(lambda x: ravel_pytree(x)[0])(proposal)
else:
proposal_probs = self.prior.log_prob(jax.vmap(unravel_fn)(proposal))
if check_proposal_probs:
proposal = proposal[jnp.isfinite(proposal_probs)]
if thetas is None:
thetas = proposal
else:
thetas = jnp.vstack([thetas, proposal])
n_curr -= proposal.shape[0]
ess = float(thetas.shape[0] / n_total_simulations_round)
def reshape(p):
if p.ndim == 1:
p = p.reshape(p.shape[0], 1)
p = p.reshape(1, *p.shape)
return p
thetas = jax.tree_util.tree_map(
reshape, jax.vmap(unravel_fn)(thetas[:n_samples])
)
inference_data = as_inference_data(thetas, jnp.squeeze(observable))
return inference_data, ess
def _simulate_parameters_with_model(
self,
rng_key,
params,
observable,
*,
n_samples=4_000,
check_proposal_probs=True,
**kwargs,
):
return self.sample_posterior(
rng_key=rng_key,
params=params,
observable=observable,
n_samples=n_samples,
check_proposal_probs=check_proposal_probs,
**kwargs,
)