Examples#
In the following, we demonstrate several sbijax methods using the "Simple Likelihood Complex Posterior" (SLCP) model.
[1]:
import arviz as az
import jax
import optax
import os
import sbijax
import seaborn as sns
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoLocator, MaxNLocator
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd
We suppress the warnings that TFP emits when 64-bit instead of 32-bit arithmetic is used.
[2]:
import warnings
# NOTE: the blanket "ignore" action is not restricted to a module or category,
# so it silences every warning raised for the rest of the session.
warnings.filterwarnings("ignore")
We implement a custom function to visualize posterior pairs.
[3]:
def plot_posteriors(obj):
    """Draw a 5x5 lower-triangular pair plot of posterior draws.

    Panels above the diagonal are switched off. Every other panel shows a
    log-binned hexbin of dimension ``col`` (x-axis) against dimension ``row``
    (y-axis). A black histogram of each marginal is added on the diagonal.

    Args:
        obj: array of draws indexed as ``obj[..., k]`` for ``k = 0, ..., 4``,
            i.e. its last axis enumerates the five parameters.

    Returns:
        The 5x5 array of matplotlib axes created by ``plt.subplots``.
    """
    n_dim = 5
    palette = sns.color_palette("rocket", as_cmap=False, desat=0.6, n_colors=10)
    hexbin_cmap = sns.blend_palette(palette, as_cmap=True)
    _, axes = plt.subplots(figsize=(12, 10), nrows=n_dim, ncols=n_dim)
    with az.style.context(["arviz-doc"], after_reset=True):
        for row in range(n_dim):
            for col in range(n_dim):
                panel = axes[row, col]
                if row >= col:
                    panel.hexbin(obj[..., col], obj[..., row], gridsize=50, bins='log', cmap=hexbin_cmap)
                else:
                    # Upper triangle stays empty.
                    panel.axis('off')

                # Thin frame and sparse, large tick labels on every panel.
                for side in ("left", "bottom", "right", "top"):
                    panel.spines[side].set_linewidth(.5)
                for axis in (panel.xaxis, panel.yaxis):
                    axis.set_major_locator(MaxNLocator(2))
                    axis.set_tick_params(width=1, length=2, labelsize=25)

                if row == col:
                    panel.set_yticklabels([])
                else:
                    panel.set_yticks([-3, 0, 3])
                    panel.set_xticks([-3, 0, 3])

                # Only the bottom row keeps x tick labels and only the first
                # column keeps y tick labels.
                if row < n_dim - 1:
                    panel.set_xticklabels([])
                    panel.xaxis.set_tick_params(width=0., length=0)
                if col > 0:
                    panel.set_yticklabels([])
                    panel.yaxis.set_tick_params(width=0., length=0)
                panel.grid(which='major', axis='both', alpha=0.5)

        # Marginal histograms on the diagonal.
        for k in range(n_dim):
            axes[k, k].hist(obj[..., k], color="black")
    return axes
