"""Source code for ``sbijax._src.nn.make_mlp``: a factory for haiku MLP networks."""
import haiku as hk
import jax
from jax import numpy as jnp
# [docs]
def make_mlp(
    n_layers: int = 2,
    hidden_size: int = 64,
    activation=jax.nn.gelu,
    w_init=hk.initializers.TruncatedNormal(stddev=0.01),
    b_init=jnp.zeros,
    output_size: int = 1,
):
    """Create an MLP-based classifier network.

    The network consists of ``n_layers`` hidden layers of width
    ``hidden_size``, followed by an output layer of width ``output_size``.
    ``activation`` is applied after every hidden layer but not after the
    output layer (``hk.nets.MLP`` defaults to ``activate_final=False``).

    Args:
        n_layers: the number of hidden layers to be used. ``0`` yields a
            single linear layer from the inputs to the outputs.
        hidden_size: the size of each hidden layer
        activation: a JAX activation function applied after each hidden layer
        w_init: a haiku initializer for the weights of every layer
        b_init: a haiku initializer for the biases of every layer
        output_size: the size of the final (linear) layer. Defaults to ``1``,
            i.e. one scalar output per input.

    Returns:
        a transformed haiku module (an ``init``/``apply`` pair) whose
        ``apply`` does not take an RNG key

    Raises:
        ValueError: if ``n_layers`` is negative or ``output_size`` is not
            positive.
    """
    # A negative count would silently collapse to ``[]`` via list repetition
    # and produce a purely linear model, so reject it explicitly.
    if n_layers < 0:
        raise ValueError(f"n_layers must be non-negative, got {n_layers}")
    if output_size < 1:
        raise ValueError(f"output_size must be positive, got {output_size}")

    # The layer widths depend only on the factory arguments, so compute them
    # once here rather than on every call of the transformed function.
    output_sizes = tuple([hidden_size] * n_layers + [output_size])

    @hk.without_apply_rng
    @hk.transform
    def _net(inputs, **kwargs):
        # Extra keyword arguments are accepted and ignored, so callers can
        # pass them through ``init``/``apply`` without raising an error.
        mlp = hk.nets.MLP(
            output_sizes=output_sizes,
            w_init=w_init,
            b_init=b_init,
            activation=activation,
        )
        return mlp(inputs)

    return _net