Source code for sbijax._src.experimental.nn.make_simformer
import dataclasses
from collections.abc import Callable
import haiku as hk
import jax
from einops import rearrange
from jax import numpy as jnp
__all__ = ["make_simformer_based_score_model"]
from sbijax._src.experimental.nn.make_score_network import (
ScoreModel,
timestep_embedding,
)
@dataclasses.dataclass
class _Encoder(hk.Module):
    """Stack of pre-layer-norm transformer encoder blocks.

    Each of the ``num_layers`` blocks applies layer normalisation followed by
    multi-head self-attention, and layer normalisation followed by a two-layer
    MLP. Both sub-layers are followed by dropout and are wrapped in a residual
    connection. A final layer normalisation is applied to the result.

    Attributes:
        num_heads: number of attention heads
        num_layers: number of encoder blocks
        head_size: size of each attention head; if falsy (``None`` or ``0``)
            the feature width divided by ``num_heads`` is used instead
        dropout_rate: dropout rate used when ``is_training`` is true
        widening_factor: multiplier of the feature width that gives the hidden
            size of the feed-forward MLP
        initializer: weight initializer of the attention and MLP layers
        activation: activation function of the feed-forward MLP
    """

    num_heads: int
    num_layers: int
    # may be None, in which case the key size is derived from the input width
    head_size: int | None
    dropout_rate: float
    widening_factor: int = 4
    initializer: Callable = hk.initializers.TruncatedNormal(stddev=0.01)
    activation: Callable = jax.nn.gelu

    def __call__(self, inputs, time, mask, *, is_training):
        """Run the encoder stack on ``inputs``.

        Args:
            inputs: array whose last axis is the feature axis
            time: time embedding.
                NOTE(review): this argument is accepted but never read in this
                method, so the encoder output does not depend on it. Confirm
                whether the time embedding was meant to be injected into the
                blocks.
            mask: optional attention mask; two leading axes are added before
                it is handed to the attention layers
            is_training: if false, dropout is disabled (rate 0.0)

        Returns:
            an array with the same shape as ``inputs``
        """
        dropout_rate = self.dropout_rate if is_training else 0.0
        # add leading axes so that the mask broadcasts against the attention
        # logits
        mask = mask[None, None, ...] if mask is not None else None
        hidden = inputs
        for _ in range(self.num_layers):
            # self-attention sub-layer
            intr = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(
                hidden
            )
            intr = hk.MultiHeadAttention(
                num_heads=self.num_heads,
                key_size=self.head_size or (intr.shape[-1] // self.num_heads),
                w_init=self.initializer,
            )(intr, intr, intr, mask=mask)
            intr = hk.dropout(hk.next_rng_key(), dropout_rate, intr)
            hidden = hidden + intr
            # feed-forward sub-layer
            intr = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(
                hidden
            )
            intr = hk.nets.MLP(
                [self.widening_factor * intr.shape[-1], intr.shape[-1]],
                w_init=self.initializer,
                # use the configured activation; it was previously hard-coded
                # to GELU, which silently ignored the ``activation`` field
                activation=self.activation,
            )(intr)
            intr = hk.dropout(hk.next_rng_key(), dropout_rate, intr)
            hidden = hidden + intr
        hidden = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(
            hidden
        )
        return hidden
@dataclasses.dataclass
class _SimFormer(hk.Module):
    """Transformer network operating on one token per variable.

    The inputs and the context are concatenated along the last axis. Every
    scalar entry becomes a token made of the value repeated
    ``embedding_dim_values`` times, a learned embedding of the position id,
    and a learned embedding of a binary label (1 for entries stemming from
    ``inputs``, 0 for entries stemming from ``context``). The tokens are
    processed by an ``_Encoder`` and projected to one scalar per token.

    Attributes:
        mask: attention mask handed to the encoder
        n_heads: number of attention heads
        n_layers: number of encoder layers
        head_size: size of an attention head (may be ``None``)
        embedding_dim_values: number of repetitions of each value
        embedding_dim_ids: dimensionality of the id embedding
        embedding_dim_conditioning: dimensionality of the label embedding
        time_embedding_layers: output sizes of the time embedding MLP; the
            first entry is also used as size of the timestep embedding
        dropout_rate: dropout rate handed to the encoder
        activation: activation of the time embedding MLP; also handed to
            the encoder
    """

    mask: jax.Array
    n_heads: int = 4
    n_layers: int = 4
    head_size: int | None = None
    embedding_dim_values: int = 32
    embedding_dim_ids: int = 32
    embedding_dim_conditioning: int = 10
    time_embedding_layers: tuple[int, ...] = (128, 128)
    dropout_rate: float = 0.1
    activation: Callable = jax.nn.relu

    def __call__(self, inputs, time, context, *, is_training=True):
        """Compute one output per entry of ``inputs``.

        Args:
            inputs: array of values; the tiling and the final rearrange
                pattern suggest a two-dimensional (batch, features) layout
            time: time points, embedded before being handed to the encoder
            context: array of conditioning values, concatenated to ``inputs``
                along the last axis
            is_training: forwarded to the encoder

        Returns:
            the flattened per-token outputs restricted to the first
            ``inputs.shape[-1]`` positions
        """
        n_inputs = inputs.shape[-1]
        n_context = context.shape[-1]
        tokens = jnp.concatenate([inputs, context], axis=-1)
        n_tokens = tokens.shape[-1]

        # time embedding: fixed timestep embedding followed by an MLP
        def _embed_time(t):
            return timestep_embedding(t, self.time_embedding_layers[0])

        time_mlp = hk.nets.MLP(
            self.time_embedding_layers, activation=self.activation
        )
        time = hk.Sequential([_embed_time, time_mlp])(time)

        # position ids and binary labels (1: from inputs, 0: from context)
        ids = jnp.arange(n_tokens, dtype=jnp.int32)[None, :]
        condition_mask = (ids < n_inputs).astype(jnp.int32)
        ids, condition_mask, tokens = jnp.broadcast_arrays(
            ids, condition_mask, tokens
        )

        # the three parts of every token; the id embedding is created
        # before the label embedding
        value_part = jnp.tile(
            tokens[..., None], (1, 1, self.embedding_dim_values)
        )
        id_embedder = hk.Embed(n_tokens, self.embedding_dim_ids)
        id_part = id_embedder(ids)
        label_embedder = hk.Embed(2, self.embedding_dim_conditioning)
        label_part = label_embedder(condition_mask)
        tokens = jnp.concatenate([value_part, id_part, label_part], axis=-1)

        encoder = _Encoder(
            num_heads=self.n_heads,
            num_layers=self.n_layers,
            head_size=self.head_size,
            dropout_rate=self.dropout_rate,
            activation=self.activation,
        )
        encoded = encoder(tokens, time, self.mask, is_training=is_training)

        # one scalar per token, flattened, keeping only the input positions
        projected = hk.Linear(1)(encoded)
        flat = rearrange(projected, "b l d -> b (l d)")
        return flat[..., :n_inputs]
# ruff: noqa: PLR0913,E501
[docs]
def make_simformer_based_score_model(
    n_dimension: int,
    mask: jax.Array,
    n_heads: int = 4,
    n_layers: int = 4,
    head_size: int | None = None,
    embedding_dim_values: int = 32,
    embedding_dim_ids: int = 32,
    embedding_dim_conditioning: int = 8,
    time_embedding_layers: tuple[int, ...] = (
        128,
        128,
    ),
    dropout_rate: float = 0.1,
    activation: Callable = jax.nn.gelu,
    sde: str = "vp",
    beta_min: float = 0.1,
    beta_max: float = 10.0,
    time_eps: float = 0.001,
    time_max: float = 1,
):
    """Create a score network for AiO.

    The score model uses a transformer as a score estimator.

    Args:
        n_dimension: dimensionality of modelled space
        mask: a binary matrix of conditional dependencies
        n_heads: number of attention heads
        n_layers: number of attention layers
        head_size: size of an attention head
        embedding_dim_values: dimensionality of the embedding for the values
        embedding_dim_ids: dimensionality of the embedding for the ids
            of the variables
        embedding_dim_conditioning: dimensionality of the embedding for the
            binary conditioning labels
        time_embedding_layers: a tuple of ints determining the output sizes
            of the time embedding network
        dropout_rate: dropout rate that is passed on to the transformer
        activation: activation function that is passed on to the transformer
        sde: can be either of 'vp' and 've'. Defines the type of SDE to be
            used as a forward process. See the original publication and
            references therein for details.
        beta_min: beta min. Again, see the paper please.
        beta_max: beta max. Again, see the paper please.
        time_eps: some small number to use as minimum time point for the
            forward process. Used for numerical stability.
        time_max: maximum integration time. 1 is good, but so is 5 or 10.

    Returns:
        returns a score model that can be used for posterior inference using
        AiO.

    References:
        Gloeckler, Manuel, et al. "All-in-one simulation-based inference." International Conference on Machine Learning, 2024.
    """

    def _forward(method, **kwargs):
        # Haiku modules have to be constructed inside of the transformed
        # function, hence both networks are built on every call.
        transformer = _SimFormer(
            mask=mask,
            n_heads=n_heads,
            n_layers=n_layers,
            head_size=head_size,
            embedding_dim_conditioning=embedding_dim_conditioning,
            embedding_dim_values=embedding_dim_values,
            embedding_dim_ids=embedding_dim_ids,
            time_embedding_layers=time_embedding_layers,
            dropout_rate=dropout_rate,
            activation=activation,
        )
        score_model = ScoreModel(
            n_dimension,
            transformer,
            sde,
            beta_min,
            beta_max,
            time_eps,
            time_max,
        )
        return score_model(method, **kwargs)

    return hk.transform(_forward)