Source code for sbijax._src.snle

from sbijax._src.nle import NLE


# ruff: noqa: PLR0913, E501
class SNLE(NLE):
    """Surjective neural likelihood estimation.

    Implements the method introduced in :cite:t:`dirmeier2023simulation`.
    SNLE is particularly useful for high-dimensional data, since its
    surjective density estimator reduces the dimensionality of the data.

    Args:
        model_fns: a tuple of callables. The first element needs to be a
            function that constructs a tfd.JointDistributionNamed, the
            second element is a simulator function.
        density_estimator: a (neural) conditional density estimator
            to model the likelihood function

    Examples:
        >>> from jax import numpy as jnp
        >>> from sbijax import SNLE
        >>> from sbijax.nn import make_maf
        >>> from tensorflow_probability.substrates.jax import distributions as tfd
        ...
        >>> prior = lambda: tfd.JointDistributionNamed(
        ...     dict(theta=tfd.Normal(jnp.zeros(5), 1.0))
        ... )
        >>> s = lambda seed, theta: tfd.Normal(
        ...     theta["theta"], 1.0
        ... ).sample(seed=seed, sample_shape=(2,)).reshape(-1, 10)
        >>> fns = prior, s
        >>> neural_network = make_maf(10, n_layer_dimensions=[10, 10, 5, 5, 5])
        >>> model = SNLE(fns, neural_network)

    References:
        Dirmeier, Simon, et al. "Simulation-based inference using surjective
        sequential neural likelihood estimation." arXiv preprint
        arXiv:2308.01054, 2023.
    """