[4]:
def plot_ess_and_trace(inference_results):
    """Plot traces, rank plots and ESS evolution for five parameters.

    Builds a 5x3 grid with one row per parameter: column 0 holds the traces of
    the 10 chains, column 1 the ArviZ rank plot and column 2 the ArviZ ESS
    evolution plot.

    Note:
        The trace column does not read from ``inference_results``. It uses the
        notebook-level ``slice_samples`` array reshaped to ``(10, 5000, 5)``
        (presumably chains x draws x parameters -- confirm against the
        sampler's output layout), so ``slice_samples`` must exist when this
        function is called.

        A function with the same name is defined again further down in this
        notebook; that later definition replaces this one once its cell runs.

    Args:
        inference_results: object forwarded to ``az.plot_ess`` and
            ``az.plot_rank``.

    Returns:
        The 5x3 array of matplotlib axes.
    """
    # Flat (row-major) indices into the 5x3 grid.
    bottom_row = (12, 13, 14)
    ess_column = (2, 5, 8, 11, 14)

    _, axes = plt.subplots(figsize=(10, 8), nrows=5, ncols=3)
    rank_axes = [axes[i, 1] for i in range(5)]
    ess_axes = [axes[i, 2] for i in range(5)]
    with az.style.context(["arviz-doc"], after_reset=True):
        plt.rcParams["font.family"] = "Times New Roman"
        az.plot_ess(
            inference_results,
            ax=ess_axes,
            color="#777777",
            extra_kwargs={"color": "#98375d"},
            kind="evolution",
        )
        # Dashed identity line as a visual reference in the ESS panels.
        for ax in ess_axes:
            ax.axline((0, 0), slope=1, color="black", ls="--")

        colors = sns.color_palette("rocket_r", as_cmap=False, desat=0.6, n_colors=10)
        az.plot_rank(
            inference_results,
            ax=rank_axes,
            kind='vlines',
            colors=colors,
            vlines_kwargs={"alpha": 0.15},
            marker_vlines_kwargs={"linestyle": 'None', "marker": "o", "ms": 3, "alpha": 0.75},
        )

        # Reshape once instead of once per plotted line.
        chains = slice_samples.reshape(10, 5000, 5)
        for i in range(5):
            for j in range(10):
                axes[i, 0].plot(chains[j, :, i], color=colors[j], alpha=0.15)
            axes[i, 0].set_ylabel(rf"$\theta_{i}$", fontsize=15)
            axes[i, 1].set_ylabel(None)
            axes[i, 2].set_ylabel(None)

        for i, ax in enumerate(axes.flatten()):
            ax.set_title(None)
            ax.spines[['right', 'top']].set_visible(False)
            ax.spines.left.set_linewidth(.5)
            ax.spines.bottom.set_linewidth(.5)
            ax.yaxis.set_major_locator(AutoLocator())
            ax.set_xlabel(None)
            if i in (13, 14):
                ax.set_xlabel("Total number of draws", fontsize=15)
            if i == 12:
                ax.set_xlabel("Number of draws per chain", fontsize=15)
            if ax.get_legend() is not None:
                ax.get_legend().remove()
            ax.yaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(width=0.5, length=2)
            ax.yaxis.set_tick_params(width=0.5, length=2)
            ax.grid(which='major', axis='both', alpha=0.5)
            # Membership test: `i != [12, 13, 14]` compares an int with a list
            # and is always True, which stripped the x tick labels from the
            # bottom row as well.
            if i not in bottom_row:
                ax.set_xticklabels([])
            if i in ess_column:
                ax.set_yticklabels([])
        axes[4, 2].legend(["Bulk ESS", "Tail ESS"], fontsize=12)
    return axes
We then define the generative model.
[5]:
def prior_fn():
    """Build the prior of the generative model.

    Returns:
        A ``tfd.JointDistributionNamed`` (``batch_ndims=0``) with the single
        entry ``"theta"``: a ``tfd.Uniform`` whose bounds are ``-3`` and ``3``
        in each of five dimensions.
    """
    lower = jnp.full(5, -3.0)
    upper = jnp.full(5, 3.0)
    return tfd.JointDistributionNamed(
        {"theta": tfd.Uniform(lower, upper)},
        batch_ndims=0,
    )
def simulator_fn(seed, theta):
    """Simulate eight values per parameter vector.

    For every row of ``theta["theta"]``, four two-dimensional points are built
    from standard-normal draws (shifted, scaled and correlated according to
    the parameters) and then flattened into eight values.

    Args:
        seed: JAX PRNG key. It is split in two and only the first sub-key is
            used for sampling.
        theta: mapping whose ``"theta"`` entry is a 2-D array (it is indexed
            as ``[:, None, :]``). Columns 0-4 are read: two means, two values
            that are squared to give the scales, and one value passed through
            ``tanh`` to give the correlation.

    Returns:
        Array of shape ``(theta["theta"].shape[0], 8)``.
    """
    params = theta["theta"][:, None, :]
    # The second sub-key is never consumed; the split itself is kept so the
    # random stream stays the same.
    normal_key, _ = jr.split(seed)

    mean_0 = params[..., [0]]
    mean_1 = params[..., [1]]
    scale_0 = params[..., [2]] ** 2
    scale_1 = params[..., [3]] ** 2
    rho = jnp.tanh(params[..., [4]])

    noise = tfd.Normal(0.0, 1.0).sample(
        seed=normal_key,
        sample_shape=(params.shape[0], params.shape[1], 4, 2),
    )
    z0 = noise[:, :, :, 0]
    z1 = noise[:, :, :, 1]

    # Fill both coordinates of a buffer that has the shape and dtype of the
    # noise draws.
    points = jnp.empty_like(noise)
    points = points.at[:, :, :, 0].set(scale_0 * z0 + mean_0)
    points = points.at[:, :, :, 1].set(
        scale_1 * (rho * z0 + jnp.sqrt(1.0 - rho**2) * z1) + mean_1
    )
    return points.reshape((*params.shape[:1], 8))
