Custom training and sampling loops
sbijax.train() and sbijax.sample() are convenience drivers over the
low-level primitives that every objective carries. When you need full control –
a custom schedule, gradient accumulation, your own early-stopping rule, or a
bespoke sampling loop – drive the primitives yourself. This is the same escape
hatch dm-haiku (init/apply) and BlackJAX (init/step) provide.
The primitives
A trainable objective obj exposes bound pure functions:
obj.train.init_fn(optimizer, rng, batch) -> TrainingState(params, opt_state)
obj.train.step_fn(optimizer, rng, state, batch) -> (metrics, TrainingState)
obj.train.eval_fn(rng, state, batch) -> metrics
obj.sample_fn(rng, params, observable, *, sampler=None) -> (samples, info)
TrainingState is an opaque carry; params is what you extract at the end.
A batch is a dict {"y": array, "theta": array} with the parameters
flattened to a (batch_size, dim) array – sbijax.simulate() returns
theta as the prior pytree, so flatten it once before batching.
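To make the flattening step concrete, here is a small sketch with a made-up prior pytree (a dict with hypothetical entries "mean" and "scale"; these names are illustrative and not taken from sbijax). It shows that mapping `ravel_pytree` over the leading axis yields a (batch_size, dim) array, and that the returned unravel function recovers the named structure for a single draw.

```python
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree

n = 5
# hypothetical prior pytree: every leaf has a leading sample axis of size n
theta = {
    "mean": jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2),
    "scale": jnp.ones((n, 1)),
}

# flatten each sample separately; leaves are concatenated in pytree order
# (dict keys sorted, so "mean" comes before "scale")
theta_flat = jax.vmap(lambda t: ravel_pytree(t)[0])(theta)

# build an unravel function from a single sample to map flat vectors back
single = jax.tree_util.tree_map(lambda leaf: leaf[0], theta)
_, unravel = ravel_pytree(single)
first = unravel(theta_flat[0])
```

Here `theta_flat` has shape (5, 3), and `first` is again a dict with a length-2 "mean" and a length-1 "scale".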
A custom training loop
import jax
import optax
from jax import random as jr
from jax.flatten_util import ravel_pytree
from sbijax import npe, simulate
from sbijax.nn import make_maf
obj = npe(make_maf(2))
optimizer = optax.adam(3e-4)
# `prior` and `simulator` are assumed to be defined already
data = simulate(jr.key(0), prior, simulator, n=10_000)
# flatten the prior-pytree parameters to a (n, d) array
theta = jax.vmap(lambda t: ravel_pytree(t)[0])(data["theta"])
xy = {"y": data["y"], "theta": theta}
def get_batch(xy, batch_size, key):
    n = xy["y"].shape[0]
    perm = jr.permutation(key, n)
    for i in range(0, n - batch_size + 1, batch_size):
        idx = perm[i : i + batch_size]
        yield {k: v[idx] for k, v in xy.items()}
# build the carry from a first batch, then jit the step
first = next(get_batch(xy, 128, jr.key(1)))
state = obj.train.init_fn(optimizer, jr.key(2), first)
step = jax.jit(lambda rng, s, b: obj.train.step_fn(optimizer, rng, s, b))
key = jr.key(3)
for epoch in range(100):
    for batch in get_batch(xy, 128, jr.fold_in(key, epoch)):
        key, step_key = jr.split(key)
        metrics, state = step(step_key, state, batch)
    # plug in your own validation / early stopping, e.g.:
    # val = obj.train.eval_fn(jr.fold_in(key, epoch), state, val_batch)
params = state.params
You now hold params exactly as sbijax.train() would have returned them.
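One way to fill in the early-stopping placeholder in the loop is a patience-based rule. The sketch below is a hypothetical helper (`EarlyStopping` is not part of sbijax) that tracks the best validation loss, keeps the matching state, and signals a stop after a fixed number of epochs without improvement. It assumes you reduce whatever `eval_fn` returns to a scalar loss yourself.

```python
class EarlyStopping:
    """Patience-based stopping rule (hypothetical helper, not part of sbijax)."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def update(self, val_loss, state):
        """Record one validation loss; return True when training should stop."""
        val_loss = float(val_loss)
        if val_loss < self.best - self.min_delta:
            # improvement: remember the loss and the state that produced it
            self.best = val_loss
            self.best_state = state
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# toy run: the loss improves for three epochs, then stops improving
stopper = EarlyStopping(patience=2)
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71]
for epoch, loss in enumerate(losses):
    if stopper.update(loss, state=f"state@{epoch}"):
        break
```

In the training loop, the call would sit right after the validation step, and `stopper.best_state.params` would replace `state.params` at the end so that you keep the best parameters rather than the last ones.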
Custom sampling
For amortized posteriors, sample_fn draws straight from the trained flow, so
no sampler is needed. Call it directly:
samples, info = obj.sample_fn(jr.key(4), params, y_obs)
For likelihood/ratio methods the posterior is formed at sample time. The
usual path is to pass a sampler from sbijax.mcmc.make_sampler(), but you can
also drive a kernel yourself with the low-level sbijax.mcmc routines by
supplying your own log-density. For example, with the slice sampler:
from jax import numpy as jnp
from sbijax.mcmc import sample_with_slice
def log_posterior(theta):
    # theta is the named prior pytree; `network` is the flow nle wraps
    theta_flat = ravel_pytree(theta)[0]
    lp_lik = network.apply(
        params, method="log_prob", y=jnp.atleast_2d(y_obs),
        x=theta_flat[None],
    )
    return jnp.sum(lp_lik) + jnp.sum(prior.log_prob(theta))
samples, info = sample_with_slice(
    jr.key(5), log_posterior, prior,
    n_chains=4, n_samples=2_000, n_warmup=1_000,
)
Any of the algorithms – nuts, mala, rmh, imh – can be passed
to sbijax.mcmc.make_sampler(), which wraps exactly this pattern (target
loglik + log p(theta), N(0, I) initialisation, chosen kernel) behind
sbijax.sample(). Rolling it by hand lets you swap the kernel, change the
initialisation, or condition on a different prior without retraining.
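To illustrate what "rolling it by hand" amounts to, here is a minimal random-walk Metropolis-Hastings chain in plain JAX. It is a sketch of the general pattern (log-density in, N(0, I) initialisation, one kernel step under `lax.scan`), not the sbijax.mcmc implementation. The names `rmh_chain` and `toy_log_density` are hypothetical, and the toy density stands in for the trained likelihood plus the log-prior.

```python
import jax
from jax import numpy as jnp
from jax import random as jr


def rmh_chain(key, log_density, init, n_samples, step_size=1.0):
    """Run one random-walk Metropolis-Hastings chain on a flat vector."""

    def one_step(carry, step_key):
        theta, lp = carry
        prop_key, accept_key = jr.split(step_key)
        # symmetric Gaussian proposal around the current position
        proposal = theta + step_size * jr.normal(prop_key, theta.shape)
        lp_prop = log_density(proposal)
        # accept with probability min(1, exp(lp_prop - lp))
        accept = jnp.log(jr.uniform(accept_key)) < lp_prop - lp
        theta = jnp.where(accept, proposal, theta)
        lp = jnp.where(accept, lp_prop, lp)
        return (theta, lp), (theta, accept)

    keys = jr.split(key, n_samples)
    carry = (init, log_density(init))
    _, (samples, accepted) = jax.lax.scan(one_step, carry, keys)
    return samples, accepted


def toy_log_density(theta):
    # stand-in for loglik + log p(theta): unnormalised N(1, I)
    return -0.5 * jnp.sum((theta - 1.0) ** 2)


init_key, chain_key = jr.split(jr.key(0))
init = jr.normal(init_key, (2,))  # N(0, I) initialisation
samples, accepted = rmh_chain(chain_key, toy_log_density, init, n_samples=5_000)
```

Swapping the kernel means replacing `one_step`, changing the initialisation means changing `init`, and conditioning on a different prior means changing the log-prior term inside the log-density, none of which touches the trained parameters.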