Examples

Examples#

In the following, we demonstrate several sbijax methods using the complex β€œSimple Liklelihood Complex Posterior” model.

[1]:
import arviz as az
import jax
import optax
import os
import sbijax
import seaborn as sns
%matplotlib inline
import matplotlib.pyplot as plt

from matplotlib.ticker import AutoLocator, MaxNLocator
from jax import numpy as jnp, random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

We remove some warnings that TFP is emitting, when using 64-bit arithmetic instead of 32-bit.

[2]:
import warnings
warnings.filterwarnings("ignore")

We implement a custom function to visualize posterior pairs.

[3]:
def plot_posteriors(obj):
    cmap = sns.color_palette("rocket", as_cmap=False, desat=0.6, n_colors=10)
    cmap = sns.blend_palette(cmap, as_cmap=True)

    _, axes = plt.subplots(figsize=(12, 10), nrows=5, ncols=5)
    with az.style.context(["arviz-doc"], after_reset=True):
        for i in range(0, 5):
            for j in range(0, 5):
                ax = axes[i, j]
                if i < j:
                    ax.axis('off')
                else:
                    ax.hexbin(obj[..., j], obj[..., i], gridsize=50, bins='log', cmap=cmap)
                ax.spines.left.set_linewidth(.5)
                ax.spines.bottom.set_linewidth(.5)
                ax.spines.right.set_linewidth(.5)
                ax.spines.top.set_linewidth(.5)
                ax.xaxis.set_major_locator(MaxNLocator(2))
                ax.yaxis.set_major_locator(MaxNLocator(2))
                ax.xaxis.set_tick_params(width=1, length=2, labelsize=25)
                ax.yaxis.set_tick_params(width=1, length=2, labelsize=25)
                if i != j:
                    ax.set_yticks([-3, 0, 3])
                    ax.set_xticks([-3, 0, 3])
                else:
                    ax.set_yticklabels([])
                if i < 4:
                    ax.set_xticklabels([])
                    ax.xaxis.set_tick_params(width=0., length=0)
                if j != 0:
                    ax.set_yticklabels([])
                    ax.yaxis.set_tick_params(width=0., length=0)
                ax.grid(which='major', axis='both', alpha=0.5)
        for i in range(5):
            axes[i, i].hist(obj[..., i], color="black")
    return axes
[4]:
def plot_ess_and_trace(inference_results):
    _, axes = plt.subplots(figsize=(10, 8), nrows=5, ncols=3)

    with az.style.context(["arviz-doc"], after_reset=True):
        plt.rcParams["font.family"] = "Times New Roman"
        sns.color_palette("rocket_r", as_cmap=False, desat=0.6, n_colors=10)
        ax = az.plot_ess(
            inference_results,
            ax=[axes[i, 2] for i in range(5)],
            color="#777777",
            extra_kwargs={"color": "#98375d"},
            kind="evolution"
        )
        for ax in [axes[i, 2] for i in range(5)]:
            ax.axline((0, 0), slope=1, color="black", ls="--")
        colors = sns.color_palette("rocket_r", as_cmap=False, desat=0.6, n_colors=10)
        az.plot_rank(inference_results, ax=[axes[i, 1] for i in range(5)],  kind='vlines', colors=colors, vlines_kwargs={"alpha":0.15}, marker_vlines_kwargs={"linestyle":'None', "marker": "o", "ms":3, "alpha": 0.75})
        for i in range(5):
            for j in range(10):
                axes[i, 0].plot(slice_samples.reshape(10, 5000, 5)[j, :, i], color=colors[j], alpha=0.15)
                axes[i, 0].set_ylabel(rf"$\theta_{i}$", fontsize=15)
                axes[i, 1].set_ylabel(None)
                axes[i, 2].set_ylabel(None)
        for i, ax in enumerate(axes.flatten()):
            ax.set_title(None)
            ax.spines[['right', 'top']].set_visible(False)
            ax.spines.left.set_linewidth(.5)
            ax.spines.bottom.set_linewidth(.5)
            ax.yaxis.set_major_locator(AutoLocator())
            ax.set_xlabel(None)
            if i in [13, 14]:
                ax.set_xlabel("Total number of draws", fontsize=15)
            if i == 12:
                ax.set_xlabel("Number of draws per chain", fontsize=15)
            if ax.get_legend() is not None:
                ax.get_legend().remove()

            ax.yaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(width=0.5, length=2)
            ax.yaxis.set_tick_params(width=0.5, length=2)
            ax.grid(which='major', axis='both', alpha=0.5)
            if i != [12, 13, 14]:
                ax.set_xticklabels([])
            if i in [2, 5, 8, 11, 14]:
                ax.set_yticklabels([])
            axes[4, 2].legend(["Bulk ESS", "Tail ESS"], fontsize=12)
        return axes