[6]:
# Observed data: one observation (leading axis of length 1) holding eight
# values, i.e. the same (n, 8) layout that `simulator_fn` returns.
y_obs = jnp.array([[
    -0.9707123,
    -2.9461224,
    -0.4494722,
    -3.4231849,
    -0.13285634,
    -3.364017,
    -0.85367596,
    -2.4271638,
]])
MCMC#
We first sample from the "true" posterior using MCMC, specifically a slice sampler.
[7]:
from functools import partial
from jax import scipy as jsp
from sbijax import as_inference_data
from sbijax.mcmc import sample_with_nuts, sample_with_slice
[8]:
def likelihood_fn(theta, y):
    """Evaluate the Gaussian log-likelihood of ``y`` for one parameter vector.

    The eight values are modelled as four independent copies of the same
    bivariate normal, written as one 8-dimensional normal with a
    block-diagonal covariance matrix.

    Args:
        theta: array read at positions 0-4: two means, two values whose
            squares are the standard deviations, and one value whose ``tanh``
            is the correlation.
        y: value(s) passed to ``log_prob``.

    Returns:
        ``log_prob(y)`` under that distribution.
    """
    loc = jnp.tile(theta[:2], 4)
    std_0 = theta[2] ** 2
    std_1 = theta[3] ** 2
    cross_cov = std_0 * std_1 * jnp.tanh(theta[4])
    block = jnp.array([[std_0**2, cross_cov], [cross_cov, std_1**2]])
    # Four identical 2x2 blocks on the diagonal, one per simulated point.
    full_cov = jsp.linalg.block_diag(block, block, block, block)
    return tfd.MultivariateNormalFullCovariance(loc, full_cov).log_prob(y)
def log_density_fn(theta, y):
    """Unnormalised log-posterior of ``theta`` given ``y``.

    Args:
        theta: parameter value passed both to the prior's ``log_prob`` and to
            ``likelihood_fn``.
        y: observation forwarded to ``likelihood_fn``.

    Returns:
        Sum of the prior log-density terms plus sum of the log-likelihood
        terms.
    """
    bounds = jnp.full(5, -3.0), jnp.full(5, 3.0)
    prior = tfd.JointDistributionNamed({"theta": tfd.Uniform(*bounds)})
    log_prior = prior.log_prob(theta)
    log_likelihood = likelihood_fn(theta, y)
    return jnp.sum(log_prior) + jnp.sum(log_likelihood)
[9]:
# Fix the observation so that the log-density is a function of theta only.
log_density = partial(log_density_fn, y=y_obs)


def lp(theta):
    """Evaluate ``log_density`` over the leading axis of ``theta`` via ``jax.vmap``."""
    return jax.vmap(log_density)(theta)


# Run the slice sampler with 10 chains. The prior's `sample` method is handed
# to the sampler (presumably to draw the initial states -- confirm in sbijax).
slice_samples = sample_with_slice(
    jr.PRNGKey(0),
    lp,
    prior_fn().sample,
    n_chains=10,
    n_samples=10_000,
    n_warmup=5_000
)
We then compute model diagnostics using Arviz.
[10]:
# Wrap the draws for ArviZ. The reshape assumes `slice_samples` holds
# 10 chains x 5000 draws x 5 parameters -- TODO confirm the sampler's layout.
slice_inference_data = as_inference_data({"theta": slice_samples.reshape(10, 5000, 5)}, y_obs)
[11]:
# Two side-by-side panels filled by sbijax's `plot_rhat_and_ress`.
_, axes = plt.subplots(figsize=(9, 3), ncols=2)
sbijax.plot_rhat_and_ress(slice_inference_data, axes=axes)
# Enlarge axis labels and tick labels.
for ax in axes:
    ax.xaxis.label.set_size(16)
    ax.yaxis.label.set_size(20)
    ax.tick_params(axis='both',labelsize=13)
