# Source code for sbijax._src.nn.make_resnet

from collections.abc import Callable

import haiku as hk
import jax
from jax import numpy as jnp
from jax._src.nn.functions import glu


# pylint: disable=too-many-arguments
class _ResnetBlock(hk.Module):
  """A block for a 1d residual network.

  The residual branch computes, in order: optional batch normalization of
  the input, a dense layer of width ``hidden_size``, the activation, dropout
  (only when ``is_training``), and a second dense layer of width
  ``hidden_size``. If a ``context`` is given, a linear projection of it is
  concatenated with the branch output and passed through a GLU. The branch
  output is added to the block input.

  Args:
    hidden_size: width of both dense layers in the block. Because the
      branch output is added to the input, this presumably has to match the
      input's last dimension (it does when used from ``_Resnet``).
    activation: activation applied after the first dense layer.
    dropout_rate: dropout rate used when the block is called with
      ``is_training=True``.
    do_batch_norm: whether to batch-normalize the block input.
    batch_norm_decay: passed to ``hk.BatchNorm`` as ``decay_rate``.
  """

  def __init__(
    self,
    hidden_size: int,
    activation: Callable = jax.nn.relu,
    dropout_rate: float = 0.2,
    do_batch_norm: bool = False,
    batch_norm_decay: float = 0.1,
  ):
    super().__init__()
    self.hidden_size = hidden_size
    self.activation = activation
    self.do_batch_norm = do_batch_norm
    self.dropout_rate = dropout_rate
    self.batch_norm_decay = batch_norm_decay

  def __call__(self, inputs, context=None, is_training=False):
    """Apply the residual block.

    Args:
      inputs: input array; dense layers act on its last axis.
      context: optional conditioning array. It is projected to
        ``inputs.shape[-1]`` features before gating.
      is_training: forwarded to ``hk.BatchNorm``; when true, dropout is
        also applied, which draws a key via ``hk.next_rng_key()``.

    Returns:
      ``inputs`` plus the output of the residual branch.
    """
    outputs = inputs
    if self.do_batch_norm:
      outputs = hk.BatchNorm(
        create_scale=True,
        create_offset=True,
        decay_rate=self.batch_norm_decay,
      )(outputs, is_training=is_training)
    outputs = hk.Linear(self.hidden_size)(outputs)
    outputs = self.activation(outputs)
    if is_training:
      # An RNG key is requested here even when dropout_rate is 0.
      outputs = hk.dropout(
        rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs
      )
    outputs = hk.Linear(self.hidden_size)(outputs)
    if context is not None:
      context_proj = hk.Linear(inputs.shape[-1])(context)
      # Public `jax.nn.glu` (instead of the private `jax._src` import): it
      # splits the last axis in half and returns a * sigmoid(b). When the
      # branch width equals inputs.shape[-1], `a` is the branch output and
      # `b` is the projected context.
      outputs = jax.nn.glu(
        jnp.concatenate([outputs, context_proj], axis=-1), axis=-1
      )
    return outputs + inputs


# ruff: noqa: PLR0913
class _Resnet(hk.Module):
  """Fully connected residual network with a single output unit.

  Layout: a dense layer of width ``hidden_size`` followed by the activation,
  then ``n_layers`` residual blocks (``_ResnetBlock``), then the activation
  again and a final dense layer with one output unit.

  Args:
    n_layers: number of residual blocks.
    hidden_size: width of the input layer and of every residual block.
    activation: activation used throughout the network.
    dropout_rate: dropout rate forwarded to each residual block.
    do_batch_norm: whether each residual block batch-normalizes its input.
    batch_norm_decay: batch-norm decay rate forwarded to each block.
  """

  def __init__(
    self,
    n_layers: int,
    hidden_size: int,
    activation: Callable = jax.nn.relu,
    dropout_rate: float = 0.1,
    do_batch_norm: bool = True,
    batch_norm_decay: float = 0.1,
  ):
    super().__init__()
    # Depth and width of the network.
    self.n_layers = n_layers
    self.hidden_size = hidden_size
    # Settings shared by every residual block.
    self.activation = activation
    self.dropout_rate = dropout_rate
    self.do_batch_norm = do_batch_norm
    self.batch_norm_decay = batch_norm_decay

  def __call__(self, inputs, is_training=False, **kwargs):
    """Run the network on ``inputs``.

    Args:
      inputs: input array; dense layers act on its last axis.
      is_training: forwarded to every residual block.
      **kwargs: accepted for call-site compatibility and ignored.

    Returns:
      array whose last axis has size 1.
    """
    block_settings = {
      "hidden_size": self.hidden_size,
      "activation": self.activation,
      "dropout_rate": self.dropout_rate,
      "do_batch_norm": self.do_batch_norm,
      "batch_norm_decay": self.batch_norm_decay,
    }
    hidden = self.activation(hk.Linear(self.hidden_size)(inputs))
    for _ in range(self.n_layers):
      # Blocks are created and applied one at a time, in sequence.
      hidden = _ResnetBlock(**block_settings)(hidden, is_training=is_training)
    features = self.activation(hidden)
    return hk.Linear(1)(features)


def make_resnet(
  n_layers: int = 2,
  hidden_size: int = 64,
  activation: Callable = jax.nn.tanh,
  dropout_rate: float = 0.2,
  do_batch_norm: bool = False,
  batch_norm_decay: float = 0.2,
):
  """Create a ResNet-based classifier network.

  The network is a ``_Resnet``: an input dense layer, ``n_layers`` residual
  blocks, and a final dense layer with a single output unit.

  Args:
    n_layers: number of residual blocks
    hidden_size: width of the input layer and of the dense layers inside
      every residual block
    activation: a jax activation function
    dropout_rate: dropout rate to use in resnet blocks (only applied when
      the network is called with ``is_training=True``)
    do_batch_norm: use batch normalization or not
    batch_norm_decay: decay rate of EMA in batch norm layer (passed to
      ``hk.BatchNorm`` as ``decay_rate``)

  Returns:
    a neural network model: a Haiku transformed function whose ``apply``
    takes ``(params, inputs, is_training=False)`` and no RNG argument
  """

  # NOTE(review): `hk.without_apply_rng` removes the RNG argument from
  # `apply`, but the residual blocks call `hk.next_rng_key()` for dropout
  # whenever `is_training=True`. Likewise, `hk.BatchNorm` keeps state that
  # plain `hk.transform` (as opposed to `hk.transform_with_state`) does not
  # support. Both paths look like they would raise; confirm that callers
  # only use `is_training=False` and `do_batch_norm=False`, or switch the
  # transform (which changes the returned API).
  @hk.without_apply_rng
  @hk.transform
  def _net(inputs, is_training=False):
    nn = _Resnet(
      n_layers=n_layers,
      hidden_size=hidden_size,
      activation=activation,
      do_batch_norm=do_batch_norm,
      dropout_rate=dropout_rate,
      batch_norm_decay=batch_norm_decay,
    )
    return nn(inputs, is_training=is_training)

  return _net