import jax
from jax import numpy as jnp
from jax import random as jr
from jax import scipy as jsp
from tensorflow_probability.substrates.jax import distributions as tfd
# ruff: noqa: PLR0913, E501
[docs]
def slcp():
    """Simple likelihood complex posterior model.

    Constructs prior, simulator, and likelihood functions.

    The five parameters describe a bivariate Gaussian: ``theta[0]`` and
    ``theta[1]`` are the means, ``theta[2] ** 2`` and ``theta[3] ** 2`` the
    standard deviations, and ``tanh(theta[4])`` the correlation. One
    observation consists of four draws from that Gaussian, flattened into a
    vector of eight values.

    Returns:
        returns a tuple of three objects. The first is a
        tfd.JointDistributionNamed serving as a prior distribution. The second
        is a simulator function that can be used to generate data. The third
        is the likelihood function.

    References:
        Papamakarios, George, Sequential Neural Likelihood: Fast Likelihood-free Inference with Autoregressive Flows, 2019
    """

    def prior_fn():
        """Build the prior: independent Uniform(-3, 3) on each of 5 parameters."""
        return tfd.JointDistributionNamed(
            dict(
                theta=tfd.Independent(
                    tfd.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0)),
                    reinterpreted_batch_ndims=1,
                )
            )
        )

    def simulator(seed, theta):
        """Simulate one observation per parameter vector.

        Args:
            seed: a JAX random key; it is split and one half is used for
                sampling.
            theta: a mapping with key ``"theta"`` whose value is reshaped to
                ``(n, 5)``.

        Returns:
            an array of shape ``(n, 8)`` holding four 2-d draws per row.
        """
        theta = theta["theta"].reshape(-1, 5)
        # Keep the split (rather than using `seed` directly) so the key that
        # drives sampling is unchanged; the second sub-key is not needed.
        us_key, _ = jr.split(seed)

        # Index with lists to keep a trailing axis of size 1, so that each
        # parameter broadcasts against the (n, 4) noise slices below.
        m0 = theta[..., [0]]
        m1 = theta[..., [1]]
        s0 = theta[..., [2]] ** 2
        s1 = theta[..., [3]] ** 2
        r = jnp.tanh(theta[..., [4]])

        us = tfd.Normal(jnp.array(0.0), jnp.array(1.0)).sample(
            seed=us_key,
            sample_shape=(theta.shape[0], 4, 2),
        )
        x0 = s0 * us[..., 0] + m0
        # r * u0 + sqrt(1 - r**2) * u1 has unit variance and correlation r
        # with u0, so x1 has standard deviation s1 and covariance
        # s0 * s1 * r with x0 -- the same covariance `likelihood` evaluates.
        x1 = s1 * (r * us[..., 0] + jnp.sqrt(1.0 - r**2) * us[..., 1]) + m1
        y = jnp.stack([x0, x1], axis=-1)
        return y.reshape((theta.shape[0], 8))

    def likelihood(y, theta):
        """Evaluate the log-likelihood of observations given parameters.

        Args:
            y: observations; reshaped to ``(n, 8)``.
            theta: a mapping with key ``"theta"`` whose value is reshaped to
                ``(n, 5)``. Row ``i`` of ``theta`` is paired with row ``i`` of
                ``y``.

        Returns:
            an array with one log-density value per row.
        """

        def fn(y, theta):
            # Log-density of a single 8-d observation for a single 5-d theta.
            mu = jnp.tile(theta[:2], 4)
            s1, s2 = theta[2] ** 2, theta[3] ** 2
            corr = s1 * s2 * jnp.tanh(theta[4])
            cov = jnp.array([[s1**2, corr], [corr, s2**2]])
            # The four 2-d draws are independent, hence a block-diagonal
            # 8x8 covariance built from four copies of the 2x2 block.
            cov = jsp.linalg.block_diag(*([cov] * 4))
            return tfd.MultivariateNormalFullCovariance(mu, cov).log_prob(y)

        theta = theta["theta"].reshape(-1, 5)
        y = y.reshape(-1, 8)
        return jax.vmap(fn)(y, theta)

    return prior_fn(), simulator, likelihood