plt.tight_layout()
plt.show()
[12]:
def plot_ess_and_trace(inference_results):
    """Plot traces, rank plots and ESS evolution for five parameters.

    Builds a 5x3 grid with one row per parameter: column 0 holds the traces of
    the 10 chains, column 1 the ArviZ rank plot and column 2 the ArviZ ESS
    evolution plot.

    Note:
        The trace column does not read from ``inference_results``. It uses the
        notebook-level ``slice_samples`` array reshaped to ``(10, 5000, 5)``
        (presumably chains x draws x parameters -- confirm against the
        sampler's output layout), so ``slice_samples`` must exist when this
        function is called.

        A function with the same name is also defined earlier in this
        notebook; this later definition is the one in effect for the call
        that follows it.

    Args:
        inference_results: object forwarded to ``az.plot_ess`` and
            ``az.plot_rank``.

    Returns:
        The 5x3 array of matplotlib axes.
    """
    # Flat (row-major) indices into the 5x3 grid.
    bottom_row = (12, 13, 14)
    ess_column = (2, 5, 8, 11, 14)

    _, axes = plt.subplots(figsize=(10, 8), nrows=5, ncols=3)
    rank_axes = [axes[i, 1] for i in range(5)]
    ess_axes = [axes[i, 2] for i in range(5)]
    with az.style.context(["arviz-doc"], after_reset=True):
        plt.rcParams["font.family"] = "Times New Roman"
        az.plot_ess(
            inference_results,
            ax=ess_axes,
            color="#777777",
            extra_kwargs={"color": "#884761"},
            kind="evolution",
        )
        # Dashed identity line as a visual reference in the ESS panels.
        for ax in ess_axes:
            ax.axline((0, 0), slope=1, color="black", ls="--")

        colors = sns.color_palette("rocket_r", as_cmap=False, desat=0.6, n_colors=10)
        az.plot_rank(
            inference_results,
            ax=rank_axes,
            kind='vlines',
            colors=colors,
            vlines_kwargs={"alpha": 0.15},
            marker_vlines_kwargs={"linestyle": 'None', "marker": "o", "ms": 3, "alpha": 0.75},
        )

        # Reshape once instead of once per plotted line.
        chains = slice_samples.reshape(10, 5000, 5)
        for i in range(5):
            for j in range(10):
                axes[i, 0].plot(chains[j, :, i], color=colors[j], alpha=0.15)
            axes[i, 0].set_ylabel(rf"$\theta_{i}$", fontsize=16)
            axes[i, 1].set_ylabel(None)
            axes[i, 2].set_ylabel(None)

        for i, ax in enumerate(axes.flatten()):
            ax.set_title(None)
            ax.spines[['right', 'top']].set_visible(False)
            ax.spines.left.set_linewidth(.5)
            ax.spines.bottom.set_linewidth(.5)
            ax.yaxis.set_major_locator(AutoLocator())
            ax.set_xlabel(None)
            if i in (13, 14):
                ax.set_xlabel("Total number of draws", fontsize=16)
            if i == 12:
                ax.set_xlabel("Number of draws per chain", fontsize=16)
            if ax.get_legend() is not None:
                ax.get_legend().remove()
            ax.yaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(width=0.5, length=2)
            ax.yaxis.set_tick_params(width=0.5, length=2)
            ax.grid(which='major', axis='both', alpha=0.5)
            # Membership test: `i != [12, 13, 14]` compares an int with a list
            # and is always True, which stripped the x tick labels from the
            # bottom row as well.
            if i not in bottom_row:
                ax.set_xticklabels([])
            if i in ess_column:
                ax.set_yticklabels([])
        axes[4, 2].legend(["Bulk ESS", "Tail ESS"], fontsize=12)
    return axes
