sbijax.nn


sbijax.nn contains utility functions and classes to construct neural networks and normalizing flows.

make_affine_maf(n_dimension[, n_layers, ...])

Create an affine masked autoregressive flow.

make_surjective_affine_maf(n_dimension, ...)

Create a surjective affine masked autoregressive flow.

make_resnet([n_layers, hidden_size, ...])

Create a resnet.

make_ccnf(n_dimension[, n_layers, ...])

Create a conditional continuous normalizing flow.

make_snass_net(summary_net_dimensions, ...)

Create a critic network for SNASS.

make_snasss_net(summary_net_dimensions, ...)

Create a critic network for SNASSS.

sbijax.nn.make_affine_maf(n_dimension, n_layers=5, hidden_sizes=(64, 64), activation=jax.numpy.tanh)

Create an affine masked autoregressive flow.

The flow consists of n_layers MAF layers, each parameterized using a MADE network with hidden_sizes neurons per layer.

Parameters:
  • n_dimension (int) – dimensionality of data

  • n_layers (int) – number of normalizing flow layers

  • hidden_sizes (Iterable[int]) – sizes of the hidden layers for each normalizing flow. For example, if hidden_sizes is the tuple (64, 64), each MAF layer uses a MADE with two hidden layers of 64 units each

  • activation (Callable) – a jax activation function

Returns:

a normalizing flow model
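
A minimal construction sketch; the data dimensionality and layer sizes below are chosen purely for illustration:

>>> import jax.numpy as jnp
>>> from sbijax.nn import make_affine_maf
>>> # a flow over 3-dimensional data: 5 MAF layers, each a MADE with
>>> # two hidden layers of 64 units each
>>> flow = make_affine_maf(3, n_layers=5, hidden_sizes=(64, 64), activation=jnp.tanh)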

sbijax.nn.make_surjective_affine_maf(n_dimension, n_layer_dimensions, hidden_sizes=(64, 64), activation=jax.numpy.tanh)

Create a surjective affine masked autoregressive flow.

The flow uses one MAF layer per entry of n_layer_dimensions, each parameterized using a MADE network with hidden_sizes neurons per layer. For each dimensionality-reducing layer, a conditional Gaussian density is used that has the same number of layers and nodes per layer as hidden_sizes. The argument n_layer_dimensions determines which layers are dimensionality-preserving and which are dimensionality-reducing. For example, for n_layer_dimensions=(5, 5, 3, 3) and n_dimension=5, the third layer reduces the dimensionality by two using a surjection layer. The other layers are dimensionality-preserving.

Parameters:
  • n_dimension (int) – dimensionality of data

  • n_layer_dimensions (Iterable[int]) – a list of integers that determines the dimensionality of each flow layer and hence whether a layer is dimensionality-preserving or -reducing

  • hidden_sizes (Iterable[int]) – sizes of hidden layers for each normalizing flow

  • activation (Callable) – a jax activation function

Examples

>>> make_surjective_affine_maf(10, (10, 10, 5, 5, 5))

Returns:

a surjective normalizing flow model

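A construction sketch mirroring the (5, 5, 3, 3) example from the description above; all sizes are illustrative:

>>> from sbijax.nn import make_surjective_affine_maf
>>> # 5-dimensional data; the third layer reduces the dimensionality
>>> # from 5 to 3, the remaining layers are dimensionality-preserving
>>> flow = make_surjective_affine_maf(5, n_layer_dimensions=(5, 5, 3, 3))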

sbijax.nn.make_resnet(n_layers=2, hidden_size=64, activation=jax.numpy.tanh, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2)

Create a resnet.

Parameters:
  • n_layers (int) – number of resnet blocks

  • hidden_size (int) – size of the hidden layers of each resnet block

  • activation (Callable) – a jax activation function

  • dropout_rate (float) – dropout rate to use in resnet blocks

  • do_batch_norm (bool) – use batch normalization or not

  • batch_norm_decay (float) – decay rate of EMA in batch norm layer

Returns:

a neural network model
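
A construction sketch using mostly the defaults; the sizes are illustrative:

>>> import jax.numpy as jnp
>>> from sbijax.nn import make_resnet
>>> # a resnet with 2 blocks of 64 hidden units, 20% dropout,
>>> # and batch normalization disabled
>>> net = make_resnet(n_layers=2, hidden_size=64, activation=jnp.tanh, dropout_rate=0.2, do_batch_norm=False)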

sbijax.nn.make_ccnf(n_dimension, n_layers=2, hidden_size=64, activation=jax.numpy.tanh, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2)

Create a conditional continuous normalizing flow.

The CCNF uses a residual network as its transformer, which is created automatically.

Parameters:
  • n_dimension (int) – dimensionality of modelled space

  • n_layers (int) – number of resnet blocks

  • hidden_size (int) – size of the hidden layers of each resnet block

  • activation (Callable) – a jax activation function

  • dropout_rate (float) – dropout rate to use in resnet blocks

  • do_batch_norm (bool) – use batch normalization or not

  • batch_norm_decay (float) – decay rate of EMA in batch norm layer

Returns:

a conditional continuous normalizing flow
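
A construction sketch; the dimensionality and network sizes are illustrative:

>>> from sbijax.nn import make_ccnf
>>> # a CCNF over 3-dimensional data whose transformer is a resnet
>>> # with 2 blocks of 64 hidden units each
>>> model = make_ccnf(3, n_layers=2, hidden_size=64)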

sbijax.nn.make_snass_net(summary_net_dimensions, critic_net_dimensions, activation=<jax._src.custom_derivatives.custom_jvp object>)

Create a critic network for SNASS.

Parameters:
  • summary_net_dimensions (Iterable[int]) – a list of integers representing the dimensionalities of the summary network. The last dimension determines the dimensionality of the summary statistic

  • critic_net_dimensions (Iterable[int]) – a list of integers representing the dimensionalities of the critic network. The last dimension needs to be 1.

  • activation (Callable[[Array], Array]) – a jax activation function

Returns:

a network that can be used within a SNASS posterior estimator
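
A construction sketch; the layer dimensions are illustrative:

>>> from sbijax.nn import make_snass_net
>>> # the summary net compresses data to a 2-dimensional summary
>>> # statistic; the critic net's last dimension has to be 1
>>> net = make_snass_net(summary_net_dimensions=(64, 64, 2), critic_net_dimensions=(64, 64, 1))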

sbijax.nn.make_snasss_net(summary_net_dimensions, sec_summary_net_dimensions, critic_net_dimensions, activation=<jax._src.custom_derivatives.custom_jvp object>)

Create a critic network for SNASSS.

Parameters:
  • summary_net_dimensions (Iterable[int]) – a list of integers representing the dimensionalities of the summary network. The last dimension determines the dimensionality of the summary statistic

  • sec_summary_net_dimensions (Iterable[int]) – a list of integers representing the dimensionalities of the second summary network. The last dimension determines the dimensionality of the second summary statistic and should be smaller than the last dimension of the first summary net.

  • critic_net_dimensions (Iterable[int]) – a list of integers representing the dimensionalities of the critic network. The last dimension needs to be 1.

  • activation (Callable[[Array], Array]) – a jax activation function

Returns:

a network that can be used within a SNASSS posterior estimator
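
A construction sketch; the layer dimensions are illustrative:

>>> from sbijax.nn import make_snasss_net
>>> # the first summary statistic has 5 dimensions; the second summary
>>> # net compresses it further to 1 dimension (smaller than 5, as
>>> # required); the critic net's last dimension has to be 1
>>> net = make_snasss_net(summary_net_dimensions=(64, 64, 5), sec_summary_net_dimensions=(32, 1), critic_net_dimensions=(64, 64, 1))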