Source code for sbijax._src.train.sample

"""Free sampling driver, symmetric with ``fit`` (design B)."""


def sample(rng_key, objective, params, observable, *, sampler=None, **kwargs):
    """Sample from the posterior of a trained objective.

    This is a thin forwarding wrapper around ``objective.sample_fn``. It exists
    so that drawing samples mirrors :func:`sbijax.train` at the call site.

    Args:
        rng_key: a jax random key
        objective: an ``ObjectiveFns`` (anything exposing ``sample_fn``)
        params: the trained parameters
        observable: the observation to condition on
        sampler: a sampler from :func:`~sbijax.mcmc.make_sampler`; required
            by MCMC-based methods and ignored by amortized ones
        **kwargs: extra keyword arguments passed straight to ``sample_fn``

    Returns:
        the result of ``objective.sample_fn``, documented as a
        ``(samples, info)`` pair
    """
    draw = objective.sample_fn
    # ``sampler`` is keyword-only here, so it can never collide with a key in
    # ``kwargs``; merging the two into one mapping is therefore unambiguous.
    forwarded = {"sampler": sampler, **kwargs}
    return draw(rng_key, params, observable, **forwarded)