# Render the trace / rank / ESS grid for the slice-sampler draws.
plot_ess_and_trace(slice_inference_data)
plt.tight_layout()
plt.show()
[13]:
# Pool all chains into one (n_draws, 5) array before drawing the pair plot.
plot_posteriors(slice_samples.reshape(-1, 5))
plt.tight_layout()
plt.show()
SNLE#
Next, we use surjective neural likelihood estimation to compute a posterior distribution.
[14]:
from sbijax import SNLE, inference_data_as_dictionary
from sbijax.nn import make_maf
[15]:
# Number of values in one observation (`y_obs` holds eight).
n_dim_data = 8
# Per-layer dimensions and hidden-layer sizes handed to `make_maf`.
n_layer_dimensions, hidden_sizes = (8, 8, 5, 5, 5), (64, 64)
neural_network = make_maf(
    n_dim_data,
    n_layer_dimensions=n_layer_dimensions,
    hidden_sizes=hidden_sizes
)
# SNLE is constructed from the (prior, simulator) pair and the network.
fns = prior_fn, simulator_fn
snle = SNLE(fns, neural_network)
[16]:
# Sequential training over 15 rounds. Each round passes the data and parameters
# of the previous round back in, then refits on the returned data set.
data, snle_params = None, {}
for i in range(15):
    # A distinct key per round is derived from a fixed base key.
    data, _ = snle.simulate_data_and_possibly_append(
        jr.fold_in(jr.PRNGKey(1), i),
        params=snle_params,
        observable=y_obs,
        data=data,
    )
    snle_params, info = snle.fit(
        jr.fold_in(jr.PRNGKey(2), i), data=data
    )
40%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | 401/1000 [01:14<01:51, 5.36it/s]
27%|βββββββββββββββββββββββββββββββββββββββ | 274/1000 [00:55<02:27, 4.93it/s]
27%|ββββββββββββββββββββββββββββββββββββββ | 270/1000 [00:55<02:28, 4.91it/s]
20%|βββββββββββββββββββββββββββββ | 203/1000 [00:44<02:55, 4.55it/s]
28%|ββββββββββββββββββββββββββββββββββββββββ | 279/1000 [01:05<02:49, 4.25it/s]
23%|βββββββββββββββββββββββββββββββββ | 231/1000 [00:54<03:02, 4.22it/s]
22%|ββββββββββββββββββββββββββββββββ | 224/1000 [00:55<03:12, 4.02it/s]
15%|ββββββββββββββββββββββ | 151/1000 [00:39<03:41, 3.84it/s]
41%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | 409/1000 [01:52<02:42, 3.64it/s]
10%|ββββββββββββββ | 98/1000 [00:28<04:19, 3.47it/s]
34%|ββββββββββββββββββββββββββββββββββββββββββββββββ | 339/1000 [01:43<03:21, 3.28it/s]
22%|ββββββββββββββββββββββββββββββββ | 224/1000 [01:10<04:03, 3.19it/s]
34%|ββββββββββββββββββββββββββββββββββββββββββββββββ | 336/1000 [01:47<03:32, 3.13it/s]
21%|ββββββββββββββββββββββββββββββ | 207/1000 [01:14<04:44, 2.79it/s]
22%|βββββββββββββββββββββββββββββββ | 218/1000 [01:20<04:48, 2.71it/s]
[17]:
# Sample the SNLE posterior for `y_obs` using the parameters of the last round.
snle_inference_results, diagnostics = snle.sample_posterior(
    jr.PRNGKey(5), snle_params, y_obs, n_samples=5_000, n_warmup=2_500, n_chains=10
)
[18]:
# Pair plot of the "theta" entry of the SNLE posterior draws.
plot_posteriors(
    inference_data_as_dictionary(snle_inference_results.posterior)["theta"],
)
plt.tight_layout()
plt.show()
FMPE#
As a comparison, we use flow matching posterior estimation.
[19]:
from sbijax import FMPE
from sbijax.nn import make_cnf
[20]:
# Dimensionality of the parameter vector (the prior is five-dimensional).
n_dim_theta = 5
n_layers, hidden_size = 5, 128
# NOTE: this rebinds the notebook-level `neural_network` and `fns` that were
# used for SNLE above.
neural_network = make_cnf(n_dim_theta, n_layers, hidden_size)
fns = prior_fn, simulator_fn
fmpe = FMPE(fns, neural_network)
[21]:
# Simulate 20,000 data sets in a single round. NOTE: this rebinds the
# notebook-level `data` that held the SNLE training data.
data, _ = fmpe.simulate_data(
    jr.PRNGKey(1),
    n_simulations=20_000,
)
# Fit with Adam (learning rate 0.001) and the early-stopping settings below.
fmpe_params, info = fmpe.fit(
    jr.PRNGKey(2),
    data=data,
    optimizer=optax.adam(0.001),
    n_early_stopping_delta=0.00001,
    n_early_stopping_patience=30
)
10%|ββββββββββββββ | 96/1000 [01:49<17:09, 1.14s/it]
[22]:
# Draw 25,000 posterior samples for `y_obs` from the fitted FMPE model.
fmpe_inference_results, diagnostics = fmpe.sample_posterior(
    jr.PRNGKey(5), fmpe_params, y_obs, n_samples=25_000
)
[23]:
# Pair plot of the "theta" entry of the FMPE posterior draws.
plot_posteriors(
    inference_data_as_dictionary(fmpe_inference_results.posterior)["theta"],
)
plt.tight_layout()
plt.show()
SMC-ABC#
Finally, we evaluate SMC-ABC using neural sufficient statistics.
[24]:
from sbijax import NASS, SMCABC, inference_data_as_dictionary
from sbijax.nn import make_nass_net
[25]:
# Embedding dimension and hidden-layer sizes handed to `make_nass_net`.
# NOTE: `hidden_sizes`, `neural_network`, `fns` and `data` rebind names used
# by earlier sections of the notebook.
n_embedding_dim, hidden_sizes = 5, (64, 64)
neural_network = make_nass_net(n_embedding_dim, hidden_sizes)
fns = prior_fn, simulator_fn
model_nass = NASS(fns, neural_network)
# Simulate 20,000 data sets and fit the summary network on them.
data, _ = model_nass.simulate_data(jr.PRNGKey(1), n_simulations=20_000)
params_nass, _ = model_nass.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25)
17%|ββββββββββββββββββββββββ | 170/1000 [02:18<11:18, 1.22it/s]
[26]:
def summary_fn(y):
    """Summarise ``y`` with the fitted NASS model.

    Relies on the notebook-level ``model_nass`` and ``params_nass``.

    Args:
        y: data passed to ``model_nass.summarize``.

    Returns:
        Whatever ``model_nass.summarize(params_nass, y)`` returns.
    """
    return model_nass.summarize(params_nass, y)
