Source code for sbijax._src.util.data

import arviz as az
import jax
import numpy as np
import xarray
from jax import numpy as jnp
from jax.tree_util import tree_flatten

from sbijax._src.util.types import PyTree


def _tree_stack(trees):
  leaves_list = []
  treedef_list = []
  for tree in trees:
    leaves, treedef = tree_flatten(tree)
    leaves_list.append(leaves)
    treedef_list.append(treedef)

  grouped_leaves = zip(*leaves_list)
  result_leaves = [jnp.vstack(leave) for leave in grouped_leaves]
  return treedef_list[0].unflatten(result_leaves)


[docs] def stack_data(data: PyTree, also_data: PyTree) -> PyTree: """Stack two data sets. Args: data: one data set also_data: another data set Returns: returns the stack of the two data sets """ if data is None: return also_data if also_data is None: return data stacked = _tree_stack([data, also_data]) return stacked
[docs] def as_inference_data(samples: PyTree, observed: jax.Array) -> xarray.DataTree: """Convert a PyTree to an inference data object. Args: samples: a PyTree of posterior samples observed: a jax.Array representing the observed data Returns: an inference data object """ d_ds = {} d_ds["posterior"] = az.dict_to_dataset( samples, coords={f"{k}_dim": np.arange(v.shape[-1]) for k, v in samples.items()}, dims={k: [f"{k}_dim"] for k in samples.keys()}, ) d_ds["observed_data"] = az.dict_to_dataset( {"y": observed}, skip_event_dims=True ) dt = xarray.DataTree.from_dict(d_ds, name=None) return dt
[docs] def inference_data_as_dictionary(inference_data: xarray.DataTree) -> PyTree: """Convert inference data to a PyTree. Args: inference_data: the `posterior` variable of an inference data object Returns: a PyTree """ posterior = inference_data["/posterior"] posterior = {k: v.data for k, v in posterior.data_vars.items()} posterior = {k: v.reshape(-1, v.shape[-1]) for k, v in posterior.items()} return posterior