Multi-level caching (C path)

Phasic uses a four-layer caching system to avoid repeating expensive computations. Each layer targets a different bottleneck in the pipeline from model definition to inference:

| Layer | What is cached | Location | When it helps |
|---|---|---|---|
| Graph cache | Fully constructed Graph objects | ~/.phasic_cache/graphs/ | Avoids callback-based construction |
| Parameterised reward-compute cache | C-level parameterized_reward_compute_graph (symbolic Gaussian elimination output), theta-independent | ~/.phasic_cache/parameterized_reward_compute/<hash>.bin | Skips the O(n³) C-side elimination on fresh process startup |
| Hierarchical SCC composer cache | Per-SCC parameterised reward-compute artefacts | ~/.phasic_cache/parameterized_reward_compute/scc_<hash>.bin | Parallel SCC elimination, cross-graph SCC reuse (opt-in via configure(parallel_elimination=True)) |
| JAX compilation cache | JIT-compiled XLA code | ~/.jax_cache/ | Avoids recompilation on restart |

All four caches are persistent — they survive across Python sessions and restarts. Cache correctness is ensured by SHA-256 content hashing: the same graph structure always produces the same hash, and any structural change automatically invalidates the entry. The two parameterised-reward-compute caches are theta-independent — running the same parameterised model with new theta values does not produce new entries.

The parameterised reward-compute cache (covered in this notebook) is what Graph.expectation(), Graph.variance(), Graph.moments(), the FFI / pybind11 forward path, and SVGD use under the hood. There is no opt-in: it is consulted automatically.

from phasic import (
    Graph, Property, StateIndexer, set_log_level,
    clear_caches, clear_model_cache,
    cache_info, get_all_cache_stats, print_all_cache_info,
    get_graph_cache_stats, print_graph_cache_info,
    configure,
)
import phasic.cache as cache
import numpy as np
import time
from vscodenb import set_vscode_theme

set_vscode_theme()

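We use a two-locus ancestral recombination graph (ARG) with two parameters as the example model. The callback accepts a dummy keyword argument that does nothing; it is there only to demonstrate how keyword arguments affect the cache key: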

nr_samples = 6
indexer = StateIndexer(descendants=[
    Property('loc1', max_value=nr_samples),
    Property('loc2', max_value=nr_samples)
])

initial = [0] * indexer.state_length
initial[indexer.props_to_index(loc1=1, loc2=1)] = nr_samples

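def two_locus_arg_2param(state, indexer=None, dummy=None):
    # Returns a list of [child_state, coefficients] transitions. The two
    # coefficients pair with the two model parameters (coalescence first,
    # recombination second). `dummy` is unused on purpose.
    transitions = []
    if state.sum() <= 1: return transitions  # one lineage left: no transitions

    for i in range(indexer.state_length):
        if state[i] == 0: continue
        pi = indexer.index_to_props(i)

        # Coalescence between a lineage of type i and one of type j >= i.
        for j in range(i, indexer.state_length):
            if state[j] == 0: continue
            pj = indexer.index_to_props(j)

            same = int(i == j)
            # Two lineages of the same type are needed to coalesce within a type.
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            child = state.copy()
            child[i] -= 1
            child[j] -= 1
            loc1 = pi.descendants.loc1 + pj.descendants.loc1
            loc2 = pi.descendants.loc2 + pj.descendants.loc2
            if loc1 <= nr_samples and loc2 <= nr_samples:
                child[indexer.props_to_index(loc1=loc1, loc2=loc2)] += 1
                # Number of pairs: n(n-1)/2 within a type, n_i * n_j across types.
                transitions.append([child, [state[i]*(state[j]-same)/(1+same), 0]])

        # Recombination: a lineage ancestral at both loci splits into two
        # single-locus lineages.
        if state[i] > 0 and pi.descendants.loc1 > 0 and pi.descendants.loc2 > 0:
            child = state.copy()
            child[i] -= 1
            child[indexer.props_to_index(loc1=pi.descendants.loc1, loc2=0)] += 1
            child[indexer.props_to_index(loc1=0, loc2=pi.descendants.loc2)] += 1
            transitions.append([child, [0, 1]])

    return transitions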

Start from a clean slate and enable info logging so the examples below show cache misses and hits clearly:

set_log_level('INFO')
clear_caches(verbose=True)
[INFO] phasic.graph_cache: Cleared 0 cached graphs
  Removed 199 file(s), preserved directory structure
  Removed 1 file(s), preserved directory structure

Graph cache

Building a graph from a callback function requires exploring the full state space, creating vertices and edges, and can take seconds to minutes for large models. The graph cache stores fully constructed Graph objects on disk so that the same model can be loaded instantly on subsequent calls.
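To make the cost of callback-based construction concrete, here is a minimal sketch of state-space exploration. It is not phasic's implementation: the `explore` helper and the toy `kingman` callback are hypothetical, and states are plain integers rather than state vectors. The callback runs once for every reachable state, which is the work the graph cache avoids repeating.

```python
from collections import deque

def explore(callback, initial):
    """Breadth-first exploration of every state reachable from `initial`."""
    index = {initial: 0}      # state -> vertex id
    edges = []                # (source id, target id, rate)
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        # One callback invocation per discovered state.
        for child, rate in callback(state):
            if child not in index:
                index[child] = len(index)
                queue.append(child)
            edges.append((index[state], index[child], rate))
    return index, edges

def kingman(n):
    """Toy callback: n lineages coalesce to n - 1 at rate n(n-1)/2."""
    if n <= 1:
        return []             # absorbing state, no outgoing transitions
    return [(n - 1, n * (n - 1) / 2)]

vertices, edges = explore(kingman, 6)
print(len(vertices), len(edges))  # 6 5
```

For the two-locus model in this notebook the same loop visits 1044 states, each requiring a Python callback call.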

The cache key is a SHA-256 hash of:

  • The callback function’s AST (abstract syntax tree), so whitespace/comment changes are ignored but code changes invalidate the cache
  • All construction parameters (ipv, nr_samples, keyword arguments)
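The key construction described above can be sketched with the standard library. This is an illustration only, assuming the key is built from an AST dump plus the parameters; the helper name `ast_cache_key` and the JSON payload layout are hypothetical and are not taken from phasic.

```python
import ast
import hashlib
import json

def ast_cache_key(source, **params):
    """Hypothetical helper: hash a callback's AST together with its parameters."""
    # The parser drops comments, and ast.dump omits line and column
    # information by default, so formatting-only edits give the same dump.
    structure = ast.dump(ast.parse(source))
    payload = json.dumps({"ast": structure, "params": params},
                         sort_keys=True, default=repr)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

original = "def f(state, dummy=None):\n    return state + 1\n"
reformatted = "def f(state, dummy=None):\n    # add one\n    return state  +  1\n"
edited = "def f(state, dummy=None):\n    return state + 2\n"

same_key = ast_cache_key(original, dummy=42) == ast_cache_key(reformatted, dummy=42)
code_changed = ast_cache_key(original, dummy=42) != ast_cache_key(edited, dummy=42)
param_changed = ast_cache_key(original, dummy=42) != ast_cache_key(original, dummy=99)
print(same_key, code_changed, param_changed)  # True True True
```

Note that the dump format of `ast.dump` can differ between Python versions, so a key built this way is only stable within one interpreter version.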

Enable the graph cache by passing graph_cache=True to Graph(). The first build constructs the graph from the callback and saves it to the cache:

%%time 
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    graph_cache=True)
[INFO] phasic.graph_cache: Saved graph to cache: 0d365121a03c3302... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 6.28 s, sys: 135 ms, total: 6.42 s
Wall time: 9.61 s

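The second build below passes dummy=42. Keyword arguments are part of the cache key, so this is a cache miss: the graph is rebuilt and saved under a new hash (e6d9bf26... instead of 0d365121...). Repeating the first call unchanged would instead load the graph from the cache: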

%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    graph_cache=True, dummy=42)
[INFO] phasic.graph_cache: Saved graph to cache: e6d9bf26c611c388... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 6.09 s, sys: 113 ms, total: 6.2 s
Wall time: 9.28 s

If you modify the callback function or pass different parameters, the cache misses and the graph is rebuilt. Even though our dummy keyword arg does nothing, passing a new value triggers a rebuild of the graph:

%%time
graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer,
    graph_cache=True, dummy=99)
[INFO] phasic.graph_cache: Saved graph to cache: 4feffc61791d2516... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
CPU times: user 5.92 s, sys: 133 ms, total: 6.05 s
Wall time: 8.13 s
clear_caches(verbose=True)
[INFO] phasic.graph_cache: Cleared 0 cached graphs
  Removed 3 file(s), preserved directory structure

Parameterised reward-compute cache

When computing moments or running SVGD inference, phasic performs Gaussian elimination on the graph to build a symbolic parameterised reward compute graph (PRC) — a linear sequence of operations that can be replayed cheaply with different parameter values. The elimination is O(n³) and is the most expensive step for large models.

The PRC cache stores the elimination result on disk, keyed by a SHA-256 hash of the graph structure (vertices, edges, coefficients — but not the specific theta values). The cache is consulted transparently by every forward call (expectation, variance, pdf, compute_pmf, compute_moments, …). You never call it directly — but you can inspect and clear it via the phasic.cache module.
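The lookup pattern described above can be sketched as a content-addressed disk cache. This is a toy model under stated assumptions: `structure_hash`, `load_or_compute`, the JSON file format, and the `fake_elimination` stand-in are all hypothetical, and the real cache is written by the C runtime in a binary format.

```python
import hashlib
import json
import tempfile
from pathlib import Path

def structure_hash(edges):
    """Hash (source, target, coefficients) triples; theta never enters the key."""
    canonical = json.dumps(sorted(edges))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

def load_or_compute(edges, cache_dir, expensive_step):
    """Return (result, was_cache_hit) for the given graph structure."""
    path = Path(cache_dir) / f"{structure_hash(edges)}.json"
    if path.exists():
        return json.loads(path.read_text()), True
    result = expensive_step(edges)
    path.write_text(json.dumps(result))
    return result, False

calls = []
def fake_elimination(edges):
    """Stand-in for the expensive elimination; records that it ran."""
    calls.append(1)
    return {"n_operations": len(edges)}

edges = [(0, 1, (1.0, 0.0)), (1, 2, (0.0, 1.0))]
with tempfile.TemporaryDirectory() as cache_dir:
    first, hit_first = load_or_compute(edges, cache_dir, fake_elimination)
    # Same structure listed in a different order: same key, so a hit.
    second, hit_second = load_or_compute(list(reversed(edges)), cache_dir,
                                         fake_elimination)
print(hit_first, hit_second, len(calls))  # False True 1
```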

Build a graph and run an expectation; the C runtime builds the PRC and writes it to disk:

# Start from a clean param-compute cache so the demo's effects are visible.
cache.clear_param_compute_cache()

graph = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer)
graph.update_weights([2, 5])
graph.expectation()
[INFO] phasic.cache: Cleared 0 parameterised compute graph cache files
[INFO] phasic.graph_cache: Saved graph to cache: 0d365121a03c3302... (1044 vertices)
[INFO] phasic: Saved graph to cache: 1044 vertices
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 1.33e+12 > threshold 1.00e+12)
[INFO] phasic.c: MPFR-A: consuming double PRC at 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
1.3007983196759683
graph2 = Graph(two_locus_arg_2param, ipv=initial, indexer=indexer)
graph2.update_weights([2, 5])
[INFO] phasic.graph_cache: Cache hit: 0d365121a03c3302... (1044 vertices)
[INFO] phasic: Loaded graph from cache: 1044 vertices

The PRC is built lazily on the first forward call and reused thereafter. Because we just cleared the cache and ran one expectation, the on-disk cache now has one entry — calling expectation() again on a fresh graph object with the same structure loads the cached PRC from disk instead of running another O(n³) elimination:

graph2.vertices_length()
1044
graph2.expectation()
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 1.33e+12 > threshold 1.00e+12)
[INFO] phasic.c: MPFR-A: consuming double PRC at 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
1.3007983196759683

From now on, repeated forward calls on the same graph object reuse the in-memory persistent compute graph (Stage A1) and skip the disk read; new graph objects with the same structure hit the on-disk cache (Stage A2):

%%time
graph2.variance()
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 1.33e+12 > threshold 1.00e+12)
[INFO] phasic.c: MPFR-A: consuming double PRC at 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
[INFO] phasic.c: Auto-activating MPFR for moment computation (condition 1.33e+12 > threshold 1.00e+12)
[INFO] phasic.c: MPFR-A: consuming double PRC at 128-bit precision
[INFO] phasic.c: MPFR computation successful - returning high-precision results
CPU times: user 11.5 ms, sys: 2.29 ms, total: 13.8 ms
Wall time: 28.1 ms
0.5844898758255803

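You can inspect the cache directly to see how many model entries are persisted and how big the cache is on disk. In the output shown here the cache reports 'disabled': True, which suggests the on-disk cache was turned off in the environment that produced it, so the file counts are zero; with the cache enabled, n_files counts the persisted entries: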

cache.param_compute_cache_info()
{'cache_dir': '/Users/kmt/.phasic_cache/parameterized_reward_compute',
 'n_files': 0,
 'n_parent_files': 0,
 'n_scc_files': 0,
 'total_size': 0,
 'disabled': True}

The cache key is theta-independent — it covers structure + coefficients only. So updating the weights to a different theta and recomputing does not produce a new cache entry:

graph2.update_weights([1, 7])
graph2.expectation()  # uses the same cached PRC
cache.param_compute_cache_info()['n_files']  # same number of files
0
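Why one cached entry serves every theta can be seen in a toy record-and-replay sketch. It assumes a linear chain in which each state has a single outgoing edge with rate equal to the dot product of its coefficients and theta, so the expected absorption time is the sum of the inverse rates. The helpers `record_program` and `replay` are hypothetical, and this is far simpler than the symbolic elimination phasic performs.

```python
def record_program(edge_coefficients):
    """Record one operation per edge of a linear chain; theta is not involved."""
    return [("inverse_rate", tuple(coeffs)) for coeffs in edge_coefficients]

def replay(program, theta):
    """Evaluate the recorded operations for concrete parameter values."""
    total = 0.0
    for _op, coeffs in program:
        rate = sum(c * t for c, t in zip(coeffs, theta))
        total += 1.0 / rate   # mean holding time of this state
    return total

# Recorded once; this is the part worth caching.
program = record_program([(1, 0), (0, 1)])

first = replay(program, (2, 5))    # 1/2 + 1/5
second = replay(program, (1, 7))   # 1/1 + 1/7
print(first, second)
```

The program depends only on structure and coefficients, so changing theta changes the result but not what would be stored.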

Because the cache is persistent on disk (~/.phasic_cache/parameterized_reward_compute/), the cached PRC is available even if you restart Python and construct the same graph structure again. This is especially valuable for iterative development, repeated SVGD runs on the same model, and SLURM workers sharing a network filesystem.

Like all phasic on-disk caches, the PRC cache honours PHASIC_DISABLE_CACHE=1 (or phasic.configure(cache_enabled=False)) to skip both reads and writes — useful in CI or when measuring un-cached baselines. The cache directory is controlled by PHASIC_CACHE_DIR (or cache_dir=...).
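For un-cached baseline measurements, one option is to set the variable only for the duration of a block. The sketch below uses just the standard library; the context manager `phasic_cache_disabled` is hypothetical, and it assumes phasic reads PHASIC_DISABLE_CACHE at call time rather than only at import. If that assumption does not hold, set the variable before importing phasic instead.

```python
import os
from contextlib import contextmanager

@contextmanager
def phasic_cache_disabled():
    """Temporarily set PHASIC_DISABLE_CACHE=1, restoring the old value on exit."""
    previous = os.environ.get("PHASIC_DISABLE_CACHE")
    os.environ["PHASIC_DISABLE_CACHE"] = "1"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop("PHASIC_DISABLE_CACHE", None)
        else:
            os.environ["PHASIC_DISABLE_CACHE"] = previous

before = os.environ.get("PHASIC_DISABLE_CACHE")
with phasic_cache_disabled():
    # Timed, un-cached work would go here.
    inside = os.environ.get("PHASIC_DISABLE_CACHE")
after = os.environ.get("PHASIC_DISABLE_CACHE")
```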

Format-version mismatches (after a phasic upgrade that changes the on-disk layout) are detected by a magic-string + version header; the loader returns NULL for mismatched files and the caller falls back to a fresh elimination that overwrites the bad file. No user action needed.
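The header check and fallback can be sketched in a few lines. The magic string, version number, and header layout below are made up for illustration and do not describe phasic's actual file format; Python's `None` plays the role of the C loader's NULL.

```python
import struct
import tempfile
from pathlib import Path

MAGIC = b"PRCX"                   # hypothetical magic string
FORMAT_VERSION = 2                # hypothetical current layout version
HEADER = struct.Struct("<4sI")    # magic bytes + unsigned 32-bit version

def save(path, payload, version=FORMAT_VERSION):
    path.write_bytes(HEADER.pack(MAGIC, version) + payload)

def load(path):
    """Return the payload, or None if the file is missing, short, or mismatched."""
    if not path.exists():
        return None
    data = path.read_bytes()
    if len(data) < HEADER.size:
        return None
    magic, version = HEADER.unpack_from(data)
    if magic != MAGIC or version != FORMAT_VERSION:
        return None
    return data[HEADER.size:]

def load_or_rebuild(path, rebuild):
    """Fall back to a fresh computation that overwrites an unusable file."""
    payload = load(path)
    if payload is None:
        payload = rebuild()
        save(path, payload)
    return payload

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "entry.bin"
    save(path, b"old layout", version=1)      # file left by an older release
    stale = load(path)                        # rejected by the version check
    fresh = load_or_rebuild(path, lambda: b"new layout")
    reloaded = load(path)                     # the bad file was overwritten
```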

JAX compilation cache

When running SVGD inference, JAX JIT-compiles the log-likelihood, kernel, and update functions the first time they are called. This compilation can take 1–10 seconds. The JAX compilation cache stores the compiled XLA code on disk so that subsequent Python sessions skip recompilation entirely.

This cache is managed by JAX itself and is enabled automatically by phasic at import time. The cache key is based on the function structure and input shapes (not values), so different parameter vectors reuse the same compiled code.
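The effect of keying on shapes rather than values can be illustrated with a toy memoisation. This is not JAX and does not compile anything: `cached_call`, `shape_of`, and the counter are hypothetical stand-ins, and only flat lists are supported.

```python
compiled = {}                 # (function name, argument shapes) -> callable
stats = {"compiles": 0}

def shape_of(value):
    """Shape of a flat list or tuple; anything else counts as a scalar."""
    return (len(value),) if isinstance(value, (list, tuple)) else ()

def cached_call(fn, *args):
    key = (fn.__qualname__, tuple(shape_of(a) for a in args))
    if key not in compiled:
        stats["compiles"] += 1    # stands in for an expensive XLA compile
        compiled[key] = fn
    return compiled[key](*args)

def log_density(theta):
    return -sum(t * t for t in theta)

cached_call(log_density, [2.0, 5.0])        # first call: "compiles"
cached_call(log_density, [1.0, 7.0])        # same shape, new values: reused
cached_call(log_density, [1.0, 7.0, 3.0])   # new shape: "compiles" again
print(stats["compiles"])  # 2
```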

Configuration

The default cache directory is ~/.jax_cache/. You can change it via environment variable before importing JAX:

export JAX_COMPILATION_CACHE_DIR=/fast/ssd/jax_cache

Or programmatically with the CompilationConfig class:

from phasic import CompilationConfig

config = CompilationConfig.balanced()   # sensible defaults
config.apply()

CompilationConfig also provides presets for maximum performance and for fast compilation during development:

from phasic.jax_config import CompilationConfig

# Maximum performance
config = CompilationConfig.max_performance()
config.apply()

# Fast compilation (for development)
config = CompilationConfig.fast_compile()
config.apply()

JAX caches compiled XLA code based on:

  • Function structure (HLO graph)
  • Input shapes
  • Device configuration

Inspecting caches

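Phasic provides a unified API for inspecting its caches. The umbrella functions cover the three legacy layers (graph, trace, JAX); the parameterised reward-compute cache is inspected through the phasic.cache module.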

Overview of all caches

print_all_cache_info()
Cache directory: /Users/kmt/.phasic_cache
Cached compilations: 1
Total size: 1.3 MB

Most recent files (showing up to 10):
  2026-05-25T23:00:00.639138 |   1347.9 KB | graphs/0d365121a03c3302f04a1d5437bdaecf.json

Cache directory: /Users/kmt/.phasic_cache/graphs
Cached graphs: 1
Total size: 1.32 MB

Cache directory: /Users/kmt/.phasic_cache/traces
Status: No cached traces

Individual cache layers

Each layer has its own inspection functions:

# Graph cache
print_graph_cache_info()
Cache directory: /Users/kmt/.phasic_cache/graphs
Cached graphs: 1
Total size: 1.32 MB
# The Python EliminationTrace cache (`cache_trace=True`) is deprecated, so
# print_trace_cache_info() is not called here; in current versions it would
# typically report zero files. The analogous information for the C-path PRC
# cache comes from cache.param_compute_cache_info().
# JAX compilation cache
jax_info = cache_info()
print(f"JAX cache: {jax_info['num_files']} files, {jax_info['total_size_mb']:.1f} MB")
JAX cache: 0 files, 0.0 MB
# Symbolic compute graph cache (the C-side parameterized_reward_compute_graph cache).
# This is the cache we populated above by calling graph.expectation().
info = cache.param_compute_cache_info()
print(f"param_compute: {info['n_files']} files, {info['total_size'] / 1024:.1f} KB")
param_compute: 0 files, 0.0 KB

Programmatic access

For scripting, get_all_cache_stats() returns a dictionary with statistics for each layer:

stats = get_all_cache_stats()
for name, layer_stats in stats.items():
    print(f"{name}: {layer_stats}")

# get_all_cache_stats() covers the legacy three layers (graph, trace, jax).
# The C-path parameterised reward-compute cache is reported via phasic.cache:
print(f"param_compute: {cache.param_compute_cache_info()}")
jax: {'exists': True, 'path': '/Users/kmt/.jax_cache', 'num_files': 0, 'total_size_mb': 0.0, 'files': []}
graph: {'num_graphs': 1, 'total_size_mb': 1.3162736892700195, 'cache_dir': '/Users/kmt/.phasic_cache/graphs'}
trace: {'total_files': 0, 'total_bytes': 0, 'cache_dir': '/Users/kmt/.phasic_cache/traces'}
param_compute: {'cache_dir': '/Users/kmt/.phasic_cache/parameterized_reward_compute', 'n_files': 0, 'n_parent_files': 0, 'n_scc_files': 0, 'total_size': 0, 'disabled': True}
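The layers report their sizes under different keys and in different units, so a small helper is useful when scripting. The sketch below is hypothetical and works on dictionaries shaped like the output above. It assumes that 'total_size' and 'total_bytes' are in bytes and that one MB means 1024 × 1024 bytes, which matches the conversions used elsewhere in this notebook but is not confirmed by the library documentation.

```python
MB = 1024 * 1024

def total_cache_bytes(stats, param_info):
    """Sum the on-disk size of all layers, normalised to bytes."""
    total = stats["jax"]["total_size_mb"] * MB
    total += stats["graph"]["total_size_mb"] * MB
    total += stats["trace"]["total_bytes"]
    total += param_info["total_size"]
    return total

# Sample values copied from the output above (paths omitted).
sample_stats = {
    "jax": {"num_files": 0, "total_size_mb": 0.0},
    "graph": {"num_graphs": 1, "total_size_mb": 1.3162736892700195},
    "trace": {"total_files": 0, "total_bytes": 0},
}
sample_param_info = {"n_files": 0, "total_size": 0}

total = total_cache_bytes(sample_stats, sample_param_info)
print(f"{total / MB:.2f} MB in total")
```

With a live installation, pass get_all_cache_stats() and cache.param_compute_cache_info() instead of the sample dictionaries.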

Listing individual trace entries

Listing individual entries applied to the deprecated Python trace cache and is no longer useful. The PRC cache on the C path does not expose per-entry metadata, because its on-disk files are content-addressed by SHA-256 hash. Inspect aggregate state via cache.param_compute_cache_info() instead, as shown above.

Clearing caches

| Function | What it clears |
|---|---|
| clear_caches() | Legacy umbrella: graph, trace, JAX caches |
| clear_model_cache() | Graph cache + (deprecated) Python trace cache |
| clear_jax_cache() | JAX compilation cache only |
| phasic.cache.clear_param_compute_cache() | Parameterised reward-compute cache (monolithic + per-SCC entries); the cache that this notebook covers |
| phasic.cache.clear_all_caches() | Both phasic on-disk caches; does not touch the JAX cache |

The phasic.cache helpers are the modern entry points. For the C path covered in this notebook, clear_param_compute_cache() is the relevant one.

# Clear the C-path symbolic compute graph cache.
n_removed = cache.clear_param_compute_cache()
print(f"Removed {n_removed} cache files")

# Verify
print(f"param_compute now: {cache.param_compute_cache_info()['n_files']} files")
[INFO] phasic.cache: Cleared 0 parameterised compute graph cache files
Removed 0 cache files
param_compute now: 0 files

Cache size

The PRC cache directory accumulates one file per distinct graph structure. For long-running development loops you may want to purge it occasionally:

import phasic.cache as cache

info = cache.param_compute_cache_info()
print(f"{info['n_files']} files, {info['total_size'] / 1024 / 1024:.1f} MB")

# Clear everything:
cache.clear_param_compute_cache()
info = cache.param_compute_cache_info()
print(f"{info['n_files']} files, {info['total_size'] / 1024 / 1024:.1f} MB")
0 files, 0.0 MB

You can also clear from the command line:

# Clear all phasic caches
rm -rf ~/.phasic_cache/ ~/.jax_cache/
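A gentler alternative to deleting everything is to prune the oldest files until the directory fits a size budget. The helper below is a hypothetical standard-library sketch, not part of phasic. It assumes cache files can be removed independently of one another and that no other process is reading them at the time.

```python
import os
import tempfile
from pathlib import Path

def prune_cache(cache_dir, max_bytes):
    """Delete the oldest files until the directory holds at most max_bytes."""
    files = sorted((p for p in Path(cache_dir).rglob("*") if p.is_file()),
                   key=lambda p: p.stat().st_mtime)
    total = sum(p.stat().st_size for p in files)
    removed = []
    for path in files:            # oldest first
        if total <= max_bytes:
            break
        total -= path.stat().st_size
        path.unlink()
        removed.append(path.name)
    return removed

# Demonstration on a temporary directory, not on the real cache.
with tempfile.TemporaryDirectory() as tmp:
    for age, name in enumerate(["old.bin", "mid.bin", "new.bin"]):
        path = Path(tmp) / name
        path.write_bytes(b"x" * 100)
        os.utime(path, (1_000 + age, 1_000 + age))   # distinct modification times
    removed = prune_cache(tmp, max_bytes=150)
    remaining = sorted(p.name for p in Path(tmp).iterdir())
print(removed, remaining)  # ['old.bin', 'mid.bin'] ['new.bin']
```

A pruned entry is rebuilt and rewritten the next time the same graph structure is evaluated.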

Caching with composed graphs

Graphs built through composition methods — add_epoch(), discretize(), laplace_transform(), joint_prob_graph() — fully support the parameterised reward-compute cache. The hash is structure-based (SHA-256 over vertices, edges, and coefficients), so it works identically regardless of how the graph was constructed.

The same composition pipeline always produces the same hash, enabling cache hits across sessions:

# Assumes `coalescent` (a callback), N0, N1, t1 and t2 are defined elsewhere.
# Session 1: builds and caches the PRC
graph = Graph(coalescent)
graph.update_weights([1/N0])
g1 = graph.add_epoch(t1)
g1.update_weights([1/N0, 1/N1, 1])
g2 = g1.add_epoch(t2)
g2.expectation()  # records and caches the PRC

# Session 2: same pipeline, cache hit
graph = Graph(coalescent)
graph.update_weights([1/N0])
g1 = graph.add_epoch(t1)
g1.update_weights([1/N0, 1/N1, 1])
g2 = g1.add_epoch(t2)
g2.expectation()  # cache hit — instant

This works because the resulting graph structure (vertices, edges, coefficient layout) is deterministic for a given composition sequence.