Source code for sbijax._src.simulators.solar_dynamo

import numpy as np
from jax import lax
from jax import numpy as jnp
from jax import random as jr
from jax.scipy.special import erf
from scipy.signal.windows import hann
from tensorflow_probability.substrates.jax import distributions as tfd

__all__ = ["solar_dynamo"]


def _sample_timeseries(
  seed, y0, alpha_min, alpha_max, epsilon_max, len_timeseries=200
):
  """Simulate a time series from the stochastic Babcock-Leighton recursion.

  At every step a growth factor ``alpha ~ Uniform(alpha_min, alpha_max)`` and
  an additive noise ``epsilon ~ Uniform(0, epsilon_max)`` are drawn, and the
  state is advanced with `_babcock_leighton`.

  Args:
    seed: a JAX random key
    y0: initial state of the recursion (presumably broadcast-compatible
      with the bounds below; verify against the caller)
    alpha_min: lower bound of the uniform distribution of the growth factor
    alpha_max: upper bound of the uniform distribution of the growth factor
    epsilon_max: upper bound of the uniform distribution of the noise
    len_timeseries: number of time steps to simulate

  Returns:
    a tuple ``(f, y, a, noise)``, each with all axes reversed (``.T``) so
    that the time axis comes last: the values of `_babcock_leighton_fn` at
    the state preceding every step, the simulated states, the sampled
    growth factors and the sampled noise terms.
  """
  # Draw the two sequences from separate keys; reusing a single key for both
  # calls would feed the same random stream into both distributions.
  alpha_key, noise_key = jr.split(seed)
  a = tfd.Uniform(alpha_min, alpha_max).sample(
    seed=alpha_key, sample_shape=(len_timeseries,)
  )
  noise = tfd.Uniform(0.0, epsilon_max).sample(
    seed=noise_key, sample_shape=(len_timeseries,)
  )

  def _step(pn, arrays):
    alpha, epsilon = arrays
    # f is evaluated at the state *before* the update.
    f = _babcock_leighton_fn(pn)
    pn_next = _babcock_leighton(pn, alpha, epsilon)
    return pn_next, (f, pn_next)

  _, (f, y) = lax.scan(_step, y0, (a, noise))
  return f.T, y.T, a.T, noise.T


def _babcock_leighton_fn(p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8):
  f = 0.5 * (1.0 + erf((p - b_1) / w_1)) * (1.0 - erf((p - b_2) / w_2))
  return f


def _babcock_leighton(p, alpha, epsilon):
  """Advance the state `p` by one step of the noisy recursion.

  Args:
    p: current state
    alpha: multiplicative growth factor for this step
    epsilon: additive noise for this step

  Returns:
    the next state, ``alpha * _babcock_leighton_fn(p) * p + epsilon``
  """
  return alpha * _babcock_leighton_fn(p) * p + epsilon


def _simulate(seed, theta, len_timeseries=100):
  """Simulate solar dynamo time series for a batch of parameters.

  Args:
    seed: a JAX random key, forwarded to `_sample_timeseries`
    theta: parameter array whose last axis holds, in this order, the lower
      bound of the growth factor, the width of its range (the upper bound
      is the sum of the two) and the upper bound of the noise. A
      two-dimensional input gets a leading axis of size one for the
      simulation.
    len_timeseries: number of time steps to simulate

  Returns:
    the simulated states. For a two-dimensional `theta` of shape ``(n, 3)``
    the result has shape ``(n, len_timeseries)``; otherwise the result of
    swapping the first two axes of the simulated series is returned as is.
  """
  orig_shape = theta.shape
  if theta.ndim == 2:
    theta = theta[None, :, :]

  alpha_min = theta[..., 0]
  alpha_max = alpha_min + theta[..., 1]
  epsilon_max = theta[..., 2]
  # Every series starts at one.
  y0 = jnp.ones(theta.shape[:-1])

  _, y, _, _ = _sample_timeseries(
    seed, y0, alpha_min, alpha_max, epsilon_max, len_timeseries
  )

  y = jnp.swapaxes(y, 1, 0)
  if len(orig_shape) == 2:
    # Drop the leading axis that was added above.
    y = y.reshape((*orig_shape[:1], len_timeseries))
  return y


# ruff: noqa: PLR0913, E501
def solar_dynamo(summarize_data=False):
  """Solar dynamo model.

  Constructs prior and simulator functions.

  Args:
    summarize_data: if True, the simulator returns summary statistics (the
      magnitudes of every sixth coefficient of the inverse FFT of the
      Hann-windowed series) instead of the raw time series

  Returns:
    returns a tuple of three objects. The first is a
    tfd.JointDistributionNamed serving as a prior distribution. The second
    is a simulator function that can be used to generate data. The third is
    None (since the likelihood is intractable and to be consistent with
    other models).

  References:
    Albert, Carlo, et al., Learning summary statistics for Bayesian
    inference with autoencoders, 2022
  """

  def prior_fn():
    return tfd.JointDistributionNamed(
      dict(
        theta=tfd.Independent(
          tfd.Uniform(
            jnp.array([0.9, 0.05, 0.02]),
            jnp.array([1.4, 0.25, 0.15]),
          ),
          reinterpreted_batch_ndims=1,
        )
      )
    )

  def summarize(ys):
    # Taper every series with a Hann window before the transform.
    window = hann(ys.shape[1])
    window = window.reshape(1, -1)
    # Keep every sixth coefficient only.
    fourier_range = np.arange(0, ys.shape[1], 6)
    fs = np.fft.ifft(window * ys, axis=1)
    ss = np.abs(fs[:, fourier_range])
    return fourier_range, ss

  def simulator(seed, theta):
    # Flatten any batch structure into rows of three parameters.
    theta = theta["theta"].reshape(-1, 3)
    ys = _simulate(seed, theta)
    if summarize_data:
      _, summ = summarize(ys)
      return summ
    return ys

  return prior_fn(), simulator, None