sbijax.experimental
sbijax.experimental contains experimental code that might get ported to the
main code base or possibly deleted again.
cmpe (consistency-model posterior estimation) and aio (all-in-one posterior estimation) are functional
factories; aio delegates to the fmpe core, and
make_truncated_proposal builds the truncated-prior proposal used with
sbijax.run_sequential(). The score networks below are consumed by the
estimators: make_score_model builds networks for sbijax.npse(), which now lives in the main package, and make_simformer_based_score_model builds networks for aio.
cmpe | Construct a consistency model posterior objective.
aio | Construct an all-in-one posterior estimator.
make_truncated_proposal | Build a truncated-prior proposal_fn for run_sequential().
- sbijax.experimental.cmpe(network, *, t_min=0.001, t_max=50.0)
Construct a consistency model posterior objective.
The returned ObjectiveFns is trained via the shared train() driver. EMA parameters are threaded through TrainingState.params as a dict: {"params": live_params, "ema_params": ema_params}.
- Parameters:
network – a consistency model with vector_field and sample methods
t_min – minimal time point for ODE integration
t_max – maximal time point for ODE integration
- Returns:
an ObjectiveFns
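To illustrate the {"params": ..., "ema_params": ...} layout, here is a minimal sketch of an exponential-moving-average update over that dict. The helper ema_update and the decay values are hypothetical (they are not part of sbijax, whose update rule and decay may differ), and parameters are modelled as nested dicts of floats so the sketch runs without JAX; with JAX one would typically use jax.tree_util.tree_map instead of the hand-written tree walk.

```python
from typing import Any, Callable, Dict


def _tree_map2(fn: Callable[[float, float], float], a: Any, b: Any) -> Any:
    """Apply fn leaf-wise to two nested dicts with identical structure."""
    if isinstance(a, dict):
        return {k: _tree_map2(fn, a[k], b[k]) for k in a}
    return fn(a, b)


def ema_update(state_params: Dict[str, Any], decay: float = 0.99) -> Dict[str, Any]:
    """Hypothetical helper: move the EMA copy toward the live params.

    state_params mirrors the layout threaded through TrainingState.params:
    {"params": live_params, "ema_params": ema_params}.
    """
    live = state_params["params"]
    ema = state_params["ema_params"]
    # ema <- decay * ema + (1 - decay) * live, applied leaf by leaf
    new_ema = _tree_map2(lambda e, p: decay * e + (1.0 - decay) * p, ema, live)
    # The live params are returned unchanged; only the EMA copy moves.
    return {"params": live, "ema_params": new_ema}


state = {
    "params": {"dense": {"w": 1.0, "b": 0.0}},
    "ema_params": {"dense": {"w": 0.0, "b": 0.0}},
}
state = ema_update(state, decay=0.5)
print(state["ema_params"]["dense"]["w"])  # 0.5: halfway toward the live weight
```

Keeping both copies in one dict lets a single training-state object carry them through a generic driver without changing its signature.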
- sbijax.experimental.aio(network)
Construct an all-in-one posterior estimator.
- Parameters:
network – a simformer-based score network with loss, sample and log_prob methods
- Returns:
an ObjectiveFns
- sbijax.experimental.make_truncated_proposal(prior, network, *, quantile=0.0005, n_calibration=100000, n_prior=1000000, max_iter=1000)
Build a truncated-prior proposal_fn for run_sequential().
- Parameters:
prior – the prior distribution
network – the score network the estimator wraps (exposes log_prob)
quantile – lower-tail quantile of posterior log-densities taken as the truncation boundary
n_calibration – number of posterior draws used to calibrate the boundary and the bounding hypercube
n_prior – number of prior draws used to bound the hypercube
max_iter – maximum number of rejection rounds before giving up
- Returns:
a proposal_fn(objective, params, observable, sampler) returning a callable (rng_key, n) -> theta that draws parameters from the truncated prior, in the pytree structure used by the prior and simulator
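The calibrate-then-reject procedure implied by these parameters can be sketched in one dimension as follows. This is a hypothetical, stdlib-only illustration, not the sbijax implementation: the prior is Uniform(-5, 5), a Normal(1, 0.5) log-density stands in for network.log_prob, the hypercube is assumed to be the range of the posterior draws clipped to the range of the prior draws, and n_calibration and n_prior are reduced from the defaults so it runs quickly. The real proposal_fn also takes (objective, params, observable, sampler) and returns an (rng_key, n) -> theta callable; the sketch collapses that into a plain sample(n).

```python
import math
import random


def normal_log_prob(theta, mu=1.0, sigma=0.5):
    # Stand-in for network.log_prob: log-density of Normal(mu, sigma).
    z = (theta - mu) / sigma
    return -0.5 * z * z - math.log(sigma * math.sqrt(2.0 * math.pi))


def make_truncated_sampler(prior_sample, posterior_sample, log_prob, *,
                           quantile=0.0005, n_calibration=10_000,
                           n_prior=10_000, max_iter=1000, rng=None):
    """Hypothetical 1-D truncated-prior sampler (not the sbijax API)."""
    if rng is None:
        rng = random.Random()
    # 1) Truncation boundary: lower-tail quantile of posterior log-densities.
    post = [posterior_sample(rng) for _ in range(n_calibration)]
    log_dens = sorted(log_prob(t) for t in post)
    boundary = log_dens[int(quantile * (n_calibration - 1))]
    # 2) Bounding box: posterior range, clipped to the range of prior draws.
    prior_draws = [prior_sample(rng) for _ in range(n_prior)]
    lo = max(min(post), min(prior_draws))
    hi = min(max(post), max(prior_draws))

    def sample(n):
        # 3) Rejection: keep prior draws inside the box and above the boundary.
        accepted = []
        for _ in range(max_iter):
            for _ in range(n):
                theta = prior_sample(rng)
                # Cheap box check first, then the (expensive) density check.
                if lo <= theta <= hi and log_prob(theta) >= boundary:
                    accepted.append(theta)
            if len(accepted) >= n:
                return accepted[:n]
        raise RuntimeError("too few accepted draws after max_iter rounds")

    return sample, boundary, (lo, hi)


rng = random.Random(0)
sample, boundary, box = make_truncated_sampler(
    prior_sample=lambda r: r.uniform(-5.0, 5.0),   # Uniform(-5, 5) prior
    posterior_sample=lambda r: r.gauss(1.0, 0.5),  # stand-in posterior draws
    log_prob=normal_log_prob,
    rng=rng,
)
draws = sample(200)
```

The box check is evaluated before the log-density, so prior draws far outside the posterior bulk are discarded without calling the (in practice expensive) network.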
make_score_model | Create a score model for NPSE.
make_simformer_based_score_model | Create a score network for AiO.
ScoreModel | Score model.
- sbijax.experimental.nn.make_simformer_based_score_model(n_dimension, mask, n_heads=4, n_layers=4, head_size=None, embedding_dim_values=32, embedding_dim_ids=32, embedding_dim_conditioning=8, time_embedding_layers=(128, 128), dropout_rate=0.1, activation=<function gelu>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)
Create a score network for AiO.
The score model uses a transformer as a score estimator.
- Parameters:
n_dimension (int) – dimensionality of modelled space
mask (Array) – a binary matrix of conditional dependencies
n_heads (int) – number of attention heads
n_layers (int) – number of attention layers
head_size (int | None) – size of an attention head
embedding_dim_values (int) – dimensionality of the embedding for the values
embedding_dim_ids (int) – dimensionality of the embedding for the ids of the variables
embedding_dim_conditioning (int) – dimensionality of the embedding for the binary conditioning labels
time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network
dropout_rate (float) – the dropout rate used in the transformer
activation (Callable[[...], Any]) – a JAX activation function used by the network
sde (str) – either "vp" or "ve"; defines the type of SDE used as the forward process. See the original publication and references therein for details.
beta_min (float) – minimum value of the noise schedule; see the publication for details.
beta_max (float) – maximum value of the noise schedule; see the publication for details.
time_eps (float) – a small minimum time point for the forward process, used for numerical stability.
time_max (float) – maximum integration time. 1 is a good default, but 5 or 10 also work.
- Returns:
a score model that can be used for posterior inference with AiO.
References
Gloeckler, Manuel, et al. "All-in-one simulation-based inference." International Conference on Machine Learning, 2024.
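A minimal sketch of building a binary dependency matrix for the mask argument. It assumes, hypothetically, that rows and columns index the joint vector (theta_1, ..., theta_p, x_1, ..., x_q), that entry [i][j] = 1 means variable i may attend to variable j, and that every data dimension depends on all parameters while parameters are a priori independent and data dimensions are conditionally independent given the parameters. The helper name and this convention are illustrative assumptions; check the sbijax source for the exact layout it expects before relying on it.

```python
def make_dependency_mask(n_theta, n_data):
    """Hypothetical helper: symmetric binary dependency matrix over (theta, x)."""
    n = n_theta + n_data
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        mask[i][i] = 1  # every variable attends to itself
    for i in range(n_theta):
        for j in range(n_theta, n):
            # parameter <-> data dependency, kept symmetric
            mask[i][j] = 1
            mask[j][i] = 1
    return mask


mask = make_dependency_mask(n_theta=2, n_data=3)
for row in mask:
    print(row)
```

Under these assumptions the matrix is 5 x 5, which would correspond to n_dimension=5; in practice it would be converted to a JAX array before being passed to the factory.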
- sbijax.experimental.nn.make_score_model(n_dimension, hidden_sizes=(128, 128), data_embedding_layers=(128, 128), param_embedding_layers=(128, 128), time_embedding_layers=(128, 128), activation=<jax._src.custom_derivatives.custom_jvp object>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)
Create a score model for NPSE.
The score model uses MLPs to embed the data, the parameters and the time points (the time points are first projected with a sinusoidal embedding). The score net itself is also an MLP.
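A sinusoidal time projection is commonly implemented like the Transformer positional encoding: a scalar t is mapped to sines and cosines at geometrically spaced frequencies. The sketch below assumes that standard form with a base (max_period) of 10000 and an even embedding size; sbijax's own embedding may scale time or frequencies differently. The resulting vector is the kind of input a time-embedding MLP (time_embedding_layers) would then process.

```python
import math


def sinusoidal_embedding(t, dim=8, max_period=10_000.0):
    """Map a scalar time t to a dim-dimensional sinusoidal feature vector."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    args = [t * f for f in freqs]
    # First half: sines, second half: cosines of the same arguments.
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


emb = sinusoidal_embedding(0.5, dim=8)
print(emb)
```

Using several frequencies lets a small MLP distinguish nearby time points as well as coarse differences across the whole integration interval.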
- Parameters:
n_dimension (int) – dimensionality of modelled space
hidden_sizes (tuple[int, ...]) – a tuple of ints determining the hidden-layer sizes of the score network
data_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the data embedding network
param_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the parameter embedding network
time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network
activation (Callable[[...], Any]) – a JAX activation function
sde – either "vp" or "ve"; defines the type of SDE used as the forward process. See the original publication and references therein for details.
beta_min – minimum value of the noise schedule; see the publication for details.
beta_max – maximum value of the noise schedule; see the publication for details.
time_eps – a small minimum time point for the forward process, used for numerical stability.
time_max – maximum integration time. 1 is a good default, but 5 or 10 also work.
- Returns:
a score model that can be used for inference with NPSE.
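For sde="vp", a common choice (Song et al., 2021, score-based modelling with SDEs) is the linear noise schedule beta(t) = beta_min + t (beta_max - beta_min), under which the perturbation kernel p(theta_t | theta_0) is Gaussian with mean alpha(t) theta_0 and standard deviation sigma(t), where alpha(t)^2 + sigma(t)^2 = 1. The sketch below assumes that standard form on [0, 1] with the documented defaults; it does not cover how sbijax handles time_max other than 1 or the "ve" variant.

```python
import math


def vp_marginal(t, beta_min=0.1, beta_max=10.0):
    """Mean coefficient and std of p(theta_t | theta_0) for a linear-schedule VP SDE."""
    # Integral of beta(s) = beta_min + s * (beta_max - beta_min) over [0, t].
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t * t
    mean_coef = math.exp(-0.5 * integral)
    std = math.sqrt(1.0 - math.exp(-integral))
    return mean_coef, std


def perturb(theta0, t, noise, beta_min=0.1, beta_max=10.0):
    """Draw theta_t given theta_0 and a standard-normal sample `noise`."""
    mean_coef, std = vp_marginal(t, beta_min, beta_max)
    return mean_coef * theta0 + std * noise


time_eps, time_max = 0.001, 1.0
for t in (time_eps, 0.5, time_max):
    print(t, vp_marginal(t))
```

With the defaults, sigma(time_eps) is only about 0.01, and since the score of the kernel scales like 1/sigma(t), starting the forward process at a small positive time_eps rather than exactly 0 keeps the training targets bounded.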
- class sbijax.experimental.nn.ScoreModel(n_dimension, transform, sde, beta_min, beta_max, time_eps, time_max, time_delta=0.01)
Score model.
- Parameters:
n_dimension (int) – the dimensionality of the modelled space
transform (Callable[[...], Any]) – a haiku module. The transform is a callable that must accept input arguments named theta, time and context, plus additional keyword arguments. Theta, time and context are two-dimensional arrays with the same batch dimension.