Design philosophy
sbijax is written in the function-first, low-level style of
Haiku. There are no estimator
classes and no hidden state: every method is a factory that returns a record
of pure functions, and the training and sampling loops are free functions
that operate on those records. This page explains the design and how it is
implemented.
Overview
```python
import optax
from jax import numpy as jnp, random as jr
from sbijax import npe, nle, train, sample, simulate
from sbijax.mcmc import make_sampler, nuts
from sbijax.nn import make_maf

# assumes `prior`, `simulator`, and the observation `y_obs` are already defined

# a factory takes only the network
obj = npe(make_maf(2))
# the prior is used only to generate data
data = simulate(jr.key(0), prior, simulator, n=10_000)
# `train` is a free driver; the optimizer is injected here
params, info = train(jr.key(1), obj, data, optimizer=optax.adam(3e-4))
# `sample` is a free driver; for amortized posteriors nothing else is needed
samples, sinfo = sample(jr.key(2), obj, params, y_obs)

# for likelihood/ratio methods the sampler (kernel + prior) is injected here
lik = nle(make_maf(2))
params, _ = train(jr.key(1), lik, data, optimizer=optax.adam(3e-4))
samples, _ = sample(
    jr.key(2), lik, params, y_obs, sampler=make_sampler(nuts, prior=prior)
)
```
Principles
Factories take only the network. npe(net), nle(net), nass(net)
and friends carry no prior, no optimizer, and no sampler. Every choice about
how to train or sample is injected at the driver that owns it. This mirrors
hk.transform(f), which knows nothing about optax or your data.
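The factory pattern described above can be sketched in plain Python. This is a minimal illustration, not sbijax code. The names `make_objective`, `Objective`, and the toy `linear` network are hypothetical. The sketch shows that the factory's only input is the network, and that the record it returns holds pure functions and no state.

```python
from typing import Callable, NamedTuple


class Objective(NamedTuple):
    # hypothetical record: bound pure functions, no optimizer/prior/sampler
    init_fn: Callable
    loss_fn: Callable


def make_objective(net_apply: Callable, n_params: int) -> Objective:
    """Hypothetical factory: closes over the network and nothing else."""

    def init_fn():
        # deterministic zero init, just for the sketch
        return [0.0] * n_params

    def loss_fn(params, batch):
        # mean squared error of the network's predictions
        xs, ys = batch
        errs = [(net_apply(params, x) - y) ** 2 for x, y in zip(xs, ys)]
        return sum(errs) / len(errs)

    return Objective(init_fn=init_fn, loss_fn=loss_fn)


def linear(params, x):
    # toy "network": y = w * x + b
    w, b = params
    return w * x + b


obj = make_objective(linear, n_params=2)
print(obj.loss_fn(obj.init_fn(), ([1.0, 2.0], [2.0, 4.0])))  # 10.0
```

Because the record holds only functions, one trained network can be handed to any driver. Nothing in `obj` fixes how it will be optimized.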
Two free, symmetric drivers. train(rng, obj, data, *, optimizer=...) and
sample(rng, obj, params, observable, *, sampler=...) both take the objective
first and their "how" as a keyword-only argument. This is exactly BlackJAX's split: the record
carries the bound primitives (SamplingAlgorithm.init/step) while the
loop is a free driver (run_inference_algorithm).
The prior enters only where a posterior is formed. Neural posterior methods
(npe/fmpe/npse) learn the posterior directly. The prior is baked
into the training data (theta ~ prior), so sampling just draws from the
network and needs no prior. Neural likelihood/ratio methods (nle/nre) learn a
prior-independent object. The posterior, p(theta|y) ∝ p(y|theta) p(theta), is
formed only at sample time, so the prior travels inside the sampler. A trained
likelihood can therefore be reused under different priors without retraining,
which is the point of these methods.
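A conjugate Gaussian case makes the reuse argument concrete. This is a textbook illustration, not something sbijax computes. Suppose the learned likelihood were exactly Gaussian in theta. Swapping the prior then changes only the second factor, and the likelihood term (the trained network) is untouched:

```latex
\begin{aligned}
p(\theta \mid y) &\propto p(y \mid \theta)\, p(\theta),
  \qquad p(y \mid \theta) = \mathcal{N}(y;\, \theta, \sigma^2),
  \quad p(\theta) = \mathcal{N}(\theta;\, \mu_0, \tau^2) \\
p(\theta \mid y) &= \mathcal{N}\!\left(\theta;\;
  \frac{\tau^2 y + \sigma^2 \mu_0}{\tau^2 + \sigma^2},\;
  \frac{\tau^2 \sigma^2}{\tau^2 + \sigma^2}\right)
\end{aligned}
```

For example, with y = 2, σ² = 1, and μ0 = 0, the prior τ² = 1 gives N(1, 0.5) and τ² = 4 gives N(1.6, 0.8). The likelihood is the same in both cases.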
One generic training loop. Because train is a single function, there is no
per-method training code and no per-method Info record. Each objective
contributes only its loss (via step_fn/eval_fn), its parameter init, and
its sample_fn.
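The "one generic loop" idea can be sketched without any library. Everything here is hypothetical and simplified: `GradientTransformation` and `sgd` are optax-like stand-ins, `make_train_fns` is an illustrative helper, and `train` omits the rng, validation, and early stopping. Two unrelated objectives contribute only a loss (with its gradient) and an initial value. The same `train` fits both.

```python
from typing import Callable, NamedTuple


class GradientTransformation(NamedTuple):
    # hypothetical optax-like optimizer: an (init, update) pair
    init: Callable
    update: Callable


def sgd(lr: float) -> GradientTransformation:
    return GradientTransformation(
        init=lambda params: (),
        update=lambda grads, state: ([-lr * g for g in grads], state),
    )


class TrainFns(NamedTuple):
    # hypothetical, simplified: no rng and no eval_fn
    init_fn: Callable
    step_fn: Callable


def make_train_fns(loss_and_grad: Callable, init_params: list) -> TrainFns:
    """Hypothetical helper: an objective supplies only its loss and its init."""

    def init_fn(optimizer, batch):
        params = list(init_params)
        return params, optimizer.init(params)

    def step_fn(optimizer, state, batch):
        params, opt_state = state
        loss, grads = loss_and_grad(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = [p + u for p, u in zip(params, updates)]
        return {"loss": loss}, (new_params, opt_state)

    return TrainFns(init_fn, step_fn)


def train(obj: TrainFns, data, *, optimizer, n_steps: int = 200):
    """One generic loop: it never asks which method built `obj`."""
    state = obj.init_fn(optimizer, data)
    for _ in range(n_steps):
        metrics, state = obj.step_fn(optimizer, state, data)
    return state[0], metrics


def mean_loss(params, xs):
    # L(m) = mean((x - m)^2), dL/dm = 2 (m - mean(x))
    (m,) = params
    loss = sum((x - m) ** 2 for x in xs) / len(xs)
    return loss, [2 * (m - sum(xs) / len(xs))]


def second_moment_loss(params, xs):
    # L(s) = mean((x^2 - s)^2), dL/ds = 2 (s - mean(x^2))
    (s,) = params
    loss = sum((x * x - s) ** 2 for x in xs) / len(xs)
    return loss, [2 * (s - sum(x * x for x in xs) / len(xs))]


data = [1.0, 2.0, 3.0, 6.0]
fitted = {}
for loss_and_grad in (mean_loss, second_moment_loss):
    params, metrics = train(make_train_fns(loss_and_grad, [0.0]), data, optimizer=sgd(0.1))
    fitted[loss_and_grad.__name__] = params[0]
print(fitted)  # mean_loss near 3.0, second_moment_loss near 12.5
```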
Components
A factory returns a record of bound pure functions; two free drivers operate on it. The prior, optimizer, and sampler enter at the driver that uses each.
```mermaid
flowchart TB
subgraph factories["factories (network only)"]
npe["npe / fmpe / npse"]
nle["nle / snle / nre"]
nass["nass / nasss"]
abc["sabc / smcabc"]
end
subgraph records["records (bound pure fns)"]
OF["ObjectiveFns(train, sample_fn, extra)"]
SF["SummaryFns(train, summarize_fn)"]
AB["ABCSampler(sample)"]
end
subgraph tf["TrainFns primitives (bound)"]
initf["init_fn(optimizer, rng, batch)"]
stepf["step_fn(optimizer, rng, state, batch)"]
evalf["eval_fn(rng, state, batch)"]
end
subgraph drivers["free generic drivers"]
fitd["train(rng, obj, data, *, optimizer)"]
sampd["sample(rng, obj, params, y, *, sampler)"]
seqd["run_sequential(rng, obj, prior, simulator, y)"]
end
mk["make_sampler(kernel, *, prior)"]
npe --> OF
nle --> OF
nass --> SF
abc --> AB
OF --> tf
SF --> tf
tf --> fitd
OF --> sampd
mk --> sampd
fitd --> seqd
sampd --> seqd
```
The records
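| Record | Returned by | Fields |
|---|---|---|
| TrainingState | init_fn / step_fn (internal to train) | params, opt_state |
| TrainFns | inside every trainable record | init_fn, step_fn, eval_fn |
| ObjectiveFns | npe, fmpe, npse, nle, snle, nre | train, sample_fn, extra |
| SummaryFns | nass, nasss | train, summarize_fn |
| ABCSampler | sabc, smcabc | sample |
| Info | train | round, losses |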
These records are return values. You receive instances from the factories and drivers but never construct them yourself, so they are not part of the public namespace.
Primitive contracts
The bound primitives on a trainable record share these signatures. The optimizer
is a leading argument of init_fn/step_fn, and train binds it by closure
before jitting. An optax GradientTransformation captured in a closure jits fine.
Storing it in the TrainingState threaded through the jitted step would not work,
because its fields are Python functions rather than arrays.

```text
init_fn(optimizer, rng_key, batch) -> TrainingState(params, opt_state)
step_fn(optimizer, rng_key, state, batch) -> (metrics, TrainingState)  # one optimizer step
eval_fn(rng_key, state, batch) -> metrics                               # validation, no update
sample_fn(rng_key, params, observable, *, sampler=None) -> (samples, info)
summarize_fn(params, data) -> summaries                                 # SummaryFns only
```
metrics is a dict with at least {"loss": ...}. Amortized posterior
methods draw from the flow and ignore sampler; likelihood/ratio methods
require it. Summary networks reuse the TrainFns seam and are trained by the
same train, exposing summarize_fn instead of sample_fn. ABC samplers
do no training and expose only sample.
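The closure rule above can be sketched in JAX. `TrainingState` mirrors the record described here. `GradientTransformation`/`sgd` are a hypothetical local stand-in for `optax.GradientTransformation` (an optax optimizer such as `optax.sgd` should slot in the same way). The toy `step_fn` follows the documented signature. The optimizer is bound by a closure, so only arrays cross the `jax.jit` boundary.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class TrainingState(NamedTuple):
    params: jnp.ndarray
    opt_state: tuple


class GradientTransformation(NamedTuple):
    # hypothetical stand-in for optax.GradientTransformation
    init: Callable
    update: Callable


def sgd(lr):
    return GradientTransformation(
        init=lambda params: (),
        update=lambda grads, state: (-lr * grads, state),
    )


def loss_fn(params, batch):
    return jnp.mean((batch - params) ** 2)


def step_fn(optimizer, rng_key, state, batch):
    # optimizer is a leading, non-array argument; rng_key is unused in this toy
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    updates, opt_state = optimizer.update(grads, state.opt_state)
    return {"loss": loss}, TrainingState(state.params + updates, opt_state)


optimizer = sgd(0.1)
# bind the optimizer by closure, then jit: only arrays are traced
step = jax.jit(lambda rng_key, state, batch: step_fn(optimizer, rng_key, state, batch))
# jax.jit(step_fn)(optimizer, ...) would instead try to trace the optimizer,
# whose fields are Python functions, not arrays

batch = jnp.array([1.0, 2.0, 3.0, 6.0])
state = TrainingState(jnp.array(0.0), optimizer.init(jnp.array(0.0)))
key = jax.random.PRNGKey(0)
for _ in range(100):
    key, sub = jax.random.split(key)
    metrics, state = step(sub, state, batch)
print(state.params)  # converges toward the batch mean, 3.0
```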
```mermaid
classDiagram
class TrainingState { +params; +opt_state }
class TrainFns {
+init_fn(optimizer, rng, batch) TrainingState
+step_fn(optimizer, rng, state, batch) tuple
+eval_fn(rng, state, batch) metrics
}
class ObjectiveFns { +train TrainFns; +sample_fn; +extra }
class SummaryFns { +train TrainFns; +summarize_fn }
class ABCSampler { +sample }
class Info { +round int; +losses }
TrainFns --> TrainingState
ObjectiveFns --> TrainFns
SummaryFns --> TrainFns
```
Training flow
train reads obj.train, builds a TrainingState with init_fn, and
threads it through step_fn/eval_fn across epochs with early stopping and
best-parameter tracking. TrainingState never escapes train; the returned
artifact is params.
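Early stopping with best-parameter tracking, as described above, can be sketched as a plain loop. This is an illustrative sketch, not sbijax's implementation. `fit_with_early_stopping`, the `patience` rule, and the scripted `step`/`evaluate` stand-ins are all hypothetical. The threaded state is dropped at the end, and only the best params escape.

```python
def fit_with_early_stopping(step, evaluate, state, *, n_epochs, patience):
    """Hypothetical sketch of the epoch loop inside a `train`-like driver."""
    best_loss, best_params, bad_epochs = float("inf"), None, 0
    losses = []
    for _ in range(n_epochs):
        state = step(state)      # one epoch of optimizer steps
        val = evaluate(state)    # validation loss, no update
        losses.append(val)
        if val < best_loss:
            best_loss, best_params, bad_epochs = val, state["params"], 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    # the threaded state is discarded; only the best params are returned
    return best_params, losses


# scripted stand-ins: validation loss per epoch is fixed in advance
val_script = [5.0, 3.0, 2.0, 2.5, 2.4, 2.6, 1.0]


def step(state):
    epoch = state["epoch"] + 1
    return {"epoch": epoch, "params": f"params after epoch {epoch}"}


def evaluate(state):
    return val_script[state["epoch"] - 1]


best, losses = fit_with_early_stopping(
    step, evaluate, {"epoch": 0, "params": None}, n_epochs=7, patience=3
)
print(best)    # params after epoch 3
print(losses)  # [5.0, 3.0, 2.0, 2.5, 2.4, 2.6]
```

With a patience of 3, the loop stops after three epochs without improvement. It never sees the later 1.0, which is the usual trade-off of early stopping.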
```mermaid
sequenceDiagram
participant U as caller
participant F as train (free)
participant T as obj.train
U->>F: train(rng, obj, data, optimizer)
F->>T: init_fn(optimizer, rng, batch)
T-->>F: state
loop epochs / batches
F->>T: step_fn(optimizer, rng, state, batch)
T-->>F: metrics, state
F->>T: eval_fn(rng, state, val_batch)
T-->>F: metrics
end
F-->>U: params, Info(round, losses)
```
Sampling flow
For amortized posteriors, sample_fn draws directly from the trained flow.
For likelihood/ratio methods it builds the likelihood log-density from
(params, observable) and hands it to the injected sampler, which forms the
posterior target and runs the chains.
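The two sampling paths above can be sketched with hypothetical factories. `make_amortized_sample_fn`, `make_likelihood_sample_fn`, and the toy `flow_draw`/`gaussian_log_prob` networks are illustrative names, not sbijax API. The amortized path ignores `sampler`. The likelihood path binds `params` and the observation into a `loglik_fn` of theta alone and hands it off. `grid_map_sampler` is deliberately not MCMC. It is a deterministic stand-in that returns the MAP under a N(0, 1) prior, so the hand-off can be checked.

```python
import random
from typing import Callable, Optional


def make_amortized_sample_fn(draw: Callable) -> Callable:
    """Hypothetical: posterior methods draw straight from the network."""

    def sample_fn(rng, params, observable, *, sampler: Optional[Callable] = None, n: int = 1000):
        # `sampler` is accepted for a uniform signature but ignored
        samples = [draw(rng, params, observable) for _ in range(n)]
        return samples, {"method": "amortized"}

    return sample_fn


def make_likelihood_sample_fn(log_prob: Callable) -> Callable:
    """Hypothetical: likelihood methods delegate to an injected sampler."""

    def sample_fn(rng, params, observable, *, sampler: Optional[Callable] = None):
        if sampler is None:
            raise ValueError("likelihood/ratio methods need a sampler (kernel + prior)")

        def loglik_fn(theta):
            # params and the observation are bound; theta stays free
            return log_prob(params, observable, theta)

        return sampler(rng, loglik_fn)

    return sample_fn


def flow_draw(rng, params, y):
    # toy stand-in for drawing from a trained conditional flow
    return rng.gauss(params["shift"] * y, params["scale"])


def gaussian_log_prob(params, y, theta):
    # toy stand-in for a trained likelihood network (up to a constant)
    return -0.5 * (y - theta) ** 2 / params["var"]


def grid_map_sampler(rng, loglik_fn):
    # NOT MCMC: returns the MAP of loglik + N(0, 1) log-prior on a grid
    grid = [i / 100 for i in range(-500, 501)]
    best = max(grid, key=lambda t: loglik_fn(t) - 0.5 * t ** 2)
    return [best], {"method": "grid-map"}


rng = random.Random(0)
post = make_amortized_sample_fn(flow_draw)
draws, _ = post(rng, {"shift": 0.5, "scale": 0.7}, 2.0)
lik = make_likelihood_sample_fn(gaussian_log_prob)
theta_map, info = lik(rng, {"var": 1.0}, 2.0, sampler=grid_map_sampler)
print(theta_map)  # [1.0]
```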
```mermaid
sequenceDiagram
participant U as caller
participant S as sample (free)
participant O as obj.sample_fn
participant K as sampler = make_sampler(nuts, prior=prior)
U->>S: sample(rng, obj, params, y, sampler=K)
S->>O: sample_fn(rng, params, y, sampler=K)
Note over O: loglik_fn(theta) = net.log_prob(y, theta)
O->>K: K(rng, loglik_fn, n_chains, n_samples, n_warmup)
Note over K: target = loglik + prior.log_prob; init ~ N(0, I)
K-->>O: samples, MCMCSampleInfo
O-->>U: (samples, info)
```
make_sampler(kernel, *, prior) bundles the MCMC kernel, the prior (which
forms the target loglik + log p(theta)), and N(0, I) chain
initialisation. Because the prior lives in the sampler, one trained likelihood is
reusable under different priors and kernels.
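What make_sampler bundles can be sketched in pure Python. This is a sketch under stated assumptions, not sbijax's implementation. `make_sampler` here takes a hypothetical `prior_logpdf` function rather than a prior object. `rwmh` is a random-walk Metropolis kernel standing in for NUTS. Warmup is plain burn-in with no adaptation. The usage reuses one fixed log-likelihood (a toy stand-in for a trained network at y = 2) under two priors. The conjugate answers for the posterior means are 1.0 and 1.6.

```python
import math
import random
from typing import Callable, NamedTuple


class Kernel(NamedTuple):
    # hypothetical BlackJAX-style record: init/step bound to one log density
    init: Callable
    step: Callable


def rwmh(logdensity_fn: Callable, scale: float = 1.0) -> Kernel:
    """Random-walk Metropolis, standing in for NUTS in this sketch."""

    def init(position):
        return (position, logdensity_fn(position))

    def step(rng, state):
        position, logp = state
        proposal = position + rng.gauss(0.0, scale)
        logp_prop = logdensity_fn(proposal)
        # accept with prob. min(1, exp(logp_prop - logp)); 1 - random() avoids log(0)
        if math.log(1.0 - rng.random()) < logp_prop - logp:
            return (proposal, logp_prop), True
        return state, False

    return Kernel(init, step)


def make_sampler(kernel: Callable, *, prior_logpdf: Callable) -> Callable:
    """Hypothetical: bundle a kernel factory, a prior, and N(0, 1) initialisation."""

    def sampler(rng, loglik_fn, n_chains=4, n_samples=2000, n_warmup=500):
        def target(theta):
            # the prior is added here, never inside the trained likelihood
            return loglik_fn(theta) + prior_logpdf(theta)

        algo = kernel(target)
        chains, n_accepted = [], 0
        for _ in range(n_chains):
            state = algo.init(rng.gauss(0.0, 1.0))  # N(0, 1) start per chain
            draws = []
            for i in range(n_warmup + n_samples):
                state, accepted = algo.step(rng, state)
                if i >= n_warmup:  # plain burn-in: no step-size adaptation
                    draws.append(state[0])
                    n_accepted += accepted
            chains.append(draws)
        return chains, {"acceptance_rate": n_accepted / (n_chains * n_samples)}

    return sampler


def normal_logpdf(mean, var):
    return lambda x: -0.5 * (x - mean) ** 2 / var - 0.5 * math.log(2 * math.pi * var)


# one fixed "trained" log-likelihood at y = 2 (toy stand-in), reused under two priors
loglik_fn = normal_logpdf(2.0, 1.0)
rng = random.Random(0)
means = {}
for prior_var in (1.0, 4.0):
    sampler = make_sampler(rwmh, prior_logpdf=normal_logpdf(0.0, prior_var))
    chains, info = sampler(rng, loglik_fn)
    draws = [x for chain in chains for x in chain]
    means[prior_var] = sum(draws) / len(draws)
print(means)  # expected near {1.0: 1.0, 4.0: 1.6}
```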
Samples and diagnostics
sample returns (samples, info) where samples is the named prior
pytree ({"theta": array}, leaves of shape (n_chains, n_draws, dim)) and
info is a small sampling record (mean acceptance and, for multi-chain MCMC,
rhat/ess). There is no arviz/InferenceData and no plotting in the
library. Build figures from the returned arrays and check convergence with
sbijax.ess() / sbijax.rhat() (thin re-exports of BlackJAX
diagnostics). Calibration is available through sbijax.sbc().
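To show what an R-hat check does with the (n_chains, n_draws, dim) layout described above, here is the classic (non-split) Gelman-Rubin statistic in NumPy. This is an illustrative reimplementation, not `sbijax.rhat()`. BlackJAX's function may differ in details, so use the library call in practice. `gelman_rubin_rhat` and the synthetic `samples` dict are hypothetical.

```python
import numpy as np


def gelman_rubin_rhat(x):
    """Classic potential scale reduction per dimension.

    x has shape (n_chains, n_draws, dim)."""
    n = x.shape[1]
    chain_means = x.mean(axis=1)                # (n_chains, dim)
    w = x.var(axis=1, ddof=1).mean(axis=0)      # mean within-chain variance
    b = n * chain_means.var(axis=0, ddof=1)     # between-chain variance
    var_hat = (n - 1) / n * w + b / n           # pooled variance estimate
    return np.sqrt(var_hat / w)


rng = np.random.default_rng(0)
# same layout as the returned samples pytree: {"theta": (n_chains, n_draws, dim)}
samples = {"theta": rng.normal(size=(4, 1000, 2))}
r = gelman_rubin_rhat(samples["theta"])
print(r)  # close to 1 for well-mixed chains

# two chains stuck around a different mode: R-hat far above 1
stuck = samples["theta"] + np.array([0.0, 0.0, 5.0, 5.0])[:, None, None]
print(gelman_rubin_rhat(stuck))
```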
Sequential inference
Multi-round inference is the free driver sbijax.run_sequential(), which
simulates from the current posterior each round, appends, and refits. It stays
out of the estimator: NPE switches to its atomic proposal-posterior loss in
rounds > 0 via an extra(prior) hook, and proposal-invariant methods reuse the
same objective.
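The round structure above can be sketched as a free function. This covers control flow only. `run_sequential` here is a hypothetical simplification that takes `train` and `sample` as injected callables. `Obj`, `toy_train`, and `toy_sample` are stand-ins, not sbijax API. Round 0 uses the objective as-is. Later rounds use `obj.extra(prior)` when the objective provides that hook, and data accumulates across rounds.

```python
import random
from typing import Callable, NamedTuple, Optional


class Obj(NamedTuple):
    name: str
    extra: Optional[Callable] = None  # hypothetical hook: prior -> objective for rounds > 0


def run_sequential(rng, obj, prior, simulator, y, *, n_rounds, n_per_round, train, sample):
    """Hypothetical sketch of a multi-round driver."""
    all_data, params, proposal, used = [], None, prior, []
    for r in range(n_rounds):
        # proposal-invariant objectives (no hook) are reused unchanged
        obj_r = obj if r == 0 or obj.extra is None else obj.extra(prior)
        used.append(obj_r.name)
        thetas = [proposal(rng) for _ in range(n_per_round)]
        all_data += [(t, simulator(rng, t)) for t in thetas]  # append, then refit
        params = train(rng, obj_r, all_data)
        # next proposal: the current posterior at y (default arg avoids late binding)
        proposal = lambda rng, p=params: sample(rng, obj, p, y)
    return params, used


def toy_train(rng, obj, data):
    # stand-in: "params" just record what was fit and on how much data
    return {"objective": obj.name, "n_train": len(data)}


def toy_sample(rng, obj, params, y):
    # stand-in posterior draw: jitter around the observation
    return y + rng.gauss(0.0, 0.1)


npe_like = Obj("npe", extra=lambda prior: Obj("npe_atomic"))
params, used = run_sequential(
    random.Random(0), npe_like,
    prior=lambda rng: rng.gauss(0.0, 1.0),
    simulator=lambda rng, t: t + rng.gauss(0.0, 0.1),
    y=1.5, n_rounds=3, n_per_round=100, train=toy_train, sample=toy_sample,
)
print(used)    # ['npe', 'npe_atomic', 'npe_atomic']
print(params)  # {'objective': 'npe_atomic', 'n_train': 300}
```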
```mermaid
sequenceDiagram
participant U as caller
participant R as run_sequential
participant F as train
U->>R: run_sequential(rng, obj, prior, simulator, y, n_rounds, sampler)
loop round r
Note over R: obj_r = obj (r==0) or obj.extra(prior) (r>0, npe atomic)
R->>R: simulate(prior, simulator, proposal)
R->>F: train(rng, obj_r, all_data, info=info, optimizer)
F-->>R: params, info
Note over R: proposal := sample(rng, obj, params, y, sampler)
end
R-->>U: params, info
```
See Migration guide: 0.3 → 0.4 for moving code from the class-based API.