sbijax.mcmc

sbijax.mcmc builds the posterior samplers and exposes the low-level MCMC routines they are built on.

make_sampler() bundles a Kernel – a handle identifying a BlackJAX MCMC algorithm – with the prior and N(0, I) chain initialisation into a sampler that is passed to sbijax.sample(). The available algorithms are nuts, mala, rmh and imh:

from sbijax import sample
from sbijax.mcmc import make_sampler, nuts

sampler = make_sampler(nuts, prior=prior)
samples, info = sample(key, estimator, params, y_obs, sampler=sampler)

make_sampler(kernel, *, prior, **kernel_kwargs)

Build a posterior sampler from an MCMC kernel and a prior.

imh

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

mala

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

nuts

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

rmh

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

sample_with_imh(rng_key, lp, prior, *[, ...])

Draw samples using the independent Metropolis-Hastings sampler.

sample_with_mala(rng_key, lp, prior, *[, ...])

Sample from a distribution using the MALA sampler.

sample_with_nuts(rng_key, lp, prior, *[, ...])

Sample from a distribution using the No-U-Turn sampler.

sample_with_rmh(rng_key, lp, prior, *[, ...])

Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.

sample_with_slice(rng_key, lp, prior, *[, ...])

Sample from a distribution using a slice sampler.

sbijax.mcmc.make_sampler(kernel, *, prior, **kernel_kwargs)

Build a posterior sampler from an MCMC kernel and a prior.

Parameters:
  • kernel – a Kernel handle (e.g. sbijax.mcmc.nuts)

  • prior – the prior; used for the target density and Gaussian chain init

  • **kernel_kwargs – forwarded to the kernel

Returns:

a callable (rng_key, loglik_fn, *, n_chains, n_samples, n_warmup) -> (samples, MCMCSampleInfo) where the target is loglik_fn(theta) + prior.log_prob(theta) and chains start at N(0, I).
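The target and initialisation described above can be sketched in plain Python. This is only an illustration of the composition, not sbijax's implementation: `make_target`, `init_chains`, `std_normal_log_prob` and `loglik` are hypothetical names, and lists stand in for arrays.

```python
import math
import random


def make_target(loglik_fn, prior_log_prob):
    """Return theta -> loglik_fn(theta) + prior_log_prob(theta)."""
    def target(theta):
        return loglik_fn(theta) + prior_log_prob(theta)
    return target


def init_chains(rng, n_chains, dim):
    """Draw one N(0, I) starting position per chain."""
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_chains)]


def std_normal_log_prob(theta):
    # Log-density of a standard normal prior, summed over dimensions.
    return sum(-0.5 * t * t - 0.5 * math.log(2.0 * math.pi) for t in theta)


def loglik(theta, y=(-1.0, 1.0)):
    # Unit-variance Gaussian likelihood centred at theta (hypothetical model).
    return sum(
        -0.5 * (yi - ti) ** 2 - 0.5 * math.log(2.0 * math.pi)
        for yi, ti in zip(y, theta)
    )


target = make_target(loglik, std_normal_log_prob)
starts = init_chains(random.Random(0), n_chains=4, dim=2)
```

Each entry of `starts` is one chain's initial position, and `target` is the unnormalised log-posterior that every chain evaluates.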

sbijax.mcmc.sample_with_imh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Draw samples using the independent Metropolis-Hastings sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain, including warmup

  • n_warmup – number of initial samples per chain to discard as warmup

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_imh
>>> prior = tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_imh(jr.key(0), prop_posterior_lp, prior)

Returns:

a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim and an MCMCSampleInfo with the mean post-warmup acceptance rate

Return type:

a tuple (samples, info)
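For intuition, the independent Metropolis-Hastings update can be written in a few lines of plain Python. This is a one-dimensional sketch of the textbook algorithm, not the BlackJAX kernel that sbijax calls; `independent_mh`, the Gaussian target and the N(0, 1) proposal are assumptions chosen for illustration.

```python
import math
import random


def independent_mh(rng, log_target, propose, log_proposal, n_samples, x0):
    """Run one independent Metropolis-Hastings chain in one dimension."""
    x = x0
    log_w = log_target(x) - log_proposal(x)  # importance weight of the current state
    chain, n_accept = [], 0
    for _ in range(n_samples):
        x_new = propose(rng)  # the proposal ignores the current state
        log_w_new = log_target(x_new) - log_proposal(x_new)
        # Accept with probability min(1, w_new / w).
        if math.log(1.0 - rng.random()) < log_w_new - log_w:
            x, log_w = x_new, log_w_new
            n_accept += 1
        chain.append(x)
    return chain, n_accept / n_samples


def log_target(x):
    # Unnormalised log-density of N(0.5, 0.5^2).
    return -0.5 * ((x - 0.5) / 0.5) ** 2


def log_proposal(x):
    # Unnormalised log-density of the N(0, 1) proposal.
    return -0.5 * x * x


def propose(rng):
    return rng.gauss(0.0, 1.0)


chain, accept_rate = independent_mh(
    random.Random(0), log_target, propose, log_proposal, 5000, 0.0
)
kept = chain[1000:]  # discard the first 1000 draws as warmup
posterior_mean = sum(kept) / len(kept)
```

Because the proposal does not depend on the current state, the sampler works best when the proposal has heavier tails than the target, as it does in this toy setup.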

sbijax.mcmc.sample_with_mala(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Sample from a distribution using the MALA sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain, including warmup

  • n_warmup – number of initial samples per chain to discard as warmup

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_mala
>>> prior = tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_mala(jr.key(0), prop_posterior_lp, prior)

Returns:

a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim and an MCMCSampleInfo with the mean post-warmup acceptance rate

Return type:

a tuple (samples, info)
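The Metropolis-adjusted Langevin update can be illustrated with a one-dimensional sketch in plain Python. It follows the textbook algorithm (a gradient-informed Gaussian proposal plus a Metropolis-Hastings correction) and is not the BlackJAX kernel itself; `mala`, the step size of 0.5 and the N(1, 1) target are assumptions for illustration.

```python
import math
import random


def mala(rng, log_target, grad_log_target, n_samples, x0, step_size):
    """Run one MALA chain in one dimension."""

    def log_q(x_to, x_from):
        # Log-density (up to a constant) of proposing x_to from x_from:
        # N(x_from + step_size * grad, 2 * step_size).
        mean = x_from + step_size * grad_log_target(x_from)
        return -((x_to - mean) ** 2) / (4.0 * step_size)

    x = x0
    chain, n_accept = [], 0
    for _ in range(n_samples):
        mean = x + step_size * grad_log_target(x)
        x_new = mean + math.sqrt(2.0 * step_size) * rng.gauss(0.0, 1.0)
        # Metropolis-Hastings correction for the asymmetric proposal.
        log_alpha = (
            log_target(x_new) - log_target(x) + log_q(x, x_new) - log_q(x_new, x)
        )
        if math.log(1.0 - rng.random()) < log_alpha:
            x = x_new
            n_accept += 1
        chain.append(x)
    return chain, n_accept / n_samples


def log_target(x):
    # Unnormalised log-density of N(1, 1).
    return -0.5 * (x - 1.0) ** 2


def grad_log_target(x):
    return -(x - 1.0)


chain, accept_rate = mala(
    random.Random(0), log_target, grad_log_target, 5000, 0.0, step_size=0.5
)
kept = chain[1000:]  # discard the first 1000 draws as warmup
posterior_mean = sum(kept) / len(kept)
```

The gradient term pulls proposals towards regions of higher density, which is what distinguishes MALA from a plain random-walk proposal.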

sbijax.mcmc.sample_with_nuts(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Sample from a distribution using the No-U-Turn sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain, including warmup

  • n_warmup – number of initial samples per chain to discard as warmup

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_nuts
>>> prior = tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_nuts(jr.key(0), prop_posterior_lp, prior)

Returns:

a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim and an MCMCSampleInfo with the mean post-warmup acceptance rate

Return type:

a tuple (samples, info)
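With the default arguments (n_chains=4, n_samples=2000, n_warmup=1000), the documented leaf shape works out to 4 x 1000 x dim. The sketch below checks that arithmetic and shows one way to pool chains for downstream summaries. It uses nested lists as a stand-in for an array leaf; `pool_chains` is a hypothetical helper, not part of sbijax.

```python
def pool_chains(leaf):
    """Merge a (n_chains, n_draws, dim) nested list into (n_chains * n_draws, dim)."""
    return [draw for chain in leaf for draw in chain]


n_chains, n_samples, n_warmup, dim = 4, 2000, 1000, 2
n_draws = n_samples - n_warmup  # draws kept per chain after warmup

# Stand-in leaf filled with zeros; a real leaf would hold the sampled values.
leaf = [[[0.0] * dim for _ in range(n_draws)] for _ in range(n_chains)]
pooled = pool_chains(leaf)
```

Pooling discards the chain index, so any per-chain convergence checks should be run before this step.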

sbijax.mcmc.sample_with_rmh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain, including warmup

  • n_warmup – number of initial samples per chain to discard as warmup

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_rmh
>>> prior = tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_rmh(jr.key(0), prop_posterior_lp, prior)

Returns:

a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim and an MCMCSampleInfo with the mean post-warmup acceptance rate

Return type:

a tuple (samples, info)
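A random-walk Metropolis chain, the simplest member of this family, can be sketched in plain Python as follows. This illustrates the textbook algorithm in one dimension rather than the BlackJAX kernel; `random_walk_mh`, the proposal scale `sigma` and the standard normal target are assumptions for illustration.

```python
import math
import random


def random_walk_mh(rng, log_target, n_samples, x0, sigma):
    """Run one random-walk Metropolis chain in one dimension."""
    x, lp = x0, log_target(x0)
    chain, n_accept = [], 0
    for _ in range(n_samples):
        # Symmetric Gaussian proposal centred at the current state.
        x_new = x + sigma * rng.gauss(0.0, 1.0)
        lp_new = log_target(x_new)
        # The proposal densities cancel, so only the target ratio remains.
        if math.log(1.0 - rng.random()) < lp_new - lp:
            x, lp = x_new, lp_new
            n_accept += 1
        chain.append(x)
    return chain, n_accept / n_samples


def log_target(x):
    # Unnormalised log-density of N(0, 1).
    return -0.5 * x * x


chain, accept_rate = random_walk_mh(random.Random(0), log_target, 5000, 0.0, sigma=2.0)
kept = chain[1000:]  # discard the first 1000 draws as warmup
posterior_mean = sum(kept) / len(kept)
```

The proposal scale trades off acceptance rate against step length: small values accept often but move slowly, large values move far but are rejected more.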

sbijax.mcmc.sample_with_slice(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, n_thin=2, n_doubling=5, step_size=1, **kwargs)

Sample from a distribution using a slice sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain, including warmup

  • n_warmup – number of initial samples per chain to discard as warmup

  • n_thin – integer specifying how many samples to discard between draws

  • n_doubling – maximum number of doubling steps

  • step_size – floating-point number specifying the size of each step

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_slice
>>> prior = tfd.JointDistributionNamed(
...     dict(
...         mean=tfd.Normal(jnp.zeros(2), 1.0),
...         std=tfd.HalfNormal(1.0)
...     ),
...     batch_ndims=0,
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["mean"], theta["std"]).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples, info = sample_with_slice(jr.PRNGKey(0), prop_posterior_lp, prior)

Returns:

a named pytree with leaves of shape n_chains x (n_samples - n_warmup) x dim and an MCMCSampleInfo (acceptance rate is nan for slice sampling)

Return type:

a tuple (samples, info)
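To show what a slice-sampling update looks like, here is a one-dimensional sketch in plain Python. Note that it uses the simpler stepping-out procedure to find the interval, not the doubling procedure that the `n_doubling` argument suggests, and it does not reproduce sbijax's implementation. `slice_sample` and `max_steps` are hypothetical names; `step_size` plays the role of the initial interval width.

```python
import math
import random


def slice_sample(rng, log_target, n_samples, x0, step_size=1.0, max_steps=32):
    """Univariate slice sampler with stepping-out and shrinkage."""
    x = x0
    chain = []
    for _ in range(n_samples):
        # Draw the slice height uniformly below the density at x (in log space).
        log_y = log_target(x) + math.log(1.0 - rng.random())
        # Place an interval of width step_size at a random offset around x.
        left = x - step_size * rng.random()
        right = left + step_size
        # Randomly split the expansion budget between the two sides.
        n_left = int(max_steps * rng.random())
        n_right = (max_steps - 1) - n_left
        while n_left > 0 and log_target(left) > log_y:
            left -= step_size
            n_left -= 1
        while n_right > 0 and log_target(right) > log_y:
            right += step_size
            n_right -= 1
        # Shrink the interval towards x until a point inside the slice is drawn.
        while True:
            x_new = left + (right - left) * rng.random()
            if log_target(x_new) >= log_y:
                x = x_new
                break
            if x_new < x:
                left = x_new
            else:
                right = x_new
        chain.append(x)
    return chain


def log_target(x):
    # Unnormalised log-density of N(2, 1).
    return -0.5 * (x - 2.0) ** 2


chain = slice_sample(random.Random(0), log_target, 5000, 0.0)
kept = chain[1000:]  # discard the first 1000 draws as warmup
posterior_mean = sum(kept) / len(kept)
```

Every iteration ends with an accepted point, which is consistent with the acceptance rate being reported as nan for slice sampling.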

sbijax.mcmc.imh

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).

sbijax.mcmc.mala

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).

sbijax.mcmc.nuts

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).

sbijax.mcmc.rmh

Identifies a BlackJAX MCMC algorithm (NUTS, MALA, RMH, IMH).

Wraps the algorithm’s initializer so make_sampler() can build the concrete kernel and initial chain states at sampling time. The single field init_fn has signature (rng_key, initial_positions, lp) -> (initial_states, kernel).
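The shape of such a handle can be sketched with a small dataclass. This is a hypothetical stand-in that only mirrors the documented init_fn signature; `ToyKernel`, `_toy_init` and the no-op transition are invented for illustration and say nothing about how sbijax or BlackJAX build their kernels.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ToyKernel:
    """Hypothetical kernel handle with a single init_fn field."""

    init_fn: Callable[[Any, Any, Callable], Tuple[Any, Callable]]


def _toy_init(rng_key, initial_positions, lp):
    # One state per chain: the position plus its cached log-density.
    initial_states = [(pos, lp(pos)) for pos in initial_positions]

    def kernel(rng_key, state):
        # Placeholder transition that leaves the state unchanged.
        return state

    return initial_states, kernel


toy = ToyKernel(init_fn=_toy_init)
states, step = toy.init_fn(None, [0.0, 1.0], lambda x: -0.5 * x * x)
```

Deferring construction to `init_fn` means the handle itself carries no target density; the log-density and starting positions are supplied only at sampling time.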