sbijax.experimental

sbijax.experimental contains experimental code that might get ported to the main code base or possibly deleted again.

cmpe (consistency-model posterior estimation) and aio are functional factories; aio delegates to the fmpe core, and make_truncated_proposal builds the truncated-prior proposal used with sbijax.run_sequential(). The score networks in sbijax.experimental.nn serve two estimators: make_score_model builds networks for sbijax.npse(), which now lives in the main package, and make_simformer_based_score_model builds networks for aio.

cmpe(network, *[, t_min, t_max])

Construct a consistency model posterior objective.

aio(network)

Construct an all-in-one posterior estimator.

make_truncated_proposal(prior, network, *[, ...])

Build a truncated-prior proposal_fn for run_sequential().

sbijax.experimental.cmpe(network, *, t_min=0.001, t_max=50.0)

Construct a consistency model posterior objective.

The returned ObjectiveFns is trained via the shared train() driver. EMA params are threaded through TrainingState.params as a dict: {"params": live_params, "ema_params": ema_params}.
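A minimal sketch of what an EMA refresh over this dict layout could look like, using jax.tree_util.tree_map so it works for any parameter pytree. Only the {"params": ..., "ema_params": ...} layout comes from the description above; the function name ema_update and the decay value are hypothetical, and the decay schedule cmpe actually uses is not documented here.

```python
import jax
import jax.numpy as jnp


def ema_update(state_params, decay=0.99):
    """Hypothetical helper: refresh the EMA copy stored next to the live params.

    state_params follows the documented layout:
    {"params": live_params, "ema_params": ema_params}.
    """
    live, ema = state_params["params"], state_params["ema_params"]
    # Leaf-wise exponential moving average over the whole parameter pytree.
    new_ema = jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema, live
    )
    # Live parameters are passed through unchanged.
    return {"params": live, "ema_params": new_ema}


live = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}
state_params = {
    "params": live,
    "ema_params": jax.tree_util.tree_map(jnp.zeros_like, live),
}
state_params = ema_update(state_params, decay=0.9)
```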

Parameters:
  • network – a consistency model with vector_field and sample methods

  • t_min – minimal time point for ODE integration

  • t_max – maximal time point for ODE integration

Returns:

an ObjectiveFns
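t_min and t_max delimit the time interval of the consistency model. To illustrate the role such bounds typically play, the sketch below uses a parameterization common in the consistency-model literature (Song et al., 2023): skip coefficients that enforce the boundary condition f(x, t_min) = x, and a Karras-style discretization of [t_min, t_max]. Whether cmpe uses exactly these formulas is an assumption; SIGMA_DATA, rho and the function names are hypothetical.

```python
import jax.numpy as jnp

T_MIN, T_MAX = 0.001, 50.0  # documented defaults of cmpe()
SIGMA_DATA = 1.0            # hypothetical data scale, not a cmpe argument


def karras_grid(n, t_min=T_MIN, t_max=T_MAX, rho=7.0):
    """n increasing time points from t_min to t_max, denser near t_min."""
    i = jnp.arange(n) / (n - 1)
    return (t_min ** (1 / rho) + i * (t_max ** (1 / rho) - t_min ** (1 / rho))) ** rho


def skip_coefficients(t, t_min=T_MIN, sigma_data=SIGMA_DATA):
    """c_skip and c_out of f(x, t) = c_skip(t) * x + c_out(t) * F(x, t).

    At t = t_min, c_skip = 1 and c_out = 0, so f(x, t_min) = x.
    """
    c_skip = sigma_data**2 / ((t - t_min) ** 2 + sigma_data**2)
    c_out = sigma_data * (t - t_min) / jnp.sqrt(sigma_data**2 + t**2)
    return c_skip, c_out


ts = karras_grid(18)
c_skip, c_out = skip_coefficients(ts)
```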

sbijax.experimental.aio(network)

Construct an all-in-one posterior estimator.

Parameters:

network – a simformer-based score network with loss, sample and log_prob methods

Returns:

an ObjectiveFns

sbijax.experimental.make_truncated_proposal(prior, network, *, quantile=0.0005, n_calibration=100000, n_prior=1000000, max_iter=1000)

Build a truncated-prior proposal_fn for run_sequential().

Parameters:
  • prior – the prior distribution

  • network – the score network the estimator wraps (exposes log_prob)

  • quantile – lower-tail quantile of posterior log-densities taken as the truncation boundary

  • n_calibration – number of posterior draws used to calibrate the boundary and the bounding hypercube

  • n_prior – number of prior draws used to bound the hypercube

  • max_iter – maximum rejection rounds before giving up

Returns:

a proposal_fn(objective, params, observable, sampler) returning a callable (rng_key, n) -> theta that draws parameters from the truncated prior, in the pytree structure the prior and simulator use
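The calibration and rejection steps described by these parameters can be sketched with toy stand-ins: a Gaussian function playing the role of network.log_prob and a uniform prior. This is not the proposal_fn(objective, params, observable, sampler) interface; all function names are hypothetical, and the way the hypercube combines posterior and prior draws is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def toy_log_prob(theta):
    # Hypothetical stand-in for network.log_prob: N(0.5, 0.1^2) per dimension.
    return jnp.sum(norm.logpdf(theta, loc=0.5, scale=0.1), axis=-1)


def toy_posterior_sample(rng_key, n, dim=2):
    # Hypothetical stand-in for drawing from the trained posterior estimate.
    return 0.5 + 0.1 * jax.random.normal(rng_key, (n, dim))


def toy_prior_sample(rng_key, n, dim=2):
    # Toy Uniform(0, 1) prior on each dimension.
    return jax.random.uniform(rng_key, (n, dim))


def make_truncated_sampler(rng_key, *, quantile=0.0005, n_calibration=100_000,
                           n_prior=1_000_000, max_iter=1000):
    key_post, key_prior = jax.random.split(rng_key)
    # 1. Truncation boundary: lower-tail quantile of posterior log-densities.
    post = toy_posterior_sample(key_post, n_calibration)
    threshold = jnp.quantile(toy_log_prob(post), quantile)
    # 2. Bounding hypercube (assumed): box around the posterior draws,
    #    clipped to the range covered by the prior draws.
    prior_draws = toy_prior_sample(key_prior, n_prior)
    lower = jnp.maximum(post.min(axis=0), prior_draws.min(axis=0))
    upper = jnp.minimum(post.max(axis=0), prior_draws.max(axis=0))

    def sample(rng_key, n):
        # 3. Rejection sampling from the prior, for at most max_iter rounds.
        accepted, total = [], 0
        for _ in range(max_iter):
            rng_key, sub = jax.random.split(rng_key)
            theta = toy_prior_sample(sub, n)
            inside = jnp.all((theta >= lower) & (theta <= upper), axis=-1)
            keep = inside & (toy_log_prob(theta) >= threshold)
            accepted.append(theta[keep])
            total += int(keep.sum())
            if total >= n:
                return jnp.concatenate(accepted)[:n]
        raise RuntimeError("too few accepted draws after max_iter rounds")

    return sample, threshold


sample, threshold = make_truncated_sampler(
    jax.random.PRNGKey(0), n_calibration=10_000, n_prior=100_000, max_iter=100
)
theta = sample(jax.random.PRNGKey(1), 500)
```

The example uses smaller n_calibration and n_prior than the documented defaults to keep it fast.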

make_score_model(n_dimension[, ...])

Create a score model for NPSE.

make_simformer_based_score_model(...[, ...])

Create a score network for AiO.

ScoreModel(n_dimension, transform, sde, ...)

Score model.

sbijax.experimental.nn.make_simformer_based_score_model(n_dimension, mask, n_heads=4, n_layers=4, head_size=None, embedding_dim_values=32, embedding_dim_ids=32, embedding_dim_conditioning=8, time_embedding_layers=(128, 128), dropout_rate=0.1, activation=<function gelu>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)

Create a score network for AiO.

The score model uses a transformer as a score estimator.

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • mask (Array) – a binary matrix of conditional dependencies

  • n_heads (int) – number of attention heads

  • n_layers (int) – number of attention layers

  • head_size (int | None) – size of an attention head

  • embedding_dim_values (int) – dimensionality of the embedding for the values

  • embedding_dim_ids (int) – dimensionality of the embedding for the ids of the variables

  • embedding_dim_conditioning (int) – dimensionality of the binary conditioning labels

  • time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network

  • dropout_rate (float) – the dropout rate

  • activation (Callable[[...], Any]) – the activation function to be used

  • sde (str) – either 'vp' or 've'. Defines the type of SDE used as the forward process. See the original publication and references therein for details.

  • beta_min (float) – lower bound of the SDE noise schedule; see the original publication for details.

  • beta_max (float) – upper bound of the SDE noise schedule; see the original publication for details.

  • time_eps (float) – a small positive number used as the minimum time point of the forward process, for numerical stability.

  • time_max (float) – maximum integration time. 1 is a good default; 5 or 10 also work.

Returns:

a score model that can be used for posterior inference with AiO.

References

Gloeckler, Manuel, et al. "All-in-one simulation-based inference." International Conference on Machine Learning, 2024.
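A sketch of the two binary inputs a Simformer-style network works with: a dependency matrix of the kind passed as mask, and per-variable conditioning labels (the quantity embedded with embedding_dim_conditioning). The variable ordering (parameters first, then data), the symmetric row/column convention and the toy dependency structure are all assumptions; check them against the implementation before use.

```python
import jax.numpy as jnp

N_THETA, N_DATA = 2, 2
n_dimension = N_THETA + N_DATA  # assumed: parameters first, then data

# Hypothetical toy model: y_0 depends only on theta_0, y_1 only on theta_1.
# Indices 0-1 are theta, 2-3 are data.
parents = {2: [0], 3: [1]}

# Assumed convention: mask[i, j] == 1 lets variable i attend to variable j.
# Every variable sees itself; each dependency is treated as undirected.
mask = jnp.eye(n_dimension, dtype=jnp.int32)
for child, ps in parents.items():
    for p in ps:
        mask = mask.at[child, p].set(1).at[p, child].set(1)

# Conditioning labels for posterior inference: 1 = observed (data), 0 = latent (parameters).
condition = jnp.concatenate(
    [jnp.zeros(N_THETA, jnp.int32), jnp.ones(N_DATA, jnp.int32)]
)
```

Under the same assumed convention, jnp.ones((n_dimension, n_dimension)) is the fully connected mask, which encodes no independence assumptions.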

sbijax.experimental.nn.make_score_model(n_dimension, hidden_sizes=(128, 128), data_embedding_layers=(128, 128), param_embedding_layers=(128, 128), time_embedding_layers=(128, 128), activation=<jax._src.custom_derivatives.custom_jvp object>, sde='vp', beta_min=0.1, beta_max=10.0, time_eps=0.001, time_max=1)

Create a score model for NPSE.

The score model uses MLPs to embed the data, the parameters and the time points (after projecting them with a sinusoidal embedding). The score net itself is also an MLP.
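The sinusoidal projection of time points can be illustrated with a standard transformer-style embedding. This is a sketch, not the sbijax implementation: dim and max_period are hypothetical, and the frequencies sbijax uses may differ. The resulting features would then be passed through the MLP whose output sizes are given by time_embedding_layers.

```python
import jax.numpy as jnp


def sinusoidal_embedding(t, dim=16, max_period=10_000.0):
    """Map times of shape (batch,) to features of shape (batch, dim)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to about 1 / max_period.
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    angles = t[:, None] * freqs[None, :]
    # First half sines, second half cosines.
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


t = jnp.array([0.0, 0.5, 1.0])
emb = sinusoidal_embedding(t)
```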

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • hidden_sizes (tuple[int, ...]) – tuple of ints determining the layers of the score network

  • data_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the data embedding network

  • param_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the parameter embedding network

  • time_embedding_layers (tuple[int, ...]) – a tuple of ints determining the output sizes of the time embedding network

  • activation (Callable[[...], Any]) – a jax activation function

  • sde – either 'vp' or 've'. Defines the type of SDE used as the forward process. See the original publication and references therein for details.

  • beta_min – lower bound of the SDE noise schedule; see the original publication for details.

  • beta_max – upper bound of the SDE noise schedule; see the original publication for details.

  • time_eps – a small positive number used as the minimum time point of the forward process, for numerical stability.

  • time_max – maximum integration time. 1 is a good default; 5 or 10 also work.

Returns:

a score model that can be used for inference with NPSE.
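To make the sde, beta_min, beta_max, time_eps and time_max parameters concrete, here is a sketch of the VP forward process under the linear schedule beta(t) = beta_min + t (beta_max - beta_min) of Song et al. (2021). Assuming sbijax follows this convention, theta_t given theta_0 is Gaussian with mean exp(-B(t)/2) theta_0 and variance 1 - exp(-B(t)), where B(t) = beta_min t + t^2 (beta_max - beta_min) / 2 is the integral of beta up to t. Function names are hypothetical.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.1, 10.0   # documented defaults
TIME_EPS, TIME_MAX = 0.001, 1.0  # documented defaults


def vp_marginal(t, beta_min=BETA_MIN, beta_max=BETA_MAX):
    """Mean scale and std of p_t(theta_t | theta_0) for a VP SDE with linear beta(t)."""
    integral = beta_min * t + 0.5 * t**2 * (beta_max - beta_min)  # B(t)
    mean_scale = jnp.exp(-0.5 * integral)
    std = jnp.sqrt(1.0 - jnp.exp(-integral))
    return mean_scale, std


def perturb(rng_key, theta0):
    """Noise a batch of parameters at random times in [TIME_EPS, TIME_MAX)."""
    key_t, key_eps = jax.random.split(rng_key)
    t = jax.random.uniform(key_t, (theta0.shape[0],), minval=TIME_EPS, maxval=TIME_MAX)
    mean_scale, std = vp_marginal(t)
    noise = jax.random.normal(key_eps, theta0.shape)
    theta_t = mean_scale[:, None] * theta0 + std[:, None] * noise
    return t, theta_t, noise


t, theta_t, noise = perturb(jax.random.PRNGKey(0), jnp.ones((8, 3)))
```

Drawing t from [time_eps, time_max] rather than from 0 keeps the standard deviation away from zero: at t = 0 it vanishes, and a denoising score-matching target scaled by 1/std would blow up.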

class sbijax.experimental.nn.ScoreModel(n_dimension, transform, sde, beta_min, beta_max, time_eps, time_max, time_delta=0.01)

Score model.

Parameters:
  • n_dimension (int) – the dimensionality of the modelled space

  • transform (Callable[[...], Any]) – a haiku module. The transform is a callable that has to take as input arguments named theta, time, context and additional keyword arguments. Theta, time and context are two-dimensional arrays with the same batch dimensions.
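A plain-JAX stand-in illustrating the calling convention described for transform; a real transform would be a Haiku module with learned layers. The function toy_transform, the is_training keyword and the assumption that the output has one entry per element of theta are hypothetical.

```python
import jax.numpy as jnp


def toy_transform(theta, time, context, **kwargs):
    """Hypothetical stand-in for the haiku transform: one output per theta entry."""
    # All inputs are 2D and share the batch dimension: (batch, features).
    assert theta.shape[0] == time.shape[0] == context.shape[0]
    inputs = jnp.concatenate([theta, time, context], axis=-1)
    # A fixed averaging map in place of learned layers, only to show shapes.
    weights = jnp.ones((inputs.shape[-1], theta.shape[-1])) / inputs.shape[-1]
    return inputs @ weights


batch, n_dimension, n_context = 5, 3, 4
theta = jnp.zeros((batch, n_dimension))
times = jnp.full((batch, 1), 0.5)
context = jnp.ones((batch, n_context))
score = toy_transform(theta=theta, time=times, context=context, is_training=False)
```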

__call__(method, **kwargs)

Apply the model.

Parameters:

method (str) – method to call

Keyword Arguments:

**kwargs – keyword arguments for the called method
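__call__ dispatches by method name. A plain-Python sketch of that pattern, assuming it simply looks up the named method and forwards the keyword arguments; ToyModel and its methods are hypothetical. A single entry point of this kind is a common way to expose several methods (for example loss, sample and log_prob) through one Haiku-transformed function.

```python
class ToyModel:
    """Hypothetical model exposing several methods through one entry point."""

    def loss(self, theta, context):
        # Toy squared-error loss between two sequences.
        return sum((a - b) ** 2 for a, b in zip(theta, context))

    def sample(self, n):
        # Toy sampler returning n zeros.
        return [0.0] * n

    def __call__(self, method, **kwargs):
        # Look up the named method and forward the keyword arguments to it.
        return getattr(self, method)(**kwargs)


model = ToyModel()
result = model("loss", theta=[1.0, 2.0], context=[1.0, 0.0])
```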