Source code for sbijax._src.simulate.simulate

"""Draw parameters and observations for simulation-based inference."""

import jax
from jax import numpy as jnp
from jax import random as jr


def simulate(rng_key, prior, simulator, *, proposal=None, n):
    """Generate ``n`` parameter draws and their simulated observations.

    The incoming key is split in two: the first half is used to draw the
    parameters, the second half is handed to ``simulator``.

    Args:
        rng_key: a jax random key
        prior: object whose ``sample(seed=..., sample_shape=...)`` method is
            called to draw parameters; only used when ``proposal`` is None
        simulator: a callable ``(rng_key, theta) -> y``
        proposal: optional callable ``(rng_key, n) -> theta``; when given it
            replaces the prior as the source of parameters
        n: number of parameter/observation pairs requested

    Returns:
        a dictionary with keys ``y`` (simulator output) and ``theta``
        (the drawn parameters)
    """
    param_key, obs_key = jr.split(rng_key)

    if proposal is not None:
        params = proposal(param_key, n)
    else:
        params = prior.sample(seed=param_key, sample_shape=(n,))

    return {"y": simulator(obs_key, params), "theta": params}
def stack(data, new_data):
    """Concatenate two datasets leaf by leaf along the first axis.

    Args:
        data: a dataset as returned by :func:`simulate`
        new_data: a second dataset with the same structure as ``data``

    Returns:
        a dataset with the structure of ``data`` in which every leaf is the
        corresponding leaf of ``data`` followed by that of ``new_data``
    """

    def _join(old_leaf, new_leaf):
        # Axis 0 is the sample axis, so entries of ``new_data`` come last.
        return jnp.concatenate((old_leaf, new_leaf), axis=0)

    return jax.tree_util.tree_map(_join, data, new_data)