sbijax.nn
sbijax.nn contains utility functions and classes to construct neural networks and normalizing flows.
| make_affine_maf | Create an affine masked autoregressive flow. |
| make_surjective_affine_maf | Create a surjective affine masked autoregressive flow. |
| make_resnet | Create a resnet. |
| make_ccnf | Create a conditional continuous normalizing flow. |
| make_snass_net | Create a critic network for SNASS. |
| make_snasss_net | Create a critic network for SNASSS. |
- sbijax.nn.make_affine_maf(n_dimension, n_layers=5, hidden_sizes=(64, 64), activation=jax.numpy.tanh)
Create an affine masked autoregressive flow.
The flow stacks n_layers affine MAF layers, each parameterized by a MADE network whose hidden layers have the sizes given in hidden_sizes.
- Parameters:
n_dimension (int): dimensionality of the data
n_layers (int): number of normalizing flow layers
hidden_sizes (Iterable[int]): sizes of the hidden layers of the MADE network in each flow layer. For example, hidden_sizes=(64, 64) gives each MAF layer a MADE with two hidden layers of 64 units each
activation (Callable): a JAX activation function
- Returns:
a normalizing flow model
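To make the role of hidden_sizes concrete, the following NumPy sketch builds MADE connectivity masks for one layer and applies the resulting affine autoregressive transform u = (x - shift(x)) * exp(-log_scale(x)), where shift_i and log_scale_i depend only on x_1, ..., x_{i-1}. It illustrates the general MAF/MADE technique only and is not sbijax's implementation: the helper names (made_masks, made_forward, maf_forward) are hypothetical, and biases, permutations between stacked layers, and training are omitted.

```python
import numpy as np

def made_masks(n_dimension, hidden_sizes, rng):
    """Build MADE masks (shape: out x in) for the given hidden layer sizes; assumes n_dimension >= 2."""
    degrees = [np.arange(1, n_dimension + 1)]  # input degrees 1..D
    for size in hidden_sizes:
        # Hidden degrees in [1, D-1] so each unit can feed at least one output
        degrees.append(rng.integers(1, n_dimension, size=size))
    masks = []
    for d_in, d_out in zip(degrees[:-1], degrees[1:]):
        # A hidden unit of degree k may only see units of degree <= k
        masks.append((d_out[:, None] >= d_in[None, :]).astype(float))
    # Output i may only see hidden units of degree < i (strict inequality)
    out_degrees = np.arange(1, n_dimension + 1)
    masks.append((out_degrees[:, None] > degrees[-1][None, :]).astype(float))
    return masks

def made_forward(x, weights, masks, activation=np.tanh):
    """Masked MLP returning (shift, log_scale); biases omitted for brevity."""
    h = x
    for W, M in zip(weights[:-1], masks[:-1]):
        h = activation(h @ (W * M).T)
    W_out, M_out = weights[-1], masks[-1]
    # Two output heads share the same autoregressive output mask
    return h @ (W_out[0] * M_out).T, h @ (W_out[1] * M_out).T

def maf_forward(x, weights, masks):
    """Map data x to base noise u with one affine MAF layer and return log|det J|."""
    shift, log_scale = made_forward(x, weights, masks)
    u = (x - shift) * np.exp(-log_scale)
    return u, -log_scale.sum(axis=-1)

rng = np.random.default_rng(0)
D, hidden_sizes = 4, (64, 64)
masks = made_masks(D, hidden_sizes, rng)
weights = [rng.normal(scale=0.1, size=m.shape) for m in masks[:-1]]
weights.append(rng.normal(scale=0.1, size=(2,) + masks[-1].shape))
x = rng.normal(size=(3, D))
u, log_det = maf_forward(x, weights, masks)
```

Because output i only sees inputs of smaller degree, perturbing the last coordinate of x leaves u[:, :-1] unchanged; this autoregressive property is what keeps the Jacobian triangular when n_layers such layers are stacked.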
- sbijax.nn.make_surjective_affine_maf(n_dimension, n_layer_dimensions, hidden_sizes=(64, 64), activation=jax.numpy.tanh)
Create a surjective affine masked autoregressive flow.
The flow has one layer per entry of n_layer_dimensions. Each affine MAF layer is parameterized by a MADE network whose hidden layers have the sizes given in hidden_sizes, and each dimensionality-reducing layer uses a conditional Gaussian density whose network has the same hidden layer sizes. The entries of n_layer_dimensions determine which layers are dimensionality-preserving and which are dimensionality-reducing. For example, for n_layer_dimensions=(5, 5, 3, 3) and n_dimension=5, the third layer reduces the dimensionality by two and is a surjection layer; the other layers are dimensionality-preserving.
- Parameters:
n_dimension (int): dimensionality of the data
n_layer_dimensions (Iterable[int]): list of integers giving the dimensionality of each flow layer, which determines whether a layer is dimensionality-preserving or -reducing
hidden_sizes (Iterable[int]): sizes of the hidden layers for each normalizing flow layer
activation (Callable): a JAX activation function
Examples
>>> make_surjective_affine_maf(10, (10, 10, 5, 5, 5))
- Returns:
a surjective normalizing flow model
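The interplay between n_dimension and n_layer_dimensions can be sketched as below. describe_layers is a hypothetical helper that labels each layer, and slice_surjection shows one common way to realize a dimensionality-reducing layer: keep some coordinates and score the dropped ones with a conditional Gaussian. Which coordinates sbijax drops, and how it parameterizes that Gaussian, are assumptions here, not documented behavior.

```python
import numpy as np

def describe_layers(n_dimension, n_layer_dimensions):
    """Classify each flow layer as dimensionality-preserving or -reducing (hypothetical helper)."""
    plan, current = [], n_dimension
    for i, dim in enumerate(n_layer_dimensions, start=1):
        if dim > current:
            raise ValueError(f"layer {i} would increase the dimensionality ({current} -> {dim})")
        kind = "preserving" if dim == current else f"reducing by {current - dim}"
        plan.append((i, current, dim, kind))
        current = dim
    return plan

def slice_surjection(x, keep, cond_mean, cond_log_std):
    """Keep the first `keep` coordinates (an assumption) and score the rest.

    Returns the reduced sample z and log N(x_dropped | mu(z), sigma(z)),
    the extra log-density term contributed by the surjection.
    """
    z, dropped = x[..., :keep], x[..., keep:]
    mu, log_std = cond_mean(z), cond_log_std(z)
    log_prob = -0.5 * np.sum(
        ((dropped - mu) / np.exp(log_std)) ** 2 + 2.0 * log_std + np.log(2.0 * np.pi),
        axis=-1,
    )
    return z, log_prob

for row in describe_layers(5, (5, 5, 3, 3)):
    print(row)
# (1, 5, 5, 'preserving')
# (2, 5, 5, 'preserving')
# (3, 5, 3, 'reducing by 2')
# (4, 3, 3, 'preserving')
```

In a trained model, cond_mean and cond_log_std would be the conditional network described above; here they are plain callables so the bookkeeping can be inspected in isolation.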
- sbijax.nn.make_resnet(n_layers=2, hidden_size=64, activation=jax.numpy.tanh, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2)
Create a resnet.
- Parameters:
n_layers (int): number of residual blocks
hidden_size (int): size of the hidden layers in each residual block
activation (Callable): a JAX activation function
dropout_rate (float): dropout rate to use in the residual blocks
do_batch_norm (bool): whether to use batch normalization
batch_norm_decay (float): decay rate of the exponential moving average (EMA) in the batch norm layers
- Returns:
a neural network model
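A minimal NumPy sketch of what these parameters control, assuming a common pre-activation residual block (activation, dense, dropout, activation, dense, skip connection). The block layout, the helper names init_resnet and resnet_forward, and the use of inverted dropout are assumptions rather than sbijax's actual architecture; batch normalization is omitted.

```python
import numpy as np

def init_resnet(n_in, n_out, n_layers=2, hidden_size=64, seed=0):
    """Randomly initialize weights for resnet_forward (hypothetical layout)."""
    rng = np.random.default_rng(seed)

    def dense(a, b):
        return rng.normal(scale=1.0 / np.sqrt(a), size=(a, b)), np.zeros(b)

    # Each block holds two dense layers: (W1, b1, W2, b2)
    blocks = [dense(hidden_size, hidden_size) + dense(hidden_size, hidden_size)
              for _ in range(n_layers)]
    return {"in": dense(n_in, hidden_size), "blocks": blocks, "out": dense(hidden_size, n_out)}

def resnet_forward(x, params, activation=np.tanh, dropout_rate=0.2, rng=None):
    """Forward pass; dropout is applied only when an rng is passed (training mode)."""
    W, b = params["in"]
    h = x @ W + b  # project the input to hidden_size
    for W1, b1, W2, b2 in params["blocks"]:
        r = activation(h) @ W1 + b1
        if rng is not None and dropout_rate > 0:
            # Inverted dropout: zero units and rescale survivors by 1 / (1 - rate)
            keep = rng.random(r.shape) >= dropout_rate
            r = np.where(keep, r / (1.0 - dropout_rate), 0.0)
        r = activation(r) @ W2 + b2
        h = h + r  # skip connection
    W, b = params["out"]
    return activation(h) @ W + b

params = init_resnet(n_in=3, n_out=2, n_layers=2, hidden_size=16)
x = np.ones((5, 3))
y_eval = resnet_forward(x, params)                                 # deterministic
y_train = resnet_forward(x, params, rng=np.random.default_rng(1))  # with dropout
```

The skip connection means a block with all-zero weights is an identity map, which is one reason residual networks are a convenient default for the networks inside flows.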
- sbijax.nn.make_ccnf(n_dimension, n_layers=2, hidden_size=64, activation=jax.numpy.tanh, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2)
Create a conditional continuous normalizing flow.
The CCNF uses a residual network, constructed automatically from the arguments below, as its transformer.
- Parameters:
n_dimension (int): dimensionality of the modelled space
n_layers (int): number of resnet blocks
hidden_size (int): size of the hidden layers in each resnet block
activation (Callable): a JAX activation function
dropout_rate (float): dropout rate to use in the resnet blocks
do_batch_norm (bool): whether to use batch normalization
batch_norm_decay (float): decay rate of the exponential moving average (EMA) in the batch norm layers
- Returns:
a conditional continuous normalizing flow
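A continuous normalizing flow transports base samples by integrating an ODE dx/dt = v(x, t; context) from t = 0 to t = 1; in a CCNF, the automatically created resnet presumably plays the role of v. The sketch below assumes that standard formulation and uses fixed-step Euler integration with a toy closed-form vector field in place of a trained network. The solver, the function names, and the toy field are illustrative assumptions, not sbijax's implementation.

```python
import numpy as np

def integrate_ccnf(x0, context, vector_field, n_steps=1000):
    """Push base samples x0 from t=0 to t=1 along dx/dt = vector_field(x, t, context).

    Fixed-step explicit Euler; practical CNF code often uses adaptive ODE solvers instead.
    """
    x = np.asarray(x0, dtype=float)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        x = x + dt * vector_field(x, k * dt, context)
    return x

def toy_field(x, t, context):
    # Toy conditional field that relaxes x toward the conditioning value;
    # a CCNF would use a learned network here.
    return context - x

x0 = np.zeros((4, 2))
context = np.array([1.0, -2.0])
x1 = integrate_ccnf(x0, context, toy_field)
```

For this toy field with x0 = 0, the exact solution is x(1) = context * (1 - e^-1), so the Euler result can be checked against it; a finer step count shrinks the discretization error.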
- sbijax.nn.make_snass_net(summary_net_dimensions, critic_net_dimensions, activation=<jax._src.custom_derivatives.custom_jvp object>)
Create a critic network for SNASS.
- Parameters:
summary_net_dimensions (Iterable[int]): a list of integers representing the layer dimensionalities of the summary network. The _last_ dimension determines the dimensionality of the summary statistic
critic_net_dimensions (Iterable[int]): a list of integers representing the layer dimensionalities of the critic network. The _last_ dimension needs to be 1
activation (Callable[[Array], Array]): a JAX activation function
- Returns:
a network that can be used within a SNASS posterior estimator
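The dimension conventions for make_snass_net can be illustrated with a small NumPy sketch: a summary MLP maps data y to a statistic whose size is the last entry of summary_net_dimensions, and a critic MLP maps that statistic together with theta to a single score, which is why the last critic dimension must be 1. Concatenating the summary with theta, the ReLU activation, and the helper names (make_critic, init_mlp, mlp) are assumptions for illustration only, not sbijax's API.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def init_mlp(n_in, dims, rng):
    """One (W, b) pair per entry of dims; the last entry sets the output width."""
    params = []
    for n_out in dims:
        params.append((rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out)))
        n_in = n_out
    return params

def mlp(params, x, activation):
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:  # keep the final layer linear
            x = activation(x)
    return x

def make_critic(y_dim, theta_dim, summary_net_dimensions, critic_net_dimensions,
                activation=relu, seed=0):
    summary_net_dimensions = list(summary_net_dimensions)
    critic_net_dimensions = list(critic_net_dimensions)
    if critic_net_dimensions[-1] != 1:
        raise ValueError("the last critic dimension must be 1")
    rng = np.random.default_rng(seed)
    summary = init_mlp(y_dim, summary_net_dimensions, rng)
    critic = init_mlp(summary_net_dimensions[-1] + theta_dim, critic_net_dimensions, rng)

    def critic_fn(y, theta):
        s = mlp(summary, y, activation)  # learned summary statistic
        score = mlp(critic, np.concatenate([s, theta], axis=-1), activation)
        return s, score

    return critic_fn

critic_fn = make_critic(20, 2, (64, 64, 4), (64, 64, 1))
s, score = critic_fn(np.ones((8, 20)), np.zeros((8, 2)))  # s: (8, 4), score: (8, 1)
```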
- sbijax.nn.make_snasss_net(summary_net_dimensions, sec_summary_net_dimensions, critic_net_dimensions, activation=<jax._src.custom_derivatives.custom_jvp object>)
Create a critic network for SNASSS.
- Parameters:
summary_net_dimensions (Iterable[int]): a list of integers representing the layer dimensionalities of the summary network. The _last_ dimension determines the dimensionality of the summary statistic
sec_summary_net_dimensions (Iterable[int]): a list of integers representing the layer dimensionalities of the second summary network. The _last_ dimension determines the dimensionality of the second summary statistic and should be smaller than the last dimension of the first summary network
critic_net_dimensions (Iterable[int]): a list of integers representing the layer dimensionalities of the critic network. The _last_ dimension needs to be 1
activation (Callable[[Array], Array]): a JAX activation function
- Returns:
a network that can be used within a SNASSS posterior estimator
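The two shape constraints documented for make_snasss_net can be checked up front with a hypothetical helper like the one below. It is plain Python, does not build any network, and is not part of sbijax.

```python
def check_snasss_dims(summary_net_dimensions, sec_summary_net_dimensions, critic_net_dimensions):
    """Validate the documented SNASSS shape constraints; return both summary sizes."""
    first = list(summary_net_dimensions)[-1]
    second = list(sec_summary_net_dimensions)[-1]
    if list(critic_net_dimensions)[-1] != 1:
        raise ValueError("the last critic dimension must be 1")
    if second >= first:
        raise ValueError(
            f"second summary size ({second}) must be smaller than the first ({first})"
        )
    return first, second

print(check_snasss_dims((64, 64, 8), (64, 64, 4), (64, 64, 1)))  # (8, 4)
```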