# Source code for sbijax._src.experimental.cmpe

"""Consistency model posterior estimation (functional objective, design B).

Implements the CMPE algorithm of :cite:t:`schmitt2023con`.  The consistency
loss requires an EMA copy of the network weights; both the live params and the
EMA params are stored together as a dict under ``TrainingState.params`` so the
shared ``fit`` driver needs no changes:

    state.params == {"params": live_params, "ema_params": ema_params}

``sample_fn`` receives this dict and extracts ``params["params"]`` before
calling the flow-rejection sampler.
"""

# ruff: noqa: PLR0913
import jax
import optax
from jax import numpy as jnp
from jax import random as jr

from sbijax._src.inference.posterior._sampling import rejection_sample_flow
from sbijax._src.train._types import ObjectiveFns, TrainFns, TrainingState


def _alpha_t(time, **schedule_kwargs):
  """Return the consistency-loss weight ``1 / (t_{n+1} - t_n)``.

  Args:
    time: discretization index ``n`` (scalar or array of indices)
    **schedule_kwargs: forwarded unchanged to both ``_time_schedule`` calls
      (``rho``, ``t_min``, ``t_max``, ``n_inters``) so the weight can be
      evaluated on the same schedule the time points were drawn from.
      With no kwargs the ``_time_schedule`` defaults are used, exactly as
      before.

  Returns:
    the reciprocal of the gap between schedule points ``n + 1`` and ``n``
  """
  t_next = _time_schedule(time + 1, **schedule_kwargs)
  t_curr = _time_schedule(time, **schedule_kwargs)
  return 1.0 / (t_next - t_curr)


def _time_schedule(n, rho=7, t_min=0.001, t_max=50, n_inters=1000):
  left = t_min ** (1 / rho)
  right = t_max ** (1 / rho) - t_min ** (1 / rho)
  right = (n - 1) / (n_inters - 1) * right
  return (left + right) ** rho


def _discretization_schedule(n_iter, max_iter=1000):
  s0, s1 = 10, 50
  nk = (
    (n_iter / max_iter) * (jnp.square(s1 + 1) - jnp.square(s0))
    + jnp.square(s0)
    - 1
  )
  nk = jnp.ceil(jnp.sqrt(nk)) + 1
  return nk


def cmpe(network, *, t_min=0.001, t_max=50.0, ema_step_size=0.01):
  """Construct a consistency model posterior objective.

  The returned ``ObjectiveFns`` is trained via the shared
  :func:`~sbijax.train` driver. EMA params are threaded through
  ``TrainingState.params`` as a dict:
  ``{"params": live_params, "ema_params": ema_params}``.

  Args:
    network: a consistency model with ``vector_field`` and ``sample`` methods
    t_min: minimal time point for ODE integration
    t_max: maximal time point for ODE integration
    ema_step_size: weight of the freshly updated live params in the EMA
      update, i.e. ``ema <- ema_step_size * live + (1 - ema_step_size) * ema``

  Returns:
    an ``ObjectiveFns``
  """

  def _loss(params_dict, rng_key, batch, is_training):
    params = params_dict["params"]
    # The EMA copy is the consistency target. It is updated in ``step_fn``
    # via ``incremental_update``, never by gradients, so block them here.
    ema_params = jax.lax.stop_gradient(params_dict["ema_params"])
    theta = batch["theta"]

    # n_iter fixed at reference value so the schedule is consistent across fit.
    n_iter = 1001
    nk = _discretization_schedule(n_iter)

    t_key, rng_key = jr.split(rng_key)
    time_idx = jr.randint(
      t_key, shape=(theta.shape[0],), minval=1, maxval=nk - 1
    )
    tn = _time_schedule(
      time_idx, t_min=t_min, t_max=t_max, n_inters=nk
    ).reshape(-1, 1)
    tnp1 = _time_schedule(
      time_idx + 1, t_min=t_min, t_max=t_max, n_inters=nk
    ).reshape(-1, 1)

    noise_key, rng_key = jr.split(rng_key)
    noise = jr.normal(noise_key, shape=(*theta.shape,))

    # Both evaluations share ``train_rng`` (presumably so any stochastic
    # layers behave identically for the online and target networks).
    train_rng, rng_key = jr.split(rng_key)
    fnp1 = network.apply(
      params,
      train_rng,
      method="vector_field",
      theta=theta + tnp1 * noise,
      time=tnp1,
      context=batch["y"],
      is_training=is_training,
    )
    fn = network.apply(
      ema_params,
      train_rng,
      method="vector_field",
      theta=theta + tn * noise,
      time=tn,
      context=batch["y"],
      is_training=is_training,
    )
    # Per-sample root-mean-square distance between online and target outputs.
    dist = jnp.sqrt(jnp.mean(jnp.square(fnp1 - fn), axis=1))
    # lambda(t_n) = 1 / (t_{n+1} - t_n), evaluated on the same
    # (t_min, t_max, nk) schedule the time points were drawn from.
    weight = 1.0 / (tnp1 - tn).reshape(-1)
    return jnp.mean(weight * dist)

  def init_fn(optimizer, rng_key, batch):
    # Separate keys so the dummy times and the weight init are independent.
    time_key, init_key = jr.split(rng_key)
    times = jr.uniform(time_key, shape=(batch["y"].shape[0], 1))
    params = network.init(
      init_key,
      method="vector_field",
      theta=batch["theta"],
      time=times,
      context=batch["y"],
      is_training=True,
    )
    params_dict = {"params": params, "ema_params": params}
    # The optimizer only ever sees the live params.
    return TrainingState(params=params_dict, opt_state=optimizer.init(params))

  def step_fn(optimizer, rng_key, state, batch):
    live_params = state.params["params"]
    ema_params = state.params["ema_params"]

    def _live_loss(p):
      return _loss(
        {"params": p, "ema_params": ema_params}, rng_key, batch, True
      )

    # Differentiate w.r.t. the live params only; no backward pass is spent
    # on the EMA copy.
    loss, grads = jax.value_and_grad(_live_loss)(live_params)
    updates, opt_state = optimizer.update(grads, state.opt_state, live_params)
    new_params = optax.apply_updates(live_params, updates)
    new_ema = optax.incremental_update(
      new_params, ema_params, step_size=ema_step_size
    )
    new_params_dict = {"params": new_params, "ema_params": new_ema}
    return {"loss": loss}, TrainingState(new_params_dict, opt_state)

  def eval_fn(rng_key, state, batch):
    return {"loss": _loss(state.params, rng_key, batch, False)}

  def sample_fn(rng_key, params, observable, *, n_samples=4_000, **kwargs):
    # params is the params_dict; sampling uses only the live params.
    live_params = params["params"]
    return rejection_sample_flow(
      rng_key, network, live_params, observable, n_samples
    )

  return ObjectiveFns(TrainFns(init_fn, step_fn, eval_fn), sample_fn)