# Source code for sbijax._src.nn.continuous_normalizing_flow

from typing import Callable

import distrax
import haiku as hk
import jax
from jax import numpy as jnp
from jax.nn import glu
from scipy import integrate

# Public API of this module: the flow module and its factory function.
__all__ = ["CCNF", "make_ccnf"]


# ruff: noqa: PLR0913,D417
class CCNF(hk.Module):
    """Conditional continuous normalizing flow.

    Samples are produced by drawing from a standard normal base
    distribution and integrating the learned, context-dependent vector
    field from time 0 to time 1.

    Args:
        n_dimension: the dimensionality of the modelled space
        transform: a haiku module. The transform is a callable that has to
            take as input arguments named 'theta', 'time', 'context' and
            **kwargs. Theta, time and context are two-dimensional arrays
            with the same batch dimensions.
    """

    def __init__(self, n_dimension: int, transform: Callable):
        """Conditional continuous normalizing flow.

        Args:
            n_dimension: the dimensionality of the modelled space
            transform: a haiku module. The transform is a callable that has to
                take as input arguments named 'theta', 'time', 'context' and
                **kwargs. Theta, time and context are two-dimensional arrays
                with the same batch dimensions.
        """
        super().__init__()
        self._n_dimension = n_dimension
        self._network = transform
        # Standard normal base distribution with zero mean and unit scale in
        # each of the n_dimension coordinates.
        self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

    def __call__(self, method, **kwargs):
        """Apply the flow.

        Dispatches to the method of this module whose name is ``method``.

        Args:
            method (str): name of the method to call, e.g. "sample" or
                "vector_field"

        Keyword Args:
            keyword arguments forwarded to the called method

        Returns:
            whatever the called method returns
        """
        return getattr(self, method)(**kwargs)

    def sample(self, context, **kwargs):
        """Sample from the pushforward.

        One base sample is drawn per row of ``context``. The samples are then
        transported from time 0 to time 1 with SciPy's adaptive RK45 solver.
        Because the integration runs in SciPy rather than in JAX, this
        method cannot be traced by ``jax.jit``.

        Args:
            context: array of conditioning variables; ``context.shape[0]``
                samples are drawn

        Keyword Args:
            keyword arguments passed on to the neural network

        Returns:
            array of shape ``(context.shape[0], n_dimension)``. If the ODE
            solver reports a failure, a ``RuntimeWarning`` is emitted and the
            last state the solver reached is returned.
        """
        import warnings

        theta_0 = self._base_distribution.sample(
            seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
        )

        def ode_func(time, theta_flat):
            # solve_ivp integrates a flat 1-D state vector, while the network
            # expects (batch, n_dimension) arrays.
            theta_t = theta_flat.reshape(-1, self._n_dimension)
            # The scalar solver time is broadcast to (batch, 1) inside
            # `vector_field`, so it is passed through unchanged here.
            velocity = self.vector_field(
                theta=theta_t, time=time, context=context, **kwargs
            )
            return velocity.reshape(-1)

        res = integrate.solve_ivp(
            ode_func,
            (0.0, 1.0),
            theta_0.reshape(-1),
            rtol=1e-5,
            atol=1e-5,
            method="RK45",
        )
        if not res.success:
            # Without this check a failed integration would silently return
            # a state that never reached time 1.
            warnings.warn(
                f"ODE integration of the flow did not succeed: {res.message}",
                RuntimeWarning,
                stacklevel=2,
            )

        # res.y has shape (n_states, n_timepoints); take the final time point.
        return res.y[:, -1].reshape(-1, self._n_dimension)

    def vector_field(self, theta, time, context, **kwargs):
        """Compute the vector field.

        Args:
            theta: two-dimensional array of parameters, (batch, n_dimension)
            time: time variables; either a scalar (shared by all rows) or one
                time per row with shape (batch, 1) or (batch,)
            context: array of conditioning variables

        Keyword Args:
            keyword arguments that are passed to the neural network

        Returns:
            the network output evaluated at ``theta``, ``time`` and
            ``context``
        """
        # Bring time to a (batch, 1) column. A scalar is repeated for every
        # row; a per-row vector is reshaped into a column first.
        time = jnp.broadcast_to(
            jnp.reshape(jnp.asarray(time), (-1, 1)), (theta.shape[0], 1)
        )
        return self._network(theta=theta, time=time, context=context, **kwargs)


