# Source code for sbijax._src.simulators.mixture_model_distractors

from jax import numpy as jnp
from jax import random as jr
from tensorflow_probability.substrates.jax import distributions as tfd


# ruff: noqa: PLR0913, E501
def mixture_model_with_distractors(
    *, alpha: float = 0.3, sigma: float = 0.3, n_distractors: int = 8
):
    """Mixture model with distractors.

    Constructs the prior, the simulator, and the likelihood of a model with a
    scalar parameter ``theta``. Each data point has ``2 + n_distractors``
    columns. The first two columns are informative and are drawn
    independently from the two-component Gaussian mixture

        alpha * N(theta, 1) + (1 - alpha) * N(-theta, sigma**2),

    while the remaining ``n_distractors`` columns are standard-normal noise
    that does not depend on ``theta``.

    Args:
        alpha: mixing weight of the component centred at ``theta``; the
            component centred at ``-theta`` has weight ``1 - alpha``.
        sigma: standard deviation of the component centred at ``-theta``.
        n_distractors: number of uninformative noise columns appended to
            each simulated data point. Defaults to 8 (10 columns in total).

    Returns:
        returns a tuple of three objects. The first is a
        tfd.JointDistributionNamed serving as a prior distribution
        (``theta ~ Uniform(-10, 10)``). The second is a simulator function
        ``simulator(seed, theta)`` that generates data of shape
        ``(n, 2 + n_distractors)``, one row per parameter value. The third
        is the likelihood function ``likelihood(y, theta)``, which returns
        one log-likelihood value per row.

    Raises:
        ValueError: if ``n_distractors`` is negative.

    References:
        Albert, Carlo, et al., Simulated Annealing ABC with multiple summary
        statistics, 2025
    """
    if n_distractors < 0:
        raise ValueError(
            f"n_distractors must be non-negative, got {n_distractors}"
        )
    # Total number of columns in one simulated data point.
    n_dims = 2 + n_distractors

    def prior_fn():
        return tfd.JointDistributionNamed(
            {"theta": tfd.Uniform(jnp.array([-10.0]), jnp.array([10.0]))}
        )

    def simulator(seed, theta):
        """Simulate one data point per parameter value.

        Args:
            seed: a JAX PRNG key.
            theta: a dict with key ``"theta"``; its values are flattened into
                a batch of ``n`` scalar parameters.

        Returns:
            an array of shape ``(n, 2 + n_distractors)``.
        """
        parameters = theta["theta"].reshape(-1, 1)
        n = parameters.shape[0]
        # Keys are split in a fixed order (component indices first, then the
        # Gaussian draws and distractors) so a given seed reproduces the data.
        idxs_rng_key, seed = jr.split(seed)
        # Component index for each informative entry:
        # 0 -> N(theta, 1) with prob. alpha, 1 -> N(-theta, sigma) otherwise.
        idxs = tfd.Categorical(probs=jnp.array([alpha, 1.0 - alpha])).sample(
            seed=idxs_rng_key, sample_shape=(n, 2)
        )
        # Column 0 holds +theta, column 1 holds -theta; pick per entry.
        candidate_means = jnp.concatenate((parameters, -parameters), axis=1)
        means = jnp.take_along_axis(candidate_means, idxs, axis=1)
        scales = jnp.array([1.0, sigma])[idxs]
        y_rng_key, distractors_rng_key, _ = jr.split(seed, 3)
        y = tfd.Normal(loc=means, scale=scales).sample(seed=y_rng_key)
        distractors = tfd.Normal(0.0, 1.0).sample(
            seed=distractors_rng_key, sample_shape=(n, n_distractors)
        )
        return jnp.concatenate((y, distractors), axis=1)

    def likelihood(y, theta):
        """Log-likelihood of each row of ``y`` given ``theta``.

        Only the two informative columns are scored. The distractor columns
        do not depend on ``theta`` and are dropped, so the result equals the
        full log-density only up to a ``theta``-independent constant.

        ``y`` (reshaped to ``(k, 2 + n_distractors)``) and ``theta["theta"]``
        (reshaped to ``(m, 1)``) are broadcast against each other, so
        ``k == m``, ``k == 1`` or ``m == 1`` are all supported.

        Returns:
            an array of shape ``(max(k, m),)``.
        """
        y = y.reshape(-1, n_dims)[:, :2]
        theta = theta["theta"].reshape(-1, 1)
        # Rely on broadcasting instead of broadcast_to(theta, y.shape), which
        # rejected a single observation scored under many parameter values.
        lp_pos = tfd.Normal(loc=theta, scale=1.0).log_prob(y)
        lp_neg = tfd.Normal(loc=-theta, scale=sigma).log_prob(y)
        # log(alpha * p_pos + (1 - alpha) * p_neg), computed stably.
        lp = jnp.logaddexp(
            jnp.log(alpha) + lp_pos, jnp.log(1.0 - alpha) + lp_neg
        )
        # The two informative columns are independent given theta.
        return lp.sum(axis=1)

    return prior_fn(), simulator, likelihood