import arviz as az
import cloudpickle as cp

name = "N_N_HN"
trace = az.from_netcdf(f"../results/traces/{name}_trace.nc")
with open(f"../results/models/{name}_model.pkl", "rb") as f:
    model = cp.load(f)
Reintroducing the latent z assignment in the GRW model
Goals
Here, we re-introduce the latent bin-wise assignment of compartment A or B from the eigenvector track.
First, we will try a simple post-hoc simulation of the \(\hat{y}\) values estimated in the previous notebook. There, no latent state was assigned to each bin; each bin only had a mixture weight, \(w\), between the two compartment distributions.
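For reference, the generative structure we are decoding is (a sketch; the component-index convention is an assumption here and may be flipped relative to the fitted model):

\[
z_i \mid w_i \sim \operatorname{Bernoulli}(w_i), \qquad
y_i \mid z_i \sim \operatorname{Normal}(\mu_{z_i},~\sigma_{z_i}),
\]

so that marginalizing out \(z_i\) recovers the two-component mixture the previous notebook actually fitted:

\[
p(y_i \mid w_i) = w_i\,\operatorname{Normal}(y_i \mid \mu_1,~\sigma_1) + (1 - w_i)\,\operatorname{Normal}(y_i \mid \mu_0,~\sigma_0).
\]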
trace
posterior:

<xarray.Dataset> Size: 246MB
Dimensions:      (chain: 7, draw: 1000, pos: 1460, mu_dim_0: 2, sigma_dim_0: 2)
Coordinates:
  * chain        (chain) int64 56B 0 1 2 3 4 5 6
  * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
  * pos          (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
  * mu_dim_0     (mu_dim_0) int64 16B 0 1
  * sigma_dim_0  (sigma_dim_0) int64 16B 0 1
Data variables:
    eps        (chain, draw, pos) float64 82MB ...
    mu         (chain, draw, mu_dim_0) float64 112kB ...
    sigma      (chain, draw, sigma_dim_0) float64 112kB ...
    grw_sigma  (chain, draw) float64 56kB ...
    logit_w    (chain, draw, pos) float64 82MB ...
    w          (chain, draw, pos) float64 82MB ...
Attributes:
    created_at:                 2025-06-26T14:27:19.174308+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
    sampling_time:              205.36568427085876
    tuning_steps:               1000

posterior_predictive:

<xarray.Dataset> Size: 82MB
Dimensions:  (chain: 7, draw: 1000, pos: 1460)
Coordinates:
  * chain    (chain) int64 56B 0 1 2 3 4 5 6
  * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * pos      (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
Data variables:
    y_hat    (chain, draw, pos) float64 82MB ...
Attributes:
    created_at:                 2025-06-26T15:23:03.760307+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0

log_likelihood:

<xarray.Dataset> Size: 82MB
Dimensions:  (chain: 7, draw: 1000, pos: 1460)
Coordinates:
  * chain    (chain) int64 56B 0 1 2 3 4 5 6
  * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * pos      (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
Data variables:
    y_hat    (chain, draw, pos) float64 82MB ...
Attributes:
    created_at:                 2025-06-26T15:22:37.706829+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0

sample_stats:

<xarray.Dataset> Size: 862kB
Dimensions:                (chain: 7, draw: 1000)
Coordinates:
  * chain                  (chain) int64 56B 0 1 2 3 4 5 6
  * draw                   (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables: (12/17)
    step_size              (chain, draw) float64 56kB ...
    reached_max_treedepth  (chain, draw) bool 7kB ...
    process_time_diff      (chain, draw) float64 56kB ...
    smallest_eigval        (chain, draw) float64 56kB ...
    largest_eigval         (chain, draw) float64 56kB ...
    perf_counter_diff      (chain, draw) float64 56kB ...
    ...                     ...
    n_steps                (chain, draw) float64 56kB ...
    max_energy_error       (chain, draw) float64 56kB ...
    tree_depth             (chain, draw) int64 56kB ...
    energy                 (chain, draw) float64 56kB ...
    index_in_trajectory    (chain, draw) int64 56kB ...
    perf_counter_start     (chain, draw) float64 56kB ...
Attributes:
    created_at:                 2025-06-26T14:27:19.389497+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
    sampling_time:              205.36568427085876
    tuning_steps:               1000

prior:

<xarray.Dataset> Size: 35MB
Dimensions:      (chain: 1, draw: 1000, pos: 1460, sigma_dim_0: 2, mu_dim_0: 2)
Coordinates:
  * chain        (chain) int64 8B 0
  * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
  * pos          (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
  * sigma_dim_0  (sigma_dim_0) int64 16B 0 1
  * mu_dim_0     (mu_dim_0) int64 16B 0 1
Data variables:
    eps        (chain, draw, pos) float64 12MB ...
    sigma      (chain, draw, sigma_dim_0) float64 16kB ...
    w          (chain, draw, pos) float64 12MB ...
    grw_sigma  (chain, draw) float64 8kB ...
    logit_w    (chain, draw, pos) float64 12MB ...
    mu         (chain, draw, mu_dim_0) float64 16kB ...
Attributes:
    created_at:                 2025-06-26T15:22:41.210842+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0

prior_predictive:

<xarray.Dataset> Size: 12MB
Dimensions:  (chain: 1, draw: 1000, pos: 1460)
Coordinates:
  * chain    (chain) int64 8B 0
  * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * pos      (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
Data variables:
    y_hat    (chain, draw, pos) float64 12MB ...
Attributes:
    created_at:                 2025-06-26T15:22:41.214288+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0

observed_data:

<xarray.Dataset> Size: 23kB
Dimensions:  (pos: 1460)
Coordinates:
  * pos      (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
Data variables:
    y_hat    (pos) float64 12kB ...
Attributes:
    created_at:                 2025-06-26T14:27:19.392559+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
model
\[
\begin{array}{rcl}
\text{mu} &\sim& \operatorname{Normal}(\text{<constant>},~0.3)\\
\text{sigma} &\sim& \operatorname{HalfNormal}(0,~0.3)\\
\text{grw\_sigma} &\sim& \operatorname{HalfNormal}(0,~0.05)\\
\text{eps} &\sim& \operatorname{Normal}(0,~1)\\
\text{logit\_w} &\sim& \operatorname{Deterministic}(f(\text{eps},~\text{grw\_sigma}))\\
\text{w} &\sim& \operatorname{Deterministic}(f(\text{eps},~\text{grw\_sigma}))\\
\text{y\_hat} &\sim& \operatorname{MarginalMixture}(f(\text{eps},~\text{grw\_sigma}),~\operatorname{Normal}(\text{mu},~\text{sigma}))
\end{array}
\]
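The `MarginalMixture` line is the observed likelihood: \(z\) never appears as a sampled variable, because `pm.Mixture` integrates it out analytically. To double-check the structure, the dependency graph can be rendered (a minimal sketch; assumes the `graphviz` Python package and the system Graphviz binaries are installed):

import pymc as pm

# Render the loaded model's dependency graph; in a notebook this displays
# inline. Requires the graphviz python package and the Graphviz binaries.
pm.model_to_graphviz(model)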
Post-hoc decoding of mixture weight
import numpy as np

# Extract posterior samples of w
w_samples = trace.posterior["w"].values  # shape (chains, draws, pos)
w_flat = w_samples.reshape(-1, w_samples.shape[-1])

# Sample z draws
rng = np.random.default_rng()
z_samples = rng.binomial(1, w_flat)  # shape (samples, pos)

# Compute posterior probability of z=1
z_probs = z_samples.mean(axis=0)  # per-position probability
z_probs
array([0.393 , 0.31385714, 0.27742857, ..., 0.05357143, 0.05442857,
0.05885714], shape=(1460,))
z_samples
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[1, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], shape=(7000, 1460))
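Note that the Bernoulli draw above uses only the prior weight \(w\); it ignores how well each observed E1 value actually fits the two components. A responsibility-based decode conditions on the data as well. A minimal sketch, assuming Normal components and that \(w\) weights component 1 (both are conventions carried over from the marginal model, so the indexing may need flipping):

from scipy.stats import norm

# Flatten posterior draws of the component parameters over (chain, draw)
mu_flat = trace.posterior["mu"].values.reshape(-1, 2)        # (samples, 2)
sigma_flat = trace.posterior["sigma"].values.reshape(-1, 2)  # (samples, 2)
y_obs = trace.observed_data["y_hat"].values                  # (pos,)

# Likelihood of each observation under each component, per posterior draw
like1 = norm.pdf(y_obs[None, :], mu_flat[:, [1]], sigma_flat[:, [1]])
like0 = norm.pdf(y_obs[None, :], mu_flat[:, [0]], sigma_flat[:, [0]])

# Responsibility of component 1: prior weight times likelihood, normalized
resp = w_flat * like1 / (w_flat * like1 + (1 - w_flat) * like0)
z_resp_probs = resp.mean(axis=0)  # per-position P(z = 1 | y)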
Let’s try some other logic: for each bin, count how often the sampled state falls below a threshold (i.e. how often the bin is assigned state 0):
threshold = 0.5

# z_samples: shape (samples, n_positions)
# How many of the samples are below the threshold (i.e. z == 0)?
below_thresh = z_samples < threshold  # shape (samples, n_positions)

# Count of Trues in each column:
n_below = below_thresh.sum(axis=0)  # shape: (n_positions,)
# Proportion of samples below the threshold for each position
transition_prop = n_below / below_thresh.shape[0]  # shape: (n_positions,)
transition_prop
array([0.607 , 0.68614286, 0.72257143, ..., 0.94642857, 0.94557143,
0.94114286], shape=(1460,))
lower_bound = 0.05
upper_bound = 0.95

in_transition = (transition_prop > lower_bound) & (transition_prop < upper_bound)

pos = trace.observed_data.pos.values
# Get positions (bin indices) where transition_prop is in the specified range
idx_in_transition = pos[in_transition]
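For plotting and interpretation it can also help to collapse the boolean mask into contiguous transition intervals rather than scattered bin indices. A small sketch, using a hypothetical `contiguous_regions` helper:

def contiguous_regions(mask, positions):
    """Return (start, end) pairs of `positions` for each run of True in `mask`."""
    # Pad with False so runs touching the array edges are detected
    padded = np.concatenate([[False], mask, [False]]).astype(np.int8)
    edges = np.flatnonzero(np.diff(padded))  # alternating run starts / ends
    starts, ends = edges[::2], edges[1::2] - 1
    return [(positions[s], positions[e]) for s, e in zip(starts, ends)]

transition_regions = contiguous_regions(in_transition, pos)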
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib_inline

matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Load the data
resolution = 100000

y = pd.Series(pd.read_csv(f"../data/eigs/fibroblast.eigs.{resolution}.cis.vecs.tsv", sep="\t")['E1'].values.flatten())
x = pd.Series(np.arange(0, y.shape[0]) * resolution)

# Make a DataFrame object
df = pd.DataFrame({"start": x, "e1": y})
df.dropna(inplace=True)

# Plot the data
fig, ax = plt.subplots(figsize=(10, 3))
ax.fill_between(df.start, df.e1, where=df.e1 > 0, color='tab:red', ec='None', step='pre')
ax.fill_between(df.start, df.e1, where=df.e1 < 0, color='tab:blue', ec='None', step='pre')
# Rescale transition_prop from [0, 1] to [-1, 1] to overlay on the E1 track
ax.plot(df.start, transition_prop / 0.5 - 1, color='black', lw=0.5, label='Transition Propensity')

# Shade the transition zone (bin indices converted to genomic coordinates)
plt.fill_between(
    idx_in_transition * resolution,
    -1,
    1,
    color='orange',
    alpha=0.3,
    label=f'{int(lower_bound*100)}–{int(upper_bound*100)}% Transition Zone'
)
plt.legend(loc="upper left")
plt.tight_layout()
Include z as a stochastic RV in the model
Let’s fetch the results from the model-comparison stats of the previous notebook. Then, we will use some of the configs that 1) did not produce any divergences or other warnings, and 2) otherwise had good results.
import pandas as pd
import arviz as az

model_comparison = pd.read_csv("../results/model_comparison.csv", index_col=0)
valid = model_comparison.query("warning == False")

top3 = model_comparison[:3]
configs = pd.concat([top3, valid])
configs
|              | rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale |
|--------------|------|----------|-------|-----------|--------|----|-----|---------|-------|
| N_T_HT_5_10  | 0 | -358.452110 | 62.930324 | 0.000000 | 4.765703e-01 | 26.860618 | 0.000000 | True | log |
| N_N_HT_10    | 1 | -359.459620 | 65.393423 | 1.007510 | 2.252832e-01 | 26.771997 | 3.069303 | True | log |
| N_T_HT_10_5  | 2 | -361.000076 | 66.023840 | 2.547967 | 2.724386e-08 | 26.918152 | 3.289561 | True | log |
| T_T_HN_10_10 | 18 | -484.641065 | 45.381785 | 126.188955 | 1.850862e-08 | 27.387257 | 6.016426 | False | log |
| T_N_HN_10    | 20 | -498.162518 | 43.544155 | 139.710408 | 1.722162e-08 | 27.387247 | 6.285019 | False | log |
| T_T_HN_5_10  | 21 | -503.006963 | 44.417449 | 144.554853 | 2.125441e-08 | 27.781687 | 7.448034 | False | log |
| T_N_HN_5     | 22 | -516.300853 | 42.496101 | 157.848743 | 2.001360e-08 | 27.763489 | 7.684781 | False | log |
| N_T_HN_5     | 23 | -530.303872 | 43.506794 | 171.851762 | 1.092973e-08 | 27.132629 | 6.870209 | False | log |
| T_T_HN_10_5  | 24 | -539.210197 | 42.654365 | 180.758088 | 1.413144e-08 | 27.420135 | 7.628495 | False | log |
| T_T_HN_5_5   | 25 | -556.981749 | 41.920557 | 198.529639 | 1.692740e-08 | 27.799341 | 8.819701 | False | log |
We now make a model that includes the latent state \(z\) as a stochastic random variable (RV), for a handful of these configs.
top_n = configs[:6].copy()
top_n.index.to_list()
['N_T_HT_5_10',
'N_N_HT_10',
'N_T_HT_10_5',
'T_T_HN_10_10',
'T_N_HN_10',
'T_T_HN_5_10']
import json

# Fetch the configs from JSON
with open("../results/model_grid.json", "r") as f:
    model_grid = json.load(f)

# Extract the configs for the selected top models
configs = [config
           for config in model_grid
           if config['name'] in top_n.index]
configs
[{'name': 'N_N_HT_10',
'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'sigma': 0.3},
{'name': 'N_T_HT_10_5',
'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'nu': 10, 'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 5, 'sigma': 0.01},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'sigma': 0.3},
{'name': 'N_T_HT_5_10',
'comp_dist': 'Normal',
'comp_kwargs': {},
'eps_dist': 'StudentT',
'eps_kwargs': {'nu': 5, 'mu': 0.0, 'sigma': 0.5},
'grw_sigma_dist': 'HalfStudentT',
'grw_sigma_kwargs': {'nu': 10, 'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'sigma': 0.3},
{'name': 'T_N_HN_10',
'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'Normal',
'eps_kwargs': {'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'sigma': 0.3},
{'name': 'T_T_HN_10_10',
'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 10},
'eps_dist': 'StudentT',
'eps_kwargs': {'nu': 10, 'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'sigma': 0.3},
{'name': 'T_T_HN_5_10',
'comp_dist': 'StudentT',
'comp_kwargs': {'nu': 5},
'eps_dist': 'StudentT',
'eps_kwargs': {'nu': 10, 'mu': 0.0, 'sigma': 1.0},
'grw_sigma_dist': 'HalfNormal',
'grw_sigma_kwargs': {'sigma': 0.05},
'mu_mu': [-0.5, 0.5],
'mu_sigma': 0.3,
'sigma': 0.3}]
Model builder
import pymc as pm
import pytensor.tensor as pt

def build_z_model_from_config(df, cfg):
    """
    Build a PyMC model based on the provided configuration dictionary.
    """
    with pm.Model(coords={"pos": df.index.values}) as model:
        e1 = pm.Data("e1", df.e1.values, dims="pos")

        mu = pm.Normal("mu", mu=cfg['mu_mu'], sigma=cfg['mu_sigma'],
                       transform=pm.distributions.transforms.ordered, shape=2)
        sigma = pm.HalfNormal("sigma", sigma=cfg['sigma'], shape=2)

        # components: Normal or StudentT
        if cfg['comp_dist'] == "Normal":
            components = pm.Normal.dist(mu=mu, sigma=sigma, shape=2)
        else:
            components = pm.StudentT.dist(mu=mu, sigma=sigma, **cfg['comp_kwargs'], shape=2)

        # grw_sigma: HalfNormal or HalfStudentT
        if cfg['grw_sigma_dist'] == "HalfNormal":
            grw_sigma = pm.HalfNormal("grw_sigma", **cfg['grw_sigma_kwargs'])
        else:
            grw_sigma = pm.HalfStudentT("grw_sigma", **cfg['grw_sigma_kwargs'])

        # eps: Normal or StudentT
        if cfg['eps_dist'] == "Normal":
            eps = pm.Normal("eps", dims="pos", **cfg['eps_kwargs'])
        else:
            eps = pm.StudentT("eps", dims="pos", **cfg['eps_kwargs'])

        logit_w = pm.Deterministic("logit_w", pt.cumsum(eps * grw_sigma), dims="pos")
        w = pm.Deterministic("w", pm.math.sigmoid(logit_w), dims="pos")

        # Latent binary state for the mixture model
        z = pm.Bernoulli("z", p=w, dims="pos")

        y_hat = pm.Mixture("y_hat", w=pt.stack([z, 1 - z], axis=1), comp_dists=components,
                           observed=e1, dims='pos')

    return model
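Before committing to the long sampling loop below, a cheap smoke test of the builder can catch config typos early (a sketch; assumes `df` from the plotting cell above is still in scope):

# Build the first config and forward-sample its prior predictive
test_model = build_z_model_from_config(df, configs[0])
with test_model:
    prior_check = pm.sample_prior_predictive(draws=100)
prior_check.prior_predictive["y_hat"].shape  # expect (1, 100, n_pos)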
Model sampler
from datetime import datetime

def log(msg, path=".05_PyMC_bernoulli_z.log", overwrite=False):
    print(f"LOG: {msg}")
    if overwrite:
        with open(path, "w") as f:
            f.write(msg + "\n")
    else:
        with open(path, "a") as f:
            f.write(msg + "\n")

# Reset the log file
log("Resetting log... \n", overwrite=True)

model_traces = {}

for i, cfg in enumerate(configs):
    name = cfg['name']

    if name in model_traces:
        continue
    else:
        model_traces[name] = {'model': None, 'trace': None, 'scores': None}

    log(f"Sampling {name}...")
    try:
        start = datetime.now()
        log(f"Started at {start.strftime('%H:%M:%S')}")
        model_traces[name]['model'] = build_z_model_from_config(df, cfg)
        model = model_traces[name]['model']
        trace = pm.sample(draws=1500, tune=1500, chains=7, cores=7, progressbar=False, model=model)
        model_traces[name]['trace'] = trace
        log(f"--> Took {(datetime.now()-start).total_seconds()} seconds!\n")
    except Exception as e:
        log(f"Model {name} failed: {e}\n")
LOG: Resetting log...
LOG: Sampling N_N_HT_10...
LOG: Started at 10:01:32
Multiprocess sampling (7 chains in 7 jobs)
CompoundStep
>NUTS: [mu, sigma, grw_sigma, eps]
>BinaryGibbsMetropolis: [z]
Sampling 7 chains for 1_500 tune and 1_500 draw iterations (10_500 + 10_500 draws total) took 1921 seconds.
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
There were 681 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 4 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 5 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 6 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 1948.941466 seconds!
LOG: Sampling N_T_HT_10_5...
LOG: Started at 10:34:01
Multiprocess sampling (7 chains in 7 jobs)
CompoundStep
>NUTS: [mu, sigma, grw_sigma, eps]
>BinaryGibbsMetropolis: [z]
Sampling 7 chains for 1_500 tune and 1_500 draw iterations (10_500 + 10_500 draws total) took 1904 seconds.
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
There were 1501 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 4 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 5 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 6 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 1915.577769 seconds!
LOG: Sampling N_T_HT_5_10...
LOG: Started at 11:05:57
Multiprocess sampling (7 chains in 7 jobs)
CompoundStep
>NUTS: [mu, sigma, grw_sigma, eps]
>BinaryGibbsMetropolis: [z]
Sampling 7 chains for 1_500 tune and 1_500 draw iterations (10_500 + 10_500 draws total) took 2001 seconds.
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
There were 735 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 4 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 5 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 6 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
LOG: --> Took 2012.403881 seconds!
LOG: Sampling T_N_HN_10...
LOG: Started at 11:39:29
Multiprocess sampling (7 chains in 7 jobs)
CompoundStep
>NUTS: [mu, sigma, grw_sigma, eps]
>BinaryGibbsMetropolis: [z]
Sampling 7 chains for 1_500 tune and 1_500 draw iterations (10_500 + 10_500 draws total) took 1696 seconds.
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
LOG: --> Took 1707.43467 seconds!
LOG: Sampling T_T_HN_10_10...
LOG: Started at 12:07:57
Multiprocess sampling (7 chains in 7 jobs)
CompoundStep
>NUTS: [mu, sigma, grw_sigma, eps]
>BinaryGibbsMetropolis: [z]
Sampling 7 chains for 1_500 tune and 1_500 draw iterations (10_500 + 10_500 draws total) took 1769 seconds.
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
LOG: --> Took 1780.576117 seconds!
LOG: Sampling T_T_HN_5_10...
LOG: Started at 12:37:37
Multiprocess sampling (7 chains in 7 jobs)
CompoundStep
>NUTS: [mu, sigma, grw_sigma, eps]
>BinaryGibbsMetropolis: [z]
Sampling 7 chains for 1_500 tune and 1_500 draw iterations (10_500 + 10_500 draws total) took 1776 seconds.
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
LOG: --> Took 1788.028742 seconds!
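The three HT configs diverged heavily and saturated the tree depth, as the warnings note, while the HN configs sampled cleanly. A possible follow-up, not run here, would be to re-sample the divergent configs with a stricter acceptance target (which shrinks the NUTS step size), or to reparameterize grw_sigma:

# Hypothetical re-run of a divergent config; target_accept shrinks the
# NUTS step size, which often removes divergences at the cost of speed.
trace = pm.sample(draws=1500, tune=3000, chains=7, cores=7,
                  target_accept=0.95, progressbar=False, model=model)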
for name, model_dict in model_traces.items():
    # model_dict.keys(): {'model', 'trace', 'scores'}
    try:
        print(name)
        pm.compute_log_likelihood(model_dict['trace'],
                                  model=model_dict['model'],
                                  progressbar=False)
    except Exception as e:
        print(f"{name} failed computing log_likelihood: {e}")
    try:
        prior_predictive = pm.sample_prior_predictive(
            1000,
            model=model_dict['model'])
        model_dict['trace'].extend(prior_predictive)
    except Exception as e:
        print(f"{name} failed sampling prior predictive: {e}")
    try:
        posterior_predictive = pm.sample_posterior_predictive(
            model_dict['trace'],
            model=model_dict['model'],
            progressbar=False)
        model_dict['trace'].extend(posterior_predictive)
    except Exception as e:
        print(f"{name} failed sampling posterior predictive: {e}")
    try:
        loo = az.loo(model_dict['trace'])
        waic = az.waic(model_dict['trace'])
        model_dict['scores'] = {
            "loo": loo,
            "waic": waic
        }
    except Exception as e:
        print(f"{name} failed during loo/waic: {e}")
N_N_HT_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat, z]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HT_10_5
Sampling: [eps, grw_sigma, mu, sigma, y_hat, z]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
N_T_HT_5_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat, z]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
T_N_HN_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat, z]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
T_T_HN_10_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat, z]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
T_T_HN_5_10
Sampling: [eps, grw_sigma, mu, sigma, y_hat, z]
Sampling: [y_hat]
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:797: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
warnings.warn(
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1655: UserWarning: For one or more samples the posterior variance of the log predictive densities exceeds 0.4. This could be indication of WAIC starting to fail.
See http://arxiv.org/abs/1507.04544 for details
warnings.warn(
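With LOO in hand for every model that survived scoring, the fits can be ranked in a single table, mirroring the model_comparison.csv from the previous notebook (a sketch):

# Rank the z-models by LOO; az.compare takes a dict of InferenceData
compare_dict = {name: d['trace'] for name, d in model_traces.items()
                if d['scores'] is not None}
az.compare(compare_dict)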
matplotlib_inline.backend_inline.set_matplotlib_formats('retina')

for name, d in model_traces.items():
    trace = d['trace']
    score = d['scores']['loo']

    f, axs = plt.subplots(3, 2, figsize=(10, 5))
    az.plot_trace(trace, var_names=["mu", "sigma", "grw_sigma"], compact=True, axes=axs)
    plt.suptitle(f"Trace plot for model: {name}")
    plt.tight_layout()

    f, axs = plt.subplots(2, 2, figsize=(10, 5))
    axs = axs.flatten()

    az.plot_ppc(trace, var_names=["y_hat"], ax=axs[0])
    az.plot_loo_pit(trace, y="y_hat", ax=axs[1])
    az.plot_loo_pit(trace, y="y_hat", ax=axs[2], ecdf=True)
    az.plot_energy(trace, ax=axs[3])
    plt.suptitle(f"PPC and stats for model: {name}")
    plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/tmp/64285303/ipykernel_3601446/245460086.py:25: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/tmp/64285303/ipykernel_3601446/245460086.py:25: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/tmp/64285303/ipykernel_3601446/245460086.py:25: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/tmp/64285303/ipykernel_3601446/245460086.py:25: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/tmp/64285303/ipykernel_3601446/245460086.py:25: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/stats/stats.py:1045: RuntimeWarning: overflow encountered in exp
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/numpy/_core/_methods.py:52: RuntimeWarning: overflow encountered in reduce
return umr_sum(a, axis, dtype, out, keepdims, initial, where)
/tmp/64285303/ipykernel_3601446/245460086.py:25: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
plt.tight_layout()
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/IPython/core/events.py:82: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
func(*args, **kwargs)
/home/sojern/miniconda3/envs/pymc/lib/python3.13/site-packages/IPython/core/pylabtools.py:170: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
fig.canvas.print_figure(bytes_io, **kw)
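Since the latent state is the whole point of these models, it is also worth looking at the posterior mean of \(z\) along the chromosome for one fit (a sketch; picks the first model and reuses `df` from above):

# Posterior mean of the latent state per bin for the first model
name0 = next(iter(model_traces))
z_mean = model_traces[name0]['trace'].posterior["z"].mean(("chain", "draw"))

fig, ax = plt.subplots(figsize=(10, 2))
ax.plot(df.start, z_mean, color='black', lw=0.5)
ax.set_ylabel("P(z = 1)")
ax.set_title(f"Posterior mean of z: {name0}")
plt.tight_layout()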
Save the models and traces
import os.path as op
import os
import cloudpickle

# Save the models (cloudpickle) and traces (netcdf) to file
results_dir = "../results"
for name, d in model_traces.items():
    name = name + "_z"

    # Save the pm.Model as a cloudpickle dump
    os.makedirs(op.join(results_dir, "models"), exist_ok=True)
    model_path = op.join(results_dir, "models", f"{name}_model.pkl")
    with open(model_path, "wb") as f:
        cloudpickle.dump(d['model'], f)

    # Save the trace as a netcdf file
    trace_path = op.join(results_dir, "traces", f"{name}_trace.nc")
    os.makedirs(op.dirname(trace_path), exist_ok=True)
    az.to_netcdf(d['trace'], trace_path)