import jax
import numpy as np
import optax
from absl import logging
from jax import numpy as jnp
from jax import random as jr
from jax._src.flatten_util import ravel_pytree
from tqdm import tqdm
from sbijax._src._ne_base import NE
from sbijax._src.util.data import as_inference_data
from sbijax._src.util.early_stopping import EarlyStopping
from sbijax._src.util.types import PyTree
# ruff: noqa: PLR0913, E501
class FMPE(NE):
    r"""Flow matching posterior estimation.

    Implements the FMPE algorithm introduced in :cite:t:`wilderberger2023flow`.

    Args:
        model_fns: a tuple of callables. The first element needs to be a
            function that constructs a tfd.JointDistributionNamed, the second
            element is a simulator function.
        density_estimator: a continuous normalizing flow model

    Examples:
        >>> from sbijax import FMPE
        >>> from sbijax.nn import make_cnf
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        ...
        >>> prior = lambda: tfd.JointDistributionNamed(
        ...     dict(theta=tfd.Normal(0.0, 1.0))
        ... )
        >>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
        >>> fns = prior, s
        >>> neural_network = make_cnf(1)
        >>> model = FMPE(fns, neural_network)

    References:
        Wildberger, Jonas, et al. "Flow Matching for Scalable Simulation-Based Inference." Advances in Neural Information Processing Systems, 2024.
    """

    def __init__(self, model_fns, density_estimator):
        super().__init__(model_fns, density_estimator)

    def fit(
        self,
        rng_key: jr.PRNGKey,
        data: PyTree,
        *,
        optimizer: optax.GradientTransformation = optax.adam(0.0003),
        n_iter: int = 1000,
        batch_size: int = 100,
        percentage_data_as_validation_set: float = 0.1,
        n_early_stopping_patience: int = 10,
        n_early_stopping_delta: float = 0.001,
        **kwargs,
    ):
        """Fit the model.

        Args:
            rng_key: a jax random key
            data: data set obtained from calling
                `simulate_data_and_possibly_append`
            optimizer: an optax optimizer object
            n_iter: maximal number of training iterations per round
            batch_size: batch size used for training the model
            percentage_data_as_validation_set: percentage of the simulated
                data that is used for validation and early stopping
            n_early_stopping_patience: number of iterations of no improvement
                of training the flow before stopping optimisation
            n_early_stopping_delta: threshold handed to the early stopping
                criterion (presumably the minimal change of the validation
                loss that counts as an improvement -- see `EarlyStopping`)
            **kwargs: optional keyword arguments (not used by this method)

        Returns:
            a tuple of the parameters that achieved the lowest validation
            loss and an array with one row per completed iteration holding
            the training loss and the validation loss
        """
        # one key for splitting/shuffling the data, one for training
        itr_key, rng_key = jr.split(rng_key)
        train_iter, val_iter = self.as_iterators(
            itr_key, data, batch_size, percentage_data_as_validation_set
        )
        params, losses = self._fit_model_single_round(
            seed=rng_key,
            train_iter=train_iter,
            val_iter=val_iter,
            optimizer=optimizer,
            n_iter=n_iter,
            n_early_stopping_patience=n_early_stopping_patience,
            n_early_stopping_delta=n_early_stopping_delta,
        )
        return params, losses

    def _fit_model_single_round(
        self,
        seed,
        train_iter,
        val_iter,
        optimizer,
        n_iter,
        n_early_stopping_patience,
        n_early_stopping_delta,
    ):
        """Run the optimisation loop and return the best parameters and losses."""
        init_key, seed = jr.split(seed)
        # the first training batch is only used to initialise the parameters
        params = self._init_params(init_key, **next(iter(train_iter)))
        state = optimizer.init(params)

        @jax.jit
        def step(params, rng, state, **batch):
            def loss_fn(params, rng, **batch):
                lp = self.model.apply(
                    params,
                    rng=rng,
                    method="loss",
                    inputs=batch["theta"],
                    context=batch["y"],
                    is_training=True,
                )
                return jnp.mean(lp)

            loss, grads = jax.value_and_grad(loss_fn)(params, rng, **batch)
            updates, new_state = optimizer.update(grads, state, params)
            new_params = optax.apply_updates(params, updates)
            return loss, new_params, new_state

        losses = np.zeros([n_iter, 2])
        early_stop = EarlyStopping(
            n_early_stopping_delta, n_early_stopping_patience
        )
        best_params, best_loss = None, np.inf
        # number of completed iterations; stays 0 if the loop never runs
        n_done = 0
        logging.info("training model")
        for i in tqdm(range(n_iter)):
            train_loss = 0.0
            rng_key = jr.fold_in(seed, i)
            for batch in train_iter:
                train_key, rng_key = jr.split(rng_key)
                batch_loss, params, state = step(
                    params, train_key, state, **batch
                )
                # weight each batch by its share of the training set
                train_loss += batch_loss * (
                    batch["y"].shape[0] / train_iter.num_samples
                )
            val_key, rng_key = jr.split(rng_key)
            validation_loss = self._validation_loss(val_key, params, val_iter)
            losses[i] = jnp.array([train_loss, validation_loss])
            n_done = i + 1
            # snapshot the best parameters *before* the early-stopping check
            # so that the parameters of the stopping iteration are not lost
            if validation_loss < best_loss:
                best_loss = validation_loss
                best_params = params.copy()
            _, early_stop = early_stop.update(validation_loss)
            if early_stop.should_stop:
                logging.info("early stopping criterion found")
                break

        if best_params is None:
            # no iteration produced a comparable validation loss (e.g. zero
            # iterations or NaN losses): fall back to the latest parameters
            # instead of returning None
            best_params = params
        losses = jnp.asarray(losses[:n_done, :])
        return best_params, losses

    def _init_params(self, rng_key, **init_data):
        """Initialise the network parameters from one batch of data."""
        params = self.model.init(
            rng_key,
            method="loss",
            inputs=init_data["theta"],
            context=init_data["y"],
            is_training=False,
        )
        return params

    def _validation_loss(self, rng_key, params, val_iter):
        """Compute the sample-weighted mean loss over the validation set."""

        def loss_fn(params, rng, **batch):
            lp = self.model.apply(
                params,
                rng=rng,
                method="loss",
                inputs=batch["theta"],
                context=batch["y"],
                is_training=False,
            )
            return jnp.mean(lp)

        def body_fn(batch_key, **batch):
            loss = loss_fn(params, batch_key, **batch)
            # weight each batch by its share of the validation set
            return loss * (batch["y"].shape[0] / val_iter.num_samples)

        loss = 0.0
        for batch in val_iter:
            val_key, rng_key = jr.split(rng_key)
            loss += body_fn(val_key, **batch)
        return loss

    # ruff: noqa: D417
    def sample_posterior(
        self, rng_key, params, observable, *, n_samples=4_000, **kwargs
    ):
        r"""Sample from the approximate posterior.

        Proposals are drawn from the model in batches of 1024 and those with
        a non-finite prior log-probability are rejected until at least
        `n_samples` draws have been accepted.

        Args:
            rng_key: a jax random key
            params: a pytree of neural network parameters
            observable: observation to condition on
            n_samples: number of samples to draw

        Returns:
            a tuple whose first element is the inference data object built
            from the first `n_samples` accepted draws (per parameter of
            dimension (1 \times n_samples \times p)) and whose second element
            is the fraction of proposals that were accepted

        Raises:
            ValueError: if `n_samples` is not positive
        """
        if n_samples < 1:
            raise ValueError("n_samples must be a positive integer")

        observable = jnp.atleast_2d(observable)
        # a single prior draw is only needed to learn the pytree structure
        _, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))
        unravel_batch = jax.vmap(unravel_fn)

        # number of proposals per round; the context does not change between
        # rounds, so it is built once
        n_sim = 1024
        context = jnp.tile(observable, [n_sim, 1])

        accepted = []
        n_curr = n_samples
        n_total_simulations_round = 0
        while n_curr > 0:
            n_total_simulations_round += n_sim
            sample_key, rng_key = jr.split(rng_key)
            proposal = self.model.apply(
                params,
                sample_key,
                method="sample",
                context=context,
                is_training=False,
            )
            # keep only proposals that lie inside the support of the prior
            proposal_probs = self.prior.log_prob(unravel_batch(proposal))
            proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
            accepted.append(proposal_accepted)
            n_curr -= proposal_accepted.shape[0]

        thetas = jnp.vstack(accepted)
        # acceptance rate of the rejection step
        ess = float(thetas.shape[0] / n_total_simulations_round)

        def reshape(p):
            # add a trailing parameter axis to scalars and a leading axis
            # of size one in front of the draws
            if p.ndim == 1:
                p = p.reshape(p.shape[0], 1)
            p = p.reshape(1, *p.shape)
            return p

        thetas = jax.tree_util.tree_map(
            reshape, unravel_batch(thetas[:n_samples])
        )
        inference_data = as_inference_data(thetas, jnp.squeeze(observable))
        return inference_data, ess

    def _simulate_parameters_with_model(
        self, rng_key, params, observable, *, n_samples=4_000, **kwargs
    ):
        """Draw parameters from the fitted model via `sample_posterior`."""
        return self.sample_posterior(
            rng_key, params, observable, n_samples=n_samples, **kwargs
        )