Bootstrapping traces

Bootstrapping the latent bin assignment from the trace to create a CI for compartment transitions

Goal

Here, we construct a confidence interval (CI) for the compartment transitions, giving us an uncertainty estimate for the exact position of each transition.

Method: bootstrapping

We will use bootstrapping to resample the z variable from the trace of a model (across chains and draws), calculating the mean of each resample.

Load a trace

import arviz as az

name = "T_N_HN_10"
trace = az.from_netcdf(f"../results/traces/{name}_z_trace.nc")
trace.posterior
<xarray.Dataset> Size: 491MB
Dimensions:      (chain: 7, draw: 1500, pos: 1460, mu_dim_0: 2, sigma_dim_0: 2)
Coordinates:
  * chain        (chain) int64 56B 0 1 2 3 4 5 6
  * draw         (draw) int64 12kB 0 1 2 3 4 5 ... 1494 1495 1496 1497 1498 1499
  * pos          (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
  * mu_dim_0     (mu_dim_0) int64 16B 0 1
  * sigma_dim_0  (sigma_dim_0) int64 16B 0 1
Data variables:
    eps          (chain, draw, pos) float64 123MB ...
    z            (chain, draw, pos) int64 123MB ...
    mu           (chain, draw, mu_dim_0) float64 168kB ...
    sigma        (chain, draw, sigma_dim_0) float64 168kB ...
    grw_sigma    (chain, draw) float64 84kB ...
    logit_w      (chain, draw, pos) float64 123MB ...
    w            (chain, draw, pos) float64 123MB ...
Attributes:
    created_at:                 2025-07-02T10:07:48.044236+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
    sampling_time:              1695.637930393219
    tuning_steps:               1500
First, we stack the chains and draws per pos, so that we have chains * draws samples per position. Also, since PyMC stores z with the default dtype (int64) even though it is binary, we convert it to int8 to save memory.
import numpy as np

stacked = trace.posterior.z.stack(sample=["chain", "draw"])
stacked.data = stacked.data.astype(np.int8)
stacked
<xarray.DataArray 'z' (pos: 1460, sample: 10500)> Size: 15MB
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], shape=(1460, 10500), dtype=int8)
Coordinates:
  * pos      (pos) int64 12kB 24 25 26 27 28 29 ... 1617 1618 1619 1620 1621
  * sample   (sample) object 84kB MultiIndex
  * chain    (sample) int64 84kB 0 0 0 0 0 0 0 0 0 0 0 ... 6 6 6 6 6 6 6 6 6 6 6
  * draw     (sample) int64 84kB 0 1 2 3 4 5 6 ... 1494 1495 1496 1497 1498 1499
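As a quick sanity check on the dtype conversion (a minimal sketch; the 123 MB and 15 MB figures can be read off the xarray summaries above):

# z as stored in the trace: int64 over (chain, draw, pos)
print(trace.posterior.z.nbytes / 1e6)  # ~123 MB
# After stacking and casting to int8
print(stacked.nbytes / 1e6)            # ~15 MB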
Then, we define a wrapper function to bootstrap the mean for each position. scipy.stats.bootstrap is vectorized and therefore very efficient at resampling each position. However, it allocates an array of n_resamples * n_samples * pos elements (about 1.2 TB for 10,000 resamples) unless we pass the batch argument, in which case it only allocates batch * n_samples * pos elements in memory at a time. Here, we also define a function to compute the batch value so that at most a 16 GiB array is allocated at a time.
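To make the numbers concrete (a back-of-envelope sketch; the 1.2 TB figure corresponds to 8-byte elements, e.g. the original int64 z, and the problem remains even for the int8 copy):

n_resamples, n_samples, n_pos = 10_000, 10_500, 1_460

# Naive allocation: all resamples at once
print(n_resamples * n_samples * n_pos * 8 / 1e12)  # ~1.23 TB at 8 bytes/element
print(n_resamples * n_samples * n_pos * 1 / 1e9)   # ~153 GB even as int8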
from scipy.stats import bootstrap
import numpy as np


def bootstrap_mean(data, **kwargs):
    """Compute the confidence interval for the mean."""
    res = bootstrap(
        (data,),
        statistic=np.mean,
        vectorized=True,
        axis=1,
        method="basic",
        **kwargs,
    )
    return res
def compute_batch(arr, max_bytes=16 * 1024**3):
    """
    Compute max batch size so that each batch of resamples fits in max_bytes.
    """
    # Elements in a single resample
    n_elements = arr.shape[0] * arr.shape[1]
    # Bytes per element
    bytes_per_elem = np.dtype(arr.dtype).itemsize
    # Bytes per resample
    bytes_per_resample = n_elements * bytes_per_elem
    # Max batch size
    batch = max(1, int(max_bytes // bytes_per_resample))
    return batch
compute_batch(stacked.values)
1120
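A quick check that this batch size respects the cap (with the int8 data, one resample is 1460 * 10500 bytes):

# Per-batch allocation in GiB: just under the 16 GiB cap at batch=1120
print(1120 * stacked.shape[0] * stacked.shape[1] / 1024**3)  # ~16.0
print(1121 * stacked.shape[0] * stacked.shape[1] / 1024**3)  # just over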
Then, try it out:
bootstrap_results = bootstrap_mean(
    stacked,
    **{
        "confidence_level": 0.99,
        "n_resamples": 10000,
        "batch": compute_batch(stacked.values)
    }
)
bootstrap_results
BootstrapResult(confidence_interval=ConfidenceInterval(low=array([ 1.90476190e-04, -9.52380952e-05, 1.90476190e-04, ...,
5.92380952e-02, 7.21904762e-02, 2.85714286e-03], shape=(1460,)), high=array([0.00171429, 0.00095238, 0.00171429, ..., 0.07171429, 0.08590476,
0.00638095], shape=(1460,))), bootstrap_distribution=array([[0.00095238, 0.00152381, 0.00209524, ..., 0.00066667, 0.0007619 ,
0.00104762],
[0.00066667, 0.00038095, 0.00028571, ..., 0.00047619, 0.00038095,
0.00028571],
[0.00095238, 0.00085714, 0.00057143, ..., 0.00104762, 0.00114286,
0.0007619 ],
...,
[0.06609524, 0.05914286, 0.06295238, ..., 0.06685714, 0.0627619 ,
0.06838095],
[0.07904762, 0.07971429, 0.07714286, ..., 0.0807619 , 0.08161905,
0.0772381 ],
[0.00438095, 0.00428571, 0.00447619, ..., 0.00447619, 0.00514286,
0.00428571]], shape=(1460, 10000)), standard_error=array([0.00031488, 0.00021414, 0.00031289, ..., 0.00242412, 0.00264991,
0.00067497], shape=(1460,)))
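For orientation, these are the pieces of the BootstrapResult used below (shapes match the 1460-position, 10,000-resample run above):

ci = bootstrap_results.confidence_interval
print(ci.low.shape, ci.high.shape)                     # (1460,) (1460,)
print(bootstrap_results.bootstrap_distribution.shape)  # (1460, 10000)
print(bootstrap_results.standard_error.shape)          # (1460,)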
import matplotlib.pyplot as plt
import matplotlib_inline

matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

x = stacked.pos.values
lower = bootstrap_results.confidence_interval.low
upper = bootstrap_results.confidence_interval.high

fig, ax = plt.subplots(figsize=(10, 3))
ax.fill_between(
    x,
    lower,
    upper,
    label="99% CI",
    # ec="black",
)
ax.legend()
ax.set_xlim(195, 204)
plt.tight_layout()
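For reference, the point estimate (the mean over all stacked samples) can be overlaid on the band; a minimal sketch reusing the objects above:

# Posterior mean of z per position, overlaid on the bootstrap CI band
point = stacked.mean("sample")

fig, ax = plt.subplots(figsize=(10, 3))
ax.fill_between(x, lower, upper, alpha=0.5, label="99% CI")
ax.plot(x, point, color="black", linewidth=0.8, label="mean")
ax.legend()
plt.tight_layout()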
Since the bootstrap distribution is already computed, we can re-evaluate the interval at a chosen confidence level without resampling, by passing the previous result via bootstrap_result and setting n_resamples to 0:

bootstrap_99 = bootstrap_mean(
    stacked,
    **{
        "n_resamples": 0,
        # "batch": compute_batch(stacked.values),
        "confidence_level": 0.99,
        "bootstrap_result": bootstrap_results
    }
)
lower_99 = bootstrap_99.confidence_interval.low
upper_99 = bootstrap_99.confidence_interval.high

fig, ax = plt.subplots(figsize=(10, 4))
ax.fill_between(
    x,
    lower_99,
    upper_99,
)

zooms = [(1360, 1420), (1175, 1190)]
# Mark the zoom regions; ravel flattens the (start, end) pairs into 1D
ax.vlines(np.ravel(zooms), 0, 1, color="red", linestyle="--", linewidth=0.5)
f, axs = plt.subplots(len(zooms), figsize=(10, 6 * len(zooms)))

for i, zoom in enumerate(zooms):
    zax = axs[i]
    zax.fill_between(
        x,
        lower_99,
        upper_99,
    )
    zax.set_xlim(zoom)

f.tight_layout()
plt.show()
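Looking ahead, the per-resample means in bootstrap_distribution can be turned into an uncertainty estimate for the transition positions themselves. A minimal sketch, assuming a transition is wherever the per-position mean of z crosses 0.5 (both the threshold and the crossing rule are assumptions, not part of the model above):

def transition_positions(means, pos, threshold=0.5):
    """Return the pos values where a per-position curve crosses the threshold."""
    above = means >= threshold
    # Indices where the curve switches between below- and above-threshold
    crossings = np.nonzero(np.diff(above.astype(np.int8)) != 0)[0] + 1
    return pos[crossings]

# One set of crossing positions per bootstrap resample;
# their spread across resamples gives a CI for each transition.
dist = bootstrap_results.bootstrap_distribution  # (pos, n_resamples)
per_resample = [transition_positions(dist[:, i], x) for i in range(dist.shape[1])]

The number of crossings can vary between resamples, so matching crossings to individual transitions (e.g. by nearest position) is left as a further step.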