We then define the generative model.

[5]:
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0))
    ), batch_ndims=0)
    return prior


def simulator_fn(seed, theta):
    theta = theta["theta"]
    theta = theta[:, None, :]
    us_key, noise_key = jr.split(seed)

    def _unpack_params(ps):
        m0 = ps[..., [0]]
        m1 = ps[..., [1]]
        s0 = ps[..., [2]] ** 2
        s1 = ps[..., [3]] ** 2
        r = jnp.tanh(ps[..., [4]])
        return m0, m1, s0, s1, r

    m0, m1, s0, s1, r = _unpack_params(theta)
    us = tfd.Normal(0.0, 1.0).sample(
        seed=us_key, sample_shape=(theta.shape[0], theta.shape[1], 4, 2)
    )
    xs = jnp.empty_like(us)
    xs = xs.at[:, :, :, 0].set(s0 * us[:, :, :, 0] + m0)
    y = xs.at[:, :, :, 1].set(
        s1 * (r * us[:, :, :, 0] + jnp.sqrt(1.0 - r**2) * us[:, :, :, 1]) + m1
    )
    y = y.reshape((*theta.shape[:1], 8))
    return y
[6]:
y_obs = jnp.array([[
    -0.9707123,
    -2.9461224,
    -0.4494722,
    -3.4231849,
    -0.13285634,
    -3.364017,
    -0.85367596,
    -2.4271638,
]])

MCMC#

We first sample from the β€œtrue” posterior using MCMC, specifically a slice sampler.

[7]:
from functools import partial
from jax import scipy as jsp
from sbijax import as_inference_data
from sbijax.mcmc import sample_with_nuts, sample_with_slice
[8]:
def likelihood_fn(theta, y):
    mu = jnp.tile(theta[:2], 4)
    s1, s2 = theta[2] ** 2, theta[3] ** 2
    corr = s1 * s2 * jnp.tanh(theta[4])
    cov = jnp.array([[s1**2, corr], [corr, s2**2]])
    cov = jsp.linalg.block_diag(*[cov for _ in range(4)])
    p = tfd.MultivariateNormalFullCovariance(mu, cov)
    return p.log_prob(y)


def log_density_fn(theta, y):
    prior_lp = tfd.JointDistributionNamed(dict(
        theta=tfd.Uniform(jnp.full(5, -3.0), jnp.full(5, 3.0))
    )).log_prob(theta)
    likelihood_lp = likelihood_fn(theta, y)
    lp = jnp.sum(prior_lp) + jnp.sum(likelihood_lp)
    return lp
[9]:
log_density = partial(log_density_fn, y=y_obs)

def lp(theta):
    return jax.vmap(log_density)(theta)

slice_samples = sample_with_slice(
    jr.PRNGKey(0),
    lp,
    prior_fn().sample,
    n_chains=10,
    n_samples=10_000,
    n_warmup=5_000
)

We then compute model diagnostics using Arviz.

[10]:
slice_inference_data = as_inference_data({"theta": slice_samples.reshape(10, 5000, 5)}, y_obs)
[11]:
_, axes = plt.subplots(figsize=(9, 3), ncols=2)
sbijax.plot_rhat_and_ress(slice_inference_data, axes=axes)
for ax in axes:
    ax.xaxis.label.set_size(16)
    ax.yaxis.label.set_size(20)
    ax.tick_params(axis='both',labelsize=13)
