"""Neural likelihood estimation (module ``sbijax._src.nle``)."""

from functools import partial

import arviz
import chex
import jax
import numpy as np
import optax
import xarray
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from jax._src.flatten_util import ravel_pytree
from tqdm import tqdm

from sbijax._src import mcmc
from sbijax._src._ne_base import NE
from sbijax._src.mcmc.util import mcmc_diagnostics
from sbijax._src.util.data import as_inference_data
from sbijax._src.util.early_stopping import EarlyStopping


# ruff: noqa: PLR0913, E501
class NLE(NE):
    """Neural likelihood estimation.

    Implements the method introduced in :cite:t:`papama2019neural`.

    Args:
        model_fns: a tuple of callables. The first element needs to be a
            function that constructs a tfd.JointDistributionNamed, the second
            element is a simulator function.
        density_estimator: a (neural) conditional density estimator
            to model the likelihood function

    Examples:
        >>> from sbijax import NLE
        >>> from sbijax.nn import make_mdn
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        ...
        >>> prior = lambda: tfd.JointDistributionNamed(
        ...    dict(theta=tfd.Normal(0.0, 1.0))
        ... )
        >>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
        >>> fns = prior, s
        >>> neural_network = make_mdn(1, 5)
        >>> model = NLE(fns, neural_network)

    References:
        Papamakarios, George, et al. "Sequential neural likelihood: Fast
        likelihood-free inference with autoregressive flows." International
        Conference on Artificial Intelligence and Statistics, 2019.
    """

    # pylint: disable=arguments-differ,too-many-locals
    def fit(
        self,
        rng_key,
        data,
        optimizer=optax.adam(0.0003),
        n_iter=1000,
        batch_size=100,
        percentage_data_as_validation_set=0.1,
        n_early_stopping_patience=10,
        **kwargs,
    ):
        """Fit the model.

        Args:
            rng_key: a jax random key
            data: data set obtained from calling
                `simulate_data_and_possibly_append`
            optimizer: an optax optimizer object
            n_iter: maximal number of training iterations per round
            batch_size: batch size used for training the model
            percentage_data_as_validation_set: percentage of the simulated
                data that is used for validation and early stopping
            n_early_stopping_patience: number of iterations of no improvement
                of training the flow before stopping optimisation
            **kwargs: additional keyword arguments (not used for NLE)

        Returns:
            a tuple of parameters and a tuple of the training information
        """
        # One key drives the train/validation split, the other initialisation.
        itr_key, rng_key = jr.split(rng_key)
        train_iter, val_iter = self.as_iterators(
            itr_key, data, batch_size, percentage_data_as_validation_set
        )
        params, losses = self._fit_model_single_round(
            seed=rng_key,
            train_iter=train_iter,
            val_iter=val_iter,
            optimizer=optimizer,
            n_iter=n_iter,
            n_early_stopping_patience=n_early_stopping_patience,
        )
        return params, losses

    def _negative_mean_log_prob(self, params, batch):
        """Return the negative mean log-likelihood of a batch.

        Evaluates the density estimator's ``log_prob`` of ``batch["y"]``
        conditioned on ``batch["theta"]`` and averages over the batch.
        """
        lp = self.model.apply(
            params,
            rng=None,
            method="log_prob",
            y=batch["y"],
            x=batch["theta"],
        )
        return -jnp.mean(lp)

    # pylint: disable=arguments-differ,undefined-loop-variable
    def _fit_model_single_round(
        self,
        seed,
        train_iter,
        val_iter,
        optimizer,
        n_iter,
        n_early_stopping_patience,
    ):
        """Train the density estimator with validation-based early stopping.

        Args:
            seed: a jax random key used to initialise the parameters
            train_iter: iterator over training batches (dicts with "y" and
                "theta"); must expose ``num_samples``
            val_iter: iterator over validation batches; must expose
                ``num_samples``
            optimizer: an optax optimizer object
            n_iter: maximal number of training epochs
            n_early_stopping_patience: patience passed to ``EarlyStopping``

        Returns:
            a tuple of the parameters with the lowest validation loss (or the
            latest parameters if no epoch produced a finite improvement) and
            an array of shape (n_epochs_run, 2) holding the training and
            validation loss of every epoch that was run
        """
        init_key, seed = jr.split(seed)
        params = self._init_params(init_key, **next(iter(train_iter)))
        state = optimizer.init(params)

        @jax.jit
        def step(params, state, **batch):
            def loss_fn(params):
                return self._negative_mean_log_prob(params, batch)

            loss, grads = jax.value_and_grad(loss_fn)(params)
            updates, new_state = optimizer.update(grads, state, params)
            new_params = optax.apply_updates(params, updates)
            return loss, new_params, new_state

        # Compile the validation loss once per round (params is a traced
        # argument) instead of re-jitting a fresh closure every epoch.
        val_loss_fn = jax.jit(self._negative_mean_log_prob)

        losses = np.zeros([n_iter, 2])
        early_stop = EarlyStopping(1e-3, n_early_stopping_patience)
        best_params, best_loss = None, np.inf
        # Tracked explicitly so n_iter <= 0 does not hit an unbound loop var.
        n_epochs_run = 0
        logging.info("training model")
        for i in tqdm(range(n_iter)):
            train_loss = 0.0
            for batch in train_iter:
                batch_loss, params, state = step(params, state, **batch)
                # Weight each batch by its share of the training set so the
                # epoch loss is a per-sample average.
                train_loss += batch_loss * (
                    batch["y"].shape[0] / train_iter.num_samples
                )
            validation_loss = self._validation_loss(
                params, val_iter, loss_fn=val_loss_fn
            )
            losses[i] = jnp.array([train_loss, validation_loss])
            n_epochs_run = i + 1

            _, early_stop = early_stop.update(validation_loss)
            if early_stop.should_stop:
                logging.info("early stopping criterion found")
                break
            if validation_loss < best_loss:
                best_loss = validation_loss
                best_params = params.copy()

        if best_params is None:
            # No epoch improved on the initial (infinite) loss, e.g. every
            # validation loss was NaN or no epoch ran: return the latest
            # parameters rather than None.
            best_params = params
        losses = jnp.vstack(losses)[:n_epochs_run, :]
        return best_params, losses

    def _validation_loss(self, params, val_iter, loss_fn=None):
        """Compute the sample-weighted negative log-likelihood on `val_iter`.

        Args:
            params: a pytree of neural network parameters
            val_iter: iterator over validation batches; must expose
                ``num_samples``
            loss_fn: optional pre-jitted callable ``(params, batch) -> loss``.
                If None, a jitted version of the default loss is created
                (this recompiles on every call, so callers in a training loop
                should pass one in).

        Returns:
            the validation loss averaged over all validation samples
        """
        if loss_fn is None:
            loss_fn = jax.jit(self._negative_mean_log_prob)

        losses = 0.0
        for batch in val_iter:
            loss = loss_fn(params, batch)
            losses += loss * (batch["y"].shape[0] / val_iter.num_samples)
        return losses

    def _init_params(self, rng_key, **init_data):
        """Initialise the density estimator's parameters from a batch."""
        params = self.model.init(
            rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"]
        )
        return params

    # ruff: noqa: D417
    def simulate_data_and_possibly_append(
        self,
        rng_key,
        params=None,
        observable=None,
        data=None,
        n_simulations=1_000,
        n_chains=4,
        n_samples=2_000,
        n_warmup=1_000,
        **kwargs,
    ):
        """Simulate data from the prior or posterior.

        Args:
            rng_key: a random key
            params: a dictionary of neural network parameters
            observable: an observation
            data: existing data set
            n_simulations: number of newly simulated data
            n_chains: number of MCMC chains
            n_samples: number of samples to draw in total
            n_warmup: number of draws to discard

        Keyword Args:
            sampler (str): either 'nuts', 'slice' or None (defaults to nuts)
            n_thin (int): number of thinning steps
                (only used if sampler='slice')
            n_doubling (int): number of doubling steps of the interval
                 (only used if sampler='slice')
            step_size (float): step size of the initial interval
                 (only used if sampler='slice')

        Returns:
            returns a NamedTuple with two elements, y and theta
        """
        return super().simulate_data_and_possibly_append(
            rng_key=rng_key,
            params=params,
            observable=observable,
            data=data,
            n_simulations=n_simulations,
            n_chains=n_chains,
            n_samples=n_samples,
            n_warmup=n_warmup,
            **kwargs,
        )

    def simulate_data(
        self,
        rng_key,
        params=None,
        observable=None,
        data=None,
        n_simulations=1_000,
        n_chains=4,
        n_samples=2_000,
        n_warmup=1_000,
        **kwargs,
    ):
        """Simulate data; forwards all arguments to the base class.

        Accepts the same arguments as `simulate_data_and_possibly_append`.
        """
        return super().simulate_data(
            rng_key=rng_key,
            params=params,
            observable=observable,
            data=data,
            n_simulations=n_simulations,
            n_chains=n_chains,
            n_samples=n_samples,
            n_warmup=n_warmup,
            **kwargs,
        )

    def sample_posterior(
        self,
        rng_key,
        params,
        observable,
        *,
        n_chains=4,
        n_samples=2_000,
        n_warmup=1_000,
        **kwargs,
    ):
        r"""Sample from the approximate posterior.

        Args:
            rng_key: a jax random key
            params: a pytree of neural network parameters
            observable: observation to condition on
            n_chains: number of MCMC chains
            n_samples: number of samples per chain, including the
                ``n_warmup`` draws that are discarded
            n_warmup: number of samples to discard

        Keyword Args:
            sampler (str): either 'nuts', 'slice' or None (defaults to nuts)
            n_thin (int): number of thinning steps
                (only used if sampler='slice')
            n_doubling (int): number of doubling steps of the interval
                 (only used if sampler='slice')
            step_size (float): step size of the initial interval
                 (only used if sampler='slice')

        Returns:
            a tuple of the posterior samples (as produced by
            ``as_inference_data``, with ``n_samples - n_warmup`` draws per
            chain) and the MCMC diagnostics
        """
        # Treat a single observation as a batch of one row.
        observable = jnp.atleast_2d(observable)
        return self._sample_posterior(
            rng_key,
            params,
            observable,
            n_chains=n_chains,
            n_samples=n_samples,
            n_warmup=n_warmup,
            **kwargs,
        )

    def _sample_posterior(
        self,
        rng_key,
        params,
        observable,
        *,
        n_chains=4,
        n_samples=2_000,
        n_warmup=1_000,
        **kwargs,
    ):
        """Run MCMC on prior times learned likelihood and wrap the results."""
        part = partial(
            self.model.apply,
            params=params,
            rng=None,
            method="log_prob",
            y=observable,
        )

        def _log_likelihood_fn(theta):
            # Flatten the parameter pytree to a vector and repeat it once per
            # observation row, so each row is scored under the same theta.
            theta, _ = ravel_pytree(theta)
            theta = jnp.tile(theta, [observable.shape[0], 1])
            return part(x=theta)

        def _prop_posterior_density(theta):
            # Unnormalised log posterior: summed log-likelihood over all
            # observation rows plus the prior log-density.
            lp_prior = self.prior.log_prob(theta)
            lp = _log_likelihood_fn(theta)
            return jnp.sum(lp) + jnp.sum(lp_prior)

        # The documented contract allows sampler=None to mean "nuts";
        # previously None reached the string concatenation and raised.
        sampler = kwargs.pop("sampler", None)
        if sampler is None:
            sampler = "nuts"
        sampling_fn = getattr(mcmc, "sample_with_" + sampler)
        samples = sampling_fn(
            rng_key=rng_key,
            lp=_prop_posterior_density,
            prior=self.prior,
            n_chains=n_chains,
            n_samples=n_samples,
            n_warmup=n_warmup,
            **kwargs,
        )
        for v in samples.values():
            chex.assert_shape(v, [n_chains, n_samples - n_warmup, None])
        inference_data = as_inference_data(samples, jnp.squeeze(observable))
        diagnostics = mcmc_diagnostics(inference_data)
        return inference_data, diagnostics

    def _simulate_parameters_with_model(
        self,
        rng_key,
        params,
        observable,
        *,
        n_chains=4,
        n_samples=2_000,
        n_warmup=1_000,
        **kwargs,
    ):
        """Draw parameters for the next round from the approximate posterior."""
        return self.sample_posterior(
            rng_key=rng_key,
            params=params,
            observable=observable,
            n_chains=n_chains,
            n_samples=n_samples,
            n_warmup=n_warmup,
            **kwargs,
        )

    @staticmethod
    def plot(inference_data: xarray.DataTree):
        """Plot MCMC traces of the posterior samples with ArviZ."""
        arviz.plot_trace(inference_data)