Source code for sbijax._src.train.train

"""Generic training driver over an objective's TrainFns (design B)."""

import logging

import jax
import numpy as np
import optax
from jax import numpy as jnp
from jax import random as jr
from tqdm import tqdm

from sbijax._src.train._types import Info, next_round
from sbijax._src.util.dataloader import as_batch_iterators
from sbijax._src.util.early_stopping import EarlyStopping

logger = logging.getLogger(__name__)


# ruff: noqa: PLR0913
def train(
    rng_key,
    objective,
    data,
    *,
    optimizer=None,
    info=None,
    n_iter=1000,
    batch_size=100,
    percentage_data_as_validation_set=0.1,
    n_early_stopping_patience=10,
    n_early_stopping_delta=1e-3,
):
    """Train any objective's ``TrainFns`` with early stopping.

    Every epoch performs one ``step_fn`` call per training batch and then
    evaluates ``eval_fn`` on all validation batches. Both epoch losses are
    batch-size-weighted averages. The parameters belonging to the lowest
    validation loss observed are tracked and returned.

    Args:
        rng_key: a jax random key
        objective: an ``ObjectiveFns``/``SummaryFns`` (anything with ``train``)
        data: a ``{"y", "theta"}`` dataset pytree
        optimizer: an optax optimizer, bound into the primitives here;
            ``optax.adam(3e-4)`` is used when ``None``
        info: previous round's ``Info`` (``None`` for round 0)
        n_iter: maximum number of epochs
        batch_size: minibatch size
        percentage_data_as_validation_set: validation split fraction
        n_early_stopping_patience: early-stopping patience
        n_early_stopping_delta: minimum early-stopping improvement

    Returns:
        a tuple ``(params, Info)``; ``Info.losses`` holds one
        ``[train_loss, val_loss]`` row per epoch that was run
    """
    if optimizer is None:
        optimizer = optax.adam(3e-4)
    train_fns = objective.train

    itr_key, rng_key = jr.split(rng_key)
    # NOTE(review): the trailing ``True`` presumably enables shuffling --
    # confirm against ``as_batch_iterators``.
    train_iter, val_iter = as_batch_iterators(
        itr_key, data, batch_size, 1.0 - percentage_data_as_validation_set, True
    )

    init_key, rng_key = jr.split(rng_key)
    state = train_fns.init_fn(optimizer, init_key, next(iter(train_iter)))
    # The optimizer is closed over so that only (rng, state, batch) are traced.
    step_fn = jax.jit(lambda rng, s, b: train_fns.step_fn(optimizer, rng, s, b))
    eval_fn = jax.jit(train_fns.eval_fn)

    def _weighted(metric_fn, itr):
        # Average of per-batch metrics weighted by each batch's sample share.
        total = 0.0
        for batch in itr:
            total += metric_fn(batch) * (batch["y"].shape[0] / itr.num_samples)
        return total

    losses = np.zeros([n_iter, 2])
    early_stop = EarlyStopping(n_early_stopping_delta, n_early_stopping_patience)
    best_params, best_loss = state.params, np.inf

    # Pre-bind ``i`` so the final slice is well defined even if n_iter == 0.
    i = 0
    for i in tqdm(range(n_iter)):
        rng_key = jr.fold_in(rng_key, i)

        train_loss = 0.0
        for batch in train_iter:
            step_key, rng_key = jr.split(rng_key)
            metrics, state = step_fn(step_key, state, batch)
            train_loss += metrics["loss"] * (
                batch["y"].shape[0] / train_iter.num_samples
            )

        val_key, rng_key = jr.split(rng_key)
        # Bind the current state/key as defaults so every validation batch
        # of this epoch is scored with the same values.
        val_loss = _weighted(
            lambda b, s=state, k=val_key: eval_fn(k, s, b)["loss"], val_iter
        )
        losses[i] = jnp.array([train_loss, val_loss])

        # Record the best parameters BEFORE the stop check: an epoch may
        # lower the validation loss by less than the early-stopping delta
        # (so it counts towards patience) and still be the best seen so far.
        if val_loss < best_loss:
            best_loss, best_params = val_loss, state.params

        _, early_stop = early_stop.update(val_loss)
        if early_stop.should_stop:
            logger.info("early stopping criterion found after epoch %d", i)
            break

    # Keep only the rows of epochs that were actually run.
    stacked = jnp.asarray(losses[: (i + 1)])
    return best_params, Info(round=next_round(info), losses=stacked)