plt.tight_layout()
plt.show()
../_images/notebooks_examples_16_0.png
[12]:
def plot_ess_and_trace(inference_results):
    _, axes = plt.subplots(figsize=(10, 8), nrows=5, ncols=3)

    with az.style.context(["arviz-doc"], after_reset=True):
        plt.rcParams["font.family"] = "Times New Roman"
        cols = sns.color_palette("rocket_r", as_cmap=False, desat=0.4, n_colors=10).as_hex()
        ax = az.plot_ess(
            inference_results,
            ax=[axes[i, 2] for i in range(5)],
            color="#777777",
            extra_kwargs={"color": "#884761"},
            kind="evolution"
        )
        for ax in [axes[i, 2] for i in range(5)]:
            ax.axline((0, 0), slope=1, color="black", ls="--")
        colors = sns.color_palette("rocket_r", as_cmap=False, desat=0.6, n_colors=10)
        az.plot_rank(inference_results, ax=[axes[i, 1] for i in range(5)],  kind='vlines', colors=colors, vlines_kwargs={"alpha":0.15}, marker_vlines_kwargs={"linestyle":'None', "marker": "o", "ms":3, "alpha": 0.75})
        for i in range(5):
            for j in range(10):
                axes[i, 0].plot(slice_samples.reshape(10, 5000, 5)[j, :, i], color=colors[j], alpha=0.15)
                axes[i, 0].set_ylabel(rf"$\theta_{i}$", fontsize=16)
                axes[i, 1].set_ylabel(None)
                axes[i, 2].set_ylabel(None)
        for i, ax in enumerate(axes.flatten()):
            ax.set_title(None)
            ax.spines[['right', 'top']].set_visible(False)
            ax.spines.left.set_linewidth(.5)
            ax.spines.bottom.set_linewidth(.5)
            ax.yaxis.set_major_locator(AutoLocator())
            ax.set_xlabel(None)
            if i in [13, 14]:
                ax.set_xlabel("Total number of draws", fontsize=16)
            if i == 12:
                ax.set_xlabel("Number of draws per chain", fontsize=16)
            if ax.get_legend() is not None:
                ax.get_legend().remove()

            ax.yaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(labelsize=12)
            ax.xaxis.set_tick_params(width=0.5, length=2)
            ax.yaxis.set_tick_params(width=0.5, length=2)
            ax.grid(which='major', axis='both', alpha=0.5)
            if i != [12, 13, 14]:
                ax.set_xticklabels([])
            if i in [2, 5, 8, 11, 14]:
                ax.set_yticklabels([])
            axes[4, 2].legend(["Bulk ESS", "Tail ESS"], fontsize=12)
        return axes

plot_ess_and_trace(slice_inference_data)
plt.tight_layout()
plt.show()
../_images/notebooks_examples_17_0.png
[13]:
plot_posteriors(slice_samples.reshape(-1, 5))
plt.tight_layout()
plt.show()
../_images/notebooks_examples_18_0.png

SNLE#

Next, we use surjective neural likelihood estimation to compute a posterior distribution.

[14]:
from sbijax import SNLE, inference_data_as_dictionary
from sbijax.nn import make_maf
[15]:
n_dim_data = 8
n_layer_dimensions, hidden_sizes = (8, 8, 5, 5, 5), (64, 64)
neural_network = make_maf(
    n_dim_data,
    n_layer_dimensions=n_layer_dimensions,
    hidden_sizes=hidden_sizes
)

fns = prior_fn, simulator_fn
snle = SNLE(fns, neural_network)
[16]:
data, snle_params = None, {}
for i in range(15):
    data, _ = snle.simulate_data_and_possibly_append(
        jr.fold_in(jr.PRNGKey(1), i),
        params=snle_params,
        observable=y_obs,
        data=data,
    )
    snle_params, info = snle.fit(
        jr.fold_in(jr.PRNGKey(2), i), data=data
    )
 40%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ                                                                                    | 401/1000 [01:14<01:51,  5.36it/s]
 27%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹                                                                                                      | 274/1000 [00:55<02:27,  4.93it/s]
 27%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ                                                                                                       | 270/1000 [00:55<02:28,  4.91it/s]
 20%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ                                                                                                                | 203/1000 [00:44<02:55,  4.55it/s]
 28%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž                                                                                                     | 279/1000 [01:05<02:49,  4.25it/s]
 23%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ                                                                                                            | 231/1000 [00:54<03:02,  4.22it/s]
 22%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ                                                                                                             | 224/1000 [00:55<03:12,  4.02it/s]
 15%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž                                                                                                                       | 151/1000 [00:39<03:41,  3.84it/s]
 41%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹                                                                                   | 409/1000 [01:52<02:42,  3.64it/s]
 10%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰                                                                                                                                | 98/1000 [00:28<04:19,  3.47it/s]
 34%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š                                                                                             | 339/1000 [01:43<03:21,  3.28it/s]
 22%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ                                                                                                             | 224/1000 [01:10<04:03,  3.19it/s]
 34%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–                                                                                             | 336/1000 [01:47<03:32,  3.13it/s]
 21%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–                                                                                                               | 207/1000 [01:14<04:44,  2.79it/s]
 22%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹                                                                                                              | 218/1000 [01:20<04:48,  2.71it/s]