def distance_fn(y_simulated, y_observed):
    """Distance between simulated and observed values.

    Args:
        y_simulated: array with a leading batch axis.
        y_observed: array that broadcasts against ``y_simulated``.

    Returns:
        One ``jnp.linalg.norm`` value per entry along the leading axis of
        ``y_simulated - y_observed``.
    """
    return jax.vmap(jnp.linalg.norm)(y_simulated - y_observed)
[27]:
# SMC-ABC built from the (prior, simulator) pair, the NASS-based summary
# function and the norm-based distance function.
model_smc = SMCABC(fns, summary_fn, distance_fn)
# Ten rounds with 5,000 particles; `eps_step` and `ess_min` are passed through
# to sbijax unchanged.
smc_inference_results, _ = model_smc.sample_posterior(
    jr.PRNGKey(5),
    y_obs,
    n_rounds=10,
    n_particles=5_000,
    eps_step=0.825,
    ess_min=2_000
)
100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 10/10 [04:58<00:00, 29.89s/it]
[28]:
# Pair plot of the "theta" entry of the SMC-ABC posterior draws.
plot_posteriors(
    inference_data_as_dictionary(smc_inference_results.posterior)["theta"],
)
plt.tight_layout()
plt.show()
Session info#
[29]:
import session_info
# Print package, interpreter and OS versions as plain text for reproducibility.
session_info.show(html=False)
-----
arviz 0.19.0
haiku 0.0.12
jax 0.4.31
jaxlib 0.4.31
matplotlib 3.9.2
optax 0.2.4
sbijax 0.3.3
seaborn 0.12.2
session_info 1.0.0
tensorflow_probability 0.25.0
-----
IPython 8.31.0
jupyter_client 8.6.3
jupyter_core 5.7.2
jupyterlab 4.3.1
-----
Python 3.11.7 (main, Dec 9 2023, 06:06:18) [Clang 14.0.3 (clang-1403.0.22.14.1)]
macOS-13.0.1-arm64-arm-64bit
-----
Session information updated at 2025-01-04 20:33