sbijax.nn#

sbijax.nn contains utility functions and classes to construct neural networks and normalizing flows.

make_mdn(n_dimension, n_components[, ...])

Create a mixture density network.

make_maf(n_dimension[, n_layers, ...])

Create an affine (surjective) masked autoregressive flow.

make_spf(n_dimension, range_min, range_max)

Create a rational-quadratic (surjective) spline coupling flow.

make_cnf(n_dimension[, n_layers, ...])

Create a conditional continuous normalizing flow.

make_mlp([n_layers, hidden_size, ...])

Create an MLP-based classifier network.

make_resnet([n_layers, hidden_size, ...])

Create a ResNet-based classifier network.

make_cm(n_dimension[, n_layers, ...])

Create a consistency model.

make_nass_net(embedding_dim, hidden_sizes[, ...])

Create a critic network for SNASS.

make_nasss_net(embedding_dim, ...[, activation])

Create a critic network for SNASSS.

Density estimators#

sbijax.nn.make_mdn(n_dimension, n_components, hidden_sizes=(64, 64), activation=<jax._src.custom_derivatives.custom_jvp object>)#

Create a mixture density network.

The MDN uses n_components mixture components, each modelling the distribution of an n_dimension-dimensional data point.

Parameters:
  • n_dimension (int) – dimensionality of data

  • n_components (int) – number of mixture components

  • hidden_sizes (Iterable[int]) – sizes of the hidden layers of the network. E.g., when the hidden sizes are a tuple (64, 64), the network uses two hidden layers of size 64 each

  • activation (Callable) – a jax activation function

Returns:

a mixture density network
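
To make the roles of n_components and n_dimension concrete, here is a minimal sketch of how a mixture density is evaluated. It is not sbijax code: the helper name mixture_log_prob and the choice of diagonal Gaussian components are assumptions made for illustration. In an MDN, the weights, means, and scales would be produced by the network from the conditioning variable.

```python
import math


def mixture_log_prob(y, log_weights, means, scales):
    """Log-density of y under a mixture of diagonal Gaussians (hypothetical helper)."""
    component_log_probs = []
    for log_w, mean, scale in zip(log_weights, means, scales):
        # Sum of independent univariate Gaussian log-densities over the dimensions.
        log_p = sum(
            -0.5 * ((yi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
            for yi, m, s in zip(y, mean, scale)
        )
        component_log_probs.append(log_w + log_p)
    # Numerically stable log-sum-exp over the mixture components.
    top = max(component_log_probs)
    return top + math.log(sum(math.exp(lp - top) for lp in component_log_probs))


# Two components (n_components=2) over a two-dimensional point (n_dimension=2).
log_weights = [math.log(0.5), math.log(0.5)]
means = [[0.0, 0.0], [0.0, 0.0]]
scales = [[1.0, 1.0], [1.0, 1.0]]
value = mixture_log_prob([0.0, 0.0], log_weights, means, scales)
```

Because both components are standard normal here, the mixture collapses to a single two-dimensional standard normal density.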

sbijax.nn.make_maf(n_dimension, n_layers=5, n_layer_dimensions=None, hidden_sizes=(64, 64), activation=<PjitFunction of <function tanh>>)#

Create an affine (surjective) masked autoregressive flow.

The MAF uses n_layers layers, each parameterized by a MADE network with hidden_sizes neurons per layer. Each dimensionality-reducing layer uses a conditional Gaussian density whose network has the same number of layers and nodes per layer as given by hidden_sizes. The argument n_layer_dimensions determines which layers are dimensionality-preserving and which are dimensionality-reducing. For example, for n_layer_dimensions=(5, 5, 3, 3) and n_dimension=5, the third layer reduces the dimensionality by two and uses a surjection layer. The other layers are dimensionality-preserving.

Return type:

Transformed

Parameters:
  • n_dimension (int) – dimensionality of the modelled space

  • n_layers (int | None) – number of layers

  • n_layer_dimensions (Iterable[int] | None) – list of integers that determine if a layer is dimensionality-preserving or -reducing

  • hidden_sizes (Iterable[int]) – sizes of hidden layers for each normalizing flow

  • activation (Callable) – a jax activation function

Examples

>>> neural_network = make_maf(10, n_layer_dimensions=(10, 10, 5, 5, 5))

Returns:

a (surjective) normalizing flow model
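
The way n_layer_dimensions marks layers as dimensionality-preserving or dimensionality-reducing can be sketched in plain Python. The helper describe_layers is hypothetical and not part of sbijax. It assumes that each entry gives the output dimensionality of the corresponding layer and that the first layer receives n_dimension inputs, which matches the (5, 5, 3, 3) example in the description above.

```python
def describe_layers(n_dimension, n_layer_dimensions):
    """Label each layer as preserving or reducing (hypothetical helper)."""
    layers = []
    n_in = n_dimension
    for n_out in n_layer_dimensions:
        # Assumption: a layer may keep or shrink the dimensionality, never grow it.
        if n_out > n_in:
            raise ValueError("a layer cannot increase the dimensionality")
        kind = "preserving" if n_out == n_in else "reducing"
        layers.append((kind, n_in, n_out))
        n_in = n_out
    return layers


layers = describe_layers(5, (5, 5, 3, 3))
for index, (kind, n_in, n_out) in enumerate(layers, start=1):
    print(f"layer {index}: {kind} ({n_in} -> {n_out})")
```

Only the third layer changes the dimensionality (from 5 to 3), so under these assumptions it is the one that would use a surjection layer.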

sbijax.nn.make_spf(n_dimension, range_min, range_max, n_layers=5, n_layer_dimensions=None, hidden_sizes=(64, 64), n_params=10, activation=<PjitFunction of <function tanh>>)#

Create a rational-quadratic (surjective) spline coupling flow.

The flow uses n_layers layers that are parameterized using networks with hidden_sizes neurons per layer. Each dimensionality-reducing layer uses a conditional Gaussian density whose network has the same number of layers and nodes per layer as given by hidden_sizes. The argument n_layer_dimensions determines which layers are dimensionality-preserving and which are dimensionality-reducing. For example, for n_layer_dimensions=(5, 5, 3, 3) and n_dimension=5, the third layer reduces the dimensionality by two and uses a surjection layer. The other layers are dimensionality-preserving.

Return type:

Transformed

Parameters:
  • n_dimension (int) – dimensionality of the modelled space

  • range_min (float) – minimum range on which the spline is defined

  • range_max (float) – maximum range on which the spline is defined

  • n_layers (int | None) – number of layers

  • n_layer_dimensions (Iterable[int] | None) – list of integers that determine if a layer is dimensionality-preserving or -reducing

  • hidden_sizes (Iterable[int]) – sizes of hidden layers for each normalizing flow

  • n_params (int) – number of parameters of each spline

  • activation (Callable) – a jax activation function

Examples

>>> neural_network = make_spf(10, -1.0, 1.0, n_layer_dimensions=(10, 10, 5, 5, 5))

Returns:

a (surjective) normalizing flow model

sbijax.nn.make_cnf(n_dimension, n_layers=2, hidden_size=64, activation=<jax._src.custom_derivatives.custom_jvp object>, dropout_rate=0.1, do_batch_norm=False, batch_norm_decay=0.2, sigma_min=0.001)#

Create a conditional continuous normalizing flow.

The CCNF uses a residual network as its transformer; the network is created automatically.

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • n_layers (int) – number of resnet blocks

  • hidden_size (int) – sizes of hidden layers for each resnet block

  • activation (Callable) – a jax activation function

  • dropout_rate (float) – dropout rate to use in resnet blocks

  • do_batch_norm (bool) – use batch normalization or not

  • batch_norm_decay (float) – decay rate of EMA in batch norm layer

  • sigma_min (float) – minimal scaling for the vector field

Returns:

a conditional continuous normalizing flow
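
One way to read sigma_min is through the optimal-transport probability path commonly used in flow matching. The sketch below is an assumption about how such a parameter typically enters, not a description of the sbijax implementation. The helper name ot_path is hypothetical, and plain Python lists stand in for arrays.

```python
def ot_path(x0, x1, t, sigma_min=0.001):
    """Point on the conditional path and its target velocity (hypothetical helper)."""
    # Interpolate from noise x0 (t = 0) towards data x1 (t = 1).
    xt = [(1.0 - (1.0 - sigma_min) * t) * a + t * b for a, b in zip(x0, x1)]
    # Time derivative of xt, i.e. the regression target for the vector field.
    ut = [b - (1.0 - sigma_min) * a for a, b in zip(x0, x1)]
    return xt, ut


noise = [1.0, -2.0]
data = [0.5, 0.5]
start, velocity = ot_path(noise, data, t=0.0)
end, _ = ot_path(noise, data, t=1.0)
```

Under this parameterization the path ends at data plus sigma_min times the noise, so sigma_min sets the smallest scale the noise is shrunk to.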

Classifier networks#

sbijax.nn.make_mlp(n_layers=2, hidden_size=64, activation=<function gelu>, w_init=<haiku._src.initializers.TruncatedNormal object>, b_init=<function zeros>)#

Create an MLP-based classifier network.

Parameters:
  • n_layers (int) – the number of hidden layers to be used

  • hidden_size (int) – the size of each layer

  • activation – a JAX activation function

  • w_init – a haiku initializer

  • b_init – a haiku initializer

Returns:

a transformable haiku neural network module

sbijax.nn.make_resnet(n_layers=2, hidden_size=64, activation=<PjitFunction of <function tanh>>, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2)#

Create a ResNet-based classifier network.

Parameters:
  • n_layers (int) – number of resnet blocks

  • hidden_size (int) – sizes of hidden layers for each resnet block

  • activation (Callable) – a jax activation function

  • dropout_rate (float) – dropout rate to use in resnet blocks

  • do_batch_norm (bool) – use batch normalization or not

  • batch_norm_decay (float) – decay rate of EMA in batch norm layer

Returns:

a neural network model
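
The batch_norm_decay parameter is described as the decay rate of an exponential moving average (EMA). A generic EMA update can be sketched as follows. The helper ema_update is hypothetical, and the exact update used inside the batch norm layer (for example any bias correction) is not confirmed here.

```python
def ema_update(average, value, decay=0.2):
    """One exponential-moving-average step (hypothetical helper)."""
    # A larger decay keeps more of the old average; a smaller one adapts faster.
    return decay * average + (1.0 - decay) * value


running_mean = 0.0
for batch_mean in [1.0, 1.0, 1.0]:
    running_mean = ema_update(running_mean, batch_mean)
```

With a decay of 0.2, the running statistic moves most of the way to the new batch value at every step.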

Consistency models#

sbijax.nn.make_cm(n_dimension, n_layers=2, hidden_size=64, activation=<PjitFunction of <function tanh>>, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2, t_min=0.001, t_max=50.0, sigma_data=1.0)#

Create a consistency model.

The consistency model uses a residual network as score network.

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • n_layers (int) – number of resnet blocks

  • hidden_size (int) – sizes of hidden layers for each resnet block

  • activation (Callable) – a jax activation function

  • dropout_rate (float) – dropout rate to use in resnet blocks

  • do_batch_norm (bool) – use batch normalization or not

  • batch_norm_decay (float) – decay rate of EMA in batch norm layer

  • t_min (float) – minimal time point for ODE integration

  • t_max (float) – maximal time point for ODE integration

  • sigma_data (float) – the standard deviation of the data

Returns:

a consistency model
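
The parameters t_min and sigma_data can be illustrated with the skip/output scaling that is commonly used to parameterize consistency models. Whether sbijax uses exactly these formulas is an assumption. The helper names below are hypothetical, and network_output stands in for the output of the residual network.

```python
import math


def skip_scaling(t, t_min=0.001, sigma_data=1.0):
    """Weight of the input x in the model output; equals 1 at t = t_min."""
    return sigma_data**2 / ((t - t_min) ** 2 + sigma_data**2)


def out_scaling(t, t_min=0.001, sigma_data=1.0):
    """Weight of the network output; equals 0 at t = t_min."""
    return sigma_data * (t - t_min) / math.sqrt(sigma_data**2 + t**2)


def consistency_output(x, t, network_output, t_min=0.001, sigma_data=1.0):
    """Combine input and network output so the model is the identity at t_min."""
    c_skip = skip_scaling(t, t_min, sigma_data)
    c_out = out_scaling(t, t_min, sigma_data)
    return [c_skip * xi + c_out * fi for xi, fi in zip(x, network_output)]


x = [0.3, -1.2]
at_boundary = consistency_output(x, 0.001, [5.0, 5.0])
```

The scaling enforces the boundary condition that the model returns its input unchanged at t_min, regardless of what the network outputs.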

Summary statistics networks#

sbijax.nn.make_nass_net(embedding_dim, hidden_sizes, activation=<jax._src.custom_derivatives.custom_jvp object>)#

Create a critic network for SNASS.

Parameters:
  • embedding_dim (int) – dimensionality of the summary statistic

  • hidden_sizes (Iterable[int]) – list of integers specifying the hidden dimensions of the networks

  • activation (Callable[[Array], Array]) – a jax activation function

Returns:

a network that can be used within a NASS posterior estimator

sbijax.nn.make_nasss_net(embedding_dim, sec_embedding_dim, hidden_sizes, activation=<jax._src.custom_derivatives.custom_jvp object>)#

Create a critic network for SNASSS.

Parameters:
  • embedding_dim (int) – dimensionality of the summary statistic

  • sec_embedding_dim (int) – dimensionality of the secondary summary statistic

  • hidden_sizes (Iterable[int]) – list of integers specifying the hidden dimensions of the networks

  • activation (Callable[[Array], Array]) – a jax activation function

Returns:

a network that can be used within a SNASSS posterior estimator