sbijax.experimental
sbijax.experimental contains experimental code that might get ported to the
main code base or possibly deleted again.
| AiO | All-in-one simulation-based inference. |
| NPSE | Neural posterior score estimation. |
- class sbijax.experimental.AiO(model_fns, density_estimator)
All-in-one simulation-based inference.
Implements all-in-one posterior estimation as introduced in Gloeckler et al. [2024]. In comparison to the original paper, this implementation (so far) only infers the posterior distribution of all latent variables, so it provides no marginals or other conditional distributions. As a consequence, training uses the same mask for all latent/conditioning variables instead of sampling a new one at every step. Hence, this implementation is essentially NPSE with a transformer as score network and a mask that encodes the conditional dependencies.
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
density_estimator – a score estimator, for instance one created with make_simformer_based_score_model
Examples
>>> from jax import numpy as jnp
>>> from sbijax.experimental import AiO
>>> from sbijax.experimental.nn import make_simformer_based_score_model
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(jnp.zeros(2), 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_simformer_based_score_model(2, jnp.eye(4))
>>> model = AiO(fns, neural_network)
References
Gloeckler, Manuel, et al. "All-in-one simulation-based inference." International Conference on Machine Learning, 2024.
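The mask passed to the score network is described only as a binary matrix of conditional dependencies. A minimal sketch of how such a matrix could be laid out for two parameters and two observables follows. It assumes that entry (i, j) equal to 1 means that variable i depends on variable j and that parameters are ordered before observables; both conventions are assumptions made for illustration and are not confirmed by this reference. The helper build_dependency_mask is hypothetical, and the nested list would still need to be converted to an array (for example with jnp.asarray) before use.

```python
def build_dependency_mask(n_theta, n_y):
    """Return a binary (n_theta + n_y) x (n_theta + n_y) nested list.

    Assumed convention (hypothetical): mask[i][j] == 1 means variable i
    depends on variable j; parameters come first, observables second.
    """
    n = n_theta + n_y
    # every variable depends on itself
    mask = [[1 if i == j else 0 for j in range(n)] for i in range(n)]
    # every observable depends on every parameter
    for i in range(n_theta, n):
        for j in range(n_theta):
            mask[i][j] = 1
    return mask


mask = build_dependency_mask(n_theta=2, n_y=2)
for row in mask:
    print(row)
```

With no additional dependencies the result reduces to the identity matrix.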
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)
Fit the model.
- Parameters:
rng_key (PRNGKey) – a jax random key
data (Any) – data set obtained from calling simulate_data_and_possibly_append
optimizer (GradientTransformation) – an optax optimizer object
n_iter (int) – maximum number of training iterations per round
batch_size (int) – batch size used for training the model
percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience (int) – number of iterations without improvement before optimisation is stopped
n_early_stopping_delta (float)
**kwargs – optional keyword arguments
- Returns:
a tuple of parameters and a tuple of the training information
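How n_early_stopping_patience and n_early_stopping_delta might interact can be illustrated with a small stand-alone loop. This is a sketch under the assumption that an iteration counts as an improvement only if the validation loss drops by more than the delta below the best loss seen so far; the reference above does not describe n_early_stopping_delta, so the exact rule used by fit may differ. The function early_stopping_iteration is hypothetical.

```python
def early_stopping_iteration(validation_losses, patience=10, delta=0.001):
    """Return the index of the iteration at which training would stop.

    Assumption: an iteration counts as an improvement only if the validation
    loss drops by more than `delta` below the best loss seen so far.
    """
    best_loss = float("inf")
    n_without_improvement = 0
    for i, loss in enumerate(validation_losses):
        if loss < best_loss - delta:
            best_loss = loss
            n_without_improvement = 0
        else:
            n_without_improvement += 1
        if n_without_improvement >= patience:
            return i
    # patience never exhausted: training runs for all iterations
    return len(validation_losses) - 1


losses = [1.0, 0.5, 0.4, 0.4, 0.3999, 0.4, 0.41]
stop_at = early_stopping_iteration(losses, patience=3, delta=0.001)
print(stop_at)
```

In this toy sequence the drop from 0.4 to 0.3999 is smaller than the delta, so it does not reset the patience counter.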
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
an array of samples from the posterior distribution of dimension (n_samples × p)
- simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, draws from the prior. If given, draws from the amortized posterior conditioned on observable.
observable – an observation. Needs to be given if posterior draws are desired.
n_simulations – number of new simulations
**kwargs – key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)
Simulate data and parameters from the prior or posterior and append them to an existing data set.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of new simulations
**kwargs – key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
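The append behaviour can be pictured with a small stand-alone sketch. The Dataset class below is a hypothetical stand-in for the returned NamedTuple, plain lists replace arrays, and placing existing simulations before new ones is an assumption; none of this is taken from the library's source.

```python
from typing import NamedTuple, Optional


class Dataset(NamedTuple):
    # hypothetical stand-in for the returned NamedTuple with elements y and theta
    y: list
    theta: list


def append_simulations(new: Dataset, data: Optional[Dataset] = None) -> Dataset:
    """Return `new` if there is no existing data set, else the concatenation."""
    if data is None:
        return new
    return Dataset(y=data.y + new.y, theta=data.theta + new.theta)


# first round: no existing data, so the new simulations are returned as they are
round_1 = append_simulations(Dataset(y=[0.1, 0.2], theta=[1.0, 2.0]))
# second round: new simulations are appended to those from the first round
round_2 = append_simulations(Dataset(y=[0.3], theta=[3.0]), data=round_1)
print(round_2)
```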
- class sbijax.experimental.NPSE(model_fns, score_estimator)
Neural posterior score estimation.
Implements (truncated sequential) neural posterior score estimation as introduced in Sharrock et al. [2024].
- Parameters:
model_fns – a tuple of callables. The first element needs to be a function that constructs a tfd.JointDistributionNamed, the second element is a simulator function.
score_estimator – a score estimator
Examples
>>> from sbijax.experimental import NPSE
>>> from sbijax.experimental.nn import make_score_model
>>> from tensorflow_probability.substrates.jax import distributions as tfd
...
>>> prior = lambda: tfd.JointDistributionNamed(
...     dict(theta=tfd.Normal(0.0, 1.0))
... )
>>> s = lambda seed, theta: tfd.Normal(theta["theta"], 1.0).sample(seed=seed)
>>> fns = prior, s
>>> neural_network = make_score_model(1)
>>> model = NPSE(fns, neural_network)
References
Sharrock, Louis, et al. "Sequential neural score estimation: likelihood-free inference with conditional score based diffusion models." International Conference on Machine Learning, 2024.
- fit(rng_key, data, *, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=1000, batch_size=100, percentage_data_as_validation_set=0.1, n_early_stopping_patience=10, n_early_stopping_delta=0.001, **kwargs)
Fit the model.
- Parameters:
rng_key (PRNGKey) – a jax random key
data (Any) – data set obtained from calling simulate_data_and_possibly_append
optimizer (GradientTransformation) – an optax optimizer object
n_iter (int) – maximum number of training iterations per round
batch_size (int) – batch size used for training the model
percentage_data_as_validation_set (float) – percentage of the simulated data that is used for validation and early stopping
n_early_stopping_patience (int) – number of iterations without improvement before optimisation is stopped
n_early_stopping_delta (float)
**kwargs – optional keyword arguments
- Returns:
a tuple of parameters and a tuple of the training information
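The effect of percentage_data_as_validation_set can be sketched as an index split. The snippet assumes that the data are shuffled before splitting and that the validation size is rounded down; both are illustrative assumptions, and split_train_validation is a hypothetical helper, not part of sbijax.

```python
import random


def split_train_validation(n_samples, percentage_validation=0.1, seed=0):
    """Shuffle sample indices and split them into training and validation sets."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # assumption: the size of the validation set is rounded down
    n_validation = int(n_samples * percentage_validation)
    return indices[n_validation:], indices[:n_validation]


train_idx, val_idx = split_train_validation(1000, percentage_validation=0.1)
print(len(train_idx), len(val_idx))
```

With 1000 simulations and a percentage of 0.1, 100 simulations are held out for validation and early stopping.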
- sample_posterior(rng_key, params, observable, *, n_samples=4000, **kwargs)
Sample from the approximate posterior.
- Parameters:
rng_key – a jax random key
params – a pytree of neural network parameters
observable – observation to condition on
n_samples – number of samples to draw
- Returns:
an array of samples from the posterior distribution of dimension (n_samples × p)
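Score-based methods generally draw samples by integrating a reverse-time process that is driven by the learned score. The sketch below shows one common way to do this, an Euler-Maruyama discretisation of the reverse-time variance-preserving SDE. It is not taken from sbijax, whose sampler may work differently (for example via a probability flow ODE). An exact score of a one-dimensional Gaussian target stands in for a trained network, and the names beta, alpha, gaussian_score and sample, as well as the linear noise schedule, are assumptions made for illustration.

```python
import math
import random


def beta(t, beta_min=0.1, beta_max=10.0):
    # assumed linear noise schedule of the variance-preserving (VP) SDE
    return beta_min + t * (beta_max - beta_min)


def alpha(t, beta_min=0.1, beta_max=10.0):
    # alpha(t) = exp(-0.5 * integral_0^t beta(s) ds)
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    return math.exp(-0.5 * integral)


def gaussian_score(x, t, mu=2.0, sigma=0.5):
    # exact score of the perturbed marginal when the target is N(mu, sigma^2);
    # it stands in for a trained score network
    a = alpha(t)
    variance = sigma**2 * a**2 + (1.0 - a**2)
    return -(x - mu * a) / variance


def sample(score_fn, n_samples=1000, n_steps=300, time_eps=0.001, time_max=1.0, seed=0):
    rng = random.Random(seed)
    dt = (time_max - time_eps) / n_steps
    samples = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)  # start from the reference distribution N(0, 1)
        for k in range(n_steps):
            t = time_max - k * dt
            b = beta(t)
            # reverse-time drift: f(x, t) - g(t)^2 * score(x, t)
            drift = -0.5 * b * x - b * score_fn(x, t)
            # one Euler-Maruyama step backwards in time
            x = x - drift * dt + math.sqrt(b * dt) * rng.gauss(0.0, 1.0)
        samples.append(x)
    return samples


samples = sample(gaussian_score)
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(mean, std)
```

Up to discretisation and Monte Carlo error, the empirical mean and standard deviation should be close to those of the target, here 2.0 and 0.5.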
- simulate_data(rng_key, *, params=None, observable=None, n_simulations=1000, **kwargs)
Simulate data from the prior or posterior.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters. If None, draws from the prior. If given, draws from the amortized posterior conditioned on observable.
observable – an observation. Needs to be given if posterior draws are desired.
n_simulations – number of new simulations
**kwargs – key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
- simulate_data_and_possibly_append(rng_key, params, observable, data=None, n_simulations=1000, **kwargs)
Simulate data and parameters from the prior or posterior and append them to an existing data set.
- Parameters:
rng_key – a random key
params – a dictionary of neural network parameters
observable – an observation
data – existing data set
n_simulations – number of new simulations
**kwargs – key-value pairs passed to sample_posterior
- Returns:
a NamedTuple with two elements, y and theta
| make_score_model | Create a score model for NPSE. |
| make_simformer_based_score_model | Create a score network for AiO. |
- sbijax.experimental.nn.make_simformer_based_score_model(n_dimension, mask, n_heads=4, n_layers=4, head_size=None, embedding_dim_values=32, embedding_dim_ids=32, embedding_dim_conditioning=8, time_embedding_layers=(128, 128), dropout_rate=0.1, activation=<function gelu>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)
Create a score network for AiO.
The score model uses a transformer as a score estimator.
- Parameters:
n_dimension (int) – dimensionality of modelled space
mask (Array) – a binary matrix of conditional dependencies
n_heads (int) – number of attention heads
n_layers (int) – number of attention layers
head_size (int | None) – size of an attention head
embedding_dim_values (int) – dimensionality of the embedding for the values
embedding_dim_ids (int) – dimensionality of the embedding for the ids of the variables
embedding_dim_conditioning (int) – dimensionality of the binary conditioning labels
time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network
dropout_rate (float) – dropout rate used in the network
activation (Callable) – activation function to be used
sde (str) – either 'vp' or 've'. Defines the type of SDE to be used as a forward process. See the original publication and references therein for details.
beta_min (float) – minimum beta of the noise schedule. See the paper for details.
beta_max (float) – maximum beta of the noise schedule. See the paper for details.
time_eps (float) – some small number to use as minimum time point for the forward process. Used for numerical stability.
time_max (float) – maximum integration time. 1 is a good choice, as are 5 or 10.
- Returns:
returns a score model that can be used for posterior inference using AiO.
References
Gloeckler, Manuel, et al. "All-in-one simulation-based inference." International Conference on Machine Learning, 2024.
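The roles of beta_min, beta_max and time_eps for sde='vp' can be made concrete with the perturbation kernel of a variance-preserving SDE. The sketch assumes the widely used linear schedule beta(t) = beta_min + t * (beta_max - beta_min); whether sbijax uses exactly this parameterisation is an assumption, and vp_marginal is a hypothetical helper.

```python
import math


def vp_marginal(t, beta_min=0.1, beta_max=10.0):
    """Mean scaling and standard deviation of x_t given x_0 under a VP SDE.

    Assumes the linear schedule beta(t) = beta_min + t * (beta_max - beta_min).
    """
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_scale = math.exp(-0.5 * integral)
    std = math.sqrt(1.0 - math.exp(-integral))
    return mean_scale, std


for t in (0.001, 0.5, 1.0):
    mean_scale, std = vp_marginal(t)
    print(f"t={t}: mean scale={mean_scale:.4f}, std={std:.4f}")
```

At t = 0 the standard deviation is exactly zero, so any quantity that divides by it is undefined there. Starting the forward process at a small positive time avoids this, which is one plausible reading of the numerical-stability remark for time_eps.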
- sbijax.experimental.nn.make_score_model(n_dimension, hidden_sizes=(128, 128), data_embedding_layers=(128, 128), param_embedding_layers=(128, 128), time_embedding_layers=(128, 128), activation=<jax._src.custom_derivatives.custom_jvp object>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)
Create a score model for NPSE.
The score model uses MLPs to embed the data, the parameters and the time points (after projecting them with a sinusoidal embedding). The score net itself is also an MLP.
- Parameters:
n_dimension (int) – dimensionality of modelled space
hidden_sizes (tuple[int, ...]) – tuple of ints determining the layers of the score network
data_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the data embedding network
param_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the parameter embedding network
time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network
activation (Callable) – a jax activation function
sde – either 'vp' or 've'. Defines the type of SDE to be used as a forward process. See the original publication and references therein for details.
beta_min – minimum beta of the noise schedule. See the paper for details.
beta_max – maximum beta of the noise schedule. See the paper for details.
time_eps – some small number to use as minimum time point for the forward process. Used for numerical stability.
time_max – maximum integration time. 1 is a good choice, as are 5 or 10.
- Returns:
returns a score model that can be used for inference using NPSE.
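The description of make_score_model mentions that time points are projected with a sinusoidal embedding before being passed through an MLP. A minimal sketch of such an embedding follows, using the common transformer-style construction. The function name, embedding_dim and max_period are illustrative assumptions and do not reflect sbijax defaults.

```python
import math


def sinusoidal_embedding(t, embedding_dim=8, max_period=10000.0):
    """Map a scalar time point to `embedding_dim` sine and cosine features.

    `embedding_dim` and `max_period` are illustrative choices, not sbijax defaults.
    """
    half = embedding_dim // 2
    # geometrically decreasing frequencies, starting at 1
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


embedding = sinusoidal_embedding(0.5)
print(embedding)
```

The resulting vector gives the time embedding network a richer input than the raw scalar time point.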