sbijax.nn
sbijax.nn contains utility functions and classes to construct neural
networks and normalizing flows.
make_mdn | Create a mixture density network.
make_maf | Create an affine (surjective) masked autoregressive flow.
make_spf | Create a rational-quadratic (surjective) spline coupling flow.
make_cnf | Create a conditional continuous normalizing flow.
make_mlp | Create an MLP-based classifier network.
make_resnet | Create a ResNet-based classifier network.
make_cm | Create a consistency model.
make_nass_net | Create a critic network for SNASS.
make_nasss_net | Create a critic network for SNASSS.
Density estimators
- sbijax.nn.make_mdn(n_dimension, n_components, hidden_sizes=(64, 64), activation=<jax._src.custom_derivatives.custom_jvp object>)
Create a mixture density network.
The MDN uses n_components mixture components, each modelling the distribution of an n_dimension-dimensional data point.
- Parameters:
n_dimension (int) – dimensionality of data
n_components (int) – number of mixture components
hidden_sizes (Iterable[int]) – sizes of the hidden layers of the network. E.g., when the hidden sizes are a tuple (64, 64), the network uses two hidden layers of size 64 each
activation (Callable) – a jax activation function
- Returns:
a mixture density network
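To make the roles of n_components and n_dimension concrete, the following is a minimal sketch of the kind of density a mixture density network parameterizes. It assumes diagonal Gaussian components, which the description above does not state. The function gaussian_mixture_log_prob and its arguments are hypothetical names for illustration, not part of sbijax.

```python
import math


def gaussian_mixture_log_prob(y, logits, means, scales):
    """Log-density of a point y under a mixture of diagonal Gaussians.

    Hypothetical helper: logits has one entry per mixture component,
    means and scales hold one n_dimension-long list per component.
    """
    # Turn the unnormalized logits into log mixture weights (log-softmax).
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    log_weights = [l - log_norm for l in logits]

    # Log-density of y under each component, plus its log weight.
    terms = []
    for log_w, mean, scale in zip(log_weights, means, scales):
        log_p = 0.0
        for y_d, m_d, s_d in zip(y, mean, scale):
            z = (y_d - m_d) / s_d
            log_p += -0.5 * z * z - math.log(s_d) - 0.5 * math.log(2.0 * math.pi)
        terms.append(log_w + log_p)

    # Combine the components with a numerically stable log-sum-exp.
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


# n_components=3 components for an n_dimension=2 data point.
log_prob = gaussian_mixture_log_prob(
    y=[0.5, -0.2],
    logits=[0.1, 0.0, -0.3],
    means=[[0.0, 0.0], [1.0, -1.0], [-1.0, 1.0]],
    scales=[[1.0, 1.0], [0.5, 0.5], [2.0, 2.0]],
)
```

In an actual MDN, the logits, means and scales would be produced by the network from the conditioning variable rather than fixed by hand.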
- sbijax.nn.make_maf(n_dimension, n_layers=5, n_layer_dimensions=None, hidden_sizes=(64, 64), activation=<PjitFunction of <function tanh>>)
Create an affine (surjective) masked autoregressive flow.
The MAFs use n_layers layers and are parameterized using MADE networks with hidden_sizes neurons per layer. For each dimensionality-reducing layer, a conditional Gaussian density is used that has the same number of layers and nodes per layer as hidden_sizes. The argument n_layer_dimensions determines which layers are dimensionality-preserving or -reducing. For example, for n_layer_dimensions=(5, 5, 3, 3) and n_dimension=5, the third layer reduces the dimensionality by two and uses a surjection layer. The other layers are dimensionality-preserving.
- Parameters:
n_dimension (int) – dimensionality of the modelled space
n_layers (int | None) – number of layers
n_layer_dimensions (Iterable[int] | None) – list of integers that determine if a layer is dimensionality-preserving or -reducing
hidden_sizes (Iterable[int]) – sizes of hidden layers for each normalizing flow
activation (Callable) – a jax activation function
Examples
>>> neural_network = make_maf(10, n_layer_dimensions=(10, 10, 5, 5, 5))
- Returns:
a (surjective) normalizing flow model
- Return type:
Transformed
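The rule for n_layer_dimensions described above can be sketched in plain Python. The helper describe_layers is hypothetical and not part of sbijax; it only mirrors the description by comparing each layer's dimensionality with that of the layer before it. Rejecting a layer that would increase the dimensionality is an assumption of this sketch.

```python
def describe_layers(n_dimension, n_layer_dimensions):
    """Label each flow layer as dimensionality-preserving or -reducing."""
    labels = []
    current = n_dimension
    for index, target in enumerate(n_layer_dimensions):
        if target > current:
            # Assumption of this sketch: layers never increase dimensionality.
            raise ValueError(
                f"layer {index} would increase dimensionality: {current} -> {target}"
            )
        if target == current:
            labels.append("preserving")
        else:
            labels.append(f"reducing by {current - target}")
        current = target
    return labels


print(describe_layers(5, (5, 5, 3, 3)))
# ['preserving', 'preserving', 'reducing by 2', 'preserving']
```

Only the third layer changes the dimensionality, so under the description above it is the only one that would use a surjection layer.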
- sbijax.nn.make_spf(n_dimension, range_min, range_max, n_layers=5, n_layer_dimensions=None, hidden_sizes=(64, 64), n_params=10, activation=<PjitFunction of <function tanh>>)
Create a rational-quadratic (surjective) spline coupling flow.
The spline flows use n_layers layers and are parameterized using neural networks with hidden_sizes neurons per layer. For each dimensionality-reducing layer, a conditional Gaussian density is used that has the same number of layers and nodes per layer as hidden_sizes. The argument n_layer_dimensions determines which layers are dimensionality-preserving or -reducing. For example, for n_layer_dimensions=(5, 5, 3, 3) and n_dimension=5, the third layer reduces the dimensionality by two and uses a surjection layer. The other layers are dimensionality-preserving.
- Parameters:
n_dimension (int) – dimensionality of the modelled space
range_min (float) – lower bound of the interval on which the spline is defined
range_max (float) – upper bound of the interval on which the spline is defined
n_layers (int | None) – number of layers
n_layer_dimensions (Iterable[int] | None) – list of integers that determine if a layer is dimensionality-preserving or -reducing
hidden_sizes (Iterable[int]) – sizes of hidden layers for each normalizing flow
n_params (int) – number of parameters of each spline
activation (Callable) – a jax activation function
Examples
>>> neural_network = make_spf(10, -1.0, 1.0, n_layer_dimensions=(10, 10, 5, 5, 5))
- Returns:
a (surjective) normalizing flow model
- Return type:
Transformed
- sbijax.nn.make_cnf(n_dimension, n_layers=2, hidden_size=64, activation=<jax._src.custom_derivatives.custom_jvp object>, dropout_rate=0.1, do_batch_norm=False, batch_norm_decay=0.2, sigma_min=0.001)
Create a conditional continuous normalizing flow.
The CCNF uses a residual network as transformer which is created automatically.
- Parameters:
n_dimension (int) – dimensionality of modelled space
n_layers (int) – number of resnet blocks
hidden_size (int) – sizes of hidden layers for each resnet block
activation (Callable) – a jax activation function
dropout_rate (float) – dropout rate to use in resnet blocks
do_batch_norm (bool) – use batch normalization or not
batch_norm_decay (float) – decay rate of EMA in batch norm layer
sigma_min (float) – minimal scaling for the vector field
- Returns:
a conditional continuous normalizing flow
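The sigma_min parameter is described only briefly above. As a hedged illustration, the sketch below shows the Gaussian probability path commonly used in flow matching, where the noise scale shrinks linearly from 1 at t=0 to sigma_min at t=1. Whether make_cnf uses exactly this form is an assumption, and conditional_path is a hypothetical helper, not part of sbijax.

```python
def conditional_path(theta, noise, t, sigma_min=0.001):
    """Point on a straight-line path from noise (t=0) towards theta (t=1).

    Assumed form: theta_t = t * theta + (1 - (1 - sigma_min) * t) * noise.
    Returns the point and the noise scale at time t.
    """
    scale = 1.0 - (1.0 - sigma_min) * t
    point = [t * th + scale * n for th, n in zip(theta, noise)]
    return point, scale


start, start_scale = conditional_path([2.0], [3.0], t=0.0)
end, end_scale = conditional_path([2.0], [3.0], t=1.0)
```

Under this assumed form, the path starts at pure noise and ends at the target value perturbed by noise scaled with sigma_min, so sigma_min bounds the scaling from below.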
Classifier networks
- sbijax.nn.make_mlp(n_layers=2, hidden_size=64, activation=<function gelu>, w_init=<haiku._src.initializers.TruncatedNormal object>, b_init=<function zeros>)
Create an MLP-based classifier network.
- Parameters:
n_layers (int) – the number of hidden layers to be used
hidden_size (int) – the size of each layer
activation – a JAX activation function
w_init – a haiku initializer
b_init – a haiku initializer
- Returns:
a transformable haiku neural network module
- sbijax.nn.make_resnet(n_layers=2, hidden_size=64, activation=<PjitFunction of <function tanh>>, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2)
Create a ResNet-based classifier network.
- Parameters:
n_layers (int) – number of resnet blocks
hidden_size (int) – sizes of hidden layers for each resnet block
activation (Callable) – a jax activation function
dropout_rate (float) – dropout rate to use in resnet blocks
do_batch_norm (bool) – use batch normalization or not
batch_norm_decay (float) – decay rate of EMA in batch norm layer
- Returns:
a neural network model
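To illustrate what batch_norm_decay controls, here is a minimal sketch of an exponential moving average (EMA) update. It assumes the common convention in which decay weights the previous average and omits bias correction. The function ema_update is hypothetical and not part of sbijax.

```python
def ema_update(average, value, decay):
    """One EMA step: keep a `decay` share of the old average."""
    return decay * average + (1.0 - decay) * value


running_mean = 0.0
for batch_mean in [1.0, 1.0, 1.0]:
    # With decay=0.2, each new batch statistic gets weight 0.8.
    running_mean = ema_update(running_mean, batch_mean, decay=0.2)
```

Under this convention, the default of 0.2 makes the running statistics follow recent batches closely, whereas values near 1 would average over many batches.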
Consistency models
- sbijax.nn.make_cm(n_dimension, n_layers=2, hidden_size=64, activation=<PjitFunction of <function tanh>>, dropout_rate=0.2, do_batch_norm=False, batch_norm_decay=0.2, t_min=0.001, t_max=50.0, sigma_data=1.0)
Create a consistency model.
The consistency model uses a residual network as score network.
- Parameters:
n_dimension (int) – dimensionality of modelled space
n_layers (int) – number of resnet blocks
hidden_size (int) – sizes of hidden layers for each resnet block
activation (Callable) – a jax activation function
dropout_rate (float) – dropout rate to use in resnet blocks
do_batch_norm (bool) – use batch normalization or not
batch_norm_decay (float) – decay rate of EMA in batch norm layer
t_min (float) – minimal time point for ODE integration
t_max (float) – maximal time point for ODE integration
sigma_data (float) – the standard deviation of the data
- Returns:
a consistency model
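The parameters t_min and sigma_data can be read through the skip-connection parameterization that is standard for consistency models. The sketch below shows that parameterization in plain Python. That make_cm uses exactly these coefficients is an assumption, and c_skip, c_out and consistency_output are hypothetical names, not part of sbijax.

```python
import math


def c_skip(t, t_min=0.001, sigma_data=1.0):
    """Weight of the input x; equals 1 at t = t_min."""
    return sigma_data**2 / ((t - t_min) ** 2 + sigma_data**2)


def c_out(t, t_min=0.001, sigma_data=1.0):
    """Weight of the network output; equals 0 at t = t_min."""
    return sigma_data * (t - t_min) / math.sqrt(sigma_data**2 + t**2)


def consistency_output(x, t, network_output, t_min=0.001, sigma_data=1.0):
    """Combine the input x and the raw network output at time t."""
    skip = c_skip(t, t_min, sigma_data)
    out = c_out(t, t_min, sigma_data)
    return [skip * x_i + out * f_i for x_i, f_i in zip(x, network_output)]


# At t = t_min the model returns its input, whatever the network outputs.
boundary = consistency_output([1.0, 2.0], 0.001, [9.0, 9.0])
```

With this choice the boundary condition at t_min holds by construction, so the residual network only has to learn the mapping for later time points up to t_max.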
Summary statistics networks
- sbijax.nn.make_nass_net(embedding_dim, hidden_sizes, activation=<jax._src.custom_derivatives.custom_jvp object>)
Create a critic network for SNASS.
- Parameters:
embedding_dim (int) – dimensionality of the summary statistic
hidden_sizes (Iterable[int]) – list of integers specifying the hidden dimensions of the networks
activation (Callable[[Array], Array]) – a jax activation function
- Returns:
a network that can be used within a NASS posterior estimator
- sbijax.nn.make_nasss_net(embedding_dim, sec_embedding_dim, hidden_sizes, activation=<jax._src.custom_derivatives.custom_jvp object>)
Create a critic network for SNASSS.
- Parameters:
embedding_dim (int) – dimensionality of the summary statistic
sec_embedding_dim (int) – dimensionality of the secondary summary statistic
hidden_sizes (Iterable[int]) – list of integers specifying the hidden dimensions of the networks
activation (Callable[[Array], Array]) – a jax activation function
- Returns:
a network that can be used within a SNASSS posterior estimator