Source code for sbijax._src.nn.make_mdn
from collections.abc import Callable, Iterable
import haiku as hk
import jax
from jax import numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd
# pylint: disable=too-many-arguments
[docs]
def make_mdn(
n_dimension: int,
n_components: int,
hidden_sizes: Iterable[int] = (64, 64),
activation: Callable = jax.nn.relu,
):
"""Create a mixture density network.
The MDN uses `n_components` mixture components each modelling the
distribution of a `n_dimension`al data point.
Args:
n_dimension: dimensionality of data
n_components: number of mixture components
hidden_sizes: sizes of hidden layers for each normalizing flow. E.g.,
when the hidden sizes are a tuple (64, 64), then each maf layer
uses a MADE with two layers of size 64 each
activation: a jax activation function
Returns:
a mixture density network
"""
@hk.transform
def mdn(method, **kwargs):
n = kwargs["x"].shape[0]
hidden = hk.nets.MLP(
hidden_sizes, activation=activation, activate_final=True
)(kwargs["x"])
logits = hk.Linear(n_components)(hidden)
mu_sigma = hk.Linear(n_components * n_dimension * 2)(hidden)
mu, sigma = jnp.split(mu_sigma, 2, axis=-1)
mixture = tfd.MixtureSameFamily(
tfd.Categorical(logits=logits),
tfd.MultivariateNormalDiag(
mu.reshape(n, n_components, n_dimension),
jnp.exp(sigma.reshape(n, n_components, n_dimension)),
),
)
if method == "sample":
return mixture.sample(seed=hk.next_rng_key())
else:
return mixture.log_prob(kwargs["y"])
return mdn