sbijax.mcmc

sbijax.mcmc contains functionality to draw posterior samples using MCMC.

sample_with_imh(rng_key, lp, prior, *[, ...])

Draw samples using the independent Metropolis-Hastings sampler.

sample_with_mala(rng_key, lp, prior, *[, ...])

Sample from a distribution using the MALA sampler.

sample_with_nuts(rng_key, lp, prior, *[, ...])

Sample from a distribution using the No-U-Turn sampler.

sample_with_rmh(rng_key, lp, prior, *[, ...])

Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.

sample_with_slice(rng_key, lp, prior, *[, ...])

Sample from a distribution using a slice sampler.

sbijax.mcmc.sample_with_imh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Draw samples using the independent Metropolis-Hastings sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density function to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain

  • n_warmup – number of warmup samples to discard from each chain

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from sbijax.mcmc import sample_with_imh
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> # unnormalized log-posterior: prior log-density plus data log-likelihood
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> # fix the observed data so that the log-density depends on theta only
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_imh(jr.PRNGKey(0), prop_posterior_lp, prior)

Returns:

a JAX pytree with keys corresponding to the variable names and tensor values of dimension n_chains x n_samples x dim_variable
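To make the independent Metropolis-Hastings update concrete, the following is a minimal standard-library sketch for a scalar parameter. It is an illustration only and not sbijax's implementation: the names imh_chain, posterior_lp, prior_sample and prior_lp are hypothetical, and using the prior as the proposal distribution is an assumption made for this toy model.

```python
import math
import random


def imh_chain(lp, proposal_sample, proposal_lp, n_samples, seed=0):
    """Run one independent Metropolis-Hastings chain on a scalar parameter."""
    rng = random.Random(seed)
    x = proposal_sample(rng)
    chain = []
    for _ in range(n_samples):
        # The proposal is drawn without looking at the current state x.
        y = proposal_sample(rng)
        # Log acceptance ratio: log importance weight of y minus that of x.
        log_alpha = (lp(y) - proposal_lp(y)) - (lp(x) - proposal_lp(x))
        if math.log(1.0 - rng.random()) < log_alpha:
            x = y
        chain.append(x)
    return chain


# Toy model (hypothetical): theta ~ N(0, 1), one observation y = 1 with
# unit noise. The exact posterior is Normal with mean 0.5 and variance 0.5.
def posterior_lp(theta):
    return -0.5 * theta**2 - 0.5 * (1.0 - theta) ** 2


def prior_sample(rng):
    return rng.gauss(0.0, 1.0)


def prior_lp(theta):
    # Normalizing constants cancel in the acceptance ratio.
    return -0.5 * theta**2


chain = imh_chain(posterior_lp, prior_sample, prior_lp, n_samples=5000)
kept = chain[1000:]  # drop the first 1000 draws as warmup
posterior_mean = sum(kept) / len(kept)
```

Because the proposal does not depend on the current state, the sampler mixes well only when the proposal covers the target; a proposal that is much narrower than the target tends to get stuck.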

sbijax.mcmc.sample_with_mala(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Sample from a distribution using the MALA sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density function to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain

  • n_warmup – number of warmup samples to discard from each chain

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from sbijax.mcmc import sample_with_mala
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> # unnormalized log-posterior: prior log-density plus data log-likelihood
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> # fix the observed data so that the log-density depends on theta only
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_mala(jr.PRNGKey(0), prop_posterior_lp, prior)

Returns:

a JAX pytree with keys corresponding to the variable names and tensor values of dimension n_chains x n_samples x dim_variable
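MALA (the Metropolis-adjusted Langevin algorithm) proposes a move along the gradient of the log-density plus Gaussian noise, then applies a Metropolis-Hastings correction. The sketch below shows one way to write this for a scalar parameter using only the standard library. It is not sbijax's implementation; mala_chain, the hand-written gradient and the step size of 0.7 are hypothetical choices for this toy model.

```python
import math
import random


def mala_chain(lp, grad_lp, x0, step_size, n_samples, seed=0):
    """Run one MALA chain on a scalar parameter."""
    rng = random.Random(seed)

    def drift(x):
        # Mean of the Langevin proposal started at x.
        return x + 0.5 * step_size**2 * grad_lp(x)

    def log_q(to, frm):
        # Log-density (up to a constant) of proposing `to` from `frm`.
        return -((to - drift(frm)) ** 2) / (2.0 * step_size**2)

    x = x0
    chain = []
    for _ in range(n_samples):
        y = drift(x) + step_size * rng.gauss(0.0, 1.0)
        # The proposal is asymmetric, so both directions enter the ratio.
        log_alpha = lp(y) + log_q(x, y) - lp(x) - log_q(y, x)
        if math.log(1.0 - rng.random()) < log_alpha:
            x = y
        chain.append(x)
    return chain


# Toy posterior (hypothetical): Normal with mean 0.5 and variance 0.5,
# written up to an additive constant, with its gradient derived by hand.
def posterior_lp(theta):
    return -((theta - 0.5) ** 2)


def posterior_grad(theta):
    return -2.0 * (theta - 0.5)


chain = mala_chain(posterior_lp, posterior_grad, x0=0.0, step_size=0.7, n_samples=5000)
kept = chain[1000:]  # drop the first 1000 draws as warmup
posterior_mean = sum(kept) / len(kept)
```

In a JAX setting the gradient would normally come from automatic differentiation rather than being written by hand.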

sbijax.mcmc.sample_with_nuts(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Sample from a distribution using the No-U-Turn sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density function to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain

  • n_warmup – number of warmup samples to discard from each chain

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from sbijax.mcmc import sample_with_nuts
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> # unnormalized log-posterior: prior log-density plus data log-likelihood
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> # fix the observed data so that the log-density depends on theta only
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_nuts(jr.PRNGKey(0), prop_posterior_lp, prior)

Returns:

a JAX pytree with keys corresponding to the variable names and tensor values of dimension n_chains x n_samples x dim_variable
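The No-U-Turn sampler is a variant of Hamiltonian Monte Carlo, and its trajectories are built from leapfrog integration steps. A full NUTS implementation is beyond a short example, so the sketch below covers only that leapfrog building block for a scalar parameter. It is not sbijax's code; leapfrog, energy and the standard normal target are hypothetical names and choices for illustration.

```python
def leapfrog(x, p, grad_lp, step_size, n_steps):
    """Integrate Hamiltonian dynamics for n_steps >= 1 leapfrog steps."""
    p = p + 0.5 * step_size * grad_lp(x)  # initial half step for the momentum
    for step in range(n_steps):
        x = x + step_size * p  # full step for the position
        if step < n_steps - 1:
            p = p + step_size * grad_lp(x)  # full momentum step in between
    p = p + 0.5 * step_size * grad_lp(x)  # final half step for the momentum
    return x, p


# Standard normal target (hypothetical): lp(x) = -x**2 / 2, so grad_lp(x) = -x.
def grad_lp(x):
    return -x


def energy(x, p):
    # Hamiltonian: negative log-density plus kinetic energy.
    return 0.5 * x**2 + 0.5 * p**2


x0, p0 = 1.0, 0.5
x1, p1 = leapfrog(x0, p0, grad_lp, step_size=0.1, n_steps=20)
energy_error = abs(energy(x1, p1) - energy(x0, p0))

# Reversibility: flipping the momentum and integrating again returns to the start.
x_back, p_back = leapfrog(x1, -p1, grad_lp, step_size=0.1, n_steps=20)
```

The integrator nearly conserves the energy and is reversible, which is what lets a Hamiltonian sampler accept distant proposals with high probability. NUTS additionally chooses the trajectory length automatically by stopping once the trajectory starts to turn back on itself.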

sbijax.mcmc.sample_with_rmh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)

Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density function to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain

  • n_warmup – number of warmup samples to discard from each chain

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from sbijax.mcmc import sample_with_rmh
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> # unnormalized log-posterior: prior log-density plus data log-likelihood
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> # fix the observed data so that the log-density depends on theta only
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_rmh(jr.PRNGKey(0), prop_posterior_lp, prior)

Returns:

a JAX pytree with keys corresponding to the variable names and tensor values of dimension n_chains x n_samples x dim_variable
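The roles of n_chains, n_samples and n_warmup can be illustrated with a small standard-library sketch of a multi-chain random-walk Metropolis sampler. Several points here are assumptions rather than statements about sbijax: that the proposal is a symmetric Gaussian random walk, that each chain starts from a prior draw, and that warmup draws are generated in addition to the n_samples retained draws. The names rmh_chains, prior_sample and step_size are hypothetical.

```python
import math
import random


def rmh_chains(lp, prior_sample, n_chains=4, n_samples=2000, n_warmup=1000,
               step_size=1.0, seed=0):
    """Run several random-walk Metropolis chains on a scalar parameter."""
    rng = random.Random(seed)
    chains = []
    for _chain in range(n_chains):
        x = prior_sample(rng)  # assumed: each chain starts at a prior draw
        draws = []
        for _step in range(n_warmup + n_samples):
            y = x + step_size * rng.gauss(0.0, 1.0)  # symmetric proposal
            # With a symmetric proposal only the target ratio is needed.
            if math.log(1.0 - rng.random()) < lp(y) - lp(x):
                x = y
            draws.append(x)
        chains.append(draws[n_warmup:])  # discard the warmup draws
    return chains


# Toy model (hypothetical): theta ~ N(0, 1), one observation y = 1 with
# unit noise. The exact posterior is Normal with mean 0.5 and variance 0.5.
def posterior_lp(theta):
    return -0.5 * theta**2 - 0.5 * (1.0 - theta) ** 2


def prior_sample(rng):
    return rng.gauss(0.0, 1.0)


chains = rmh_chains(posterior_lp, prior_sample)
all_draws = [draw for chain in chains for draw in chain]
posterior_mean = sum(all_draws) / len(all_draws)
```

The result is a list of n_chains lists with n_samples draws each, which mirrors the n_chains x n_samples layout of the returned pytree for a one-dimensional variable.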

sbijax.mcmc.sample_with_slice(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, n_thin=2, n_doubling=5, step_size=1, **kwargs)

Sample from a distribution using a slice sampler.

Parameters:
  • rng_key – a jax random key

  • lp – the log-density function to sample from

  • prior – a function that returns a prior sample

  • n_chains – number of chains to sample

  • n_samples – number of samples per chain

  • n_warmup – number of warmup samples to discard from each chain

  • n_thin – integer specifying how many samples to discard between draws

  • n_doubling – maximum number of doubling steps

  • step_size – floating-point number specifying the size of each step

Examples

>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from sbijax.mcmc import sample_with_slice
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = tfd.JointDistributionNamed(
...     dict(
...         mean=tfd.Normal(jnp.zeros(2), 1.0),
...         std=tfd.HalfNormal(1.0)
...     ),
...     batch_ndims=0,
... )
>>> # unnormalized log-posterior: prior log-density plus data log-likelihood
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["mean"], theta["std"]).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> # fix the observed data so that the log-density depends on theta only
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_slice(jr.PRNGKey(0), prop_posterior_lp, prior)

Returns:

a JAX pytree with keys corresponding to the variable names and tensor values of dimension n_chains x n_samples x dim_variable
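To show how n_thin, n_doubling and step_size typically enter a slice sampler, here is a standard-library sketch of a univariate slice sampler with the doubling procedure (following the textbook description by Neal, 2003). It is not sbijax's implementation. The helper names are hypothetical, and treating n_thin as "keep every n_thin-th state" is an assumption about the thinning convention.

```python
import math
import random


def _doubling_interval(lp, x0, log_y, step_size, n_doubling, rng):
    # Place an interval of width step_size randomly around x0, then double
    # it until both ends lie outside the slice or the budget is used up.
    left = x0 - step_size * rng.random()
    right = left + step_size
    k = n_doubling
    while k > 0 and (log_y < lp(left) or log_y < lp(right)):
        if rng.random() < 0.5:
            left -= right - left
        else:
            right += right - left
        k -= 1
    return left, right


def _acceptable(lp, x0, x1, log_y, left, right, step_size):
    # Check that the doubling started from x1 could yield the same interval.
    differ = False
    while right - left > 1.1 * step_size:
        mid = 0.5 * (left + right)
        if (x0 < mid) != (x1 < mid):
            differ = True
        if x1 < mid:
            right = mid
        else:
            left = mid
        if differ and log_y >= lp(left) and log_y >= lp(right):
            return False
    return True


def slice_chain(lp, x0, n_samples, n_thin=2, n_doubling=5, step_size=1.0, seed=0):
    """Run one slice-sampling chain on a scalar parameter."""
    rng = random.Random(seed)
    x = x0
    chain = []
    for i in range(n_samples * n_thin):
        # Draw the slice height uniformly below the density at x (log scale).
        log_y = lp(x) + math.log(1.0 - rng.random())
        left, right = _doubling_interval(lp, x, log_y, step_size, n_doubling, rng)
        lo, hi = left, right
        while True:
            x1 = lo + rng.random() * (hi - lo)
            if log_y < lp(x1) and _acceptable(lp, x, x1, log_y, left, right, step_size):
                break
            # Shrink the interval towards the current state after a rejection.
            if x1 < x:
                lo = x1
            else:
                hi = x1
        x = x1
        if (i + 1) % n_thin == 0:  # assumed thinning: keep every n_thin-th state
            chain.append(x)
    return chain


# Toy posterior (hypothetical): Normal with mean 0.5 and variance 0.5,
# written up to an additive constant.
def posterior_lp(theta):
    return -((theta - 0.5) ** 2)


chain = slice_chain(posterior_lp, x0=0.0, n_samples=3000)
kept = chain[500:]  # drop the first 500 retained draws as warmup
posterior_mean = sum(kept) / len(kept)
```

A larger step_size or n_doubling lets the interval cover wide slices in fewer expansions, at the cost of more log-density evaluations during shrinkage.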