from collections.abc import Callable, Iterable
import distrax
import haiku as hk
import jax
import surjectors
from jax import numpy as jnp
from surjectors import (
AffineMaskedAutoregressiveInferenceFunnel,
Chain,
MaskedAutoregressive,
MaskedCoupling,
MaskedCouplingInferenceFunnel,
Permutation,
TransformedDistribution,
)
from surjectors.nn import MADE
from surjectors.nn import make_mlp as surjectors_mlp
from surjectors.util import make_alternating_binary_mask, unstack
from tensorflow_probability.substrates.jax import distributions as tfd
# ruff: noqa: PLR0913, E501
[docs]
def make_maf(
n_dimension: int,
n_layers: int | None = 5,
n_layer_dimensions: Iterable[int] | None = None,
hidden_sizes: Iterable[int] = (64, 64),
activation: Callable = jax.nn.tanh,
) -> hk.Transformed:
"""Create an affine (surjective) masked autoregressive flow.
The MAFs use `n_layers` layers and are parameterized using MADE networks
with `hidden_sizes` neurons per layer. For each dimensionality reducing
layer, a conditional Gaussian density is used that uses the same number of
layer and nodes per layers as `hidden_sizes`. The argument
`n_layer_dimensions` determines which layer is dimensionality-preserving
or -reducing. For example, for `n_layer_dimensions=(5, 5, 3, 3)` and
`n_dimension=5`, the third layer would reduce the dimensionality by two
and use a surjection layer. THe other layers are dimensionality-preserving.
Args:
n_dimension: a list of integers that determine the dimensionality
of each flow layer
n_layers: number of layers
n_layer_dimensions: list of integers that determine if a layer is
dimensionality-preserving or -reducing
hidden_sizes: sizes of hidden layers for each normalizing flow
activation: a jax activation function
Examples:
>>> neural_network = make_maf(10, n_layer_dimensions=(10, 10, 5, 5, 5))
Returns:
a (surjective) normalizing flow model
"""
if isinstance(n_layers, int) and n_layer_dimensions is not None:
assert n_layers == len(list(n_layer_dimensions))
elif isinstance(n_layers, int):
n_layer_dimensions = [n_dimension] * n_layers
return _make_maf(
n_dimension=n_dimension,
n_layer_dimensions=n_layer_dimensions,
hidden_sizes=hidden_sizes,
activation=activation,
)
def _make_maf(
n_dimension,
n_layer_dimensions,
hidden_sizes,
activation,
):
def _bijector_fn(params):
means, log_scales = unstack(params, -1)
return surjectors.ScalarAffine(means, jnp.exp(log_scales))
def _decoder_fn(n_dim, hidden_sizes):
decoder_net = surjectors_mlp(
hidden_sizes + [n_dim * 2],
w_init=hk.initializers.TruncatedNormal(stddev=0.001),
)
def _fn(z):
params = decoder_net(z)
mu, log_scale = jnp.split(params, 2, -1)
return tfd.Independent(tfd.Normal(mu, jnp.exp(log_scale)), 1)
return _fn
@hk.transform
def _flow(method, **kwargs):
layers = []
order = jnp.arange(n_dimension)
curr_dim = n_dimension
for i, n_dim_curr_layer in enumerate(n_layer_dimensions):
# layer is dimensionality preserving
if n_dim_curr_layer == curr_dim:
layer = MaskedAutoregressive(
bijector_fn=_bijector_fn,
conditioner=MADE(
n_dim_curr_layer,
list(hidden_sizes),
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
activation=activation,
),
)
order = order[::-1]
elif n_dim_curr_layer < curr_dim:
n_latent = n_dim_curr_layer
layer = AffineMaskedAutoregressiveInferenceFunnel(
n_latent,
_decoder_fn(curr_dim - n_latent, list(hidden_sizes)),
conditioner=MADE(
n_latent,
list(hidden_sizes),
2,
w_init=hk.initializers.TruncatedNormal(0.001),
b_init=jnp.zeros,
activation=jax.nn.tanh,
),
)
curr_dim = n_latent
order = order[::-1]
order = order[:curr_dim] - jnp.min(order[:curr_dim])
else:
raise ValueError(
f"n_dimension at layer {i} is layer than the dimension of"
f" the following layer {i + 1}"
)
layers.append(layer)
layers.append(Permutation(order, 1))
chain = Chain(layers[:-1])
base_distribution = tfd.Independent(
tfd.Normal(jnp.zeros(curr_dim), jnp.ones(curr_dim)),
1,
)
td = TransformedDistribution(base_distribution, chain)
return td(method, **kwargs)
return _flow
# ruff: noqa: PLR0913, E501
[docs]
def make_spf(
n_dimension: int,
range_min: float,
range_max: float,
n_layers: int | None = 5,
n_layer_dimensions: Iterable[int] | None = None,
hidden_sizes: Iterable[int] = (64, 64),
n_params: int = 10,
activation: Callable = jax.nn.tanh,
) -> hk.Transformed:
"""Create a rational-quadratic (surjective) spline coupling flow.
The MAFs use `n_layers` layers and are parameterized using MADE networks
with `hidden_sizes` neurons per layer. For each dimensionality reducing
layer, a conditional Gaussian density is used that uses the same number of
layer and nodes per layers as `hidden_sizes`. The argument
`n_layer_dimensions` determines which layer is dimensionality-preserving
or -reducing. For example, for `n_layer_dimensions=(5, 5, 3, 3)` and
`n_dimension=5`, the third layer would reduce the dimensionality by two
and use a surjection layer. THe other layers are dimensionality-preserving.
Args:
n_dimension: a list of integers that determine the dimensionality
of each flow layer
range_min: minimum range on which the spline is defined
range_max: maximum range on which the spline is defined
n_layers: number of layers
n_layer_dimensions: list of integers that determine if a layer is
dimensionality-preserving or -reducing
hidden_sizes: sizes of hidden layers for each normalizing flow
n_params: number of parameters of each spline
activation: a jax activation function
Examples:
>>> neural_network = make_spf(10, -1.0, 1.0, n_layer_dimensions=(10, 10, 5, 5, 5))
Returns:
a (surjective) normalizing flow model
"""
if isinstance(n_layers, int) and n_layer_dimensions is not None:
assert n_layers == len(list(n_layer_dimensions))
if isinstance(n_layers, int):
n_layer_dimensions = [n_dimension] * n_layers
return _make_spf(
n_dimension=n_dimension,
range_min=range_min,
range_max=range_max,
n_layer_dimensions=n_layer_dimensions,
hidden_sizes=hidden_sizes,
n_params=n_params,
activation=activation,
)
def _make_spf(
n_dimension,
n_layer_dimensions,
range_min,
range_max,
n_params,
hidden_sizes,
activation,
):
def _bijector_fn(params):
return distrax.RationalQuadraticSpline(
params, range_min=range_min, range_max=range_max
)
def _decoder_fn(dims):
def fn(z):
params = surjectors_mlp(dims, activation=activation)(z)
mu, log_scale = jnp.split(params, 2, -1)
return tfd.Independent(tfd.Normal(mu, jnp.exp(log_scale)))
return fn
def _conditioner(n_dim):
return hk.Sequential(
[
surjectors_mlp(
list(hidden_sizes) + [n_params * n_dim],
activation=activation,
),
hk.Reshape((n_dimension, n_params)),
]
)
@hk.transform
def _flow(method, **kwargs):
layers = []
curr_dim = n_dimension
for i, n_dim_curr_layer in enumerate(n_layer_dimensions):
# layer is dimensionality preserving
if n_dim_curr_layer == curr_dim:
layer = MaskedCoupling(
mask=make_alternating_binary_mask(curr_dim, i % 2 == 0),
conditioner=_conditioner(curr_dim),
bijector_fn=_bijector_fn,
)
# layer is dimensionality reducing
elif n_dim_curr_layer < curr_dim:
n_latent = n_dim_curr_layer
layer = MaskedCouplingInferenceFunnel(
n_keep=n_latent,
decoder=_decoder_fn(list(hidden_sizes) + [2 * (curr_dim - n_latent)]),
conditioner=surjectors_mlp(
list(hidden_sizes) + [2 * curr_dim],
activation=activation,
),
bijector_fn=_bijector_fn,
)
curr_dim = n_latent
else:
raise ValueError(
f"n_dimension at layer {i} is layer than the dimension of"
f" the following layer {i + 1}"
)
layers.append(layer)
chain = Chain(layers)
base_distribution = tfd.Independent(
tfd.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
1,
)
td = TransformedDistribution(base_distribution, chain)
return td(method, **kwargs)
return _flow