Phasic supports two layers of parallelism: multi-CPU compute (via JAX’s vmap/pmap for SVGD particles, and OpenMP for SCC-decomposed elimination) and multi-node distributed JAX on SLURM clusters. The parallelization strategy for SVGD depends on the number of available JAX devices:
| Devices | Strategy | Description |
|---|---|---|
| >1 | pmap | Particles distributed across devices |
| 1 (multi-CPU) | vmap | Vectorized computation on a single device |
| 1 (single-CPU) | none | Sequential execution |
JAX is not imported eagerly by import phasic. It is loaded on first use: when you access phasic.SVGD, phasic.MCMC or phasic.BFFG, when you call phasic.configure(compute='jax-cpu'), or when you call phasic.init_parallel(). This makes pure-C++ pipelines (no JAX dependency) cheaper to import.
```python
from phasic import (
    Graph, with_ipv, configure, get_config, init_parallel,
    detect_environment, EnvironmentInfo, ParallelConfig, set_log_level,
)
import numpy as np
import time, sys
from vscodenb import set_vscode_theme

set_vscode_theme()

# JAX is imported lazily on first use. To force activation now
# (e.g. so the next cell can call jax.devices()), call any of
# init_parallel(), phasic.configure(compute='jax-cpu'), or just
# access a JAX-dependent attribute like phasic.SVGD.
init_parallel()
import jax
```
[WARNING] phasic.auto_parallel: JAX already initialized with 1 device(s), but 10 CPU(s) available.
To use all CPUs, restart kernel and import phasic first:
import phasic as pta
pta.init_parallel(cpus=10)
We use a simple parameterized coalescent model throughout:
```python
nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, [state[i] * (state[j] - same) / (1 + same)]))
    return transitions

graph = Graph(coalescent)
graph.update_weights([7.0])
observed_data = graph.sample(500)
```
Automatic multi-CPU setup
JAX initialisation runs the first time it is needed. It performs the following steps:
Detects available CPUs — on Apple Silicon, it uses performance cores; otherwise os.cpu_count() (and respects SLURM_CPUS_PER_TASK when present).
Sets XLA_FLAGS — configures --xla_force_host_platform_device_count so JAX creates one virtual device per CPU.
Enables 64-bit precision — sets JAX_ENABLE_X64=1 for numerical accuracy.
Sets platform to CPU — via JAX_PLATFORMS=cpu.
Imports JAX.
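All of these steps come down to setting a few environment variables before JAX is imported. The sketch below shows that logic. detect_cpu_count and jax_cpu_env are hypothetical helpers, not phasic functions. Giving PTDALG_CPUS precedence over SLURM_CPUS_PER_TASK is an assumption, and the Apple Silicon performance-core check is left out.

```python
import os

def detect_cpu_count(environ=os.environ):
    # Assumed precedence: explicit PTDALG_CPUS override, then the SLURM
    # allocation, then os.cpu_count(). Apple Silicon core detection is omitted.
    for var in ("PTDALG_CPUS", "SLURM_CPUS_PER_TASK"):
        value = environ.get(var)
        if value:
            return int(value)
    return os.cpu_count() or 1

def jax_cpu_env(n_devices):
    # Variables that must be in the environment *before* `import jax`,
    # because JAX/XLA read them at import and backend initialisation.
    return {
        "XLA_FLAGS": f"--xla_force_host_platform_device_count={n_devices}",
        "JAX_ENABLE_X64": "1",
        "JAX_PLATFORMS": "cpu",
    }
```

To apply this yourself, run os.environ.update(jax_cpu_env(detect_cpu_count())) and only then import jax. After JAX has initialised, changing these variables no longer changes the device count. That is why the warning near the top of this page asks for a kernel restart.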
The trigger is one of:
Calling phasic.init_parallel() explicitly (recommended in notebooks where you want JAX ready early).
Calling phasic.configure(compute='jax-cpu') or phasic.configure(compute='jax-gpu').
Accessing a JAX-dependent attribute (phasic.SVGD, phasic.MCMC, etc.) or calling a JAX-dependent method such as Graph.svgd(...). Either one lazily fires JAX activation.
You can check the current device setup after activation:
```python
devices = jax.devices()
print(f"Number of JAX devices: {len(devices)}")
print(f"Device type: {devices[0].platform}")
```
Number of JAX devices: 1
Device type: cpu
Overriding CPU count
To change the number of CPUs phasic uses for JAX, set the PTDALG_CPUS environment variable before JAX activation (i.e. before any call that triggers _ensure_jax_active()):
```bash
export PTDALG_CPUS=4
python my_script.py
```
Or in a notebook (before any imports or JAX-touching code):
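Here is a minimal version of such a cell. It assumes, as described above, that phasic reads PTDALG_CPUS from the process environment when JAX is activated. It must therefore run before any cell that triggers activation.

```python
import os

# Set the override first, before anything activates JAX (or imports jax).
os.environ["PTDALG_CPUS"] = "4"

# Importing phasic afterwards is safe: JAX is only activated later, on first
# use, at which point the override is picked up.
```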
For more control, call init_parallel() to explicitly configure parallelism. This is especially useful when you want to set a specific CPU count or when running on a SLURM cluster where automatic detection should be triggered:
```python
config = init_parallel(cpus=8)  # Use exactly 8 devices
# Only takes effect if JAX has not been initialized yet in this process.
```
The returned ParallelConfig is stored globally and used by graph.svgd() to select its parallelization strategy.
Parallelism in SVGD
The graph.svgd() method automatically parallelizes particle updates across available devices. The parallel parameter controls the strategy:
| Value | Behavior |
|---|---|
| None (default) | Auto-select: pmap if multiple devices, vmap otherwise |
| 'pmap' | Distribute particles across devices (multiple CPUs/GPUs) |
| 'vmap' | Vectorize particles on a single device |
| 'none' | Sequential execution (useful for debugging) |
With pmap, particles are split evenly across devices. Each device computes log-likelihood and gradients for its particles independently. The SVGD kernel computation and particle updates are also parallelized.
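The even split can be sketched in plain JAX. This only illustrates the data layout and is not phasic's internal code. log_lik is a hypothetical stand-in for the model's log-likelihood, and the SVGD kernel and update steps are omitted.

```python
import jax
import jax.numpy as jnp

DATA = jnp.array([0.1, 0.3, 0.2])  # toy observations

def log_lik(theta):
    # Hypothetical per-particle log-likelihood (exponential model with rate
    # theta[0]); a stand-in for phasic's model, not its actual code.
    return jnp.sum(jnp.log(theta[0]) - theta[0] * DATA)

# Value and gradient for every particle held by one device.
per_device = jax.vmap(jax.value_and_grad(log_lik))

def sharded_value_and_grad(particles):
    n_dev = jax.local_device_count()
    n, d = particles.shape
    if n % n_dev != 0:
        raise ValueError(f"{n} particles cannot be split evenly over {n_dev} devices")
    # (n, d) -> (n_dev, n // n_dev, d): one equal shard per device.
    shards = particles.reshape(n_dev, n // n_dev, d)
    vals, grads = jax.pmap(per_device)(shards)
    # Flatten the device axis back out.
    return vals.reshape(n), grads.reshape(n, d)
```

Each device evaluates only its own shard, so no communication is needed for this step. Communication is only needed when the kernel term mixes information across all particles.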
The n_devices parameter limits how many devices pmap uses (default: all available).
To explicitly control the strategy:
```python
# Force pmap with specific device count
svgd = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    parallel='pmap',
    n_devices=len(jax.devices()),
    progress=False,
)
result_pmap = svgd.get_results()
print(f"pmap posterior mean: {result_pmap['theta_mean']}")

# Force vmap (single-device vectorization)
svgd = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    parallel='vmap',
    progress=False,
)
result_vmap = svgd.get_results()
print(f"vmap posterior mean: {result_vmap['theta_mean']}")
```
/Users/kmt/phasic/.pixi/envs/default/lib/python3.13/site-packages/phasic/__init__.py:7039: UserWarning: parallel='pmap' requested but only 1 JAX device available. Using 'vmap' instead. To use pmap, configure more devices via PTDALG_CPUS environment variable or initialize_distributed().
svgd = SVGD(
Using the %%slurm cell magic from the vscodenb package, you can allocate SLURM resources for a single cell. Phasic will automatically pick up the allocated resources.
Graph construction is inherently sequential (state-space exploration). Moment / expectation computation defaults to a single CPU: the monolithic parameterised reward compute graph (PRC) is built once with a single-threaded Gaussian elimination and cached on disk for subsequent runs.
For large graphs, opt into the hierarchical SCC composer to parallelise the elimination across CPUs. Sibling SCCs within a level run concurrently under OpenMP, and per-SCC PRC artefacts are content-hashed so identical SCCs across different parent graphs share cache entries:
```python
import phasic

phasic.configure(
    parallel_elimination=True,
    parallel_elimination_max_concurrent=8,  # cap simultaneous SCC computes per level
)

# Or via env var (e.g. in a SLURM job script):
# export PHASIC_HIERAR_ELIMINATION=1
# export PHASIC_MAX_PARALLEL_SCCS=8
```
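The scheduling idea can be made concrete with a toy, pure-Python sketch. This is not phasic's implementation, which eliminates SCCs in compiled code under OpenMP. The sketch rests on several assumptions:

- The graph is a plain adjacency dict.
- A level is an SCC's longest-path distance from the sink SCCs, so siblings in the same level share no edges.
- work is a hypothetical stand-in for per-SCC elimination.
- A hash of each SCC's node labels and internal edges stands in for phasic's content hash.
- Threads stand in for OpenMP.

```python
import hashlib
from concurrent.futures import ThreadPoolExecutor

def tarjan_scc(graph):
    """SCCs of {node: [successors]}, emitted sinks-first (reverse topological order)."""
    index, low, on_stack, stack, sccs = {}, {}, set(), [], []

    def visit(v):
        index[v] = low[v] = len(index)
        stack.append(v)
        on_stack.add(v)
        for w in graph.get(v, []):
            if w not in index:
                visit(w)
                low[v] = min(low[v], low[w])
            elif w in on_stack:
                low[v] = min(low[v], index[w])
        if low[v] == index[v]:  # v is the root of an SCC: pop it off the stack
            comp = set()
            while True:
                w = stack.pop()
                on_stack.discard(w)
                comp.add(w)
                if w == v:
                    break
            sccs.append(frozenset(comp))

    for v in graph:
        if v not in index:
            visit(v)
    return sccs

def scc_levels(graph):
    """Group SCCs into levels; an SCC's successors always sit in lower levels."""
    sccs = tarjan_scc(graph)
    comp_of = {v: c for c in sccs for v in c}
    level = {}
    for c in sccs:  # sinks first, so successor levels are already known
        succ = {comp_of[w] for v in c for w in graph.get(v, [])} - {c}
        level[c] = 1 + max((level[s] for s in succ), default=-1)
    levels = [[] for _ in range(max(level.values(), default=-1) + 1)]
    for c in sccs:
        levels[level[c]].append(c)
    return levels

def content_hash(graph, comp):
    """Key an SCC by its node labels and internal edges (toy canonical form)."""
    edges = sorted((repr(v), repr(w)) for v in comp for w in graph.get(v, []) if w in comp)
    nodes = sorted(repr(v) for v in comp)
    return hashlib.sha256(repr((nodes, edges)).encode()).hexdigest()

def eliminate_by_level(graph, work, max_concurrent=4, cache=None):
    """Run work(scc) level by level; siblings in a level run concurrently."""
    cache = {} if cache is None else cache
    results = {}
    for siblings in scc_levels(graph):
        keys = [content_hash(graph, c) for c in siblings]
        todo = {k: c for k, c in zip(keys, siblings) if k not in cache}
        # Siblings share no edges, so they can be processed independently.
        with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
            for k, out in zip(todo, pool.map(work, todo.values())):
                cache[k] = out
        for k, c in zip(keys, siblings):
            results[c] = cache[k]
    return results
```

You can pass the same cache dict to several calls. An SCC that has already been processed for one graph is then reused for another graph, which loosely mirrors the cross-graph sharing described above. max_concurrent plays the role of parallel_elimination_max_concurrent.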
For evaluating a model at many parameter values (outside SVGD), use JAX’s vmap or pmap directly:
```python
import jax
import jax.numpy as jnp

graph = Graph(coalescent)
model = graph.pmf_from_graph()

# Evaluate at many parameter values in parallel
theta_grid = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])
batch_model = jax.vmap(lambda theta: model(theta, times))
all_pdfs = batch_model(theta_grid)  # Shape: (100, 3)
```
For multi-device parallelism, replace vmap with pmap and reshape the input to (n_devices, batch_per_device, ...).
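If the batch size is not a multiple of the device count, one option is to pad the batch, shard it, and trim the result. The minimal sketch below does this in plain JAX. toy_model is a hypothetical stand-in for the model returned by graph.pmf_from_graph(). Padding by repeating the last row is a design choice made here, not phasic behaviour.

```python
import jax
import jax.numpy as jnp

def toy_model(theta, times):
    # Hypothetical stand-in for model(theta, times): exponential density.
    return theta[0] * jnp.exp(-theta[0] * times)

def pmap_batch(fn, thetas, times):
    n_dev = jax.local_device_count()
    n = thetas.shape[0]
    per_dev = -(-n // n_dev)  # ceiling division
    pad = per_dev * n_dev - n
    # Pad by repeating the last row so the batch splits evenly across devices.
    padded = jnp.concatenate([thetas, jnp.repeat(thetas[-1:], pad, axis=0)])
    shards = padded.reshape(n_dev, per_dev, *thetas.shape[1:])
    # pmap over devices, vmap over the parameter values within each device.
    out = jax.pmap(jax.vmap(lambda th: fn(th, times)))(shards)
    # Merge the device and batch axes, then drop the padded rows.
    return out.reshape(n_dev * per_dev, *out.shape[2:])[:n]

theta_grid = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])
all_pdfs = pmap_batch(toy_model, theta_grid, times)  # shape (100, 3)
```

On a single device this reduces to the vmap call above. Padding with a copy of a real parameter row, rather than with zeros, avoids evaluating the model at an invalid parameter value.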
Context managers
All phasic settings, including the SVGD-particle parallelism strategy (pmap / vmap / none), can be changed temporarily for a block of code by using phasic.configure(...) as a context manager. The settings (and the corresponding env vars and dataclass fields) are restored at the end of the block. See the Configuration tutorial for the full list of settings.
For SVGD specifically, the relevant field is svgd_strategy:
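'auto' (default): pick 'pmap' if multiple JAX devices are visible, else 'vmap'.
'pmap', 'vmap', 'none': request that strategy, with the same meaning as the corresponding parallel= values of Graph.svgd().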
Graph.svgd()’s own parallel= kwarg overrides this on a per-call basis; the config field controls only the default when parallel=None.
Disabling parallelism
Wrap code in with configure(svgd_strategy='none'): to force sequential SVGD execution inside that block. This is useful for debugging or when you want predictable serial behaviour:
```python
with configure(svgd_strategy='none'):
    svgd = graph.svgd(
        observed_data, n_particles=20, n_iterations=10, progress=False
    )
    result = svgd.get_results()
    print(f"Sequential posterior mean: {result['theta_mean']}")

# Outside the context, the default (auto / pmap / vmap) is restored.
print(f"Strategy outside block: {get_config().svgd_strategy}")
```
Sequential posterior mean: [7.08906581]
Strategy outside block: auto
Switching strategy for one block
Same pattern, any strategy:
```python
# Force vmap (single-device vectorisation) for this block only.
with configure(svgd_strategy='vmap'):
    svgd = graph.svgd(
        observed_data, n_particles=20, n_iterations=10, progress=False
    )
    result = svgd.get_results()
    print(f"vmap posterior mean: {result['theta_mean']}")
```
vmap posterior mean: [7.23017806]
Temporary library-wide configuration
phasic.configure(...) is itself a context manager. Use it to toggle SCC-composer parallelism, MPFR precision, the cache directory, or any other PhasicConfig field for a single block of code. Settings (and the backing env vars) are restored when the block exits.
```python
# Temporarily enable parallel SCC elimination just for this block.
# Outside the `with`, the previous configuration is restored.
with configure(parallel_elimination=True, parallel_elimination_max_concurrent=4):
    # Any moment / expectation call here uses the SCC composer.
    graph_big = Graph(coalescent)
    graph_big.update_weights([7.0])
    print('expectation (SCC composer):', graph_big.expectation())

# Outside: parallel_elimination is back to its prior value.
print('parallel_elimination after block:', get_config().parallel_elimination)
```
expectation (SCC composer): 0.2142857142857143
parallel_elimination after block: False
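configure(...) has a dual behaviour: a plain call changes settings for the rest of the session, while a with block restores them on exit. This is a common Python pattern, sketched minimally below. ToyConfig, toy_configure and toy_get_config are hypothetical names, chosen so they do not shadow phasic's own functions. The sketch assumes one field is mirrored into an environment variable, the way parallel_elimination is mirrored into PHASIC_HIERAR_ELIMINATION. Writing "0" for False is also an assumption.

```python
import os
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyConfig:
    # Hypothetical two-field stand-in for phasic's PhasicConfig.
    svgd_strategy: str = "auto"
    parallel_elimination: bool = False

# Hypothetical field -> env var mapping, modelled on PHASIC_HIERAR_ELIMINATION.
ENV_VARS = {"parallel_elimination": "PHASIC_HIERAR_ELIMINATION"}

_current = ToyConfig()

def toy_get_config():
    return _current

class toy_configure:
    """Applies settings on construction; restores them on exit from a with block."""

    def __init__(self, **changes):
        global _current
        self._saved_config = _current
        self._saved_env = {var: os.environ.get(var)
                           for field, var in ENV_VARS.items() if field in changes}
        _current = replace(_current, **changes)  # unknown fields raise TypeError
        for field, var in ENV_VARS.items():
            if field in changes:
                os.environ[var] = "1" if changes[field] else "0"

    def __enter__(self):
        return _current

    def __exit__(self, *exc_info):
        global _current
        _current = self._saved_config
        for var, old in self._saved_env.items():
            if old is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = old
        return False  # never swallow exceptions raised inside the block
```

Because the changes are applied in __init__ rather than in __enter__, the same call works whether or not it is used with with. Without with, __exit__ never runs, so the new settings persist.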
SLURM clusters
Phasic detects SLURM environments automatically and configures JAX accordingly. There are two modes:
| Mode | SLURM config | Use case |
|---|---|---|
| Single-node | --cpus-per-task=N | Multiple CPUs on one machine |
| Multi-node | --nodes=N --ntasks-per-node=1 | Distribute across machines |
Single-node SLURM
For single-node jobs, phasic reads SLURM_CPUS_PER_TASK and creates that many JAX devices. The SLURM script is straightforward:
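A minimal version might look like this. The resource values and my_script.py are placeholders.

```bash
#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --time=01:00:00

# phasic reads SLURM_CPUS_PER_TASK when JAX is activated and creates that
# many JAX devices, so no XLA_FLAGS setup is needed here.
python my_script.py
```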
Multi-node SLURM
For multi-node jobs, phasic uses jax.distributed.initialize() to coordinate computation across machines. The SLURM batch script sets up the coordinator address and launches one process per node via srun:
```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --time=01:00:00

# First node becomes the coordinator
COORDINATOR_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export SLURM_COORDINATOR_ADDRESS=$COORDINATOR_NODE
export JAX_COORDINATOR_PORT=12345

# Configure XLA for local CPUs
export XLA_FLAGS="--xla_force_host_platform_device_count=$SLURM_CPUS_PER_TASK"
export JAX_PLATFORMS=cpu
export JAX_ENABLE_X64=1

# Launch one process per node
srun python my_distributed_inference.py
```
With 4 nodes x 8 CPUs each, this creates 32 global JAX devices.
Multi-node Python script
The Python script uses phasic’s distributed utilities to initialize JAX across nodes:
```python
from phasic import (
    Graph, init_parallel, detect_slurm_environment,
    get_coordinator_address, initialize_jax_distributed,
)
import jax

# Step 1: Detect SLURM environment
slurm_env = detect_slurm_environment()
print(f"Process {slurm_env['process_id']}/{slurm_env['num_processes']}")
print(f"CPUs per task: {slurm_env['cpus_per_task']}")
print(f"Nodes: {slurm_env['node_count']}")

# Step 2: Get coordinator address (first node)
coordinator = get_coordinator_address(slurm_env)
print(f"Coordinator: {coordinator}")

# Step 3: Initialize JAX distributed
initialize_jax_distributed(
    coordinator_address=coordinator,
    num_processes=slurm_env['num_processes'],
    process_id=slurm_env['process_id'],
)
print(f"Local devices: {len(jax.local_devices())}")
print(f"Global devices: {len(jax.devices())}")

# Step 4: Build model and run SVGD
# (my_model is your model callback, e.g. the coalescent function above)
graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# SVGD distributes particles across all 32 global devices
svgd = graph.svgd(
    observed_data,
    n_particles=320,  # Must be divisible by global device count
    parallel='pmap',
)
result = svgd.get_results()
```
SLURM detection API
The detect_slurm_environment() function parses SLURM environment variables and returns a dictionary with the allocation details:
```python
from phasic import detect_slurm_environment

slurm_env = detect_slurm_environment()
print(f"Running under SLURM: {slurm_env['is_slurm']}")
if slurm_env['is_slurm']:
    print(f"Job ID: {slurm_env['job_id']}")
    print(f"Process ID: {slurm_env['process_id']}")
    print(f"Total processes: {slurm_env['num_processes']}")
    print(f"CPUs per task: {slurm_env['cpus_per_task']}")
    print(f"Nodes: {slurm_env['node_count']}")
```
Running under SLURM: False
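These fields correspond to standard SLURM environment variables. The hypothetical re-implementation below shows where each value could come from. Several details are assumptions:

- which variables are read (SLURM_PROCID, SLURM_NTASKS, SLURM_JOB_NUM_NODES);
- the defaults used outside SLURM;
- the host:port format of the coordinator address.

Phasic's own detect_slurm_environment() and get_coordinator_address() may differ. In particular, they may fall back to the first allocated hostname when SLURM_COORDINATOR_ADDRESS is unset.

```python
import os

def parse_slurm_env(environ=os.environ):
    """Hypothetical stand-in for detect_slurm_environment()."""
    return {
        "is_slurm": "SLURM_JOB_ID" in environ,
        "job_id": environ.get("SLURM_JOB_ID"),
        "process_id": int(environ.get("SLURM_PROCID", 0)),
        "num_processes": int(environ.get("SLURM_NTASKS", 1)),
        "cpus_per_task": int(environ.get("SLURM_CPUS_PER_TASK", os.cpu_count() or 1)),
        "node_count": int(environ.get("SLURM_JOB_NUM_NODES", 1)),
    }

def coordinator_address(environ=os.environ, default_port=12345):
    """host:port for jax.distributed, or None if no coordinator is configured."""
    host = environ.get("SLURM_COORDINATOR_ADDRESS")
    port = environ.get("JAX_COORDINATOR_PORT", str(default_port))
    return f"{host}:{port}" if host else None
```

With the 4-node, 8-CPU batch script above, each process would see num_processes = 4 and cpus_per_task = 8. Together these give the 32 global devices.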
Parallelism and graph construction
Graph construction explores the state space by calling the user-provided callback function repeatedly. This is inherently sequential — there is no automatic parallelization of graph construction.
However, if you need to build multiple graphs (e.g., for different sample sizes), you can parallelize at the Python level:
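```python
from concurrent.futures import ProcessPoolExecutor
from phasic import Graph, with_ipv

def coalescent_transitions(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, [state[i] * (state[j] - same) / (1 + same)]))
    return transitions

def build_graph(n):
    callback = with_ipv([n] + [0] * (n - 1))(coalescent_transitions)
    return Graph(callback, graph_cache=True)

if __name__ == "__main__":  # needed when worker processes are spawned
    # Build graphs for different sample sizes in parallel
    with ProcessPoolExecutor(max_workers=4) as pool:
        graphs = list(pool.map(build_graph, [4, 6, 8, 10]))
```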
Using graph_cache=True ensures that each graph is built only once and loaded from cache on subsequent runs.
Parallelism and parameterised PRC evaluation
Gaussian elimination is O(n³) and, by default, runs as a single dense computation on one CPU — producing a parameterised reward compute graph (PRC) that is theta-independent and disk-cached at ~/.phasic_cache/parameterized_reward_compute/. Once the PRC is built, evaluating it for a specific parameter vector is O(n) and can run in parallel across particles using JAX’s vmap or pmap.
For large graphs, you can also parallelise the elimination itself. Setting phasic.configure(parallel_elimination=True) (or PHASIC_HIERAR_ELIMINATION=1 in the shell) routes the elimination through the hierarchical SCC composer: the graph is decomposed into strongly-connected components and sibling SCCs within a level are eliminated concurrently under OpenMP. Per-SCC artefacts are content-hashed and shared across graphs and processes, so they can also be precomputed on a SLURM cluster via phasic.distributed_scc.precompute_distributed.
```python
from phasic import Graph, configure

configure(parallel_elimination=True)

def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, [state[i] * (state[j] - same) / (1 + same)]))
    return transitions

nr_samples = 30
graph = Graph(coalescent, ipv=[nr_samples] + [0] * (nr_samples - 1))
graph.update_weights([7.0])
graph.expectation()
```
0.2761904761904762
When you call graph.svgd(), phasic automatically:
Builds the PRC once (sequential, or per-SCC parallel if the hierarchical composer is enabled)
Creates a JIT-compiled function that evaluates the PRC (O(n) per evaluation)
Uses pmap or vmap to evaluate this function across particles in parallel
This means the expensive elimination happens only once per structural model, and all subsequent evaluations during SVGD iterations benefit from parallelism.
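The split between a one-off elimination and cheap per-parameter evaluation can be illustrated with a toy that does not use phasic. The toy model is a lineage-counting coalescent, in which k lineages coalesce at rate theta·k(k−1)/2. Because every rate is proportional to the single parameter theta, the result of the expensive solve factors out exactly. Phasic's PRC handles general parameterisations, which this toy does not. build_once and expectation are hypothetical names.

```python
import jax
import jax.numpy as jnp

def lineage_subintensity(n):
    # Transient states k = n, n-1, ..., 2 lineages; rates at theta = 1.
    ks = jnp.arange(n, 1, -1)
    rates = ks * (ks - 1) / 2.0
    # Leave state k at rate k(k-1)/2, moving to k-1 (the last state absorbs).
    return jnp.diag(-rates) + jnp.diag(rates[:-1], k=1)

def build_once(n):
    # Expensive step, done once per structure: solve (-S0) x = 1, O(n^3).
    S0 = lineage_subintensity(n)
    alpha = jnp.zeros(n - 1).at[0].set(1.0)  # start with n lineages
    base = alpha @ jnp.linalg.solve(-S0, jnp.ones(n - 1))  # E[T] at theta = 1

    @jax.jit
    def expectation(theta):
        # Cheap step, repeated per parameter value: all rates scale by theta.
        return base / theta[0]

    return expectation

expectation = build_once(4)
print(jax.vmap(expectation)(jnp.array([[1.0], [7.0]])))
```

For n = 4 and theta = 7, the toy returns 1.5 / 7 ≈ 0.2143. This agrees with the expectation printed for the same sample size and weight earlier on this page. The vmap call plays the role of evaluating many particles at once.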
Computing expectations in parallel
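Moment and expectation computations (graph.moments(), graph.expected_waiting_time()) operate on a single graph with fixed parameters and are not parallelized across parameter values. The elimination that precedes them can still use the SCC composer described above.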
To evaluate expectations at many parameter values in parallel, use the model function with vmap:
Environment variables
Phasic recognizes the following environment variables. The PHASIC_* ones can also be set via phasic.configure(...). The JAX_*, XLA_* and SLURM_* ones are read, but not written, by configure.
| Variable | Description |
|---|---|
| PTDALG_CPUS | Override the number of CPUs JAX uses (read at JAX activation time) |
| OMP_NUM_THREADS | OpenMP thread count for the SCC composer (auto-set by phasic at import if unset) |
| PHASIC_HIERAR_ELIMINATION | 1 enables the hierarchical SCC composer (equivalent to configure(parallel_elimination=True)) |
| PHASIC_MAX_PARALLEL_SCCS | Cap on simultaneous SCC computes per level (equivalent to parallel_elimination_max_concurrent) |
| PHASIC_MIN_SCC_SIZE_TO_CACHE | Threshold for caching per-SCC PRCs (equivalent to parallel_elimination_min_subgraph) |
| PHASIC_CACHE_DIR | Cache root directory (equivalent to cache_dir) |
| PHASIC_DISABLE_CACHE | 1 disables on-disk caching (equivalent to cache_enabled=False) |
| XLA_FLAGS | JAX/XLA flags; phasic sets --xla_force_host_platform_device_count on JAX activation |
| JAX_PLATFORMS | Platform selection; phasic sets it to cpu by default on JAX activation |
| JAX_ENABLE_X64 | Enable 64-bit precision; phasic enables it on JAX activation |
| SLURM_CPUS_PER_TASK | Read by phasic on SLURM clusters to set the OpenMP thread count and JAX device count |
| SLURM_COORDINATOR_ADDRESS | Manual override for the coordinator node address (multi-node JAX) |
| JAX_COORDINATOR_PORT | Port for the JAX distributed coordinator (default: 12345) |
Tips and troubleshooting
Particle count: When using pmap, the number of particles must be divisible by the number of devices. If not, SVGD will raise a ValueError.
JAX import timing: import phasic no longer imports JAX. The first JAX-touching operation (any of init_parallel(), phasic.configure(compute='jax-*'), phasic.SVGD, Graph.svgd(...), etc.) triggers activation. If you want JAX ready before timing-sensitive code, call phasic.init_parallel() near the top of your script or notebook.
Debugging: Pass parallel='none' to Graph.svgd() directly, or wrap the call in with phasic.configure(svgd_strategy='none'):. Either gives sequential execution and readable error messages — parallel execution can obscure the source of errors.
Performance: For small models (< 100 vertices), the overhead of pmap may outweigh the benefit. Use parallel='vmap' or let the auto-detection choose.
SLURM proxies: On some HPC systems, HTTP proxy variables can interfere with JAX distributed initialization. Phasic’s initialize_jax_distributed() temporarily unsets proxy variables during initialization.
Graph construction scaling: Graph construction is CPU-bound and single-threaded. For very large state spaces, consider using graph_cache=True to avoid rebuilding across sessions.
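To meet the divisibility rule in the first tip without hand-tuning, you can round the requested particle count up to the next multiple of the device count. divisible_particle_count is a hypothetical helper, not part of phasic:

```python
def divisible_particle_count(requested, n_devices):
    """Smallest particle count >= requested that splits evenly over n_devices."""
    if n_devices < 1:
        raise ValueError("n_devices must be at least 1")
    return -(-requested // n_devices) * n_devices  # ceiling division, then scale up
```

You would use it as n_particles=divisible_particle_count(50, len(jax.devices())) when calling graph.svgd(..., parallel='pmap').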
SLURM cell magic
This feature is part of the independent vscodenb library that contains various utilities for working with jupyter notebooks in vscode. Install like this:
If you use the “Remote Development” extension to work on notebooks on the frontend node (or a single node) of a SLURM cluster, you can use the %%slurm cell magic to run a single cell as a SLURM job with as many resources as you need. Once the job completes, cell variables are created or updated and the output appears in the cell as if you had executed it locally.