Source code for sbijax._src.nn.make_nass_network

from collections.abc import Callable, Iterable

import haiku as hk
import jax

from sbijax._src.nn.nass_net import NASSNet
from sbijax._src.nn.nasss_net import NASSSNet


def make_nass_net(
    embedding_dim: int,
    hidden_sizes: Iterable[int],
    activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
):
    """Create a critic network for SNASS.

    Builds two MLPs that share the same hidden layout -- a summary network
    whose final layer has ``embedding_dim`` units and a critic network whose
    final layer has a single unit -- and wraps them in a ``NASSNet``.

    Args:
        embedding_dim: dimensionality of the summary statistic
        hidden_sizes: integers specifying the hidden dimensions of the
            networks. Any iterable is accepted (including one-shot
            iterators); it is consumed exactly once when this factory is
            called.
        activation: a jax activation function

    Returns:
        a Haiku-transformed function (wrapped with ``hk.without_apply_rng``)
        that can be used within a NASS posterior estimator. Its
        ``method`` argument and any keyword arguments are forwarded to the
        underlying ``NASSNet``.
    """
    # Materialise the sizes once, outside the transformed function. `_net`
    # is re-executed on every init/apply, and each MLP needs its own copy of
    # the hidden sizes; calling `list(hidden_sizes)` repeatedly would yield
    # an empty list after the first call if a generator was passed in.
    hidden_layers = list(hidden_sizes)

    @hk.without_apply_rng
    @hk.transform
    def _net(method, **kwargs):
        summary_net = hk.nets.MLP(
            output_sizes=[*hidden_layers, embedding_dim],
            activation=activation,
        )
        critic_net = hk.nets.MLP(
            output_sizes=[*hidden_layers, 1],
            activation=activation,
        )
        net = NASSNet(summary_net=summary_net, critic_net=critic_net)
        return net(method, **kwargs)

    return _net
def make_nasss_net(
    embedding_dim: int,
    sec_embedding_dim: int,
    hidden_sizes: Iterable[int],
    activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
):
    """Create a critic network for SNASSS.

    Builds three MLPs that share the same hidden layout -- a summary network
    ending in ``embedding_dim`` units, a secondary summary network ending in
    ``sec_embedding_dim`` units, and a critic network ending in a single
    unit -- and wraps them in a ``NASSSNet``.

    Args:
        embedding_dim: dimensionality of the summary statistic
        sec_embedding_dim: dimensionality of the secondary summary statistic
        hidden_sizes: integers specifying the hidden dimensions of the
            networks. Any iterable is accepted (including one-shot
            iterators); it is consumed exactly once when this factory is
            called.
        activation: a jax activation function

    Returns:
        a Haiku-transformed function (wrapped with ``hk.without_apply_rng``)
        that can be used within a SNASSS posterior estimator. Its
        ``method`` argument and any keyword arguments are forwarded to the
        underlying ``NASSSNet``.
    """
    # Materialise the sizes once, outside the transformed function. `_net`
    # is re-executed on every init/apply, and each of the three MLPs needs
    # its own copy of the hidden sizes; calling `list(hidden_sizes)`
    # repeatedly would yield an empty list after the first call if a
    # generator was passed in.
    hidden_layers = list(hidden_sizes)

    @hk.without_apply_rng
    @hk.transform
    def _net(method, **kwargs):
        summary_net = hk.nets.MLP(
            output_sizes=[*hidden_layers, embedding_dim],
            activation=activation,
        )
        sec_summary_net = hk.nets.MLP(
            output_sizes=[*hidden_layers, sec_embedding_dim],
            activation=activation,
        )
        critic_net = hk.nets.MLP(
            output_sizes=[*hidden_layers, 1],
            activation=activation,
        )
        net = NASSSNet(
            summary_net=summary_net,
            sec_summary_net=sec_summary_net,
            critic_net=critic_net,
        )
        return net(method, **kwargs)

    return _net