Source code for sbijax._src.simulators.hyperboloid

from functools import partial

import jax
from jax import numpy as jnp
from jax import random as jr
from tensorflow_probability.python.internal.backend.jax.gen.linear_operator_lower_triangular import (
  LinearOperatorLowerTriangular,
)
from tensorflow_probability.substrates.jax import distributions as tfd

# Pairs of reference points for the two distance features. For a pair
# (m1, m2), |‖theta - m1‖ - ‖theta - m2‖| is constant along hyperbolas with
# foci m1 and m2. The first pair lies on the x-axis, the second on the y-axis.
m11, m12 = jnp.array([-0.5, 0.0]), jnp.array([0.5, 0.0])
m21, m22 = jnp.array([0.0, -0.5]), jnp.array([0.0, 0.5])
# Multiplier of the identity matrix used as the Student-t scale operator.
scale = jnp.array(0.1)
# Degrees of freedom of the Student-t observation model.
nu = 3.0


def _eudclidean(theta, m1, m2):
  """Absolute difference of Euclidean distances, tiled to length 10.

  Args:
    theta: point at which the distances are evaluated.
    m1: first reference point.
    m2: second reference point.

  Returns:
    a 1-D array of 10 identical entries, each equal to
    ``|‖theta - m1‖₂ - ‖theta - m2‖₂|``.
  """
  to_first = jnp.linalg.norm(theta - m1, ord=2)
  to_second = jnp.linalg.norm(theta - m2, ord=2)
  gap = jnp.abs(to_first - to_second)
  # The scalar feature is replicated so it can serve as a 10-dim location.
  return jnp.repeat(gap, 10)


# Batched versions of `_eudclidean`, mapped over the leading axis of theta:
# the first uses reference pair (m11, m12), the second (m21, m22). The pair is
# bound via `partial` at definition time.
dists_1_fn, dists_2_fn = (
  jax.vmap(partial(_eudclidean, m1=first, m2=second))
  for first, second in ((m11, m12), (m21, m22))
)


# ruff: noqa: PLR0913, E501
def hyperboloid():
  """Hyperboloid model.

  Constructs prior, simulator, and likelihood functions.

  The prior over ``theta`` is uniform on ``[-2, 2]^2``. Given ``theta``, an
  observation is a 10-dimensional multivariate Student-t draw (``nu`` degrees
  of freedom, scale ``scale * I``). It is centred on one of two location
  vectors, chosen with equal probability. Each location vector repeats the
  absolute difference of the Euclidean distances from ``theta`` to a pair of
  reference points.

  Returns:
    returns a tuple of three objects. The first is a
    tfd.JointDistributionNamed serving as a prior distribution. The second is
    a simulator function ``simulator(seed, theta)`` that can be used to
    generate data of shape ``(batch, 10)``. The third is the likelihood
    function ``likelihood(y, theta)``, returning the log-density of the
    equal-weight two-component mixture.

  References:
    Forbes, Florence, et al., Summary statistics and discrepancy measures for
    approximate Bayesian computation via surrogate posteriors, 2022
  """

  def _student_t(loc):
    # Observation model shared by both mixture components: multivariate
    # Student-t with `nu` degrees of freedom and scale `scale * I_10`.
    return tfd.MultivariateStudentTLinearOperator(
      df=nu,
      loc=loc,
      scale=LinearOperatorLowerTriangular(scale * jnp.eye(10)),
    )

  def prior_fn():
    low, high = jnp.full(2, -2.0), jnp.full(2, 2.0)
    return tfd.JointDistributionNamed(
      {"theta": tfd.Independent(tfd.Uniform(low, high), 1)}
    )

  def simulator(seed, theta):
    mix_key, data_key = jr.split(seed)
    params = theta["theta"].reshape(-1, 2)
    # Candidate locations for each parameter row: shape (batch, 2, 10).
    candidates = jnp.stack([dists_1_fn(params), dists_2_fn(params)], axis=1)
    # Equal logits, so each component is selected with probability 1/2.
    component = jr.categorical(
      mix_key, logits=jnp.ones(2), shape=(candidates.shape[0],)
    )
    selected = jnp.take_along_axis(
      candidates, component[:, None, None], axis=1
    )
    draws = _student_t(selected.squeeze()).sample(seed=data_key)
    # squeeze() also drops the batch axis when batch == 1; restore (batch, 10).
    return draws.reshape(-1, 10)

  def likelihood(y, theta):
    params = theta["theta"].reshape(-1, 2)
    loc_1 = dists_1_fn(params).reshape(-1, 10).squeeze()
    loc_2 = dists_2_fn(params).reshape(-1, 10).squeeze()
    log_weight = jnp.log(0.5)
    # log(0.5 * p1 + 0.5 * p2), evaluated stably in log space.
    return jnp.logaddexp(
      log_weight + _student_t(loc_1).log_prob(y),
      log_weight + _student_t(loc_2).log_prob(y),
    )

  return prior_fn(), simulator, likelihood