sbijax.experimental

sbijax.experimental contains experimental code that may be ported to the main code base or removed again.

AiO(model_fns, density_estimator)

All-in-one simulation-based inference.

NPSE(model_fns, score_estimator)

Neural posterior score estimation.

class sbijax.experimental.AiO(model_fns, density_estimator)

All-in-one simulation-based inference.

Implements all-in-one posterior estimation as introduced by Gloeckler et al. [2024]. In contrast to the original paper, this implementation (so far) only infers the joint posterior distribution of all latent variables, not marginals or other conditional distributions. As a consequence, training uses the same mask for all latent and conditioning variables instead of sampling a new one at every step. Hence, this implementation is essentially the same as NPSE, except that it uses a transformer as the score network and a mask to encode the conditional dependencies.
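To make the role of the mask more concrete, here is a NumPy sketch for the model in the example below (two parameters θ₁, θ₂ and two observations y₁, y₂, with each yᵢ depending only on θᵢ). It builds a binary dependency mask and uses it to block attention between unrelated variables in a single attention head. The helper names, the variable ordering and the convention that mask[i, j] = 1 means "variable i may attend to variable j" are assumptions for illustration; the transformer built by make_simformer_based_score_model may encode the mask differently.

```python
import numpy as np

# Hypothetical variable order: [theta_1, theta_2, y_1, y_2].
n_theta, n_y = 2, 2


def dependency_mask(n_theta, n_y):
    """Binary matrix with mask[i, j] = 1 if variable i may attend to variable j.

    Assumes y_k depends only on theta_k (an element-wise simulator) and that
    every variable attends to itself.
    """
    n = n_theta + n_y
    mask = np.eye(n, dtype=np.int32)
    for k in range(min(n_theta, n_y)):
        t_idx, y_idx = k, n_theta + k
        mask[t_idx, y_idx] = mask[y_idx, t_idx] = 1  # edge theta_k <-> y_k
    return mask


def masked_attention(x, mask, w_q, w_k, w_v):
    """Single-head scaled dot-product attention restricted by `mask`."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores = np.where(mask == 1, scores, -np.inf)  # block disallowed pairs
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(n_theta + n_y, d))  # one embedding (token) per variable
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
mask = dependency_mask(n_theta, n_y)
out, weights = masked_attention(x, mask, w_q, w_k, w_v)

# Which variables are conditioned on: theta is latent (0), y is observed (1).
# Kept fixed for every training step, i.e. only the full posterior is learned.
condition = np.array([0] * n_theta + [1] * n_y)
```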

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed; the second element is a simulator function.

  • density_estimator – a score estimator, e.g., one created with make_simformer_based_score_model

Examples

>>> from sbijax.experimental import AiO
>>> from sbijax.experimental.nn import make_simformer_based_score_model
>>> from tensorflow_probability.substrates.jax import distributions as tfd
>>> import jax.numpy as jnp
>>> prior = lambda: tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_simformer_based_score_model(2, jnp.eye(4))
>>> model = AiO(fns, neural_network)

References

Gloeckler, Manuel, et al. “All-in-one simulation-based inference.” International Conference on Machine Learning, 2024.

fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)

Fit the model.

Parameters:
  • rng_key (PRNGKey) – a jax random key

  • data (Any) – data set obtained from calling simulate_data_and_possibly_append

  • optimizer (GradientTransformation) – an optax optimizer object

  • n_iter (int) – maximum number of training iterations per round

  • batch_size (int) – batch size used for training the model

  • percentage_data_as_validation_set (float) – fraction of the simulated data that is held out for validation and early stopping

  • n_early_stopping_patience (int) – number of iterations without improvement after which optimisation is stopped

  • n_early_stopping_delta (float) – minimum decrease of the validation loss that counts as an improvement

  • **kwargs – optional keyword arguments

Returns:

a tuple of parameters and a tuple of the training information
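The validation and early-stopping options above interact roughly as in the following NumPy sketch: a fraction of the data is held out, and optimisation stops once the validation loss has failed to improve by more than the delta for `patience` consecutive checks. This mirrors the common patience/minimum-delta scheme (as in Flax's EarlyStopping) and is an assumption about fit's behaviour; the helpers train_val_split and EarlyStopper are hypothetical and the loss sequence is synthetic.

```python
import numpy as np


def train_val_split(rng, n, fraction_val=0.1):
    """Shuffle indices 0..n-1 and hold out `fraction_val` of them for validation."""
    idx = rng.permutation(n)
    n_val = int(n * fraction_val)
    return idx[n_val:], idx[:n_val]


class EarlyStopper:
    """Hypothetical helper: signal a stop once the validation loss has not
    improved by more than `delta` for `patience` consecutive checks."""

    def __init__(self, patience=10, delta=0.001):
        self.patience, self.delta = patience, delta
        self.best = np.inf
        self.n_bad = 0

    def update(self, loss):
        if self.best - loss > self.delta:  # a real improvement
            self.best, self.n_bad = loss, 0
        else:
            self.n_bad += 1
        return self.n_bad >= self.patience  # True means: stop training


rng = np.random.default_rng(0)
train_idx, val_idx = train_val_split(rng, n=1000, fraction_val=0.1)

stopper = EarlyStopper(patience=3, delta=0.001)
# Synthetic validation losses: steady improvement, then a plateau.
losses = [1.0, 0.8, 0.7, 0.6995, 0.6999, 0.7001, 0.5]
stopped_at = None
for it, loss in enumerate(losses):
    if stopper.update(loss):
        stopped_at = it
        break
```

Changes smaller than the delta (0.7 to 0.6995) count as "no improvement", so the loop stops at the sixth loss and never sees the later drop to 0.5.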

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

an array of samples from the approximate posterior distribution with shape (n_samples, p), where p is the dimensionality of the parameters
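Score-based samplers produce their n_samples draws by integrating a reverse-time SDE from noise, which is why the result has one row per sample and one column per parameter. The NumPy sketch below illustrates the idea with an Euler–Maruyama discretisation of the reverse variance-preserving (VP) SDE. It is a toy, not sbijax's sampler: the trained score network is replaced by the exact score of a Gaussian stand-in posterior N(m, s²I), and the linear schedule β(t) = β_min + t(β_max − β_min) with the documented default values is an assumption.

```python
import numpy as np

beta_min, beta_max, time_eps, time_max = 0.1, 10.0, 1e-3, 1.0


def beta(t):
    # Assumed linear noise schedule of the VP SDE.
    return beta_min + t * (beta_max - beta_min)


def alpha(t):
    # Mean coefficient of p(theta_t | theta_0): exp(-0.5 * int_0^t beta(s) ds).
    return np.exp(-0.5 * (beta_min * t + 0.5 * (beta_max - beta_min) * t**2))


# Toy "posterior" N(m, s^2 I): its noised marginals stay Gaussian, so the
# score is known in closed form and stands in for a trained score network.
m, s = np.array([1.0, -2.0]), 0.5


def score(x, t):
    a = alpha(t)
    var = a**2 * s**2 + 1.0 - a**2
    return -(x - a * m) / var


def sample_reverse_sde(rng, n_samples, n_steps=500):
    p = m.shape[0]
    x = rng.normal(size=(n_samples, p))  # start from N(0, I) at t = time_max
    ts = np.linspace(time_max, time_eps, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        h = t - t_next  # positive step size, moving backwards in time
        b = beta(t)
        drift = 0.5 * b * x + b * score(x, t)  # negated reverse-time drift
        x = x + h * drift + np.sqrt(b * h) * rng.normal(size=x.shape)
    return x


samples = sample_reverse_sde(np.random.default_rng(0), n_samples=4000)
```

With the exact score, the sample mean and standard deviation come out close to m and s; with a learned score, the quality of the draws depends on how well the network approximates it.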

simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)

Simulate data from the prior or the posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, draws from the prior. If parameters are given, draws from the amortized posterior conditioned on ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired.

  • n_simulations – number of new simulations

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two fields, y and theta
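For the model in the AiO example (θ ~ N(0, I), y | θ ~ N(θ, 1)), drawing n_simulations pairs from the prior and packing them into a named tuple looks roughly as follows. SimulatedData and simulate_from_prior are hypothetical stand-ins in plain NumPy; sbijax itself draws from the prior and simulator passed in model_fns, and its container type may differ.

```python
from typing import NamedTuple

import numpy as np


class SimulatedData(NamedTuple):
    """Hypothetical container mirroring the documented fields y and theta."""

    y: np.ndarray  # simulated observations, shape (n_simulations, dim_y)
    theta: np.ndarray  # parameters that generated them, shape (n_simulations, dim_theta)


def simulate_from_prior(rng, n_simulations=1000, dim=2):
    theta = rng.normal(0.0, 1.0, size=(n_simulations, dim))  # prior draws
    y = rng.normal(theta, 1.0)  # simulator: y ~ N(theta, 1), element-wise
    return SimulatedData(y=y, theta=theta)


data = simulate_from_prior(np.random.default_rng(1), n_simulations=1000)
```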

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)

Simulate data and parameters from the prior or posterior and, if a data set is given, append them to it.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of new simulations

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two fields, y and theta
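When an existing data set is passed, appending new simulations amounts to concatenating each field along the first (simulation) axis. A minimal NumPy sketch, assuming a named tuple with fields y and theta; append_data is a hypothetical helper, not part of sbijax's API.

```python
from typing import NamedTuple, Optional

import numpy as np


class SimulatedData(NamedTuple):
    y: np.ndarray
    theta: np.ndarray


def append_data(new: SimulatedData, data: Optional[SimulatedData] = None) -> SimulatedData:
    """Return `new` if there is no existing data set, else stack both along axis 0."""
    if data is None:
        return new
    return SimulatedData(
        y=np.concatenate([data.y, new.y], axis=0),
        theta=np.concatenate([data.theta, new.theta], axis=0),
    )


rng = np.random.default_rng(2)
round_1 = SimulatedData(y=rng.normal(size=(1000, 2)), theta=rng.normal(size=(1000, 2)))
round_2 = SimulatedData(y=rng.normal(size=(500, 2)), theta=rng.normal(size=(500, 2)))
data = append_data(round_1)  # first round: nothing to append to
data = append_data(round_2, data)  # later round: grow the data set
```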

class sbijax.experimental.NPSE(model_fns, score_estimator)

Neural posterior score estimation.

Implements (truncated sequential) neural posterior score estimation as introduced in Sharrock et al. [2024].
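NPSE trains a conditional score network by denoising score matching: parameters are perturbed with the forward SDE, and the network is regressed onto the score of the Gaussian perturbation kernel. The NumPy sketch below estimates that loss for a VP SDE. The linear β schedule, the σ(t)² weighting and the score_fn(theta_t, t, y) signature are assumptions for illustration, not sbijax's implementation.

```python
import numpy as np

beta_min, beta_max, time_eps, time_max = 0.1, 10.0, 1e-3, 1.0


def vp_kernel(t):
    """Mean coefficient and std of p(theta_t | theta_0) for a linear-beta VP SDE."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    return np.exp(-0.5 * integral), np.sqrt(1.0 - np.exp(-integral))


def dsm_loss(score_fn, theta, y, rng):
    """Monte Carlo estimate of the weighted denoising score-matching loss."""
    n = theta.shape[0]
    t = rng.uniform(time_eps, time_max, size=(n, 1))  # never exactly t = 0
    mean_coef, std = vp_kernel(t)
    noise = rng.normal(size=theta.shape)
    theta_t = mean_coef * theta + std * noise  # perturbed parameters
    target = -noise / std  # score of N(theta_t; mean_coef * theta, std**2 I)
    residual = score_fn(theta_t, t, y) - target
    # Weighting lambda(t) = std**2 keeps the loss on a comparable scale for all t.
    return np.mean(np.sum((std * residual) ** 2, axis=-1))


def zero_score(theta_t, t, y):
    return np.zeros_like(theta_t)


def prior_score(theta_t, t, y):
    # Exact score of the noised N(0, I) prior; ignores the observation y.
    return -theta_t


rng = np.random.default_rng(0)
theta = rng.normal(size=(256, 2))  # draws from a N(0, I) prior
y = theta + rng.normal(size=(256, 2))  # simulator y ~ N(theta, 1)
loss_zero = dsm_loss(zero_score, theta, y, rng)
loss_prior = dsm_loss(prior_score, theta, y, rng)
```

A score function that matches the noised prior already beats the all-zero baseline; a trained network additionally uses y to approach the score of the noised posterior.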

Parameters:
  • model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed; the second element is a simulator function.

  • score_estimator – a score estimator, e.g., one created with make_score_model

Examples

>>> from sbijax.experimental import NPSE
>>> from sbijax.experimental.nn import make_score_model
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...    dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_score_model(1)
>>> model = NPSE(fns, neural_network)

References

Sharrock, Louis, et al. “Sequential neural score estimation: likelihood-free inference with conditional score based diffusion models.” International Conference on Machine Learning, 2024.

fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)

Fit the model.

Parameters:
  • rng_key (PRNGKey) – a jax random key

  • data (Any) – data set obtained from calling simulate_data_and_possibly_append

  • optimizer (GradientTransformation) – an optax optimizer object

  • n_iter (int) – maximum number of training iterations per round

  • batch_size (int) – batch size used for training the model

  • percentage_data_as_validation_set (float) – fraction of the simulated data that is held out for validation and early stopping

  • n_early_stopping_patience (int) – number of iterations without improvement after which optimisation is stopped

  • n_early_stopping_delta (float) – minimum decrease of the validation loss that counts as an improvement

  • **kwargs – optional keyword arguments

Returns:

a tuple of parameters and a tuple of the training information

sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)

Sample from the approximate posterior.

Parameters:
  • rng_key – a jax random key

  • params – a pytree of neural network parameters

  • observable – observation to condition on

  • n_samples – number of samples to draw

Returns:

an array of samples from the approximate posterior distribution with shape (n_samples, p), where p is the dimensionality of the parameters

simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)

Simulate data from the prior or the posterior.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters. If None, draws from the prior. If parameters are given, draws from the amortized posterior conditioned on ‘observable’.

  • observable – an observation. Needs to be given if posterior draws are desired.

  • n_simulations – number of new simulations

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two fields, y and theta

simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)

Simulate data and parameters from the prior or posterior and, if a data set is given, append them to it.

Parameters:
  • rng_key – a random key

  • params – a dictionary of neural network parameters

  • observable – an observation

  • data – existing data set

  • n_simulations – number of new simulations

  • kwargs – dictionary of key-value pairs passed to sample_posterior

Returns:

a NamedTuple with two fields, y and theta

make_score_model(n_dimension[, ...])

Create a score model for NPSE.

make_simformer_based_score_model(...[, ...])

Create a score network for AiO.

sbijax.experimental.nn.make_simformer_based_score_model(n_dimension, mask, n_heads=4, n_layers=4, head_size=None, embedding_dim_values=32, embedding_dim_ids=32, embedding_dim_conditioning=8, time_embedding_layers=(128, 128), dropout_rate=0.1, activation=<function gelu>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)

Create a score network for AiO.

The score model uses a transformer as a score estimator.

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • mask (Array) – a binary matrix of conditional dependencies

  • n_heads (int) – number of attention heads

  • n_layers (int) – number of attention layers

  • head_size (int | None) – size of an attention head

  • embedding_dim_values (int) – dimensionality of the embedding for the values

  • embedding_dim_ids (int) – dimensionality of the embedding for the ids of the variables

  • embedding_dim_conditioning (int) – dimensionality of the embedding for the binary conditioning labels

  • time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network

  • dropout_rate (float) – the dropout rate

  • activation (Callable) – the activation function used in the network

  • sde (str) – either ‘vp’ (variance preserving) or ‘ve’ (variance exploding). Defines the type of SDE used as the forward process. See the original publication and references therein for details.

  • beta_min (float) – lower bound of the noise schedule of the forward SDE. See the original publication for details.

  • beta_max (float) – upper bound of the noise schedule of the forward SDE. See the original publication for details.

  • time_eps (float) – a small positive number used as the minimum time point of the forward process, for numerical stability.

  • time_max (float) – maximum integration time of the forward process. 1 works well, as do 5 or 10.
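The sde, beta_min, beta_max, time_eps and time_max options all shape the Gaussian perturbation kernel p(θ_t | θ_0). The NumPy sketch below evaluates it for both SDE types under common textbook choices: a linear β(t) for the VP SDE and a geometric σ(t) for the VE SDE. These schedules, and the sigma_min/sigma_max names in the VE case, are assumptions; how sbijax parametrises each SDE is not stated in this documentation.

```python
import numpy as np


def vp_marginal(t, beta_min=0.1, beta_max=10.0):
    """VP SDE with linear beta(t): theta_t | theta_0 ~ N(a(t) theta_0, s(t)^2 I)."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    return np.exp(-0.5 * integral), np.sqrt(1.0 - np.exp(-integral))


def ve_marginal(t, sigma_min=0.01, sigma_max=10.0):
    """VE SDE with geometric sigma(t): theta_t | theta_0 ~ N(theta_0, s(t)^2 I).

    sigma_min / sigma_max are hypothetical names for the VE noise scales.
    """
    sigma_t = sigma_min * (sigma_max / sigma_min) ** t
    return np.ones_like(t), np.sqrt(sigma_t**2 - sigma_min**2)


time_eps, time_max = 1e-3, 1.0
ts = np.linspace(time_eps, time_max, 5)
vp_coef, vp_std = vp_marginal(ts)  # mean shrinks, std grows towards 1
ve_coef, ve_std = ve_marginal(ts)  # mean unchanged, std grows towards sigma_max
coef_t5, _ = vp_marginal(5.0)  # a longer horizon, e.g. time_max = 5
```

At t = time_eps the VP kernel is close to the identity but still has a small nonzero standard deviation, which keeps the score target finite; increasing time_max pushes the terminal distribution closer to pure noise.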

Returns:

returns a score model that can be used for posterior inference using AiO.

References

Gloeckler, Manuel, et al. “All-in-one simulation-based inference.” International Conference on Machine Learning, 2024.

sbijax.experimental.nn.make_score_model(n_dimension, hidden_sizes=(128, 128), data_embedding_layers=(128, 128), param_embedding_layers=(128, 128), time_embedding_layers=(128, 128), activation=<jax._src.custom_derivatives.custom_jvp object>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)

Create a score model for NPSE.

The score model uses MLPs to embed the data, the parameters and the time points (the time points are first projected with a sinusoidal embedding). The score network itself is also an MLP.
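A network with this structure could be sketched in NumPy as below: the time point is mapped to sinusoidal features, data, parameters and time are each embedded by an MLP, and the concatenated embeddings go through the score MLP, whose output has size n_dimension. The random initialisation, the embedding frequencies, the SiLU activation and the layer sizes (taken from the defaults) are illustrative assumptions; the network built by make_score_model may combine the embeddings differently.

```python
import numpy as np


def sinusoidal_embedding(t, dim=16, max_period=10_000.0):
    """Map times of shape (n, 1) to [sin(t * f_k), cos(t * f_k)] features."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t * freqs  # (n, half) via broadcasting
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


def init_mlp(rng, sizes):
    """Random (weight, bias) pairs for an MLP with layer sizes, e.g. (in, 128, 128)."""
    return [(rng.normal(size=(a, b)) / np.sqrt(a), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]


def mlp(params, x, final_activation=True):
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1 or final_activation:
            x = x / (1.0 + np.exp(-x))  # SiLU: x * sigmoid(x)
    return x


rng = np.random.default_rng(0)
n_dimension, dim_y, n = 2, 3, 5
emb_data = init_mlp(rng, (dim_y, 128, 128))  # data_embedding_layers
emb_param = init_mlp(rng, (n_dimension, 128, 128))  # param_embedding_layers
emb_time = init_mlp(rng, (16, 128, 128))  # time_embedding_layers
score_net = init_mlp(rng, (3 * 128, 128, 128, n_dimension))  # hidden_sizes + output


def score(theta_t, t, y):
    h = np.concatenate([
        mlp(emb_data, y),
        mlp(emb_param, theta_t),
        mlp(emb_time, sinusoidal_embedding(t)),
    ], axis=-1)
    return mlp(score_net, h, final_activation=False)  # linear output layer


out = score(rng.normal(size=(n, n_dimension)), rng.uniform(size=(n, 1)),
            rng.normal(size=(n, dim_y)))
```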

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • hidden_sizes (tuple[int, ...]) – a tuple of ints determining the sizes of the hidden layers of the score network

  • data_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the data embedding network

  • param_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the parameter embedding network

  • time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network

  • activation (Callable) – a jax activation function

  • sde – either ‘vp’ (variance preserving) or ‘ve’ (variance exploding). Defines the type of SDE used as the forward process. See the original publication and references therein for details.

  • beta_min – lower bound of the noise schedule of the forward SDE. See the original publication for details.

  • beta_max – upper bound of the noise schedule of the forward SDE. See the original publication for details.

  • time_eps – a small positive number used as the minimum time point of the forward process, for numerical stability.

  • time_max – maximum integration time of the forward process. 1 works well, as do 5 or 10.

Returns:

returns a score model that can be used for inference using NPSE.