import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib_inline
import pymc as pm
import arviz as az
from scipy.stats import norm
from pprint import pprint
## Use a custom style for the plots
plt.style.use('smaller.mplstyle')
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
# %config InlineBackend.figure_formats = ['retina']
Gaussian Random Walk Model for A/B Clustering
Implementing dependency between bins with a Gaussian Random Walk through a sigmoid
This model infers A/B chromatin compartment structure from the first eigenvector (E1) of the Hi-C contact matrix using a Bayesian framework.
Here we implement dependency between bins with a Gaussian Random Walk (GRW) that is then squished through a sigmoid to obtain probabilities for compartment assignment.
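As a quick toy illustration of the idea (a sketch only, not the model fitted below; the bin count and step size are made-up values), a random walk in logit space squished through a sigmoid gives smoothly varying per-bin probabilities:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
n_bins = 300   # toy number of bins
step_sd = 0.2  # toy innovation scale of the walk

# Gaussian Random Walk in logit space: cumulative sum of Gaussian increments
logit_p = np.cumsum(rng.normal(0.0, step_sd, size=n_bins))

# Squish through the sigmoid to obtain probabilities in (0, 1)
p = 1.0 / (1.0 + np.exp(-logit_p))

fig, ax = plt.subplots(figsize=(8, 2))
ax.plot(p)
ax.set_xlabel("bin"); ax.set_ylabel("P(compartment A)")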
Goals
To treat the eigenvector as a soft predictor instead of a deterministic assignment of compartments. Hopefully, we will salvage some variance and thus be able to obtain uncertainty estimates for the compartment assignment.
Implementation: GRW on Latent Space
Imports
Load/show data
Here, we load the first eigenvector (E1) of the Hi-C contact matrix. E1 is used as the predictor for compartment assignment.
The eigendecomposition of the Hi-C matrix was performed with the workflow in the main folder (workflow.py), using the Open2C ecosystem.
# Load the data
resolution = 50000
y = pd.Series(pd.read_csv(f"../data/eigs/fibroblast.eigs.{resolution}.cis.vecs.tsv", sep="\t")['E1'].values.flatten())
x = pd.Series(np.arange(0, y.shape[0]) * resolution)

# Make a DataFrame object
df = pd.DataFrame({"start": x, "e1": y})
df.dropna(inplace=True)
# # TEST DF
# df = df.iloc[:200]
# df
The data are stored as pd.Series() objects and merged into a pd.DataFrame(). Having the bin numbers converted into genomic positions (the start column) saves some typing for the plots, and this is already implemented above.
# Plot the data as a track (stairs for a nicer look)
fig, ax = plt.subplots()

x_stair = np.zeros(2*df.start.shape[0])
y_stair = np.zeros(2*df.shape[0])
x_stair[0::2] = df.start
x_stair[1::2] = df.start + resolution
y_stair[0::2] = df.e1
y_stair[1::2] = df.e1

ax.fill_between(x_stair, y_stair, where=(y_stair<0), color="tab:blue", ec='None')
ax.fill_between(x_stair, y_stair, where=(y_stair>0), color="tab:red", ec='None')

ax.set_xlim(0, df["start"].max()+resolution)
ticks = np.arange(0, df["start"].max()+resolution, step=1e7, dtype=int)
ax.set_xticks(ticks)
ax.set_xticklabels((ticks/1e6).astype(int), rotation=45)
ax.set_xlabel("Genomic Position (Mbp)")
ax.set_ylabel("E1"); ax.set_title("E1")
# Histogram of distribution of E1 values (A and B)
= df["start"].values
x = df["e1"].values
y
= plt.subplots()
f, ax
# A compartments
>0], bins=25, color="tab:red", alpha=0.5, label="A", density=True)
ax.hist(y[y# B compartments
<0], bins=25, color="tab:blue", alpha=0.5, label="B", density=True)
ax.hist(y[y# Mean values (vline)
>0].mean(), color="tab:red", linestyle="--", label="Mean A")
ax.axvline(y[y<0].mean(), color="tab:blue", linestyle="--", label="Mean B")
ax.axvline(y[y
# Plot the normal distributions
= np.linspace(y.min(), y.max(), 1000)
x_values = np.median(y[y>0])
mean_a = np.median(y[y<0])
mean_b = y[y>0].std()
std_a = y[y<0].std()
std_b
=mean_a, scale=std_a), color="tab:red", label="Normal A")
ax.plot(x_values, norm.pdf(x_values, loc=mean_b, scale=std_b), color="tab:blue", label="Normal B")
ax.plot(x_values, norm.pdf(x_values, loc= norm.pdf(x_values, loc=mean_a, scale=std_a) + norm.pdf(x_values, loc=mean_b, scale=std_b)
stacked ="tab:purple", label="Stacked Normals")
ax.plot(x_values, stacked, color
# Final touches
"E1")
plt.xlabel("Density")
plt.ylabel(='best')
plt.legend(loc plt.tight_layout()
Model Specification
Here, the structure of the model is as follows:
We keep the assumption that the E1 values are drawn from two Gaussians, one for A compartments and one for B compartments. To introduce spatial dependency, we place a rolling-regression-style Gaussian Random Walk (GRW) prior on the latent state variable (\(p\)), which is then squished through a sigmoid function.
NB: We only use a GRW on the latent variable in this model, as a simpler alternative to a GRW on the E1 itself. Guess what comes next.
Notation
Let:
- \(y_i\): Observed E1 value at bin \(i\)
- \(z_i \in \{0, 1\}\): Latent compartment assignment (0 = B, 1 = A)
- \(\mu_A, \mu_B\): Mean E1 values for compartments A and B
- \(\sigma\): Shared standard deviation across bins
- \(p_i\): Probability of bin \(i\) belonging to compartment A
Model Variables
\[ \begin{aligned} \mu_A &\sim \mathcal{N}(0.5, 0.5) \\ \mu_B &\sim \mathcal{N}(-0.5, 0.5) \\ \sigma &\sim \text{HalfNormal}(1) \\ \text{logit}(p) &\sim \text{GaussianRandomWalk}(\sigma) \\ p_i &= \text{Sigmoid}(\text{logit}(p)_i) \\ z_i &\sim \text{Bernoulli}(p_i) \\ \mu_i &= z_i \cdot \mu_A + (1 - z_i) \cdot \mu_B \\ y_i &\sim \mathcal{N}(\mu_i, \sigma) \end{aligned} \]
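In the PyMC implementations below, the discrete assignment \(z_i\) is not sampled directly; pm.Mixture marginalizes it out, so the likelihood is evaluated as a two-component mixture weighted by \(p_i\):

\[ y_i \sim p_i \, \mathcal{N}(\mu_A, \sigma) + (1 - p_i) \, \mathcal{N}(\mu_B, \sigma) \]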
PyMC
Visualize distributions
#params = [(0.5, 0.5), (2, 2), (1, 1)]
nus = [3, 5, 10, 20]
mus = [-0.5, 0.5]
colors = ["tab:blue", "tab:orange", "tab:green", "tab:purple"]

fig, axs = plt.subplots(len(nus), figsize=(10, 4*len(nus)))
axs = axs.flatten()

for nu, col, ax in zip(nus, colors, axs):
    with pm.Model() as yhat_priors:
        nu_T = pm.StudentT("nu_T", nu=nu, mu=mus, sigma=0.3)
        nu_norm = pm.Normal("nu_norm", mu=mus, sigma=0.3)
        prior = pm.sample_prior_predictive(samples=10000)
    ax.set_title("Comparison of priors for yhat")
    az.plot_kde(prior.prior["nu_T"].values, label=f"StudT(nu={nu})", ax=ax, fill_kwargs={"alpha": 0.1, "color": colors[0]})
    az.plot_kde(prior.prior["nu_norm"].values, label="Normal", ax=ax, fill_kwargs={"alpha": 0.1, "color": colors[1]})
    ax.set_xlim(-2, 2)
plt.tight_layout()
Sampling: [nu_T, nu_norm]
Sampling: [nu_T, nu_norm]
Sampling: [nu_T, nu_norm]
Sampling: [nu_T, nu_norm]
Legacy: Centered GRW
with pm.Model(coords={"pos": df.index.values}) as legacy_model:
    """
    Model the E1 track as a continuous mixture of two normals, using a Gaussian Random Walk (GRW) to model the mixing proportions.
    The GRW is applied in logit space to ensure the mixing proportions are between 0 and 1.
    """
    e1 = pm.Data("e1", df.e1.values, dims="pos")
    n = e1.shape[0]

    mu_a = pm.Normal("mu_a", 0.5, 0.3)
    mu_b = pm.Normal("mu_b", -0.5, 0.3)
    sigma = pm.HalfNormal("sigma", 0.3)

    # GRW over logit space
    grw_sigma = pm.HalfNormal("grw_sigma", 0.1)
    logit_w = pm.GaussianRandomWalk("logit_w",
        sigma=grw_sigma, shape=n,
        init_dist=pm.Normal.dist(mu=0.0, sigma=1.0),  # or a tighter prior if preferred
        dims="pos"
    )
    w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

    components = pm.Normal.dist(mu=pm.math.stack([mu_a, mu_b]),
                                sigma=sigma,
                                shape=(2,))

    # Likelihood estimate
    y_hat = pm.Mixture("y_hat", w=pm.math.stack([w, 1-w], axis=1), comp_dists=components, observed=e1, dims='pos')
New: Non-centered GRW
This patch is still a Gaussian Random Walk, but expressed in a non-centered form.
What stays the same:
- The process is still a cumulative sum of Gaussian increments, i.e.,
\[\text{logit}_w[t] = \sum_{i=1}^t \epsilon_i \cdot \sigma\]
where \(\epsilon_i \sim \mathcal{N}(0,1)\).
This is exactly what a GRW is.
What’s different:
- Instead of sampling \(\text{logit}_w\) directly, we sample the standardized increments \(\epsilon\) and reconstruct \(\text{logit}_w\) via cumsum.
- This reduces posterior curvature and pathologies, especially around \(\sigma\), improving NUTS performance.
Why it helps:
- In the centered form, large values of grw_sigma lead to very spread-out walks → hard geometry.
- The non-centered form handles the scale explicitly, so NUTS explores \(\epsilon\) space more efficiently (see the small simulation sketch below).
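A minimal simulation sketch (toy walk length and scale, independent of the model) showing that the two constructions describe the same process: drawing \(\mathcal{N}(0, \sigma)\) increments directly, or drawing standardized \(\epsilon \sim \mathcal{N}(0, 1)\) and rescaling by \(\sigma\) before the cumulative sum, gives the same marginal spread of the walk.

import numpy as np

rng = np.random.default_rng(1)
n, sigma = 500, 0.1  # toy walk length and innovation scale

# Centered: the walk is a cumulative sum of N(0, sigma) increments
finals_centered = [np.cumsum(rng.normal(0.0, sigma, size=n))[-1] for _ in range(2000)]

# Non-centered: sample eps ~ N(0, 1), rescale by sigma, then cumsum
finals_noncentered = [np.cumsum(rng.normal(0.0, 1.0, size=n) * sigma)[-1] for _ in range(2000)]

# Both final values have standard deviation ~ sqrt(n) * sigma
print(np.std(finals_centered), np.std(finals_noncentered), np.sqrt(n) * sigma)

The distributions match; what changes is the geometry the sampler sees, because grw_sigma now only rescales standard-normal increments instead of setting the scale of the sampled walk itself.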
import pytensor.tensor as pt
with pm.Model(coords={"pos": df.index.values}) as model:
    """
    Model the E1 track as a continuous mixture of two normals, using a Gaussian Random Walk (GRW) to model the mixing proportions.
    The GRW is applied in logit space to ensure the mixing proportions are between 0 and 1.
    This is a non-centered reparameterization of the model.
    The model uses ordered parameters for mu_a and mu_b to ensure they are distinct and do not switch.
    """
    e1 = pm.Data("e1", df.e1.values, dims="pos")

    # Old version, not ordered
    # mu_a = pm.Normal("mu_a", 0.5, 0.3)
    # mu_b = pm.Normal("mu_b", -0.5, 0.3)
    # New: Make ordered parameters for mu_a and mu_b instead
    mu = pm.Normal("mu", mu=[-0.5, 0.5], sigma=0.3,
                   transform=pm.distributions.transforms.ordered,  # IMPORTANT
                   shape=2,
    )
    sigma = pm.HalfNormal("sigma", 0.3, shape=2)

    # GRW over logit space; non-centered reparameterization
    grw_sigma = pm.HalfNormal("grw_sigma", 0.05)
    # Normal distribution for eps
    #eps = pm.Normal("eps", mu=0.0, sigma=1, dims="pos")
    # StudentT distribution for eps (got a lot of divergences with nu=3)
    eps = pm.StudentT("eps", nu=10, mu=0.0, sigma=1.0, dims="pos")

    logit_w = pm.Deterministic("logit_w", pt.cumsum(eps * grw_sigma), dims="pos")
    w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

    # Old version, not ordered
    # components = pm.Normal.dist(mu = pm.math.stack([mu_a, mu_b]), sigma = sigma, shape=(2,))
    # New: Use the ordered mu parameters (cleaner code as well)
    # components = pm.Normal.dist(mu=mu, sigma=sigma, shape=2)
    components = pm.StudentT.dist(nu=10, mu=mu, sigma=sigma, shape=2)

    # Likelihood estimate
    y_hat = pm.Mixture("y_hat", w=pm.math.stack([w, 1-w], axis=1), comp_dists=components, observed=e1, dims='pos')
Model graphs
# Visualize
# display(legacy_model.to_graphviz(),
#         model.to_graphviz())
gv1 = pm.model_to_graphviz(legacy_model, figsize=(5,5))
gv2 = pm.model_to_graphviz(model, figsize=(5,5))

display(gv1, gv2)
with model:
    # Sample from the prior
    idata = pm.sample_prior_predictive(samples=1000)

idata
Sampling: [eps, grw_sigma, mu, sigma, y_hat]
Output (abridged): an arviz.InferenceData with
- prior: grw_sigma, sigma, eps, mu, w, logit_w — 1 chain × 1000 draws, pos: 2913
- prior_predictive: y_hat — 1 chain × 1000 draws, pos: 2913
- observed_data: y_hat — pos: 2913
prior_predictive_e1 = az.extract(idata, group="prior_predictive", var_names=["y_hat"])

print(pd.DataFrame({
    "Statistic": ["Min", "Max", "Std"],
    "Value": [
        prior_predictive_e1.values.min(),
        prior_predictive_e1.values.max(),
        prior_predictive_e1.values.std()
    ]
}))
#y_lik_norm (min, max, std): {prior_predictive_e1.y_lik_norm.values.min(), prior_predictive_e1.y_lik_norm.values.max(), prior_predictive_e1.y_lik_norm.values.std()}
Statistic Value
0 Min -5.779701
1 Max 5.194263
2 Std 0.666041
e1_obs = df['e1'].values

az.plot_dist(
    e1_obs,
    kind="hist",
    color="C1",
    hist_kwargs={"alpha": 0.6, "bins": 25},
    label="observed"
)

prior_pred_obs = idata.prior_predictive["y_hat"].stack(sample=("chain", "draw")).values
prior_pred_obs = prior_pred_obs[prior_pred_obs > e1_obs.min()]  # Constrain the range to observed values
prior_pred_obs = prior_pred_obs[prior_pred_obs < e1_obs.max()]  # Constrain the range to observed values

az.plot_dist(
    prior_pred_obs,
    kind="hist",
    hist_kwargs={"alpha": 0.6, "bins": 25},
    label="simulated",
)
plt.xticks(rotation=90);
Posterior sampling
with model:
    # Sample from the posterior
    trace = pm.sample(1000, tune=1000, cores=8, chains=8, target_accept=0.95)

idata.extend(trace)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 8 jobs)
NUTS: [mu, sigma, grw_sigma, eps]
Output from legacy_model:
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 8 jobs)
NUTS: [mu_a, mu_b, sigma, grw_sigma, logit_w]
Sampling 8 chains for 2_000 tune and 1_000 draw iterations (16_000 + 8_000 draws total) took 1377 seconds.
There were 106 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 4 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 5 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 6 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 7 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
with model:
    # Posterior predictive checks
    ppc = pm.sample_posterior_predictive(idata, var_names=["y_hat"], random_seed=42, extend_inferencedata=True)
Sampling: [y_hat]
idata
Output (abridged): an arviz.InferenceData with
- posterior: eps, mu, sigma, grw_sigma, logit_w, w — 8 chains × 1000 draws, pos: 2913
- posterior_predictive: y_hat — 8 chains × 1000 draws, pos: 2913
- sample_stats: step_size, energy, diverging, tree_depth, … — 8 chains × 1000 draws
- prior, prior_predictive, observed_data — as above
# az.plot_trace(idata, var_names=['mu_a','mu_b', 'grw_sigma'], combined=True, backend_kwargs={'layout': 'tight'})
az.plot_trace(idata, var_names=['mu', 'sigma', 'grw_sigma'], compact=True, backend_kwargs={'layout': 'tight'});

# az.summary(idata, var_names=['mu_a','mu_b', 'grw_sigma'], kind='stats', round_to=2)
az.summary(idata, var_names=['mu', 'sigma', 'grw_sigma'], kind='stats', round_to=2)
|           | mean  | sd   | hdi_3% | hdi_97% |
|-----------|-------|------|--------|---------|
| mu[0]     | -0.40 | 0.01 | -0.41  | -0.39   |
| mu[1]     | 0.51  | 0.01 | 0.49   | 0.54    |
| sigma[0]  | 0.24  | 0.01 | 0.23   | 0.25    |
| sigma[1]  | 0.33  | 0.01 | 0.31   | 0.34    |
| grw_sigma | 0.64  | 0.03 | 0.59   | 0.70    |
=["y_hat"], num_pp_samples=200, group='prior');
az.plot_ppc(idata, var_names=["y_hat"], num_pp_samples=200); az.plot_ppc(idata, var_names
with model:
    ll = pm.compute_log_likelihood(idata, var_names=["y_hat"])
az.loo(idata)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
Computed from 8000 posterior samples and 2913 observations log-likelihood matrix.
Estimate SE
elpd_loo -765.31 38.32
p_loo 52.31 -
There has been a warning during the calculation. Please check the results.
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.70] (good) 2898 99.5%
(0.70, 1] (bad) 15 0.5%
(1, Inf) (very bad) 0 0.0%
trace_yhat = az.summary(idata, var_names=["y_hat"], kind="stats", group="posterior_predictive")
trace_yhat
|             | mean  | sd    | hdi_3% | hdi_97% |
|-------------|-------|-------|--------|---------|
| y_hat[47]   | 0.206 | 0.551 | -0.768 | 1.117   |
| y_hat[48]   | 0.300 | 0.517 | -0.715 | 1.115   |
| y_hat[49]   | 0.359 | 0.493 | -0.633 | 1.187   |
| y_hat[50]   | 0.384 | 0.474 | -0.615 | 1.168   |
| y_hat[51]   | 0.418 | 0.451 | -0.565 | 1.157   |
| ...         | ...   | ...   | ...    | ...     |
| y_hat[3240] | 0.482 | 0.401 | -0.375 | 1.198   |
| y_hat[3241] | 0.476 | 0.402 | -0.324 | 1.227   |
| y_hat[3242] | 0.470 | 0.400 | -0.374 | 1.188   |
| y_hat[3243] | 0.482 | 0.407 | -0.313 | 1.278   |
| y_hat[3244] | 0.467 | 0.412 | -0.428 | 1.190   |

2913 rows × 4 columns
print(pd.DataFrame({
    "mean": [trace_yhat['mean'].mean(), trace_yhat['mean'].std()],
    "sd": [trace_yhat['sd'].mean(), trace_yhat['sd'].std()]
}).set_index(pd.Index(['means', 'sds'])).round(4))

f, ax = plt.subplots(1, 2, figsize=(10, 5))
ax = ax.flatten()
ax[0].set_title("Mean of y_hat estimate")
ax[0].hist(trace_yhat['mean'], bins=100, color="tab:grey", alpha=0.5)
#ax[0].set_xlim(-0.05, 0.05) # zoom-in on the near-zero peak
ax[1].set_title("SD of y_hat")
ax[1].hist(trace_yhat['sd'], bins=30, color="tab:grey", alpha=0.5)
plt.tight_layout()
mean sd
means -0.0161 0.3699
sds 0.3939 0.0828
f, ax = plt.subplots(figsize=(12, 4))
ax.plot(df['start'], trace_yhat['mean'], lw=0.3)
ax.fill_between(df['start'], trace_yhat['hdi_97%'], trace_yhat['hdi_3%'], color="tab:grey", alpha=0.3, label="94% HDI")

plt.legend()
plt.tight_layout()
=("chain", "draw")) idata.posterior_predictive.stack(sample
<xarray.Dataset> Size: 187MB Dimensions: (pos: 2913, sample: 8000) Coordinates: * pos (pos) int64 23kB 47 48 49 50 51 52 ... 3240 3241 3242 3243 3244 * sample (sample) object 64kB MultiIndex * chain (sample) int64 64kB 0 0 0 0 0 0 0 0 0 0 0 ... 7 7 7 7 7 7 7 7 7 7 7 * draw (sample) int64 64kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: y_hat (pos, sample) float64 186MB 0.2372 0.9097 -0.4309 ... 0.5997 0.5012 Attributes: created_at: 2025-06-23T13:02:28.738804+00:00 arviz_version: 0.21.0 inference_library: pymc inference_library_version: 5.22.0
y_ppc = idata.posterior_predictive["y_hat"]

y_ppc_flat = y_ppc.stack(sample=("chain", "draw"))

ppc_mean = y_ppc_flat.mean(dim="sample").values  # mean of predictions per bin
ppc_std = y_ppc_flat.std(dim="sample").values    # predictive std (uncertainty)
ppc_hdi = az.hdi(y_ppc, hdi_prob=0.95)
f, ax = plt.subplots(1, 2, figsize=(10, 3))
ax[0].set_title("Posterior Predictive Mean of y_hat")
ax[0].set_ylabel("Counts")
ax[0].hist(ppc_mean, bins=100, alpha=0.7, label="Mean(y_hat)")
ax[1].set_title("Posterior Predictive Std of y_hat")
ax[1].hist(ppc_std, bins=100, alpha=0.7, label="Std(y_hat)")
plt.tight_layout()
# ppc_mean: shape (bins,)
# mean_ll: shape (bins,)
posterior_mean = idata.posterior_predictive["y_hat"].mean(dim=["chain", "draw"]).values
mean_ll = idata.log_likelihood.y_hat.mean(dim=["chain", "draw"]).values
x = df.start.values
y = df.e1.values

# Plot the data as a track (stairs for a nicer look)
fig, ax = plt.subplots(figsize=(20, 5))

# Stairs
x_stair = np.zeros(2*y.shape[0])
y_stair = np.zeros(2*y.shape[0])
x_stair[0::2] = x
x_stair[1::2] = x + resolution
y_stair[0::2] = y
y_stair[1::2] = y

ax.fill_between(x_stair, y_stair, where=(y_stair<0), color="tab:blue", alpha=0.5, ec='None')
ax.fill_between(x_stair, y_stair, where=(y_stair>0), color="tab:red", alpha=0.5, ec='None')

# Plot the posterior predictive mean
post_mean_stairs = np.zeros(2*y.shape[0])
post_mean_stairs[0::2] = posterior_mean
post_mean_stairs[1::2] = posterior_mean
ax.plot(x_stair, post_mean_stairs, lw=1, label="Posterior Predictive Mean", color="tab:orange")

# Plot the point-wise log-likelihood
ll_stair = np.zeros(2*mean_ll.shape[0])
ll_stair[0::2] = mean_ll
ll_stair[1::2] = mean_ll
ax.plot(x_stair, ll_stair, lw=0.5, label="Binwise Log-Likelihood", color="tab:green")

ax.set_xlim(0, 57_984_683 + resolution)
ticks = np.arange(0, 57_984_683, step=5_000_000)
ax.set_xticks(ticks)
ax.set_xticklabels((ticks / 1e6).astype(int), rotation=45)
ax.set_xlabel("Genomic Position (Mbp)")
ax.set_ylabel("E1")
ax.set_title("E1")
ax.legend(loc="best");
bins = np.arange(len(y))

plt.figure(figsize=(15, 4))

# Observed data
plt.plot(bins, y, label="Observed E1", color="black", alpha=0.5)

# Posterior predictive mean
plt.plot(bins, ppc_mean, label="Posterior Mean", color="tab:red")

# Credible interval shading
plt.fill_between(
    bins,
    ppc_hdi.y_hat.sel(hdi='lower').values,
    ppc_hdi.y_hat.sel(hdi='higher').values,
    color="tab:red",
    alpha=0.3,
    label="95% Credible Interval",
    step='pre'
)

plt.hlines(0, xmin=0, xmax=bins.shape[0], color='black', linestyle='--', linewidth=0.5)

plt.xlabel("Genomic bin")
plt.ylabel("E1 value")
plt.title("Posterior Predictive Mean with 95% CI")

plt.legend()
plt.tight_layout()
plt.show()
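Finally, since w is the posterior mixing weight of the first mixture component, compartment calls with uncertainty can be read directly off the posterior of w. A minimal sketch, assuming (per the ordered mu above, where mu[0] is the negative mean) that component 0 is B, so P(A) = 1 - w; the 0.5 threshold, the 95% interval, and the compartment_calls table are illustrative choices, not part of the analysis above:

# Per-bin posterior probability of compartment A (assumption: component 0 = B, so P(A) = 1 - w)
p_A_draws = 1.0 - idata.posterior["w"]                        # dims: (chain, draw, pos)
p_A_mean = p_A_draws.mean(dim=("chain", "draw"))
p_A_low = p_A_draws.quantile(0.025, dim=("chain", "draw"))
p_A_high = p_A_draws.quantile(0.975, dim=("chain", "draw"))

calls = np.where(p_A_mean.values > 0.5, "A", "B")             # point calls at a 0.5 threshold
uncertain = (p_A_low.values < 0.5) & (p_A_high.values > 0.5)  # interval straddles the threshold

compartment_calls = pd.DataFrame({
    "start": df["start"].values,
    "p_A": p_A_mean.values,
    "call": calls,
    "uncertain": uncertain,
})
print(compartment_calls.head())
print(f"Bins with ambiguous assignment: {uncertain.sum()} / {len(uncertain)}")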