sbijax.mcmc
sbijax.mcmc builds posterior samplers and exposes the low-level MCMC routines they are built on.
make_sampler() bundles a Kernel (a handle identifying a BlackJAX MCMC algorithm) with the prior and an N(0, I) chain initialisation into a sampler that is passed to sbijax.sample(). The available kernels are nuts, mala, rmh and imh:
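from sbijax import sample
from sbijax.mcmc import make_sampler, nuts

# key, estimator, params, y_obs and prior are assumed to be defined already
sampler = make_sampler(nuts, prior=prior)
samples, info = sample(key, estimator, params, y_obs, sampler=sampler)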
make_sampler – Build a posterior sampler from an MCMC kernel and a prior.
imh, mala, nuts, rmh – Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

sample_with_imh – Draw samples using the independent Metropolis-Hastings sampler.
sample_with_mala – Sample from a distribution using the MALA sampler.
sample_with_nuts – Sample from a distribution using the No-U-Turn sampler.
sample_with_rmh – Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
sample_with_slice – Sample from a distribution using a slice sampler.
- sbijax.mcmc.make_sampler(kernel, *, prior, **kernel_kwargs)
Build a posterior sampler from an MCMC kernel and a prior.
- Parameters:
kernel – a Kernel handle (e.g. sbijax.mcmc.nuts)
prior – the prior; used for the target density and Gaussian chain init
**kernel_kwargs – forwarded to the kernel
- Returns:
a callable (rng_key, loglik_fn, *, n_chains, n_samples, n_warmup) -> (samples, MCMCSampleInfo), where the target is loglik_fn(theta) + prior.log_prob(theta) and chains start at N(0, I).
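The target and initialisation described above can be sketched in plain NumPy. This is a minimal illustration, not sbijax's implementation. The following names are hypothetical: make_target, init_positions, prior_log_prob, loglik_fn and y_obs. A NumPy Generator also stands in for a JAX random key. For an N(0, I) prior and a unit-variance Gaussian likelihood, the composed target equals the log-density of the conjugate posterior N(y_obs / 2, I / 2) up to an additive constant.

```python
import numpy as np


def make_target(loglik_fn, prior_log_prob):
    """Compose the unnormalised log-posterior loglik_fn(theta) + log prior(theta)."""
    def target(theta):
        return loglik_fn(theta) + prior_log_prob(theta)
    return target


def init_positions(rng, n_chains, dim):
    """Draw one starting position per chain from N(0, I)."""
    return rng.standard_normal((n_chains, dim))


# Hypothetical toy model: N(0, I) prior, unit-variance Gaussian likelihood.
y_obs = np.array([-1.0, 1.0])


def prior_log_prob(theta):
    return -0.5 * np.sum(theta**2) - 0.5 * theta.size * np.log(2.0 * np.pi)


def loglik_fn(theta):
    return -0.5 * np.sum((y_obs - theta) ** 2) - 0.5 * theta.size * np.log(2.0 * np.pi)


target = make_target(loglik_fn, prior_log_prob)
positions = init_positions(np.random.default_rng(0), n_chains=4, dim=2)
print(positions.shape)  # (4, 2)
```

Because normalising constants cancel in MCMC acceptance ratios, only differences of target values matter.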
- sbijax.mcmc.sample_with_imh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Draw samples using the independent Metropolis-Hastings sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial samples per chain to discard as warm-up
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_imh
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_imh(jr.key(0), prop_posterior_lp, prior)
- Returns:
a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim, and an MCMCSampleInfo with the mean post-warmup acceptance rate
- Return type:
a tuple (samples, info)
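Independent Metropolis-Hastings draws each candidate from a fixed distribution that ignores the current state. The acceptance ratio must therefore include the proposal densities as well as the target. The sketch below is a NumPy illustration under stated assumptions, not sbijax's BlackJAX-based routine:

- imh_sketch is a hypothetical name.
- The N(0, s² I) proposal is an assumption.
- Chains start from N(0, I).
- The output follows the documented n_chains x (n_samples - n_warmup) x dim layout, plus a mean post-warm-up acceptance rate.

```python
import numpy as np


def imh_sketch(rng, log_target, dim, *, n_chains=4, n_samples=2000,
               n_warmup=1000, proposal_scale=2.0):
    """Independent Metropolis-Hastings with an assumed N(0, s^2 I) proposal."""

    def log_q(x):
        # Proposal log-density up to a constant (the constant cancels in the ratio).
        return -0.5 * np.sum(x**2) / proposal_scale**2

    draws = np.empty((n_chains, n_samples, dim))
    accepted = np.zeros((n_chains, n_samples), dtype=bool)
    for c in range(n_chains):
        x = rng.standard_normal(dim)  # N(0, I) starting point
        lp_x = log_target(x)
        for t in range(n_samples):
            # The candidate does not depend on the current state x.
            x_new = proposal_scale * rng.standard_normal(dim)
            lp_new = log_target(x_new)
            # log of [pi(x_new) q(x)] / [pi(x) q(x_new)]
            log_ratio = (lp_new - lp_x) + (log_q(x) - log_q(x_new))
            if np.log(rng.uniform()) < log_ratio:
                x, lp_x = x_new, lp_new
                accepted[c, t] = True
            draws[c, t] = x
    # Discard warm-up: result has shape (n_chains, n_samples - n_warmup, dim).
    return draws[:, n_warmup:], accepted[:, n_warmup:].mean()


# Posterior of the docstring example (N(0, I) prior, y = [-1, 1]) is N([-0.5, 0.5], I / 2).
mu = np.array([-0.5, 0.5])
samples, acc_rate = imh_sketch(np.random.default_rng(1), lambda th: -np.sum((th - mu) ** 2), dim=2)
```

IMH works best when the proposal covers the target's tails; a proposal much narrower than the target can leave a chain stuck for long stretches.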
- sbijax.mcmc.sample_with_mala(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Sample from a distribution using the MALA sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial samples per chain to discard as warm-up
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_mala
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_mala(jr.key(0), prop_posterior_lp, prior)
- Returns:
a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim, and an MCMCSampleInfo with the mean post-warmup acceptance rate
- Return type:
a tuple (samples, info)
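MALA makes a gradient-informed Langevin proposal, x' = x + ε ∇log π(x) + sqrt(2ε) ξ with ξ ~ N(0, I). Because this proposal is not symmetric, a Metropolis-Hastings test corrects it. The NumPy sketch below is illustrative only:

- mala_sketch and its step_size default are hypothetical.
- The gradient is passed in explicitly; in JAX it would typically come from jax.grad.
- No step-size adaptation is performed.

```python
import numpy as np


def mala_sketch(rng, log_target, grad_log_target, dim, *, n_chains=4,
                n_samples=2000, n_warmup=1000, step_size=0.1):
    """Metropolis-adjusted Langevin algorithm with a fixed step size."""

    def log_q(x_to, x_from):
        # log N(x_to; x_from + eps * grad(x_from), 2 * eps * I), up to a constant.
        mean = x_from + step_size * grad_log_target(x_from)
        return -np.sum((x_to - mean) ** 2) / (4.0 * step_size)

    draws = np.empty((n_chains, n_samples, dim))
    accepted = np.zeros((n_chains, n_samples), dtype=bool)
    for c in range(n_chains):
        x = rng.standard_normal(dim)  # N(0, I) starting point
        lp_x = log_target(x)
        for t in range(n_samples):
            noise = rng.standard_normal(dim)
            x_new = x + step_size * grad_log_target(x) + np.sqrt(2.0 * step_size) * noise
            lp_new = log_target(x_new)
            # Asymmetric proposal: include q(x | x_new) / q(x_new | x).
            log_ratio = (lp_new - lp_x) + (log_q(x, x_new) - log_q(x_new, x))
            if np.log(rng.uniform()) < log_ratio:
                x, lp_x = x_new, lp_new
                accepted[c, t] = True
            draws[c, t] = x
    return draws[:, n_warmup:], accepted[:, n_warmup:].mean()


# Gaussian posterior of the docstring example: log pi(theta) = -||theta - mu||^2 + const.
mu = np.array([-0.5, 0.5])
samples, acc_rate = mala_sketch(
    np.random.default_rng(2),
    lambda th: -np.sum((th - mu) ** 2),
    lambda th: -2.0 * (th - mu),
    dim=2,
)
```

Dropping the q terms would turn this into unadjusted Langevin dynamics, which samples a slightly biased distribution.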
- sbijax.mcmc.sample_with_nuts(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Sample from a distribution using the No-U-Turn sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial samples per chain to discard as warm-up
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_nuts
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_nuts(jr.key(0), prop_posterior_lp, prior)
- Returns:
a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim, and an MCMCSampleInfo with the mean post-warmup acceptance rate
- Return type:
a tuple (samples, info)
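NUTS builds a Hamiltonian trajectory by repeatedly doubling it forwards or backwards in time. It stops once the trajectory starts to turn back on itself (a "U-turn") and then draws the next state from the points visited. The sketch below follows the efficient NUTS of Hoffman and Gelman (2014, Algorithm 3), with a fixed step size and an identity mass matrix. Unlike a production sampler, it performs no warm-up adaptation; warm-up draws are simply discarded. The names nuts_sketch, nuts_step, build_tree and leapfrog, the step_size default and max_depth are hypothetical choices made for illustration.

```python
import numpy as np

DELTA_MAX = 1000.0  # divergence threshold used by Hoffman and Gelman (2014)


def leapfrog(theta, r, eps, grad):
    """One leapfrog step of Hamiltonian dynamics with an identity mass matrix."""
    r = r + 0.5 * eps * grad(theta)
    theta = theta + eps * r
    r = r + 0.5 * eps * grad(theta)
    return theta, r


def build_tree(rng, theta, r, log_u, v, j, eps, logp, grad):
    """Recursively build a subtree of 2**j leapfrog steps in direction v."""
    if j == 0:
        theta1, r1 = leapfrog(theta, r, v * eps, grad)
        joint = logp(theta1) - 0.5 * r1 @ r1
        n1 = int(log_u <= joint)             # point lies inside the slice
        s1 = int(log_u < joint + DELTA_MAX)  # no divergence
        return theta1, r1, theta1, r1, theta1, n1, s1
    th_m, r_m, th_p, r_p, th1, n1, s1 = build_tree(
        rng, theta, r, log_u, v, j - 1, eps, logp, grad)
    if s1:
        if v == -1:
            th_m, r_m, _, _, th2, n2, s2 = build_tree(
                rng, th_m, r_m, log_u, v, j - 1, eps, logp, grad)
        else:
            _, _, th_p, r_p, th2, n2, s2 = build_tree(
                rng, th_p, r_p, log_u, v, j - 1, eps, logp, grad)
        # Keep a candidate from the second half with probability n2 / (n1 + n2).
        if n1 + n2 > 0 and rng.uniform() < n2 / (n1 + n2):
            th1 = th2
        d = th_p - th_m
        s1 = int(s2 and d @ r_m >= 0 and d @ r_p >= 0)  # U-turn check
        n1 += n2
    return th_m, r_m, th_p, r_p, th1, n1, s1


def nuts_step(rng, theta, eps, logp, grad, max_depth=10):
    """One NUTS transition with a fixed step size."""
    r0 = rng.standard_normal(theta.shape)
    # Slice variable u ~ Uniform(0, exp(joint)), stored as log u.
    log_u = logp(theta) - 0.5 * r0 @ r0 + np.log(rng.uniform())
    th_m = th_p = theta
    r_m = r_p = r0
    new_theta, n, s, j = theta, 1, 1, 0
    while s and j < max_depth:
        v = -1 if rng.uniform() < 0.5 else 1
        if v == -1:
            th_m, r_m, _, _, th1, n1, s1 = build_tree(rng, th_m, r_m, log_u, v, j, eps, logp, grad)
        else:
            _, _, th_p, r_p, th1, n1, s1 = build_tree(rng, th_p, r_p, log_u, v, j, eps, logp, grad)
        if s1 and rng.uniform() < min(1.0, n1 / n):
            new_theta = th1
        n += n1
        d = th_p - th_m
        s = int(s1 and d @ r_m >= 0 and d @ r_p >= 0)
        j += 1
    return new_theta


def nuts_sketch(rng, logp, grad, dim, *, n_chains=4, n_samples=2000,
                n_warmup=1000, step_size=0.3):
    draws = np.empty((n_chains, n_samples, dim))
    for c in range(n_chains):
        theta = rng.standard_normal(dim)  # N(0, I) starting point
        for t in range(n_samples):
            theta = nuts_step(rng, theta, step_size, logp, grad)
            draws[c, t] = theta
    return draws[:, n_warmup:]


# Smaller sizes than the defaults keep this pure-Python demo fast.
mu = np.array([-0.5, 0.5])
samples = nuts_sketch(np.random.default_rng(3), lambda th: -np.sum((th - mu) ** 2),
                      lambda th: -2.0 * (th - mu), dim=2,
                      n_chains=2, n_samples=600, n_warmup=100)
```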
- sbijax.mcmc.sample_with_rmh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial samples per chain to discard as warm-up
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_rmh
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_rmh(jr.key(0), prop_posterior_lp, prior)
- Returns:
a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim, and an MCMCSampleInfo with the mean post-warmup acceptance rate
- Return type:
a tuple (samples, info)
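Rosenbluth-Metropolis-Hastings (random-walk Metropolis) perturbs the current state with symmetric Gaussian noise. The proposal densities therefore cancel, and only the target ratio enters the acceptance test. The sketch below is a NumPy illustration, not sbijax's BlackJAX-backed routine. rmh_sketch and its scale default are hypothetical. It returns the documented n_chains x (n_samples - n_warmup) x dim array together with the mean post-warm-up acceptance rate.

```python
import numpy as np


def rmh_sketch(rng, log_target, dim, *, n_chains=4, n_samples=2000,
               n_warmup=1000, scale=0.5):
    """Random-walk Metropolis with an N(0, scale^2 I) increment (symmetric proposal)."""
    draws = np.empty((n_chains, n_samples, dim))
    accepted = np.zeros((n_chains, n_samples), dtype=bool)
    for c in range(n_chains):
        x = rng.standard_normal(dim)  # N(0, I) starting point
        lp_x = log_target(x)
        for t in range(n_samples):
            x_new = x + scale * rng.standard_normal(dim)
            lp_new = log_target(x_new)
            # Symmetric proposal: the acceptance ratio is just pi(x_new) / pi(x).
            if np.log(rng.uniform()) < lp_new - lp_x:
                x, lp_x = x_new, lp_new
                accepted[c, t] = True
            draws[c, t] = x
    return draws[:, n_warmup:], accepted[:, n_warmup:].mean()


mu = np.array([-0.5, 0.5])
samples, acc_rate = rmh_sketch(np.random.default_rng(4), lambda th: -np.sum((th - mu) ** 2), dim=2)
```

A larger scale lowers the acceptance rate but moves further per accepted step, so the reported acceptance rate is a natural diagnostic for tuning it.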
- sbijax.mcmc.sample_with_slice(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, n_thin=2, n_doubling=5, step_size=1, **kwargs)
Sample from a distribution using a slice sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial samples per chain to discard as warm-up
n_thin – an integer specifying how many samples to discard between retained draws
n_doubling – maximum number of doubling steps
step_size – a float specifying the size of each step
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_slice
>>> prior = tfd.JointDistributionNamed(
...     dict(
...         mean=tfd.Normal(jnp.zeros(2), 1.0),
...         std=tfd.HalfNormal(1.0)
...     ),
...     batch_ndims=0,
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["mean"], theta["std"]).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_slice(jr.PRNGKey(0), prop_posterior_lp, prior)
- Returns:
a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim, and an MCMCSampleInfo (the acceptance rate is nan for slice sampling)
- Return type:
a tuple (samples, info)
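Slice sampling replaces a fixed proposal step with an adaptive interval. It draws a height under the density at the current point and grows an interval of initial width step_size by up to n_doubling doublings. It then samples uniformly from the interval, shrinking it on rejection. The sketch below applies Neal's (2003) univariate doubling procedure to one coordinate at a time. It is an illustrative NumPy version, not sbijax's implementation:

- slice_1d and slice_sketch are hypothetical names.
- Thinning (n_thin) is omitted.

```python
import numpy as np


def slice_1d(rng, x0, log_f, w, max_doublings):
    """One univariate slice-sampling update (Neal 2003, doubling procedure)."""
    log_y = log_f(x0) + np.log(rng.uniform())  # log height of the horizontal slice
    # Doubling: start with width w around x0, double until both ends leave the slice.
    left = x0 - w * rng.uniform()
    right = left + w
    k = max_doublings
    while k > 0 and (log_y < log_f(left) or log_y < log_f(right)):
        if rng.uniform() < 0.5:
            left -= right - left
        else:
            right += right - left
        k -= 1

    def acceptable(x1):
        # Neal's check that doubling from x1 could have produced the same interval.
        lo, hi, differ = left, right, False
        while hi - lo > 1.1 * w:
            mid = 0.5 * (lo + hi)
            if (x0 < mid) != (x1 < mid):
                differ = True
            if x1 < mid:
                hi = mid
            else:
                lo = mid
            if differ and log_y >= log_f(lo) and log_y >= log_f(hi):
                return False
        return True

    # Shrinkage: sample uniformly from the interval, shrinking towards x0 on rejection.
    lo, hi = left, right
    while True:
        x1 = lo + rng.uniform() * (hi - lo)
        if log_y < log_f(x1) and acceptable(x1):
            return x1
        if x1 < x0:
            lo = x1
        else:
            hi = x1


def slice_sketch(rng, log_target, dim, *, n_chains=4, n_samples=2000,
                 n_warmup=1000, step_size=1.0, n_doubling=5):
    draws = np.empty((n_chains, n_samples, dim))
    for c in range(n_chains):
        x = rng.standard_normal(dim)  # N(0, I) starting point
        for t in range(n_samples):
            for i in range(dim):
                # Log-density along coordinate i with the other coordinates held fixed.
                def log_f(xi, i=i):
                    z = x.copy()
                    z[i] = xi
                    return log_target(z)
                x[i] = slice_1d(rng, x[i], log_f, step_size, n_doubling)
            draws[c, t] = x
    return draws[:, n_warmup:]


mu = np.array([-0.5, 0.5])
samples = slice_sketch(np.random.default_rng(5), lambda th: -np.sum((th - mu) ** 2), dim=2,
                       n_chains=2, n_samples=800, n_warmup=100)
```

Every slice update moves to a point inside the slice, so there is no accept/reject rate to report. This is consistent with the nan acceptance rate documented above.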
- sbijax.mcmc.imh
Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).
Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).
- sbijax.mcmc.mala
Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).
Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).
- sbijax.mcmc.nuts
Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).
Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).
- sbijax.mcmc.rmh
Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).
Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).
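The init_fn contract can be illustrated with a hypothetical stand-in. In the NumPy sketch below, Kernel mirrors the documented single field init_fn with signature (rng_key, initial_positions, lp) -> (initial_states, kernel). Further assumptions:

- RMHState, rmh_init, rmh_handle, run_kernel and STEP_SCALE are hypothetical names.
- A NumPy Generator replaces the JAX key; rng_key is unused by this random-walk stand-in.
- The vectorised random-walk step is only an example of what a concrete kernel might do.

Each state caches the log-density of its position, so an update evaluates lp only at the new proposals.

```python
from typing import Callable, NamedTuple

import numpy as np

STEP_SCALE = 0.5  # assumed random-walk scale for this illustration


class Kernel(NamedTuple):
    """Hypothetical stand-in for a kernel handle with a single init_fn field."""
    init_fn: Callable


class RMHState(NamedTuple):
    position: np.ndarray    # shape (n_chains, dim)
    logdensity: np.ndarray  # shape (n_chains,), cached log-density at `position`


def rmh_init(rng_key, initial_positions, lp):
    """Return initial chain states and a step function (rng_key is unused here)."""
    initial_states = RMHState(initial_positions,
                              np.array([lp(p) for p in initial_positions]))

    def kernel(rng, state):
        # Propose a random-walk move for every chain at once.
        proposal = state.position + STEP_SCALE * rng.standard_normal(state.position.shape)
        lp_prop = np.array([lp(p) for p in proposal])
        accept = np.log(rng.uniform(size=lp_prop.shape)) < lp_prop - state.logdensity
        position = np.where(accept[:, None], proposal, state.position)
        logdensity = np.where(accept, lp_prop, state.logdensity)
        return RMHState(position, logdensity)

    return initial_states, kernel


rmh_handle = Kernel(init_fn=rmh_init)


def run_kernel(handle, rng, lp, dim, *, n_chains=4, n_samples=2000, n_warmup=1000):
    """Initialise chains from N(0, I), iterate the kernel, and drop warm-up draws."""
    initial_positions = rng.standard_normal((n_chains, dim))
    state, kernel = handle.init_fn(rng, initial_positions, lp)
    draws = np.empty((n_samples, n_chains, dim))
    for t in range(n_samples):
        state = kernel(rng, state)
        draws[t] = state.position
    # Reorder to (n_chains, n_samples - n_warmup, dim).
    return draws[n_warmup:].transpose(1, 0, 2)


mu = np.array([-0.5, 0.5])
samples = run_kernel(rmh_handle, np.random.default_rng(6),
                     lambda th: -np.sum((th - mu) ** 2), dim=2)
```

Deferring kernel construction to init_fn lets a handle be chosen up front while the target lp is only supplied at sampling time.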