[17]:
snle_inference_results, diagnostics = snle.sample_posterior(
    jr.PRNGKey(5), snle_params, y_obs, n_samples=5_000, n_warmup=2_500, n_chains=10
)
[18]:
plot_posteriors(
    inference_data_as_dictionary(snle_inference_results.posterior)["theta"],
)
plt.tight_layout()
plt.show()
../_images/notebooks_examples_24_0.png

FMPE#

As a comparison, we use flow matching posterior estimation.

[19]:
from sbijax import FMPE
from sbijax.nn import make_cnf
[20]:
n_dim_theta = 5
n_layers, hidden_size = 5, 128
neural_network = make_cnf(n_dim_theta, n_layers, hidden_size)

fns = prior_fn, simulator_fn
fmpe = FMPE(fns, neural_network)
[21]:
data, _ = fmpe.simulate_data(
    jr.PRNGKey(1),
    n_simulations=20_000,
)
fmpe_params, info = fmpe.fit(
    jr.PRNGKey(2),
    data=data,
    optimizer=optax.adam(0.001),
    n_early_stopping_delta=0.00001,
    n_early_stopping_patience=30
)
 10%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹                                                                                                                                | 96/1000 [01:49<17:09,  1.14s/it]
[22]:
fmpe_inference_results, diagnostics = fmpe.sample_posterior(
    jr.PRNGKey(5), fmpe_params, y_obs, n_samples=25_000
)
[23]:
plot_posteriors(
    inference_data_as_dictionary(fmpe_inference_results.posterior)["theta"],
)
plt.tight_layout()
plt.show()
../_images/notebooks_examples_30_0.png

SMC-ABC#

Finally, we evaluate SMC-ABC using neural sufficient statistics.

[24]:
from sbijax import NASS, SMCABC, inference_data_as_dictionary
from sbijax.nn import make_nass_net
[25]:
n_embedding_dim, hidden_sizes = 5, (64, 64)
neural_network = make_nass_net(n_embedding_dim, hidden_sizes)

fns = prior_fn, simulator_fn
model_nass = NASS(fns, neural_network)

data, _ = model_nass.simulate_data(jr.PRNGKey(1), n_simulations=20_000)
params_nass, _ = model_nass.fit(jr.PRNGKey(2), data=data, n_early_stopping_patience=25)
 17%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰                                                                                                                     | 170/1000 [02:18<11:18,  1.22it/s]
[26]:
def summary_fn(y):
    s = model_nass.summarize(params_nass, y)
    return s

def distance_fn(y_simulated, y_observed):
    diff = y_simulated - y_observed
    dist = jax.vmap(lambda el: jnp.linalg.norm(el))(diff)
    return dist
[27]:
model_smc = SMCABC(fns, summary_fn, distance_fn)

smc_inference_results, _ = model_smc.sample_posterior(
    jr.PRNGKey(5),
    y_obs,
    n_rounds=10,
    n_particles=5_000,
    eps_step=0.825,
    ess_min=2_000
)
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 10/10 [04:58<00:00, 29.89s/it]
[28]:
plot_posteriors(
    inference_data_as_dictionary(smc_inference_results.posterior)["theta"],
)
plt.tight_layout()
plt.show()
../_images/notebooks_examples_36_0.png

Session info#

[29]:
import session_info

session_info.show(html=False)
-----
arviz                       0.19.0
haiku                       0.0.12
jax                         0.4.31
jaxlib                      0.4.31
matplotlib                  3.9.2
optax                       0.2.4
sbijax                      0.3.3
seaborn                     0.12.2
session_info                1.0.0
tensorflow_probability      0.25.0
-----
IPython             8.31.0
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.3.1
-----
Python 3.11.7 (main, Dec  9 2023, 06:06:18) [Clang 14.0.3 (clang-1403.0.22.14.1)]
macOS-13.0.1-arm64-arm-64bit
-----
Session information updated at 2025-01-04 20:33