sbijax.mcmc
sbijax.mcmc provides functions for drawing posterior samples with Markov chain Monte Carlo (MCMC).
- sample_with_imh – Draw samples using the independent Metropolis-Hastings sampler.
- sample_with_mala – Sample from a distribution using the MALA sampler.
- sample_with_nuts – Sample from a distribution using the No-U-Turn sampler.
- sample_with_rmh – Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
- sample_with_slice – Sample from a distribution using a slice sampler.
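Every sampler in this module returns arrays with a leading chain axis (n_chains x n_samples x dim_variable). Two common post-processing steps are merging the chains into one set of draws and comparing the chains with each other. The NumPy sketch below does both on a synthetic stand-in for a sampler's output (i.i.d. normal draws, not real MCMC output). `pool_chains` and `gelman_rubin` are hypothetical helpers, not part of sbijax, and the basic Gelman-Rubin statistic shown is only one of several convergence diagnostics.

```python
import numpy as np


def pool_chains(samples):
    """Hypothetical helper: (n_chains, n_samples, ...) -> (n_chains * n_samples, ...)."""
    return {name: value.reshape(-1, *value.shape[2:]) for name, value in samples.items()}


def gelman_rubin(x):
    """Hypothetical helper: basic (non-split) R-hat per dimension.

    x has shape (n_chains, n_samples, dim).
    """
    n = x.shape[1]
    within = x.var(axis=1, ddof=1).mean(axis=0)       # W: mean within-chain variance
    between = n * x.mean(axis=1).var(axis=0, ddof=1)  # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)


# Stand-in for sampler output, shaped like the documented return value
rng = np.random.default_rng(0)
samples = {"theta": rng.normal(loc=[-0.5, 0.5], scale=np.sqrt(0.5), size=(4, 2000, 2))}
pooled = pool_chains(samples)
print(pooled["theta"].shape)           # (8000, 2)
print(gelman_rubin(samples["theta"]))  # values close to 1
```

Values of R-hat noticeably above 1 would suggest the chains have not yet mixed; with well-behaved chains, pooling them is then a reasonable way to summarise the posterior.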
- sbijax.mcmc.sample_with_imh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Draw samples using the independent Metropolis-Hastings sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial warm-up (burn-in) samples to discard
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_imh
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_imh(jr.PRNGKey(0), prop_posterior_lp, prior)
- Returns:
a JAX pytree whose keys are the variable names and whose values are arrays of shape n_chains x n_samples x dim_variable
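To illustrate what an independent Metropolis-Hastings step does, the NumPy sketch below runs the algorithm on the posterior from the example above (prior N(0, 1), likelihood N(theta, 1), y = [-1, 1]). Here each coordinate's posterior is N(y/2, 1/2). This is not sbijax's implementation. The proposal distribution is assumed to be the standard-normal prior, nothing is tuned, and `sample_with_imh_sketch` and `log_proposal` are hypothetical names. Because proposals are drawn independently of the current state, the acceptance ratio has to include the proposal density q as well as the target p.

```python
import numpy as np

Y = np.array([-1.0, 1.0])  # observed data from the docstring example


def log_posterior(theta):
    """Unnormalised log-posterior of the example; theta has shape (n_chains, 2)."""
    log_prior = -0.5 * np.sum(theta**2, axis=-1)        # N(0, 1) prior
    log_lik = -0.5 * np.sum((Y - theta) ** 2, axis=-1)  # N(theta, 1) likelihood
    return log_prior + log_lik


def log_proposal(theta):
    """Assumed proposal q: the standard-normal prior (up to a constant)."""
    return -0.5 * np.sum(theta**2, axis=-1)


def sample_with_imh_sketch(seed, n_chains=4, n_samples=2000, n_warmup=1000, dim=2):
    """Hypothetical helper, not part of sbijax."""
    rng = np.random.default_rng(seed)
    state = rng.standard_normal((n_chains, dim))  # initial states drawn from q
    draws = np.empty((n_chains, n_samples, dim))
    for i in range(n_warmup + n_samples):
        proposal = rng.standard_normal((n_chains, dim))  # independent of `state`
        # log of [p(x') q(x)] / [p(x) q(x')]
        log_ratio = (
            log_posterior(proposal) + log_proposal(state)
            - log_posterior(state) - log_proposal(proposal)
        )
        accept = np.log(rng.uniform(size=n_chains)) < log_ratio
        state = np.where(accept[:, None], proposal, state)
        if i >= n_warmup:  # discard the warm-up iterations
            draws[:, i - n_warmup] = state
    return {"theta": draws}
```

For example, `sample_with_imh_sketch(0)["theta"].reshape(-1, 2).mean(axis=0)` should land near the analytic posterior mean [-0.5, 0.5]. An independent proposal works well when q covers the target's tails, as the prior does here.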
- sbijax.mcmc.sample_with_mala(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Sample from a distribution using the Metropolis-adjusted Langevin algorithm (MALA).
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial warm-up (burn-in) samples to discard
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_mala
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_mala(jr.PRNGKey(0), prop_posterior_lp, prior)
- Returns:
a JAX pytree whose keys are the variable names and whose values are arrays of shape n_chains x n_samples x dim_variable
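MALA proposes a move along the gradient of the log-density plus Gaussian noise, x' = x + eps * grad log p(x) + sqrt(2 eps) z. It then applies a Metropolis-Hastings correction, which needs the (asymmetric) proposal density in both directions. The NumPy sketch below applies this to the example posterior above. It is not sbijax's implementation: the fixed `step_size`, the hand-written gradient, and the name `sample_with_mala_sketch` are assumptions for illustration, and no step-size adaptation is done.

```python
import numpy as np

Y = np.array([-1.0, 1.0])  # observed data from the docstring example


def log_posterior(theta):
    """Unnormalised log-posterior; theta has shape (n_chains, 2)."""
    return -0.5 * np.sum(theta**2, axis=-1) - 0.5 * np.sum((Y - theta) ** 2, axis=-1)


def grad_log_posterior(theta):
    # d/dtheta of the log-posterior above: -theta + (Y - theta)
    return Y - 2.0 * theta


def log_q(to, frm, step_size):
    """Log-density (up to a constant) of the Langevin proposal frm -> to."""
    mean = frm + step_size * grad_log_posterior(frm)
    return -np.sum((to - mean) ** 2, axis=-1) / (4.0 * step_size)


def sample_with_mala_sketch(seed, n_chains=4, n_samples=2000, n_warmup=1000,
                            step_size=0.3, dim=2):
    """Hypothetical helper, not part of sbijax."""
    rng = np.random.default_rng(seed)
    state = rng.standard_normal((n_chains, dim))
    draws = np.empty((n_chains, n_samples, dim))
    for i in range(n_warmup + n_samples):
        noise = rng.standard_normal((n_chains, dim))
        proposal = (state + step_size * grad_log_posterior(state)
                    + np.sqrt(2.0 * step_size) * noise)
        # Metropolis-Hastings correction with the asymmetric proposal
        log_ratio = (log_posterior(proposal) + log_q(state, proposal, step_size)
                     - log_posterior(state) - log_q(proposal, state, step_size))
        accept = np.log(rng.uniform(size=n_chains)) < log_ratio
        state = np.where(accept[:, None], proposal, state)
        if i >= n_warmup:  # discard the warm-up iterations
            draws[:, i - n_warmup] = state
    return {"theta": draws}
```

Without the correction step, the Langevin update alone would settle on a slightly wrong variance for this target. The acceptance test is what makes the chain target the exact posterior.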
- sbijax.mcmc.sample_with_nuts(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Sample from a distribution using the No-U-Turn sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial warm-up (burn-in) samples to discard
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_nuts
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_nuts(jr.PRNGKey(0), prop_posterior_lp, prior)
- Returns:
a JAX pytree whose keys are the variable names and whose values are arrays of shape n_chains x n_samples x dim_variable
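The No-U-Turn sampler simulates Hamiltonian dynamics with leapfrog steps and doubles the trajectory, forwards or backwards in time at random, until it starts to turn back on itself. It then picks the next state from the points visited. The NumPy sketch below follows the "efficient" variant described by Hoffman and Gelman (2014) with a fixed step size `eps`, applied to the example posterior above. It is not sbijax's implementation. The step-size and mass-matrix adaptation that practical NUTS implementations run during warm-up is omitted, so warm-up iterations are simply discarded. All function names are hypothetical.

```python
import numpy as np

Y = np.array([-1.0, 1.0])  # observed data from the docstring example
MAX_DELTA = 1000.0  # divergence threshold used by Hoffman and Gelman (2014)


def log_posterior(theta):
    return -0.5 * np.sum(theta**2) - 0.5 * np.sum((Y - theta) ** 2)


def grad_log_posterior(theta):
    return Y - 2.0 * theta  # gradient of log_posterior


def leapfrog(theta, r, eps):
    r = r + 0.5 * eps * grad_log_posterior(theta)
    theta = theta + eps * r
    r = r + 0.5 * eps * grad_log_posterior(theta)
    return theta, r


def no_u_turn(theta_minus, theta_plus, r_minus, r_plus):
    # True while neither end of the trajectory has started moving back
    span = theta_plus - theta_minus
    return span @ r_minus >= 0 and span @ r_plus >= 0


def build_tree(rng, theta, r, log_u, direction, depth, eps):
    """Return (theta-, r-, theta+, r+, candidate, n_in_slice, keep_going)."""
    if depth == 0:  # base case: a single leapfrog step
        theta1, r1 = leapfrog(theta, r, direction * eps)
        joint = log_posterior(theta1) - 0.5 * r1 @ r1
        return theta1, r1, theta1, r1, theta1, int(log_u <= joint), log_u < joint + MAX_DELTA
    tm, rm, tp, rp, cand, n, ok = build_tree(rng, theta, r, log_u, direction, depth - 1, eps)
    if ok:  # build a second subtree of the same depth and merge it
        if direction == -1:
            tm, rm, _, _, cand2, n2, ok2 = build_tree(rng, tm, rm, log_u, direction, depth - 1, eps)
        else:
            _, _, tp, rp, cand2, n2, ok2 = build_tree(rng, tp, rp, log_u, direction, depth - 1, eps)
        if n + n2 > 0 and rng.uniform() < n2 / (n + n2):
            cand = cand2
        ok = ok2 and no_u_turn(tm, tp, rm, rp)
        n += n2
    return tm, rm, tp, rp, cand, n, ok


def nuts_step(rng, theta, eps, max_depth=10):
    r0 = rng.standard_normal(theta.shape)
    # slice variable: log u = joint - Exp(1), i.e. u ~ Uniform(0, exp(joint))
    log_u = log_posterior(theta) - 0.5 * r0 @ r0 - rng.exponential()
    tm, tp, rm, rp = theta, theta, r0, r0
    new, n, depth, ok = theta, 1, 0, True
    while ok and depth < max_depth:
        if rng.uniform() < 0.5:  # extend the trajectory backwards in time
            tm, rm, _, _, cand, n1, ok1 = build_tree(rng, tm, rm, log_u, -1, depth, eps)
        else:  # extend it forwards
            _, _, tp, rp, cand, n1, ok1 = build_tree(rng, tp, rp, log_u, 1, depth, eps)
        if ok1 and rng.uniform() < n1 / n:
            new = cand
        n += n1
        ok = ok1 and no_u_turn(tm, tp, rm, rp)
        depth += 1
    return new


def sample_with_nuts_sketch(seed, n_chains=4, n_samples=2000, n_warmup=1000, eps=0.3, dim=2):
    """Hypothetical helper, not part of sbijax."""
    rng = np.random.default_rng(seed)
    draws = np.empty((n_chains, n_samples, dim))
    for c in range(n_chains):
        theta = rng.standard_normal(dim)
        for i in range(n_warmup + n_samples):
            theta = nuts_step(rng, theta, eps)
            if i >= n_warmup:  # discard the warm-up iterations
                draws[c, i - n_warmup] = theta
    return {"theta": draws}
```

The trajectory length is chosen per iteration by the U-turn check, which is the main practical difference from plain Hamiltonian Monte Carlo with a fixed number of leapfrog steps.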
- sbijax.mcmc.sample_with_rmh(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, **kwargs)
Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial warm-up (burn-in) samples to discard
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_rmh
>>> prior = tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["theta"], 1.0).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_rmh(jr.PRNGKey(0), prop_posterior_lp, prior)
- Returns:
a JAX pytree whose keys are the variable names and whose values are arrays of shape n_chains x n_samples x dim_variable
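With a symmetric proposal, the Rosenbluth-Metropolis-Hastings acceptance ratio reduces to p(x') / p(x), because the proposal densities cancel. The NumPy sketch below assumes a Gaussian random-walk proposal with a fixed `scale` and runs it on the example posterior above. It is not sbijax's implementation: the proposal family, the `scale` value, and the name `sample_with_rmh_sketch` are assumptions for illustration.

```python
import numpy as np

Y = np.array([-1.0, 1.0])  # observed data from the docstring example


def log_posterior(theta):
    """Unnormalised log-posterior; theta has shape (n_chains, 2)."""
    return -0.5 * np.sum(theta**2, axis=-1) - 0.5 * np.sum((Y - theta) ** 2, axis=-1)


def sample_with_rmh_sketch(seed, n_chains=4, n_samples=2000, n_warmup=1000,
                           scale=1.0, dim=2):
    """Hypothetical helper, not part of sbijax."""
    rng = np.random.default_rng(seed)
    state = rng.standard_normal((n_chains, dim))
    draws = np.empty((n_chains, n_samples, dim))
    for i in range(n_warmup + n_samples):
        # Assumed Gaussian random-walk proposal centred on the current state
        proposal = state + scale * rng.standard_normal((n_chains, dim))
        # Symmetric proposal: the q terms cancel in the acceptance ratio
        log_ratio = log_posterior(proposal) - log_posterior(state)
        accept = np.log(rng.uniform(size=n_chains)) < log_ratio
        state = np.where(accept[:, None], proposal, state)
        if i >= n_warmup:  # discard the warm-up iterations
            draws[:, i - n_warmup] = state
    return {"theta": draws}
```

The random walk needs no gradients, but its efficiency depends strongly on `scale`. Steps that are too small mix slowly, and steps that are too large are mostly rejected.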
- sbijax.mcmc.sample_with_slice(rng_key, lp, prior, *, n_chains=4, n_samples=2000, n_warmup=1000, n_thin=2, n_doubling=5, step_size=1, **kwargs)
Sample from a distribution using a slice sampler.
- Parameters:
rng_key – a JAX random key
lp – the log-density you wish to sample from
prior – a function that returns a prior sample
n_chains – number of chains to sample
n_samples – number of samples per chain
n_warmup – number of initial warm-up (burn-in) samples to discard
n_thin – number of samples to discard between retained draws (thinning)
n_doubling – maximum number of interval-doubling steps
step_size – floating-point number specifying the size of each step
Examples
>>> import functools as ft
>>> from jax import numpy as jnp, random as jr
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> from sbijax.mcmc import sample_with_slice
>>> prior = tfd.JointDistributionNamed(
...     dict(
...         mean=tfd.Normal(jnp.zeros(2), 1.0),
...         std=tfd.HalfNormal(1.0)
...     ),
...     batch_ndims=0,
... )
>>> def log_prob(theta, y):
...     lp_prior = prior.log_prob(theta)
...     lp_data = tfd.Normal(theta["mean"], theta["std"]).log_prob(y)
...     return jnp.sum(lp_data) + jnp.sum(lp_prior)
...
>>> prop_posterior_lp = ft.partial(log_prob, y=jnp.array([-1.0, 1.0]))
>>> samples = sample_with_slice(jr.PRNGKey(0), prop_posterior_lp, prior)
- Returns:
a JAX pytree whose keys are the variable names and whose values are arrays of shape n_chains x n_samples x dim_variable
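A slice sampler draws a random height under the density at the current point, then samples uniformly from the "slice" of points whose density exceeds that height. The NumPy sketch below uses the univariate doubling and shrinkage procedures described by Neal (2003), applied one coordinate at a time. It keeps one draw and then discards `n_thin`, following the parameter description above. It is not sbijax's implementation. The coordinate-wise scheme, the mapping of `n_doubling` and `step_size` onto Neal's procedure, and the names `slice_update_1d` and `sample_with_slice_sketch` are all assumptions. To stay short, it targets the 2-D posterior from the earlier examples rather than the mean/std model in this docstring.

```python
import numpy as np

Y = np.array([-1.0, 1.0])  # observed data from the earlier examples


def log_posterior(theta):
    """Unnormalised log-posterior for a single point theta of shape (2,)."""
    return -0.5 * np.sum(theta**2) - 0.5 * np.sum((Y - theta) ** 2)


def slice_update_1d(rng, log_f, x0, step_size, n_doubling):
    """One univariate slice-sampling update: doubling, then shrinkage (Neal, 2003)."""
    log_y = log_f(x0) - rng.exponential()  # log of a uniform height under f(x0)
    # Doubling: randomly position an interval of width step_size, then grow it
    left = x0 - step_size * rng.uniform()
    right = left + step_size
    for _ in range(n_doubling):
        if log_y >= log_f(left) and log_y >= log_f(right):
            break  # both ends are already outside the slice
        if rng.uniform() < 0.5:
            left -= right - left
        else:
            right += right - left

    def acceptable(x1):
        # Could doubling started from x1 have produced the same interval?
        lo, hi, differs = left, right, False
        while hi - lo > 1.1 * step_size:
            mid = 0.5 * (lo + hi)
            if (x0 < mid) != (x1 < mid):
                differs = True
            if x1 < mid:
                hi = mid
            else:
                lo = mid
            if differs and log_y >= log_f(lo) and log_y >= log_f(hi):
                return False
        return True

    # Shrinkage: propose uniformly, narrowing the interval towards x0 on rejection
    lo, hi = left, right
    while True:
        x1 = lo + rng.uniform() * (hi - lo)
        if log_y < log_f(x1) and acceptable(x1):
            return x1
        if x1 < x0:
            lo = x1
        else:
            hi = x1


def sample_with_slice_sketch(seed, n_chains=4, n_samples=2000, n_warmup=1000,
                             n_thin=2, n_doubling=5, step_size=1.0, dim=2):
    """Hypothetical helper, not part of sbijax; updates one coordinate at a time."""
    rng = np.random.default_rng(seed)
    stride = n_thin + 1  # keep one draw, then discard n_thin
    draws = np.empty((n_chains, n_samples, dim))
    for c in range(n_chains):
        theta = rng.standard_normal(dim)
        for i in range(n_warmup + n_samples * stride):
            for k in range(dim):
                def log_f(x, k=k):  # density of coordinate k, others held fixed
                    point = theta.copy()
                    point[k] = x
                    return log_posterior(point)
                theta[k] = slice_update_1d(rng, log_f, theta[k], step_size, n_doubling)
            j = i - n_warmup
            if j >= 0 and j % stride == 0:  # past warm-up and not thinned out
                draws[c, j // stride] = theta
    return {"theta": draws}
```

The acceptance check inside the shrinkage loop is needed because doubling can produce intervals that are not reversible from every point in the slice. Skipping it would bias the sampler.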