Source code for sbijax._src.simulators.sir

import jax
from diffrax import ODETerm, SaveAt, Tsit5, diffeqsolve
from jax import numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd


# ruff: noqa: PLR0913, E501
[docs] def sir( population_size=1_000_000, binomial_count=1_000, initial_conditions=(1_000_000 - 1, 1, 0), t_end=160, summarize_data=False, ): """SIR model. Construct prior, simulator, and likelihood functions. Args: population_size: the size of the population for the SIR model binomial_count: the number of Bernoulli trials for the Binomial likelihood initial_conditions: tuple of three integers that should sum to population_size t_end: end time of the ODE summarize_data: if true returns the data from the simulator in a summarized version of 5 values. Otherwise returns the infection counts of the ODE. Returns: returns a tuple of three objects. The first is a tfd.JointDistributionNamed serving as a prior distribution. The second is a simulator function that can be used to generate data. The third is the likelihood function. References: Lueckmann, Jan-Matthis, et al., "Benchmarking Simulation-Based Inference", 2021. """ dt = 0.1 ts = jnp.linspace(0, t_end, t_end) def prior_fn(): prior = tfd.JointDistributionNamed( dict( beta=tfd.LogNormal(jnp.log(jnp.array([0.4])), 0.5), gamma=tfd.LogNormal(jnp.log(jnp.array([0.125])), 0.2), ) ) return prior def _solve_ode(theta): def f(t, y, args): susceptible, infected, _ = y beta, gamma = jnp.squeeze(args["beta"]), jnp.squeeze(args["gamma"]) d_s = -beta * susceptible * infected / population_size d_i = beta * susceptible * infected / population_size - gamma * infected d_r = gamma * infected d_y = d_s, d_i, d_r return d_y term = ODETerm(f) solver = Tsit5() saveat = SaveAt(ts=ts) def _solve(args): sol = diffeqsolve( term, solver, 0, t_end, dt, initial_conditions, args=args, saveat=saveat, ) return sol.ys[1] infections = jax.vmap(_solve)(theta) return infections def summarize(ys): I_max = jnp.max(ys) t_peak = ts[jnp.argmax(ys)] R_final = jnp.sum(ys) mean_I = jnp.mean(ys) var_I = jnp.var(ys) return jnp.array([I_max, t_peak, R_final, mean_I, var_I]) def simulator(seed, theta): infections = _solve_ode(theta) probs = jnp.clip(infections / population_size, 0, 1) ys = tfd.Binomial(total_count=binomial_count, probs=probs).sample(seed=seed) if summarize_data: return jax.vmap(summarize)(ys) return ys def likelihood(y, theta): infections = _solve_ode(theta) probs = jnp.clip(infections / population_size, 0, 1) lik_fn = tfd.Independent( tfd.Binomial(total_count=binomial_count, probs=probs), reinterpreted_batch_ndims=1, ) log_lik = lik_fn.log_prob(y) return log_lik return prior_fn(), simulator, likelihood