# pylint: disable=too-many-arguments
class _ResnetBlock(hk.Module):
    """A residual block whose output is gated by a conditioning input.

    The block input passes through these layers in order:
    - (optional) batch norm
    - linear
    - activation
    - dropout (while training)
    - linear

    The result is then gated by a linear projection of ``context`` through
    a gated linear unit. Finally, the block input is added back as a skip
    connection.
    """

    def __init__(
        self,
        hidden_size: int,
        activation: Callable = jax.nn.relu,
        dropout_rate: float = 0.2,
        do_batch_norm: bool = False,
        batch_norm_decay: float = 0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.do_batch_norm = do_batch_norm
        self.batch_norm_decay = batch_norm_decay

    def __call__(self, inputs, context, is_training=False):
        # NOTE: haiku names sub-modules by the order in which they are
        # created. The layer order below must therefore stay fixed, or
        # parameter names would shift.
        if self.do_batch_norm:
            normalizer = hk.BatchNorm(True, True, self.batch_norm_decay)
            hidden = normalizer(inputs, is_training=is_training)
        else:
            hidden = inputs

        hidden = self.activation(hk.Linear(self.hidden_size)(hidden))
        if is_training:
            hidden = hk.dropout(
                rng=hk.next_rng_key(), rate=self.dropout_rate, x=hidden
            )
        hidden = hk.Linear(self.hidden_size)(hidden)

        gate = hk.Linear(inputs.shape[-1])(context)
        # glu splits the concatenation into two halves along the last axis
        # and returns first * sigmoid(second). This is a gating of `hidden`
        # by `gate` when hidden_size == inputs.shape[-1].
        gated = glu(jnp.concatenate([hidden, gate], axis=-1))
        return gated + inputs


# pylint: disable=too-many-arguments
class _CCNFResnet(hk.Module):
    """A simplified 1-d residual network used as the CCNF vector field.

    The network's main input is ``context``. An embedding of ``theta``
    and ``time`` is fed into every residual block as the conditioning
    signal.
    """

    def __init__(
        self,
        n_layers: int,
        n_dimension: int,
        hidden_size: int,
        activation: Callable = jax.nn.relu,
        dropout_rate: float = 0.2,
        do_batch_norm: bool = True,
        batch_norm_decay: float = 0.1,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_dimension = n_dimension
        self.hidden_size = hidden_size
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.do_batch_norm = do_batch_norm
        self.batch_norm_decay = batch_norm_decay

    def __call__(self, theta, time, context, is_training=False, **kwargs):
        # This is a bit unusual but follows the paper. The conditioning
        # variables are times and theta, not times and context (i.e., y),
        # and y goes through the resnet blocks. The reason is that theta is
        # typically low-dimensional while y is typically high-dimensional.
        # NOTE: haiku names sub-modules by creation order; keep the order of
        # the layers below unchanged.
        theta_features = hk.Linear(self.n_dimension)(theta)
        time_features = hk.Linear(self.n_dimension)(time)
        conditioning = jnp.concatenate([theta_features, time_features], axis=-1)

        hidden = self.activation(hk.Linear(self.hidden_size)(context))
        for _ in range(self.n_layers):
            block = _ResnetBlock(
                hidden_size=self.hidden_size,
                activation=self.activation,
                dropout_rate=self.dropout_rate,
                do_batch_norm=self.do_batch_norm,
                batch_norm_decay=self.batch_norm_decay,
            )
            hidden = block(hidden, context=conditioning, is_training=is_training)

        return hk.Linear(self.n_dimension)(self.activation(hidden))


# ruff: noqa: PLR0913
def make_ccnf(
    n_dimension: int,
    n_layers: int = 2,
    hidden_size: int = 64,
    activation: Callable = jax.nn.tanh,
    dropout_rate: float = 0.2,
    do_batch_norm: bool = False,
    batch_norm_decay: float = 0.2,
):
    """Create a conditional continuous normalizing flow.

    The CCNF uses a residual network as transformer which is created
    automatically.

    Args:
        n_dimension: dimensionality of modelled space
        n_layers: number of resnet blocks
        hidden_size: sizes of hidden layers for each resnet block
        activation: a jax activation function
        dropout_rate: dropout rate to use in resnet blocks
        do_batch_norm: use batch normalization or not
        batch_norm_decay: decay rate of EMA in batch norm layer

    Returns:
        returns a conditional continuous normalizing flow (a haiku-transformed
        function taking a method name, e.g. "sample" or "vector_field", and
        that method's keyword arguments)
    """
    resnet_config = {
        "n_layers": n_layers,
        "n_dimension": n_dimension,
        "hidden_size": hidden_size,
        "activation": activation,
        "do_batch_norm": do_batch_norm,
        "dropout_rate": dropout_rate,
        "batch_norm_decay": batch_norm_decay,
    }

    def _flow(method, **kwargs):
        # The network is created before the flow module, matching the
        # original construction order (and hence haiku module names).
        network = _CCNFResnet(**resnet_config)
        flow = CCNF(n_dimension, network)
        return flow(method, **kwargs)

    return hk.transform(_flow)