import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib_inline
import pymc as pm
import arviz as az
import pytensor.tensor as pt
from scipy.stats import norm, t
from pprint import pprint
## Use a custom style for the plots
plt.style.use('smaller.mplstyle')
matplotlib_inline.backend_inline.set_matplotlib_formats('retina')
#%config InlineBackend.figure_formats = ['retina']
Comparing GRW models for A/B Clustering
Quantifying differences in GRW models with varying parameters
Here, we compare probabilistic models that turn eigenvectors into A/B compartment calls for Hi-C data.
We compare different models, all sharing the same underlying structure but varying in smaller aspects, such as the choice of priors and likelihoods.
Goals
To quantify differences in model performance, we use the following metrics:
- WAIC: Widely Applicable Information Criterion, a measure of model fit that penalizes complexity.
- LOO: Leave-One-Out Cross-Validation, a method for estimating the predictive accuracy of a model.
We also try to assign uncertainty to the compartment calls, ultimately defining credible intervals for the compartment transitions.
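As a sketch of how these metrics can be turned into a ranking once the models have been fitted further down in this notebook (assuming every entry of the model_traces dictionary holds an InferenceData object with a log_likelihood group, which pm.compute_log_likelihood adds later), ArviZ can compare all models in one call:

# Sketch: rank all fitted models by LOO (use ic="waic" for WAIC instead)
import arviz as az
idatas = {name: d["trace"] for name, d in model_traces.items() if d["trace"] is not None}
comparison = az.compare(idatas, ic="loo")
print(comparison[["rank", "elpd_loo", "p_loo", "weight"]])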
Imports and setup
Load/show data
Here, we load the first eigenvector (E1) of the Hi-C contact matrix. E1 is used as the predictor for compartment assignment.
The eigendecomposition of the Hi-C matrix was performed with the workflow in the main folder (workflow.py), using the Open2C ecosystem.
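For orientation only, the eigendecomposition step could look roughly like the sketch below; the .mcool path and the exact cooltools call are assumptions on our part, and the authoritative version is the pipeline in workflow.py.

# Hypothetical sketch of producing the E1 track with the Open2C stack (cooler + cooltools);
# paths and call details are assumed, see workflow.py for the actual pipeline.
import cooler
import cooltools
clr = cooler.Cooler("../data/fibroblast.mcool::resolutions/100000")  # assumed file path
eigvals, eigvecs = cooltools.eigs_cis(clr, n_eigs=3)                 # assumed call signature
eigvecs.to_csv("../data/eigs/fibroblast.eigs.100000.cis.vecs.tsv", sep="\t", index=False)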
# Load the data
resolution = 100000
y = pd.Series(pd.read_csv(f"../data/eigs/fibroblast.eigs.{resolution}.cis.vecs.tsv", sep="\t")['E1'].values.flatten())
x = pd.Series(np.arange(0, y.shape[0]) * resolution)

# Make a DataFrame object
df = pd.DataFrame({"start": x, "e1": y})

# Plot the data
fig, ax = plt.subplots(figsize=(10, 3))
ax.fill_between(df.start, df.e1, where=df.e1 > 0, color='tab:red', ec='None', label='E1', step='pre')
ax.fill_between(df.start, df.e1, where=df.e1 < 0, color='tab:blue', ec='None', label='E1', step='pre')
df.dropna(inplace=True)
# Histogram of distribution of E1 values (A and B)
x = df["start"].values
y = df["e1"].values
dist_x_values = np.linspace(y.min(), y.max(), 1000)
a_dist = t.pdf(dist_x_values, loc=np.mean(y[y>0]), scale=y[y>0].std(), df=3)
b_dist = t.pdf(dist_x_values, loc=np.mean(y[y<0]), scale=y[y<0].std(), df=3)

f, ax = plt.subplots()

# A compartments
ax.hist(y[y>0], bins=25, color="tab:red", alpha=0.5, label="A", density=True)
# B compartments
ax.hist(y[y<0], bins=25, color="tab:blue", alpha=0.5, label="B", density=True)
# Mean values (vline)
ax.axvline(y[y>0].mean(), color="tab:red", linestyle="--", label="Mean A")
ax.axvline(y[y<0].mean(), color="tab:blue", linestyle="--", label="Mean B")

# Plot the Student's t distributions
ax.plot(dist_x_values, a_dist, color="tab:red", label="Stud A")
ax.plot(dist_x_values, b_dist, color="tab:blue", label="Stud B")
ax.plot(dist_x_values, a_dist+b_dist, color="tab:purple", label="Stacked StudT")

# Final touches
plt.xlabel("E1")
plt.ylabel("Density")
plt.legend(loc='best')
plt.tight_layout()
Here, the observed distributions are compared to Student's t-distributions with 3 degrees of freedom and the same mean and standard deviation as the observed data. The Student's t-distribution is used because its heavier tails make it more robust to outliers than the normal distribution.
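As a quick side calculation (not part of the notebook's analysis) illustrating what "heavier tails" means: the two-sided probability mass beyond ±3 scale units is roughly 0.3% for a normal distribution but almost 6% for a Student's t with 3 degrees of freedom.

# Side calculation: tail mass beyond +/-3 scale units, Normal vs Student's t(3)
from scipy.stats import norm, t
print(f"Normal:         {2 * norm.sf(3):.4f}")    # ~0.0027
print(f"Student's t(3): {2 * t.sf(3, df=3):.4f}") # ~0.0577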
Models
Here, we expand on the results from 03_PyMC_GRW_logit.ipynb and use the same model structure, but with different priors and likelihoods. Overall, the following ideas are implemented (summarized in the equations after the list):
- We assume that the eigenvector values are drawn from a mixture of two distributions, one for each compartment (A and B). It could be any mixture, but the means should be negative for B and positive for A. To model mu, we use either:
  - Gaussians, as a simple starting point, or
  - Student's t-distributions, as a heavier-tailed alternative.
- The mixture means have to be ordered in PyMC to ensure that the A compartment is always positive and the B compartment is always negative.
- The mixture components have their own (learned) standard deviation, sigma.
- The mixture components use mu and sigma (each of shape 2) in either a
  - Normal distribution, or
  - Student's t-distribution with 10 degrees of freedom.
- We model the mixture weight as a non-centered (Gaussian) random walk prior, as a way of probabilistically smoothing the PCA compartment calls.
  - eps is the (learned) step size of the random walk, either Normal or Student's t-distributed.
  - grw_sigma is the (learned) standard deviation of the random walk, either HalfNormal or HalfStudentT-distributed.
- The logit-space mixture weight is then modeled deterministically as the cumulative sum of the random walk steps.
- The logit-space weight is then transformed to probability space by squashing it through a sigmoid function.
- \(\hat{y}\) (y_hat) draws the observations from the mixture distribution, using the mixture weight and the two components (the learned distributions of E1 values for the A and B compartments).
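To make the structure above concrete, the generative model can be written compactly as below. The notation is ours (it is not spelled out in the notebook): \(\mathcal{D}\) is the chosen component density (Normal or Student's t), \(\mathcal{D}_\varepsilon\) the chosen step density, and \(i\) indexes genomic bins.

\[
\varepsilon_i \sim \mathcal{D}_\varepsilon(0, \sigma_\varepsilon), \qquad
\operatorname{logit}(w_i) = \sum_{j \le i} \sigma_{\mathrm{grw}}\,\varepsilon_j, \qquad
w_i = \operatorname{sigmoid}\!\big(\operatorname{logit}(w_i)\big),
\]
\[
y_i \sim w_i\,\mathcal{D}(\mu_1, \sigma_1) + (1 - w_i)\,\mathcal{D}(\mu_2, \sigma_2), \qquad \mu_1 < \mu_2 \ \text{(ordered)}.
\]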
Model 1: T-dist in logit space, T-dist in mixture components
import pytensor.tensor as pt

with pm.Model(coords={"pos": df.index.values}) as latentT_Tmix_model:
    e1 = pm.Data("e1", df.e1.values, dims="pos")

    # Ordered parameters for the means of the two components
    # Note: The ordered transform ensures that mu_a < mu_b
    mu = pm.Normal("mu", mu=[-0.5, 0.5], sigma=0.3,
                   transform=pm.distributions.transforms.ordered,  # IMPORTANT
                   shape=2,
                   )
    sigma = pm.HalfNormal("sigma", 0.3, shape=2)

    # GRW over logit space; non-centered reparameterization
    grw_sigma = pm.HalfNormal("grw_sigma", 0.05)

    # StudentT distribution for eps (got a lot of divergences with nu=3)
    eps = pm.StudentT("eps", nu=10, mu=0.0, sigma=1.0, dims="pos")

    # Cumulative sum to create a Gaussian Random Walk
    logit_w = pm.Deterministic("logit_w", pt.cumsum(eps * grw_sigma), dims="pos")
    w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

    # Components of the mixture model
    components = pm.StudentT.dist(nu=10, mu=mu, sigma=sigma, shape=2)

    # Mixture model
    # The observed data is modeled as a mixture of the two components
    y_hat = pm.Mixture("y_hat", w=pm.math.stack([w, 1-w], axis=1), comp_dists=components, observed=e1, dims='pos')
Model 2: T-dist in logit space, Normal in mixture components
with pm.Model(coords={"pos": df.index.values}) as latentT_Nmix_model:
    e1 = pm.Data("e1", df.e1.values, dims="pos")

    # Ordered parameters for the means of the two components
    # Note: The ordered transform ensures that mu_a < mu_b
    mu = pm.Normal("mu", mu=[-0.5, 0.5], sigma=0.3,
                   transform=pm.distributions.transforms.ordered,  # IMPORTANT
                   shape=2,
                   )
    sigma = pm.HalfNormal("sigma", 0.3, shape=2)

    # GRW over logit space; non-centered reparameterization
    grw_sigma = pm.HalfNormal("grw_sigma", 0.05)

    # StudentT distribution for eps (got a lot of divergences with nu=3)
    eps = pm.StudentT("eps", nu=10, mu=0.0, sigma=1.0, dims="pos")

    # Cumulative sum to create a Gaussian Random Walk
    logit_w = pm.Deterministic("logit_w", pt.cumsum(eps * grw_sigma), dims="pos")
    w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

    # Components of the mixture model (Normal, as named in the heading)
    components = pm.Normal.dist(mu=mu, sigma=sigma, shape=2)

    # Mixture model
    # The observed data is modeled as a mixture of the two components
    y_hat = pm.Mixture("y_hat", w=pm.math.stack([w, 1-w], axis=1), comp_dists=components, observed=e1, dims='pos')
Parameter sweep configuration
Here is an attempt to sweep through the parameters in a reproducible and deterministic way. The idea is to define a grid of parameter configurations, run the models over that grid, and finally collect and analyze the results.
First, let's create a function that creates a grid of parameters to sweep over, using itertools.
Create parameter grid
import itertools as it

def create_model_grid():
    """
    Create a grid of models for the different configurations as defined below.
    Name the model according to the configuration for easier identification:
    - comp_dist: Distribution of the components (Normal or StudentT)
    - eps_dist: Distribution of the GRW steps (Normal or StudentT)
    - grw_sigma_dist: Distribution of the GRW noise (HalfNormal or HalfStudentT)
    name = f"{comp_dist}_{eps_dist}_{grw_sigma_dist}"
    - If any distribution is StudentT, append the nu parameter to the name.
    """
    def config_namer(comp_dist, comp_kwargs, eps_dist, eps_kwargs, grw_sigma_dist, grw_sigma_kwargs):
        abbr = {
            'Normal': 'N',
            'StudentT': 'T',
            'HalfNormal': 'HN',
            'HalfStudentT': 'HT'
        }
        name = f"{abbr[comp_dist]}_{abbr[eps_dist]}_{abbr[grw_sigma_dist]}"
        if comp_dist == 'StudentT':
            name += f"_{comp_kwargs['nu']}"
        if eps_dist == 'StudentT':
            name += f"_{eps_kwargs['nu']}"
        if grw_sigma_dist == 'HalfStudentT':
            name += f"_{grw_sigma_kwargs['nu']}"
        return name

    comp_kwargs_list = [
        ('Normal', {}),
        ('StudentT', {'nu': 10}),
        ('StudentT', {'nu': 5})
    ]

    eps_kwargs_list = [
        ('Normal', {'mu': 0.0, 'sigma': 1.0}),
        ('StudentT', {'nu': 10, 'mu': 0.0, 'sigma': 1.0}),
        ('StudentT', {'nu': 5, 'mu': 0.0, 'sigma': 0.5})
    ]

    grw_sigma_kwargs_list = [
        ('HalfNormal', {'sigma': 0.05}),
        ('HalfStudentT', {'nu': 10, 'sigma': 0.05}),
        ('HalfStudentT', {'nu': 5, 'sigma': 0.01})
    ]

    # Prior configurations for E1 mixture
    mu_mu = [-0.5, 0.5]
    mu_sigma = 0.3
    sigma = 0.3

    grid = it.product(comp_kwargs_list, eps_kwargs_list, grw_sigma_kwargs_list)
    configs = []
    for (comp_dist, comp_kwargs), (eps_dist, eps_kwargs), (grw_sigma_dist, grw_sigma_kwargs) in grid:
        name = config_namer(
            comp_dist, comp_kwargs,
            eps_dist, eps_kwargs,
            grw_sigma_dist, grw_sigma_kwargs
        )
        config = {
            'name': name,
            'comp_dist': comp_dist,
            'comp_kwargs': comp_kwargs,
            'eps_dist': eps_dist,
            'eps_kwargs': eps_kwargs,
            'grw_sigma_dist': grw_sigma_dist,
            'grw_sigma_kwargs': grw_sigma_kwargs,
            'mu_mu': mu_mu,
            'mu_sigma': mu_sigma,
            'sigma': sigma
        }
        configs.append(config)
    return configs
from pprint import pprint
import json

# Instantiate the model grid
configs = create_model_grid()
pprint(configs)

# Save the model grid to a JSON file
with open("../results/model_grid.json", "w") as f:
    json.dump(configs, f, indent=4)
[{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_N_HN',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_N_HT_10',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_N_HT_5',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_T_HN_10',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_T_HT_10_10',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_T_HT_10_5',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_T_HN_5',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_T_HT_5_10',
'sigma': 0.3},
{'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'N_T_HT_5_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_N_HN_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_N_HT_10_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_N_HT_10_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HN_10_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_10_10_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_10_10_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HN_10_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_10_5_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_10_5_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_N_HN_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_N_HT_5_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_N_HT_5_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HN_5_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_5_10_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 10, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_5_10_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HN_5_5',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_5_5_10',
'sigma': 0.3},
{'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'mu': 0.0, 'nu': 5, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'name': 'T_T_HT_5_5_5',
'sigma': 0.3}]
Create model builder function
Then, we create a function (build_model_from_config) that takes a configuration dictionary and returns a PyMC model. This function is used to build each model in the sampling loop below.
NOTE: I accidentally switched the order of the mixture weights, so the posterior mean of w now gives the probability of being in the B compartment (i.e., coming from the component with the negative mean).
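To make the flipped convention concrete, here is a minimal sketch (ours, not from the notebook) of how compartment calls and per-bin uncertainty could be read off a fitted trace, assuming idata is one of the InferenceData objects produced by the sampling loop below:

# Sketch: turn the posterior of w (P(B) per bin, due to the flipped weight order) into calls
import numpy as np
import arviz as az
w_mean = idata.posterior["w"].mean(dim=("chain", "draw")).values  # posterior mean of P(B)
p_a = 1.0 - w_mean                                                # posterior mean of P(A)
calls = np.where(p_a > 0.5, "A", "B")                             # point estimate per bin
w_hdi = az.hdi(idata, var_names=["w"])                            # 94% HDI, flags uncertain bins/transitions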
def build_model_from_config(df, cfg):
    """
    Build a PyMC model based on the provided configuration dictionary.
    """
    with pm.Model(coords={"pos": df.index.values}) as model:
        e1 = pm.Data("e1", df.e1.values, dims="pos")

        mu = pm.Normal("mu", mu=cfg['mu_mu'], sigma=cfg['mu_sigma'],
                       transform=pm.distributions.transforms.ordered, shape=2)
        sigma = pm.HalfNormal("sigma", sigma=cfg['sigma'], shape=2)

        # components: Normal or StudentT
        if cfg['comp_dist'] == "Normal":
            components = pm.Normal.dist(mu=mu, sigma=sigma, shape=2)
        else:
            components = pm.StudentT.dist(mu=mu, sigma=sigma, **cfg['comp_kwargs'], shape=2)

        # grw_sigma: HalfNormal or HalfStudentT
        if cfg['grw_sigma_dist'] == "HalfNormal":
            grw_sigma = pm.HalfNormal("grw_sigma", **cfg['grw_sigma_kwargs'])
        else:
            grw_sigma = pm.HalfStudentT("grw_sigma", **cfg['grw_sigma_kwargs'])

        # eps: Normal or StudentT
        if cfg['eps_dist'] == "Normal":
            eps = pm.Normal("eps", dims="pos", **cfg['eps_kwargs'])
        else:
            eps = pm.StudentT("eps", dims="pos", **cfg['eps_kwargs'])

        logit_w = pm.Deterministic("logit_w", pt.cumsum(eps * grw_sigma), dims="pos")
        w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

        y_hat = pm.Mixture("y_hat", w=pt.stack([w, 1-w], axis=1), comp_dists=components,
                           observed=e1, dims='pos')

    return model
Sampling loop
Now, we create the sampling loop that runs the models in sequence.
The models are named by config_namer above, which joins the abbreviated [comp_dist, eps_dist, grw_sigma_dist] choices and appends the nu values of any Student's t or HalfStudentT components.
All the traces (InferenceData) are saved in a dictionary, model_traces.
Sample the models
from datetime import datetime

def log(msg, path=".04_PyMC_GRW_compare.log", overwrite=False):
    print(f"LOG: {msg}")
    if overwrite:
        with open(path, "w") as f:
            f.write(msg + "\n")
    else:
        with open(path, "a") as f:
            f.write(msg + "\n")

# Reset the log file
log("Resetting log... \n", overwrite=True)

model_traces = {}

for i, cfg in enumerate(configs):
    name = cfg['name']

    if name in model_traces:
        continue
    else:
        model_traces[name] = {'model': None, 'trace': None, 'scores': None}

    log(f"Sampling {name}...")
    try:
        start = datetime.now()
        log(f"Started at {start.strftime('%H:%M:%S')}")
        model_traces[name]['model'] = build_model_from_config(df, cfg)
        model = model_traces[name]['model']
        trace = pm.sample(draws=1000, tune=1000, chains=7, cores=7, progressbar=False, model=model)
        model_traces[name]['trace'] = trace
        log(f"--> Took {(datetime.now()-start).total_seconds()} seconds!\n")
    except Exception as e:
        log(f"Model {name} failed: {e}\n")
LOG: Resetting log...
LOG: Sampling N_N_HN...
LOG: Started at 16:23:36
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 205 seconds.
LOG: --> Took 226.266818 seconds!
LOG: Sampling N_N_HT_10...
LOG: Started at 16:27:22
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 42 seconds.
There were 6999 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 56.923738 seconds!
LOG: Sampling N_N_HT_5...
LOG: Started at 16:28:19
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 40 seconds.
There were 6993 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 48.845641 seconds!
LOG: Sampling N_T_HN_10...
LOG: Started at 16:29:08
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 218 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 224.320325 seconds!
LOG: Sampling N_T_HT_10_10...
LOG: Started at 16:32:53
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 66 seconds.
There were 6992 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 73.587468 seconds!
LOG: Sampling N_T_HT_10_5...
LOG: Started at 16:34:06
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 53 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 60.553168 seconds!
LOG: Sampling N_T_HN_5...
LOG: Started at 16:35:07
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 200 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 206.580966 seconds!
LOG: Sampling N_T_HT_5_10...
LOG: Started at 16:38:33
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 73 seconds.
There were 6963 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 80.68908 seconds!
LOG: Sampling N_T_HT_5_5...
LOG: Started at 16:39:54
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 219 seconds.
There were 6258 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 226.273568 seconds!
LOG: Sampling T_N_HN_10...
LOG: Started at 16:43:40
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 238 seconds.
LOG: --> Took 243.337909 seconds!
LOG: Sampling T_N_HT_10_10...
LOG: Started at 16:47:44
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 55 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 62.482996 seconds!
LOG: Sampling T_N_HT_10_5...
LOG: Started at 16:48:46
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 66 seconds.
There were 6984 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 73.402269 seconds!
LOG: Sampling T_T_HN_10_10...
LOG: Started at 16:49:59
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 251 seconds.
LOG: --> Took 255.851731 seconds!
LOG: Sampling T_T_HT_10_10_10...
LOG: Started at 16:54:15
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 62 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 68.849916 seconds!
LOG: Sampling T_T_HT_10_10_5...
LOG: Started at 16:55:24
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 47 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 54.206634 seconds!
LOG: Sampling T_T_HN_10_5...
LOG: Started at 16:56:18
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 225 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 229.865992 seconds!
LOG: Sampling T_T_HT_10_5_10...
LOG: Started at 17:00:08
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 97 seconds.
There were 6942 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 104.442058 seconds!
LOG: Sampling T_T_HT_10_5_5...
LOG: Started at 17:01:53
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 97 seconds.
There were 6944 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 6 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 103.842859 seconds!
LOG: Sampling T_N_HN_5...
LOG: Started at 17:03:36
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 237 seconds.
LOG: --> Took 245.707629 seconds!
LOG: Sampling T_N_HT_5_10...
LOG: Started at 17:07:42
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 74 seconds.
There were 6976 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 81.628761 seconds!
LOG: Sampling T_N_HT_5_5...
LOG: Started at 17:09:04
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 47 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 54.098754 seconds!
LOG: Sampling T_T_HN_5_10...
LOG: Started at 17:09:58
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 254 seconds.
LOG: --> Took 259.284411 seconds!
LOG: Sampling T_T_HT_5_10_10...
LOG: Started at 17:14:17
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 61 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 67.761221 seconds!
LOG: Sampling T_T_HT_5_10_5...
LOG: Started at 17:15:25
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 52 seconds.
There were 7000 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 59.315111 seconds!
LOG: Sampling T_T_HN_5_5...
LOG: Started at 17:16:24
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 230 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 234.735501 seconds!
LOG: Sampling T_T_HT_5_5_10...
LOG: Started at 17:20:19
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 63 seconds.
There were 6999 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 70.738131 seconds!
LOG: Sampling T_T_HT_5_5_5...
LOG: Started at 17:21:30
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (7 chains in 7 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Sampling 7 chains for 1_000 tune and 1_000 draw iterations (7_000 + 7_000 draws total) took 59 seconds.
There were 6998 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 66.229913 seconds!
model_traces.items()
dict_items([('N_N_HN', {'model': <pymc.model.core.Model object at 0x147c288d5a90>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_N_HT_10', {'model': <pymc.model.core.Model object at 0x147c22fb29e0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_N_HT_5', {'model': <pymc.model.core.Model object at 0x147c22fb2b10>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_T_HN_10', {'model': <pymc.model.core.Model object at 0x147c22fb1e00>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_T_HT_10_10', {'model': <pymc.model.core.Model object at 0x147c4efccfc0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_T_HT_10_5', {'model': <pymc.model.core.Model object at 0x147c4efcd0f0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_T_HN_5', {'model': <pymc.model.core.Model object at 0x147c22fb1ba0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_T_HT_5_10', {'model': <pymc.model.core.Model object at 0x147c22fb3360>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('N_T_HT_5_5', {'model': <pymc.model.core.Model object at 0x147c22fb23f0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_N_HN_10', {'model': <pymc.model.core.Model object at 0x147c22fb28b0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_N_HT_10_10', {'model': <pymc.model.core.Model object at 0x147c22fb2c40>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_N_HT_10_5', {'model': <pymc.model.core.Model object at 0x147c22fb3490>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HN_10_10', {'model': <pymc.model.core.Model object at 0x147c22fb1940>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_10_10_10', {'model': <pymc.model.core.Model object at 0x147c22fb2d70>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_10_10_5', {'model': <pymc.model.core.Model object at 0x147c22fb2650>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HN_10_5', {'model': <pymc.model.core.Model object at 0x147c22fb2fd0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_10_5_10', {'model': <pymc.model.core.Model object at 0x147c22fb2ea0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_10_5_5', {'model': <pymc.model.core.Model object at 0x147c22fb22c0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_N_HN_5', {'model': <pymc.model.core.Model object at 0x147c22fb1f30>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_N_HT_5_10', {'model': <pymc.model.core.Model object at 0x147c22fb36f0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_N_HT_5_5', {'model': <pymc.model.core.Model object at 0x147c22fb3100>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HN_5_10', {'model': <pymc.model.core.Model object at 0x147c22fb35c0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_5_10_10', {'model': <pymc.model.core.Model object at 0x147c22fb2190>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_5_10_5', {'model': <pymc.model.core.Model object at 0x147c22fb2780>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HN_5_5', {'model': <pymc.model.core.Model object at 0x147c22fb3820>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_5_5_10', {'model': <pymc.model.core.Model object at 0x147c100515b0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None}), ('T_T_HT_5_5_5', {'model': <pymc.model.core.Model object at 0x147c100508a0>, 'trace': Inference data with groups:
> posterior
> sample_stats
> observed_data, 'scores': None})])
Log-lik, WAIC and PPC
for name, model_dict in model_traces.items():
    # model_dict.keys(): {'model', 'trace', 'scores'}
    try:
        print(name)
        pm.compute_log_likelihood(model_dict['trace'],
                                  model=model_dict['model'],
                                  progressbar=False)
    except Exception as e:
        print(f"{name} failed computing log_likelihood: {e}")
    try:
        prior_predictive = pm.sample_prior_predictive(
            1000,
            model=model_dict['model'])
        model_dict['trace'].extend(prior_predictive)
    except Exception as e:
        print(f"{name} failed sampling prior predictive: {e}")
    try:
        posterior_predictive = pm.sample_posterior_predictive(
            model_dict['trace'],
            model=model_dict['model'],
            progressbar=False)
        model_dict['trace'].extend(posterior_predictive)
    except Exception as e:
        print(f"{name} failed sampling posterior predictive: {e}")
    try:
        loo = az.loo(model_dict['trace'])
        waic = az.waic(model_dict['trace'])
        model_dict['scores'] = {
            "loo": loo,
            "waic": waic
        }
    except Exception as e:
        print(f"{name} failed during loo/waic: {e}")
N_N_HN
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_N_HT_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_N_HT_5
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HN_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HT_10_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HT_10_5
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HN_5
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HT_5_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HT_5_5
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
T_N_HN_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Sampling: [y_hat]
Sampling: [eps, grw_sigma, mu, sigma, y_hat] followed by Sampling: [y_hat] was run in turn for each of the remaining models:
T_N_HT_10_10, T_N_HT_10_5, T_T_HN_10_10, T_T_HT_10_10_10, T_T_HT_10_10_5, T_T_HN_10_5, T_T_HT_10_5_10, T_T_HT_10_5_5, T_N_HN_5, T_N_HT_5_10, T_N_HT_5_5, T_T_HN_5_10, T_T_HT_5_10_10, T_T_HT_5_10_5, T_T_HN_5_5, T_T_HT_5_5_10, T_T_HT_5_5_5
For every fit, ArviZ warned that the posterior variance of the log predictive densities exceeds 0.4, an indication that WAIC may be starting to fail (see http://arxiv.org/abs/1507.04544). Most fits additionally produced overflow RuntimeWarnings while computing the importance-sampling weights, plus the warning that the estimated Pareto shape parameter exceeds 0.70 for one or more samples, i.e. PSIS-LOO may be unreliable for highly influential observations.
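The Pareto-k warnings above can be traced back to individual bins. A quick way to do so is pointwise PSIS-LOO and ArviZ's k-hat plot; the snippet below is a sketch, assuming the fitted traces in model_traces store the pointwise log-likelihood for y_hat (the model name used here is just one example).

# Sketch: inspect which observations push Pareto k above 0.7 for a single model
loo_pointwise = az.loo(model_traces["N_N_HN"]["trace"], pointwise=True)
az.plot_khat(loo_pointwise)
print((loo_pointwise.pareto_k > 0.7).sum().values, "bins with k > 0.7")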
Save the results
import os.path as op
import os
import cloudpickle

# Save the traces to file (netcdf)
dir = "../results"
for name, d in model_traces.items():
    # Save the model.Model as a cloudpickle.dumps
    os.makedirs(op.join(dir, "models"), exist_ok=True)
    model_path = op.join(dir, "models", f"{name}_model.pkl")
    # print(model_path)
    with open(model_path, "wb") as f:
        cloudpickle.dump(d['model'], f)

    # Save the trace as a netcdf file
    trace_path = op.join(dir, "traces", f"{name}_trace.nc")
    os.makedirs(op.dirname(trace_path), exist_ok=True)
    az.to_netcdf(d['trace'], trace_path)
# Load model
with open("../results/models/N_N_HN_model.pkl", "rb") as f:
    test_model = cloudpickle.load(f)

# Load trace
test_trace = az.from_netcdf("../results/traces/N_N_HN_trace.nc")
from math import ceil

# Plot the energy for all traces
ncol = 3
nrows = ceil(len(model_traces) / ncol)
f, axs = plt.subplots(nrows=nrows, ncols=ncol, figsize=(15, 3*nrows))
axs = axs.flatten()

for i, (name, d) in enumerate(model_traces.items()):
    ax = axs[i]
    az.plot_energy(d['trace'], ax=ax)
    ax.set_title(name + "\n" +
                 f"Diverging draws: {d['trace'].sample_stats.diverging.sum().values}\n" +
                 f"Sample time: {d['trace'].sample_stats.sampling_time:.2f} sec"
    )

# Remove empty subplots
[f.delaxes(ax) for ax in axs if not ax.has_data()]
plt.tight_layout()
# az.plot_energy(test_trace);
# Combine into InferenceData comparison object
loo_dict = {name: model_traces[name]['scores']['loo'] for name in model_traces.keys()}
cmp_df = az.compare(loo_dict, method="stacking", ic="loo")

# Save the comparison DataFrame
cmp_df.to_csv("../results/model_comparison.csv")

# Print the comparison DataFrame
print("\nModel Comparison Results:")
cmp_df
Model Comparison Results:
model | rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale
---|---|---|---|---|---|---|---|---|---
N_T_HT_5_10 | 0 | -358.452110 | 62.930324 | 0.000000 | 4.765703e-01 | 26.860618 | 0.000000 | True | log |
N_N_HT_10 | 1 | -359.459620 | 65.393423 | 1.007510 | 2.252832e-01 | 26.771997 | 3.069303 | True | log |
N_T_HT_10_5 | 2 | -361.000076 | 66.023840 | 2.547967 | 2.724386e-08 | 26.918152 | 3.289561 | True | log |
N_T_HT_10_10 | 3 | -361.222374 | 65.823139 | 2.770264 | 9.981287e-02 | 27.320468 | 3.897161 | True | log |
T_T_HT_10_5_5 | 4 | -363.641294 | 57.862940 | 5.189184 | 8.081423e-02 | 27.381287 | 4.840141 | True | log |
T_T_HT_10_5_10 | 5 | -364.153489 | 61.305255 | 5.701379 | 1.175195e-01 | 27.850611 | 5.006629 | True | log |
T_N_HT_10_5 | 6 | -368.051247 | 64.387889 | 9.599137 | 3.166623e-08 | 27.337736 | 4.930719 | True | log |
T_N_HT_10_10 | 7 | -368.332451 | 65.708143 | 9.880341 | 3.003514e-08 | 27.584315 | 4.323734 | True | log |
T_T_HT_10_10_10 | 8 | -368.842565 | 61.994153 | 10.390455 | 3.140541e-08 | 27.010677 | 4.743943 | True | log |
N_N_HT_5 | 9 | -373.748497 | 75.827076 | 15.296387 | 2.680213e-08 | 27.183989 | 4.619605 | True | log |
T_T_HT_10_10_5 | 10 | -375.301755 | 67.138740 | 16.849646 | 2.974943e-08 | 27.421577 | 4.743063 | True | log |
N_T_HT_5_5 | 11 | -377.913791 | 83.800530 | 19.461681 | 2.676479e-08 | 27.503237 | 6.329224 | True | log |
T_T_HT_5_5_5 | 12 | -386.325316 | 63.302957 | 27.873206 | 3.534710e-08 | 27.516315 | 6.724489 | True | log |
T_T_HT_5_10_5 | 13 | -386.815659 | 62.094507 | 28.363549 | 3.348353e-08 | 27.577744 | 6.048645 | True | log |
T_T_HT_5_5_10 | 14 | -387.753509 | 61.805375 | 29.301399 | 3.266521e-08 | 27.954686 | 5.986390 | True | log |
T_N_HT_5_10 | 15 | -387.863099 | 60.859364 | 29.410990 | 3.359195e-08 | 27.788176 | 6.217155 | True | log |
T_T_HT_5_10_10 | 16 | -388.062134 | 62.580083 | 29.610024 | 3.282912e-08 | 28.113280 | 6.137160 | True | log |
T_N_HT_5_5 | 17 | -395.524849 | 67.629619 | 37.072739 | 3.317884e-08 | 27.457881 | 6.544416 | True | log |
T_T_HN_10_10 | 18 | -484.641065 | 45.381785 | 126.188955 | 1.850862e-08 | 27.387257 | 6.016426 | False | log |
N_N_HN | 19 | -489.660761 | 44.196325 | 131.208651 | 1.396381e-08 | 27.082763 | 5.372256 | True | log |
T_N_HN_10 | 20 | -498.162518 | 43.544155 | 139.710408 | 1.722162e-08 | 27.387247 | 6.285019 | False | log |
T_T_HN_5_10 | 21 | -503.006963 | 44.417449 | 144.554853 | 2.125441e-08 | 27.781687 | 7.448034 | False | log |
T_N_HN_5 | 22 | -516.300853 | 42.496101 | 157.848743 | 2.001360e-08 | 27.763489 | 7.684781 | False | log |
N_T_HN_5 | 23 | -530.303872 | 43.506794 | 171.851762 | 1.092973e-08 | 27.132629 | 6.870209 | False | log |
T_T_HN_10_5 | 24 | -539.210197 | 42.654365 | 180.758088 | 1.413144e-08 | 27.420135 | 7.628495 | False | log |
T_T_HN_5_5 | 25 | -556.981749 | 41.920557 | 198.529639 | 1.692740e-08 | 27.799341 | 8.819701 | False | log |
N_T_HN_10 | 26 | -621.109350 | 163.667842 | 262.657240 | 0.000000e+00 | 26.780468 | 8.132667 | True | log |
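The ranking is easier to read graphically. As a small sketch (not part of the original notebook), the comparison DataFrame can be passed straight to ArviZ's comparison plot:

# Visual summary of the LOO comparison above
az.plot_compare(cmp_df, figsize=(8, 6))
plt.tight_layout()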
Import and analyze the results
import pandas as pd
import arviz as az
from pathlib import Path
from cloudpickle import load
from glob import glob

# Load models
model_traces = {}
for pkl in glob("../results/models/*.pkl"):
    name = Path(pkl).name.removesuffix("_model.pkl")
    trace_path = f"../results/traces/{name}_trace.nc"
    model_traces[name] = {}
    with open(pkl, "rb") as f:
        model_traces[name]['model'] = load(f)
    model_traces[name]['trace'] = az.from_netcdf(trace_path)

# Load the comparison DataFrame
cmp_df = pd.read_csv("../results/model_comparison.csv", index_col=0)
import matplotlib.pyplot as plt

top_n = 8

for name in cmp_df.index[:top_n]:
    f, axs = plt.subplots(1, 2, figsize=(12, 3))
    trace = model_traces[name]['trace']
    az.plot_loo_pit(trace, y='y_hat', ax=axs[0])
    az.plot_loo_pit(trace, y='y_hat', ecdf=True, ax=axs[1])
    f.suptitle(f"PIT for {name}")
    f.tight_layout()
[Repeated arviz overflow RuntimeWarnings, one pair per LOO-PIT call]
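As a numerical companion to the PIT plots, one could also test the LOO-PIT values for uniformity. The snippet below is a sketch under the assumption that the traces contain posterior-predictive samples and log-likelihoods for y_hat; the Kolmogorov-Smirnov test is my choice of uniformity check, not the notebook's.

# Sketch: KS test of LOO-PIT uniformity for the top models
from scipy.stats import kstest

for name in cmp_df.index[:top_n]:
    pit = np.asarray(az.loo_pit(model_traces[name]['trace'], y='y_hat')).ravel()
    ks = kstest(pit, 'uniform')
    print(f"{name}: KS stat = {ks.statistic:.3f}, p = {ks.pvalue:.3g}")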
Test: Post-hoc decoding of mixture weight
import arviz as az
import cloudpickle as cp

name = "N_N_HN"

trace = az.from_netcdf(f"../results/traces/{name}_trace.nc")

with open(f"../results/models/{name}_model.pkl", "rb") as f:
    model = cp.load(f)
trace
InferenceData summary (full repr omitted): the posterior group has dims (chain: 7, draw: 1000, pos: 1460) with variables eps, mu, sigma, grw_sigma, logit_w and w; two further groups hold y_hat over the same dims (the posterior-predictive draws and the pointwise log-likelihood); sample_stats contains the usual NUTS diagnostics; the prior and prior-predictive groups have a single chain of 1000 draws; observed_data holds y_hat over the 1460 positions. Sampling took about 205 s with 1000 tuning steps (PyMC 5.22.0, ArviZ 0.21.0).
model
\[
\begin{array}{rcl}
\text{mu} &\sim& \operatorname{Normal}(\text{<constant>},~0.3)\\
\text{sigma} &\sim& \operatorname{HalfNormal}(0,~0.3)\\
\text{grw\_sigma} &\sim& \operatorname{HalfNormal}(0,~0.05)\\
\text{eps} &\sim& \operatorname{Normal}(0,~1)\\
\text{logit\_w} &\sim& \operatorname{Deterministic}(f(\text{eps},~\text{grw\_sigma}))\\
\text{w} &\sim& \operatorname{Deterministic}(f(\text{eps},~\text{grw\_sigma}))\\
\text{y\_hat} &\sim& \operatorname{MarginalMixture}(f(\text{eps},~\text{grw\_sigma}),~\operatorname{Normal}(\text{mu},~\text{sigma}))
\end{array}
\]
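For readability, here is a minimal PyMC sketch of what this graph corresponds to. It is my reconstruction from the variable names and priors in the repr, not the notebook's own model-builder: the <constant> prior means are replaced by placeholder values, the GRW is written in non-centered form (eps scaled by grw_sigma and cumulatively summed), and the component ordering of the mixture weight may differ from the original.

# Hypothetical sketch of the N_N_HN structure (names mirror the trace variables)
with pm.Model(coords={"pos": df.index.values, "comp": ["A", "B"]}) as sketch:
    # Two mixture components; the notebook centres mu on the data, placeholders used here
    mu = pm.Normal("mu", mu=np.array([0.5, -0.5]), sigma=0.3, dims="comp")
    sigma = pm.HalfNormal("sigma", sigma=0.3, dims="comp")

    # Non-centered Gaussian random walk on the logit of the mixture weight
    grw_sigma = pm.HalfNormal("grw_sigma", sigma=0.05)
    eps = pm.Normal("eps", 0, 1, dims="pos")
    logit_w = pm.Deterministic("logit_w", pt.cumsum(eps * grw_sigma), dims="pos")
    w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

    # Per-bin two-component Normal mixture over the observed E1 values
    y_hat = pm.NormalMixture("y_hat", w=pt.stack([w, 1 - w], axis=-1),
                             mu=mu, sigma=sigma, observed=df["e1"].values, dims="pos")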
# Just check what the xarray dims look like
trace.posterior['w'].stack(dims=['chain', 'draw'])
<xarray.DataArray 'w' (pos: 1460, dims: 7000)>: the per-position mixture weights with chain and draw stacked into a single sample dimension (7 chains × 1000 draws); full array output omitted.
Now we can flatten the weights to an n_bins × (n_chains · n_draws) array and decode a latent compartment label z by running a Bernoulli trial on each posterior draw. Note that the mean of these Bernoulli samples converges towards the posterior mean of w; the extra sampling step therefore only adds Monte Carlo noise and no additional information.
import numpy as np

# Extract posterior samples of w
w_flat = trace.posterior['w'].stack(dims=['chain', 'draw'])
w_mean = trace.posterior['w'].mean(dim=['chain', 'draw'])

# Sample z draws
seed = 42
rng = np.random.default_rng(seed)
z_samples = rng.binomial(1, w_flat)  # shape (pos, chain*draw)

# Compute posterior probability of z=1
z_probs = z_samples.mean(axis=1)  # per-position probability
z_samples
array([[1, 0, 0, ..., 0, 1, 1],
[1, 1, 1, ..., 0, 0, 0],
[0, 0, 1, ..., 0, 1, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], shape=(1460, 7000))
z_probs
# Plots as a quick curve
plt.plot(z_probs, color='black', label='Posterior Probability of z=1')
plt.plot(w_mean, color='tab:orange', label='Posterior Mean of w', alpha=0.5)
plt.hlines([0.95, 0.05], xmin=0, xmax=len(z_probs), colors='gray', linestyles='--', label='95% CI');
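The claim above can be checked numerically: the Bernoulli-decoded probabilities track the posterior mean of w up to Monte Carlo noise. A quick sketch:

# Sketch: the Bernoulli decoding only adds Monte Carlo noise around the posterior mean of w
max_abs_diff = np.max(np.abs(z_probs - w_mean.values))
print(f"max |z_probs - posterior mean of w| = {max_abs_diff:.4f}")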
Try a different approach: threshold the posterior probabilities into confident A and B calls, and treat the remaining bins as "in transition".
# Logic: we define thresholds for a confident assignment of either
# A: z_probs < A_threshold, or B: z_probs > B_threshold.
# The rest is considered in-transition and could be visualized as a gradient or with shaded regions.
A_threshold = 0.05  # A = 0 in the binary assignment
B_threshold = 0.95  # B = 1 in the binary assignment

# Assign compartments based on the thresholds
z_assignment = np.where(z_probs < A_threshold, 0,
                        np.where(z_probs > B_threshold, 1, np.nan))

# Flip the assignment to have 1 for A and -1 for B
z_assignment = np.where(z_assignment == 0, 1,
                        np.where(z_assignment == 1, -1, np.nan))

# Plot as quick curve
plt.plot(z_assignment, color='tab:purple', label='Compartment assignment (A=1, B=-1)')
plt.hlines([0.95, 0.05], xmin=0, xmax=len(z_assignment), colors='gray', linestyles='--', label='95% CI');
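The NaN bins are the "in transition" regions; they can also be summarized as genomic intervals, which is one way to report the uncertainty around compartment transitions. The interval extraction below is my own sketch, not part of the original workflow:

# Sketch: collect the NaN (in-transition) runs as start/end intervals in genomic coordinates
is_trans = np.isnan(z_assignment).astype(int)
edges = np.diff(is_trans, prepend=0, append=0)
run_starts = np.where(edges == 1)[0]
run_ends = np.where(edges == -1)[0]          # exclusive end index
transitions = pd.DataFrame({
    "start": df.start.values[run_starts].astype(int),
    "end": (df.start.values[run_ends - 1] + resolution).astype(int),
})
transitions.head()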
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib_inline
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Load the data
resolution = 100000

y = pd.Series(pd.read_csv(f"../data/eigs/fibroblast.eigs.{resolution}.cis.vecs.tsv", sep="\t")['E1'].values.flatten())
x = pd.Series(np.arange(0, y.shape[0]) * resolution)

# Make a DataFrame object
df = pd.DataFrame({"start": x, "e1": y})
df.dropna(inplace=True)

# Quick Plot:
fig, ax = plt.subplots(figsize=(10, 3))
ax.fill_between(df.start, df.e1, where=df.e1 > 0, color='tab:red', ec='None', label='A', step='pre')
ax.fill_between(df.start, df.e1, where=df.e1 < 0, color='tab:blue', ec='None', label='B', step='pre')
ax.plot(df.start, z_assignment, color='tab:green', lw=1, label='z assignment')

# Add shaded areas for the nans
patches = []
in_nan = False
start = None

for i in range(len(z_assignment)):
    if np.isnan(z_assignment[i]) and not in_nan:
        # Start of NaN block
        in_nan = True
        start = df.start.iloc[i]
    elif not np.isnan(z_assignment[i]) and in_nan:
        # End of NaN block
        end = df.start.iloc[i]
        patches.append(Rectangle((start, -1), end - start, 2, alpha=0.2, color='gray'))
        in_nan = False

# If ending on NaNs
if in_nan:
    end = df.start.iloc[-1] + resolution
    patches.append(Rectangle((start, -1), end - start, 2, alpha=1, color='gray'))

ax.add_collection(PatchCollection(patches, match_original=True))

ax.set_xlabel("Genomic Position")
ax.set_ylabel("E1")
ax.legend(loc='lower right')

plt.tight_layout()
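Finally, the per-bin calls and their uncertainty can be written to a simple table for downstream use. This is a sketch: the column layout and output path are my own choices, and a chromosome column would need to be added from the original eigenvector table.

# Sketch: export per-bin compartment calls with posterior probabilities
calls = pd.DataFrame({
    "start": df.start.values.astype(int),
    "end": (df.start.values + resolution).astype(int),
    "call": np.where(z_assignment == 1, "A",
             np.where(z_assignment == -1, "B", "transition")),
    "p_z1": z_probs,
    "w_mean": w_mean.values,
})
calls.to_csv("../results/compartment_calls.tsv", sep="\t", index=False)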