Source code for sbijax._src.plot.plot

import arviz as az
import arviz_plots
import jax
import numpy as np
import xarray
from matplotlib import pyplot
from matplotlib.ticker import MaxNLocator


def plot_trace(inference_data: xarray.DataTree):
    """Draw an MCMC trace plot.

    This is a thin wrapper that forwards the inference data unchanged to
    `arviz.plot_trace`.

    Args:
        inference_data: an inference data object received from calling
            `sample_posterior` of an SBI algorithm

    Returns:
        the plot object returned by `arviz.plot_trace`
    """
    return az.plot_trace(inference_data)
def plot_posterior(inference_data: xarray.DataTree):
    """Draw the posterior distribution plot.

    This is a thin wrapper that forwards the inference data unchanged to
    `arviz_plots.plot_dist`.

    Args:
        inference_data: an inference data object received from calling
            `sample_posterior` of an SBI algorithm

    Returns:
        the plot object returned by `arviz_plots.plot_dist`
    """
    return arviz_plots.plot_dist(inference_data)
def plot_loss_profile(
    losses: jax.Array, axes: pyplot.Axes = None
) -> pyplot.Axes:
    """Draw the training and validation loss curves on one axes.

    Args:
        losses: a jax.Array of losses; `losses[:, 0]` is drawn as the
            training loss and `losses[:, 1]` as the validation loss
        axes: a matplotlib axes to draw on; when None, a new figure with
            `figsize=(5, 3)` and a single axes is created

    Returns:
        the matplotlib axes the curves were drawn on
    """
    if axes is None:
        axes = pyplot.subplots(figsize=(5, 3), sharey=False, sharex=False)[1]

    # (column index, legend label, dash pattern) for each curve
    curves = (
        (0, "Training loss", (0, (3, 1, 1, 1))),
        (1, "Validation loss", (0, (5, 1))),
    )
    for column, label, dash_pattern in curves:
        axes.plot(losses[:, column], label=label, linestyle=dash_pattern)

    # cap the number of major y ticks to keep the small figure readable
    axes.yaxis.set_major_locator(MaxNLocator(5))
    axes.legend()
    return axes
def plot_rank(inference_data: xarray.DataTree):
    """Draw rank statistics plots.

    This is a thin wrapper that forwards the inference data unchanged to
    `arviz.plot_rank`.

    Args:
        inference_data: an inference data object received from calling
            `sample_posterior` of an SBI algorithm

    Returns:
        the plot object returned by `arviz.plot_rank`
    """
    return az.plot_rank(inference_data)
# ruff: noqa: PLR2004
def plot_ess(inference_data: xarray.DataTree):
    """Draw the effective sample size plot.

    This is a thin wrapper that forwards the inference data unchanged to
    `arviz.plot_ess`.

    Args:
        inference_data: an inference data object received from calling
            `sample_posterior` of an SBI algorithm

    Returns:
        the plot object returned by `arviz.plot_ess`
    """
    return az.plot_ess(inference_data)
def _flatten_diagnostic(dataset) -> np.ndarray:
    """Collect all data variables of a diagnostics result in one 1-d array.

    Args:
        dataset: an object exposing `data_vars` (as returned by `az.rhat`
            or `az.ess`)

    Returns:
        a one-dimensional array with one entry per scalar diagnostic value,
        in the iteration order of `data_vars`
    """
    # Ravel each variable before concatenating: scalar (0-d) variables cannot
    # be concatenated directly, variables with different trailing shapes do
    # not line up, and a single-element result must stay 1-d so that `len`
    # works on it.
    return np.concatenate(
        [np.ravel(np.asarray(values)) for values in dataset.data_vars.values()]
    )


def plot_rhat_and_ress(
    inference_data: xarray.DataTree, axes: np.ndarray[pyplot.Axes] = None
) -> np.ndarray[pyplot.Axes]:
    r"""Split-$\hat{R}$ and relative effective sample size plot.

    Draws two lollipop charts with one row per diagnostic value: the left
    axes shows the values computed by `az.rhat`, anchored at 1.0, and the
    right axes shows the values computed by `az.ess(..., relative=True)`,
    anchored at 0.0. Dashed vertical lines mark 1.05 and 1.1 (and 1.3 when
    any value reaches it) on the left, and 0.1, 0.5 and 1.0 on the right.

    Args:
        inference_data: an inference data object received from calling
            `sample_posterior` of an SBI algorithm
        axes: an array of at least two matplotlib axes; when None, a new
            figure with two columns is created

    Returns:
        the array of matplotlib axes with added plots
    """
    rhats = _flatten_diagnostic(az.rhat(inference_data))
    ress = _flatten_diagnostic(az.ess(inference_data, relative=True))

    if axes is None:
        _, axes = pyplot.subplots(ncols=2)

    # left panel: split-Rhat
    rhat_rows = range(len(rhats))
    axes[0].plot(
        rhats, rhat_rows, marker="o", linestyle="None", color="black"
    )
    axes[0].hlines(rhat_rows, np.ones(len(rhats)), rhats, color="black")
    axes[0].axvline(1.05, color="darkgrey", alpha=0.5, linestyle="dashed")
    axes[0].axvline(1.1, color="darkgrey", alpha=0.5, linestyle="dashed")
    axes[0].axvline(1.0, color="black", alpha=0.5)
    axes[0].set_yticks(list(rhat_rows))
    # only the left limit is set; widen it when values fall below 1.0
    if np.any(rhats < 1.0):
        axes[0].set_xlim(0.95)
    else:
        axes[0].set_xlim(0.99)
    if np.any(rhats >= 1.3):
        axes[0].axvline(1.3, color="dimgrey", alpha=0.5, linestyle="dashed")
        axes[0].set_xticks([1.0, 1.05, 1.1, 1.3])
    else:
        axes[0].set_xticks([1.0, 1.05, 1.1])
    axes[0].set_yticklabels([])
    axes[0].set_ylabel(r"$\theta$")
    axes[0].set_xlabel(r"Split-$\hat{R}$")

    # right panel: relative effective sample size
    ress_rows = range(len(ress))
    axes[1].plot(
        ress, ress_rows, marker="o", linestyle="None", color="black"
    )
    axes[1].hlines(ress_rows, np.zeros(len(ress)), ress, color="black")
    axes[1].axvline(0.0, color="black", alpha=0.5)
    axes[1].axvline(0.1, color="darkgrey", alpha=0.5, linestyle="dashed")
    axes[1].axvline(0.5, color="darkgrey", alpha=0.5, linestyle="dashed")
    axes[1].axvline(1.0, color="darkgrey", alpha=0.5, linestyle="dashed")
    axes[1].set_xticks([0.0, 0.1, 0.5, 1.0])
    axes[1].set_yticks(list(ress_rows))
    axes[1].set_yticklabels([])
    axes[1].set_xlabel(r"Relative ESS")

    # only the first axes keeps its y label
    for ax in axes[1:]:
        ax.set_ylabel(None)
    return axes