Distributed computing

Phasic supports two layers of parallelism: multi-CPU compute (via JAX’s vmap/pmap for SVGD particles, and OpenMP for SCC-decomposed elimination) and multi-node distributed JAX on SLURM clusters. The parallelization strategy for SVGD depends on the number of available JAX devices:

Devices           Strategy   Description
>1                pmap       Particles distributed across devices
1 (multi-CPU)     vmap       Vectorized computation on single device
1 (single-CPU)    none       Sequential execution

JAX is not imported eagerly by import phasic. It is loaded on first use: when you access phasic.SVGD, phasic.MCMC, or phasic.BFFG, when you call phasic.configure(compute='jax-cpu'), or when you call phasic.init_parallel(). This makes pure-C++ pipelines, which have no JAX dependency, cheaper to import.

from phasic import (
    Graph, with_ipv, configure, get_config,
    init_parallel,
    detect_environment,
    EnvironmentInfo, ParallelConfig,
    set_log_level,
)
import numpy as np
import time, sys
from vscodenb import set_vscode_theme
set_vscode_theme()


# JAX is imported lazily on first use. To force activation now
# (e.g. so the next cell can call jax.devices()), call any of
# init_parallel(), phasic.configure(compute='jax-cpu'), or just
# access a JAX-dependent attribute like phasic.SVGD.
init_parallel()
import jax
[WARNING] phasic.auto_parallel: JAX already initialized with 1 device(s), but 10 CPU(s) available.
  To use all CPUs, restart kernel and import phasic first:
    import phasic as pta
    pta.init_parallel(cpus=10)

We use a simple parameterized coalescent model throughout:

nr_samples = 4

@with_ipv([nr_samples] + [0] * (nr_samples - 1))
def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, [state[i] * (state[j] - same) / (1 + same)])) 
    return transitions

graph = Graph(coalescent)


graph.update_weights([7.0])
observed_data = graph.sample(500)

Automatic multi-CPU setup

JAX initialisation runs the first time it is needed. It performs the following steps:

  1. Detects available CPUs — on Apple Silicon, it uses performance cores; otherwise os.cpu_count() (and respects SLURM_CPUS_PER_TASK when present).
  2. Sets XLA_FLAGS — configures --xla_force_host_platform_device_count so JAX creates one virtual device per CPU.
  3. Enables 64-bit precision — sets JAX_ENABLE_X64=1 for numerical accuracy.
  4. Sets platform to CPU — via JAX_PLATFORMS=cpu.
  5. Imports JAX.
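
The steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not phasic's actual implementation: the helper name `build_jax_env` is hypothetical, and it covers only the three environment variables named in the list. The point it shows is that these variables have to be in place before JAX is first imported.

```python
import os

def build_jax_env(n_cpus, existing_flags=""):
    """Return the variables to set before importing JAX (hypothetical helper)."""
    device_flag = f"--xla_force_host_platform_device_count={n_cpus}"
    # Keep any XLA flags that are already set, then append the device count.
    flags = f"{existing_flags} {device_flag}".strip()
    return {
        "XLA_FLAGS": flags,        # one virtual device per CPU
        "JAX_ENABLE_X64": "1",     # 64-bit precision
        "JAX_PLATFORMS": "cpu",    # CPU platform
    }

env = build_jax_env(n_cpus=4, existing_flags=os.environ.get("XLA_FLAGS", ""))
# os.environ.update(env) would have to run before `import jax` to take effect.
print(env["XLA_FLAGS"])
```

Building the values as a dictionary first keeps the sketch free of side effects; applying them is a single `os.environ.update(env)` call.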

The trigger is one of:

  • Calling phasic.init_parallel() explicitly (recommended in notebooks where you want JAX ready early).
  • Calling phasic.configure(compute='jax-cpu') or compute='jax-gpu'.
  • Accessing any JAX-dependent attribute (phasic.SVGD, phasic.MCMC, Graph.svgd(...), etc.) — this lazily fires JAX activation.
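
The lazy-activation pattern behind these triggers can be illustrated with a small stand-alone sketch. The class `LazyBackend` is hypothetical and uses the standard-library `json` module as a stand-in for a heavy dependency; how phasic implements this internally is not described here.

```python
import importlib

class LazyBackend:
    """Import a module the first time one of its attributes is requested."""

    def __init__(self, module_name):
        self._module_name = module_name
        self._module = None      # nothing is imported yet
        self.load_count = 0      # how many times the import actually ran

    def _ensure_active(self):
        # Import on first use only; later calls reuse the loaded module.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
            self.load_count += 1
        return self._module

    def __getattr__(self, name):
        # Called only for attributes that normal lookup does not find,
        # i.e. anything that lives on the wrapped module.
        return getattr(self._ensure_active(), name)

backend = LazyBackend("json")          # cheap: no import has happened
text = backend.dumps({"cpus": 4})      # first attribute access triggers the import
print(text, backend.load_count)
```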

You can check the current device setup after activation:

devices = jax.devices()
print(f"Number of JAX devices: {len(devices)}")
print(f"Device type: {devices[0].platform}")
Number of JAX devices: 1
Device type: cpu

Overriding CPU count

To change the number of CPUs phasic uses for JAX, set the PTDALG_CPUS environment variable before JAX activation (i.e. before any call that triggers _ensure_jax_active()):

export PTDALG_CPUS=4
python my_script.py

Or in a notebook (before any imports or JAX-touching code):

import os
os.environ['PTDALG_CPUS'] = '4'
from phasic import Graph, init_parallel
init_parallel()  # picks up PTDALG_CPUS=4
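
A minimal sketch of how such an override could be resolved is shown below. The function `resolve_cpu_count` is hypothetical, and the precedence order (PTDALG_CPUS first, then SLURM_CPUS_PER_TASK, then the detected count) is an assumption made for illustration rather than a documented guarantee.

```python
import os

def resolve_cpu_count(env=None, detected=None):
    """Pick a CPU count from override variables, falling back to detection."""
    env = os.environ if env is None else env
    # Assumed precedence: explicit override first, then the SLURM allocation.
    for var in ("PTDALG_CPUS", "SLURM_CPUS_PER_TASK"):
        value = str(env.get(var, "")).strip()
        if value.isdigit() and int(value) > 0:
            return int(value)
    # No usable override: use the detected count (or ask the OS).
    return detected if detected is not None else (os.cpu_count() or 1)

print(resolve_cpu_count({"PTDALG_CPUS": "4"}, detected=10))
```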

Inspecting the detected environment

The detect_environment() function returns an EnvironmentInfo object describing the current execution context:

env = detect_environment()
print(f"Environment type:  {env.env_type}")
print(f"Interactive:       {env.is_interactive}")
print(f"Available CPUs:    {env.available_cpus}")
print(f"SLURM detected:   {env.slurm_info is not None}")
print(f"JAX imported:      {env.jax_already_imported}")
Environment type:  jupyter
Interactive:       True
Available CPUs:    10
SLURM detected:   False
JAX imported:      True

Explicit initialization with init_parallel()

For more control, call init_parallel() to configure parallelism explicitly. This is especially useful when you want a specific CPU count, or when running on a SLURM cluster and you want detection of the allocation to happen at a known point:

config = init_parallel()
print(f"Device count:       {config.device_count}")
print(f"Local device count: {config.local_device_count}")
print(f"Strategy:           {config.strategy}")
print(f"Environment:        {config.env_info.env_type}")
[WARNING] phasic.auto_parallel: JAX already initialized with 1 device(s), but 10 CPU(s) available.
  To use all CPUs, restart kernel and import phasic first:
    import phasic as pta
    pta.init_parallel(cpus=10)
Device count:       1
Local device count: 1
Strategy:           vmap
Environment:        jupyter

You can pass an explicit CPU count:

config = init_parallel(cpus=8)  # Use exactly 8 devices

The returned ParallelConfig is stored globally and used by graph.svgd() to select its parallelization strategy.

Parallelism in SVGD

The graph.svgd() method automatically parallelizes particle updates across available devices. The parallel parameter controls the strategy:

Value             Behavior
None (default)    Auto-select: pmap if multiple devices, vmap otherwise
'pmap'            Distribute particles across devices (multiple CPUs/GPUs)
'vmap'            Vectorize particles on a single device
'none'            Sequential execution (useful for debugging)

With pmap, particles are split evenly across devices. Each device computes log-likelihood and gradients for its particles independently. The SVGD kernel computation and particle updates are also parallelized.

The n_devices parameter limits how many devices pmap uses (default: all available).
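
The even split that pmap requires can be pictured as a reshape of the particle array. The sketch below uses NumPy only; `shard_particles` is a hypothetical helper, and the layout `(n_devices, particles_per_device, dim)` is the usual convention for pmap inputs, not a description of phasic's internals.

```python
import numpy as np

def shard_particles(particles, n_devices):
    """Reshape (n_particles, dim) into (n_devices, per_device, dim)."""
    n_particles = particles.shape[0]
    if n_particles % n_devices != 0:
        raise ValueError(
            f"{n_particles} particles cannot be split evenly over {n_devices} devices"
        )
    per_device = n_particles // n_devices
    return particles.reshape(n_devices, per_device, *particles.shape[1:])

particles = np.arange(12.0).reshape(6, 2)   # 6 particles, 2 parameters each
shards = shard_particles(particles, n_devices=3)
print(shards.shape)                         # (3, 2, 2)
```

Each leading index of `shards` is the block one device would work on; flattening the first two axes recovers the original particle order.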

To explicitly control the strategy:

# Force pmap with specific device count
svgd = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    parallel='pmap',
    n_devices=len(jax.devices()),
    progress=False
)
result_pmap = svgd.get_results()
print(f"pmap posterior mean: {result_pmap['theta_mean']}")

# Force vmap (single-device vectorization)
svgd = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    parallel='vmap',
    progress=False
)
result_vmap = svgd.get_results()
print(f"vmap posterior mean: {result_vmap['theta_mean']}")
/Users/kmt/phasic/.pixi/envs/default/lib/python3.13/site-packages/phasic/__init__.py:7039: UserWarning: parallel='pmap' requested but only 1 JAX device available. Using 'vmap' instead. To use pmap, configure more devices via PTDALG_CPUS environment variable or initialize_distributed().
  svgd = SVGD(
pmap posterior mean: [7.06923219]
vmap posterior mean: [7.36114747]
Tip

Using the %%slurm magic from the vscodenb package, you can allocate SLURM resources for a single cell. Phasic will automatically pick up the allocated resources.

%%slurm -m 4G -c 20 -t 00:10:00 -A xy-drive

graph = Graph(coalescent)

# Simulate data
graph.update_weights([7.0])
observed_data = graph.sample(500)

# Auto-parallelized SVGD (uses pmap if multiple devices)
svgd = graph.svgd(
    observed_data,
    n_particles=50,
    n_iterations=50,
    progress=False
)
result = svgd.get_results()

Parallelism outside SVGD

Graph construction is inherently sequential (state-space exploration). Moment / expectation computation defaults to a single CPU: the monolithic parameterised reward compute graph (PRC) is built once with a single-threaded Gaussian elimination and cached on disk for subsequent runs.

For large graphs, opt into the hierarchical SCC composer to parallelise the elimination across CPUs. Sibling SCCs within a level run concurrently under OpenMP, and per-SCC PRC artefacts are content-hashed so identical SCCs across different parent graphs share cache entries:

import phasic

phasic.configure(
    parallel_elimination=True,
    parallel_elimination_max_concurrent=8,  # cap simultaneous SCC computes per level
)

# Or via env var (e.g. in a SLURM job script):
#   export PHASIC_HIERAR_ELIMINATION=1
#   export PHASIC_MAX_PARALLEL_SCCS=8
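
To make the idea of content-hashed cache entries concrete, here is a small sketch. The key format is purely an assumption (the actual hashing scheme is not described in this tutorial), and `scc_cache_key` is a hypothetical helper. It shows why structurally identical SCCs map to the same cache entry regardless of which parent graph they come from.

```python
import hashlib
import json

def scc_cache_key(edges):
    """Hash an SCC's structure; edges are (src, dst, weight_index) with local ids."""
    # Sort so the key does not depend on the order in which edges were listed.
    canonical = sorted((int(s), int(d), int(w)) for s, d, w in edges)
    payload = json.dumps(canonical, separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()

scc_in_graph_a = [(0, 1, 0), (1, 0, 0), (1, 2, 1)]
scc_in_graph_b = [(1, 2, 1), (0, 1, 0), (1, 0, 0)]   # same structure, other order
print(scc_cache_key(scc_in_graph_a) == scc_cache_key(scc_in_graph_b))
```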

For evaluating a model at many parameter values (outside SVGD), use JAX’s vmap or pmap directly:

import jax
import jax.numpy as jnp

graph = Graph(coalescent)
model = graph.pmf_from_graph()

# Evaluate at many parameter values in parallel
theta_grid = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])

batch_model = jax.vmap(lambda theta: model(theta, times))
all_pdfs = batch_model(theta_grid)  # Shape: (100, 3)

For multi-device parallelism, replace vmap with pmap and reshape the input to (n_devices, batch_per_device, ...).
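
The reshape described above can be sketched as follows. To keep the example self-contained it uses a stand-in density (`toy_density`, an exponential PDF) instead of a phasic model; that function and the grid size are assumptions for illustration only. The leading axis of the input must equal the number of devices pmap runs on.

```python
import jax
import jax.numpy as jnp

def toy_density(theta, times):
    # Stand-in for a phasic model: exponential density with rate theta[0].
    rate = theta[0]
    return rate * jnp.exp(-rate * times)

n_devices = jax.local_device_count()
n_theta = 8 * n_devices                      # divisible by the device count
times = jnp.array([1.0, 2.0, 3.0])

# Shape (n_devices, batch_per_device, 1): one block of parameters per device.
theta_grid = jnp.linspace(0.5, 10.0, n_theta).reshape(n_devices, -1, 1)

# vmap handles the batch inside each device; pmap spreads blocks over devices.
per_device = jax.vmap(lambda theta: toy_density(theta, times))
sharded = jax.pmap(per_device)(theta_grid)   # (n_devices, batch_per_device, 3)

# Flatten back to one row per parameter value.
flat = sharded.reshape(n_theta, times.shape[0])
print(flat.shape)
```

With a single device this runs as a one-block pmap, so the same code works on a laptop and on a multi-device allocation.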

Context managers

All phasic settings — including the SVGD-particle parallelism strategy (pmap / vmap / none) — can be changed temporarily for a block of code via phasic.configure(...) used as a context manager. The settings (and the corresponding env vars and dataclass fields) are restored at the end of the block. See the Configuration tutorial for the full surface.

For SVGD specifically, the relevant field is svgd_strategy:

  • 'auto' (default) — pick 'pmap' if multiple JAX devices are visible, else 'vmap'.
  • 'pmap' — distribute particles across devices.
  • 'vmap' — vectorise particles on a single device.
  • 'none' — sequential particle evaluation (debugging).

Graph.svgd()’s own parallel= kwarg overrides this on a per-call basis; the config field controls only the default when parallel=None.
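
The precedence between the per-call argument and the config field can be summarised in a few lines. This is a sketch of the rules as stated in this tutorial, with a hypothetical function name; the fallback from 'pmap' to 'vmap' on a single device mirrors the warning shown in the earlier output.

```python
def resolve_strategy(parallel=None, config_strategy="auto", n_devices=1):
    """Combine the per-call `parallel` value with the configured default."""
    # The per-call argument wins; the config field applies only when it is None.
    requested = parallel if parallel is not None else config_strategy
    if requested not in ("auto", "pmap", "vmap", "none"):
        raise ValueError(f"unknown strategy: {requested!r}")
    if requested == "auto":
        return "pmap" if n_devices > 1 else "vmap"
    if requested == "pmap" and n_devices < 2:
        return "vmap"   # pmap needs more than one device
    return requested

print(resolve_strategy(n_devices=4))                                   # pmap
print(resolve_strategy(parallel="none", config_strategy="pmap", n_devices=4))  # none
```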

Disabling parallelism

Use with configure(svgd_strategy='none'): to force sequential SVGD execution inside a block — useful for debugging or when you want predictable serial behaviour:

with configure(svgd_strategy='none'):
    svgd = graph.svgd(
        observed_data,
        n_particles=20,
        n_iterations=10,
        progress=False
    )
    result = svgd.get_results()
    print(f"Sequential posterior mean: {result['theta_mean']}")

# Outside the context, the default (auto / pmap / vmap) is restored.
print(f"Strategy outside block: {get_config().svgd_strategy}")
Sequential posterior mean: [7.08906581]
Strategy outside block: auto

Switching strategy for one block

Same pattern, any strategy:

# Force vmap (single-device vectorisation) for this block only.
with configure(svgd_strategy='vmap'):
    svgd = graph.svgd(
        observed_data,
        n_particles=20,
        n_iterations=10,
        progress=False
    )
    result = svgd.get_results()
    print(f"vmap posterior mean: {result['theta_mean']}")
vmap posterior mean: [7.23017806]

Temporary library-wide configuration

phasic.configure(...) is itself a context manager. Use it to toggle SCC-composer parallelism, MPFR precision, the cache directory, or any other PhasicConfig field for a single block of code. Settings (and the backing env vars) are restored when the block exits.

# Temporarily enable parallel SCC elimination just for this block.
# Outside the `with`, the previous configuration is restored.
with configure(parallel_elimination=True,
               parallel_elimination_max_concurrent=4):
    # Any moment / expectation call here uses the SCC composer.
    graph_big = Graph(coalescent)
    graph_big.update_weights([7.0])
    print('expectation (SCC composer):', graph_big.expectation())

# Outside: parallel_elimination is back to its prior value.
print('parallel_elimination after block:',
      get_config().parallel_elimination)
expectation (SCC composer): 0.2142857142857143
parallel_elimination after block: False

SLURM clusters

Phasic detects SLURM environments automatically and configures JAX accordingly. There are two modes:

Mode           SLURM config                     Use case
Single-node    --cpus-per-task=N                Multiple CPUs on one machine
Multi-node     --nodes=N --ntasks-per-node=1    Distribute across machines

Single-node SLURM

For single-node jobs, phasic reads SLURM_CPUS_PER_TASK and creates that many JAX devices. The SLURM script is straightforward:

#!/bin/bash
#SBATCH --job-name=svgd_inference
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=4G
#SBATCH --time=01:00:00

python my_inference.py

The Python script needs no special configuration — init_parallel() detects the SLURM allocation:

from phasic import Graph, init_parallel

config = init_parallel()  # Detects 16 CPUs from SLURM
print(f"Using {config.device_count} devices")  # 16

graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# SVGD automatically uses pmap across 16 devices
result = graph.svgd(observed_data, n_particles=160)

Multi-node SLURM

For multi-node jobs, phasic uses jax.distributed.initialize() to coordinate computation across machines. The SLURM batch script sets up the coordinator address and launches one process per node via srun:

#!/bin/bash
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --time=01:00:00

# First node becomes the coordinator
COORDINATOR_NODE=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export SLURM_COORDINATOR_ADDRESS=$COORDINATOR_NODE
export JAX_COORDINATOR_PORT=12345

# Configure XLA for local CPUs
export XLA_FLAGS="--xla_force_host_platform_device_count=$SLURM_CPUS_PER_TASK"
export JAX_PLATFORMS=cpu
export JAX_ENABLE_X64=1

# Launch one process per node
srun python my_distributed_inference.py

With 4 nodes × 8 CPUs each, this creates 32 global JAX devices.

Multi-node Python script

The Python script uses phasic’s distributed utilities to initialize JAX across nodes:

from phasic import (
    Graph, init_parallel,
    detect_slurm_environment,
    get_coordinator_address,
    initialize_jax_distributed
)
import jax

# Step 1: Detect SLURM environment
slurm_env = detect_slurm_environment()
print(f"Process {slurm_env['process_id']}/{slurm_env['num_processes']}")
print(f"CPUs per task: {slurm_env['cpus_per_task']}")
print(f"Nodes: {slurm_env['node_count']}")

# Step 2: Get coordinator address (first node)
coordinator = get_coordinator_address(slurm_env)
print(f"Coordinator: {coordinator}")

# Step 3: Initialize JAX distributed
initialize_jax_distributed(
    coordinator_address=coordinator,
    num_processes=slurm_env['num_processes'],
    process_id=slurm_env['process_id']
)

print(f"Local devices:  {len(jax.local_devices())}")
print(f"Global devices: {len(jax.devices())}")

# Step 4: Build model and run SVGD
graph = Graph(my_model)
graph.update_weights([7.0])
observed_data = graph.sample(1000)

# SVGD distributes particles across all 32 global devices
result = graph.svgd(
    observed_data,
    n_particles=320,    # Must be divisible by global device count
    parallel='pmap'
)
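
The bookkeeping in this script can be checked with a short stand-alone sketch. The helper names are hypothetical, and the "host:port" format is what jax.distributed.initialize() expects for its coordinator address; whether get_coordinator_address() returns exactly this string is an assumption.

```python
def coordinator_address(hostnames, env):
    """Build 'host:port' from the node list and optional overrides."""
    # A manual override wins; otherwise the first node is the coordinator.
    host = env.get("SLURM_COORDINATOR_ADDRESS") or hostnames[0]
    port = int(env.get("JAX_COORDINATOR_PORT", "12345"))
    return f"{host}:{port}"

def global_device_count(n_nodes, cpus_per_task):
    """One process per node, one virtual device per CPU."""
    return n_nodes * cpus_per_task

nodes = ["node01", "node02", "node03", "node04"]   # example host names
address = coordinator_address(nodes, env={})
devices = global_device_count(n_nodes=len(nodes), cpus_per_task=8)
particles_per_device = 320 // devices

print(address, devices, particles_per_device)
```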

SLURM detection API

The detect_slurm_environment() function parses SLURM environment variables and returns a dictionary with the allocation details:

from phasic import detect_slurm_environment

slurm_env = detect_slurm_environment()
print(f"Running under SLURM: {slurm_env['is_slurm']}")

if slurm_env['is_slurm']:
    print(f"Job ID:           {slurm_env['job_id']}")
    print(f"Process ID:       {slurm_env['process_id']}")
    print(f"Total processes:  {slurm_env['num_processes']}")
    print(f"CPUs per task:    {slurm_env['cpus_per_task']}")
    print(f"Nodes:            {slurm_env['node_count']}")
Running under SLURM: False
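
For orientation, the sketch below shows how such a dictionary could be assembled from standard SLURM variables. The function `parse_slurm_env` is hypothetical and the mapping from variables to keys is an assumption; only the variable names themselves (SLURM_JOB_ID, SLURM_PROCID, SLURM_NTASKS, SLURM_CPUS_PER_TASK, SLURM_JOB_NUM_NODES) are standard SLURM.

```python
def parse_slurm_env(env):
    """Turn SLURM environment variables into a plain dictionary."""
    if "SLURM_JOB_ID" not in env:
        return {"is_slurm": False}
    return {
        "is_slurm": True,
        "job_id": env["SLURM_JOB_ID"],
        "process_id": int(env.get("SLURM_PROCID", 0)),
        "num_processes": int(env.get("SLURM_NTASKS", 1)),
        "cpus_per_task": int(env.get("SLURM_CPUS_PER_TASK", 1)),
        "node_count": int(env.get("SLURM_JOB_NUM_NODES", 1)),
    }

example_env = {
    "SLURM_JOB_ID": "123456",
    "SLURM_PROCID": "2",
    "SLURM_NTASKS": "4",
    "SLURM_CPUS_PER_TASK": "8",
    "SLURM_JOB_NUM_NODES": "4",
}
info = parse_slurm_env(example_env)
print(info["process_id"], info["num_processes"], info["cpus_per_task"])
```

Passing the environment as an argument, rather than reading os.environ directly, makes the parsing easy to test outside a cluster.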

Parallelism and graph construction

Graph construction explores the state space by calling the user-provided callback function repeatedly. This is inherently sequential — there is no automatic parallelization of graph construction.

However, if you need to build multiple graphs (e.g., for different sample sizes), you can parallelize at the Python level:

from concurrent.futures import ProcessPoolExecutor

def build_graph(n):
    @with_ipv([n] + [0] * (n - 1))
    def coalescent(state):
        # ... callback ...
        return transitions
    return Graph(coalescent, graph_cache=True)

# Build graphs for different sample sizes in parallel
with ProcessPoolExecutor(max_workers=4) as pool:
    graphs = list(pool.map(build_graph, [4, 6, 8, 10]))

Using graph_cache=True ensures that each graph is built only once and loaded from cache on subsequent runs.

Parallelism and parameterised PRC evaluation

Gaussian elimination is O(n³) and, by default, runs as a single dense computation on one CPU — producing a parameterised reward compute graph (PRC) that is theta-independent and disk-cached at ~/.phasic_cache/parameterized_reward_compute/. Once the PRC is built, evaluating it for a specific parameter vector is O(n) and can run in parallel across particles using JAX’s vmap or pmap.

For large graphs, you can also parallelise the elimination itself. Setting phasic.configure(parallel_elimination=True) (or PHASIC_HIERAR_ELIMINATION=1 in the shell) routes the elimination through the hierarchical SCC composer: the graph is decomposed into strongly-connected components and sibling SCCs within a level are eliminated concurrently under OpenMP. Per-SCC artefacts are content-hashed and shared across graphs and processes, so they can also be precomputed on a SLURM cluster via phasic.distributed_scc.precompute_distributed.

from phasic import Graph, configure

configure(parallel_elimination=True)

def coalescent(state):
    transitions = []
    for i in range(state.size):
        for j in range(i, state.size):
            same = int(i == j)
            if same and state[i] < 2:
                continue
            if not same and (state[i] < 1 or state[j] < 1):
                continue
            new = state.copy()
            new[i] -= 1
            new[j] -= 1
            new[i + j + 1] += 1
            transitions.append((new, [state[i] * (state[j] - same) / (1 + same)]))
    return transitions

nr_samples = 30
graph = Graph(coalescent, ipv=[nr_samples] + [0] * (nr_samples - 1))
graph.update_weights([7.0])
graph.expectation()
0.2761904761904762
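
The decomposition step can be illustrated on a toy graph. The sketch below uses Tarjan's algorithm to find strongly connected components and then groups them into levels, where a component's level is one more than the highest level among the components it points to. This is a generic illustration, not phasic's composer; the level convention and all names are assumptions. Components that share a level have no edges between them, which is what allows them to be processed concurrently.

```python
from collections import defaultdict

def strongly_connected_components(graph):
    """Tarjan's algorithm; graph maps vertex -> list of successors."""
    index, lowlink = {}, {}
    stack, on_stack = [], set()
    components = []
    counter = 0

    def visit(v):
        nonlocal counter
        index[v] = lowlink[v] = counter
        counter += 1
        stack.append(v)
        on_stack.add(v)
        for w in graph.get(v, []):
            if w not in index:
                visit(w)
                lowlink[v] = min(lowlink[v], lowlink[w])
            elif w in on_stack:
                lowlink[v] = min(lowlink[v], index[w])
        if lowlink[v] == index[v]:          # v is the root of a component
            comp = []
            while True:
                w = stack.pop()
                on_stack.discard(w)
                comp.append(w)
                if w == v:
                    break
            components.append(sorted(comp))

    for v in graph:
        if v not in index:
            visit(v)
    return components                        # sinks come first

def scc_levels(graph):
    """Group components into levels; same-level components are independent."""
    components = strongly_connected_components(graph)
    comp_of = {v: i for i, comp in enumerate(components) for v in comp}
    level = {}
    for i, comp in enumerate(components):    # successors are already levelled
        succ = {comp_of[w] for v in comp for w in graph.get(v, [])} - {i}
        level[i] = 1 + max((level[j] for j in succ), default=-1)
    by_level = defaultdict(list)
    for i, comp in enumerate(components):
        by_level[level[i]].append(comp)
    return [sorted(by_level[k]) for k in sorted(by_level)]

toy = {"a": ["b"], "b": ["a", "c", "d"], "c": ["e"], "d": ["e"], "e": []}
print(scc_levels(toy))
```

In the toy graph, the components containing "c" and "d" land in the same level and could be handled in parallel, while the cycle between "a" and "b" forms a single component.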

When you call graph.svgd(), phasic automatically:

  1. Builds the PRC once (sequential, or per-SCC parallel if the hierarchical composer is enabled)
  2. Creates a JIT-compiled function that evaluates the PRC (O(n) per evaluation)
  3. Uses pmap or vmap to evaluate this function across particles in parallel

This means the expensive elimination happens only once per structural model, and all subsequent evaluations during SVGD iterations benefit from parallelism.
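
To show where vmap enters an SVGD iteration, here is a minimal textbook SVGD step in JAX. It is a sketch under stated assumptions: the target is a stand-in Gaussian rather than a phase-type likelihood, the RBF bandwidth and step size are fixed, and none of the names correspond to phasic's API.

```python
import jax
import jax.numpy as jnp

def log_prob(theta):
    # Stand-in target: unit-variance Gaussian centred at 3.
    return -0.5 * jnp.sum((theta - 3.0) ** 2)

@jax.jit
def svgd_step(particles, step_size=0.1, bandwidth=1.0):
    n = particles.shape[0]
    # Per-particle gradients of the log density, evaluated in one vmap call.
    grads = jax.vmap(jax.grad(log_prob))(particles)            # (n, d)
    diffs = particles[:, None, :] - particles[None, :, :]      # x_i - x_j
    kernel = jnp.exp(-jnp.sum(diffs ** 2, axis=-1) / (2.0 * bandwidth ** 2))
    attraction = kernel @ grads                                # pulls towards high density
    repulsion = jnp.sum(kernel[:, :, None] * diffs, axis=1) / bandwidth ** 2
    return particles + step_size * (attraction + repulsion) / n

particles = jax.random.normal(jax.random.PRNGKey(0), (20, 1))
for _ in range(300):
    particles = svgd_step(particles)

print(float(particles.mean()), float(particles.std()))
```

The gradient evaluation is the part that scales with model cost, which is why vectorising or distributing it across particles pays off.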

Computing expectations in parallel

Moment and expectation computations (graph.moments(), graph.expected_waiting_time()) operate on a single graph with fixed parameters and are not parallelized.

To evaluate expectations at many parameter values in parallel, use the model function with vmap:

# Create parameterized model
model = graph.pmf_and_moments_from_graph(nr_moments=2)

# Evaluate at 100 parameter values in parallel
theta_values = jnp.linspace(0.5, 10.0, 100).reshape(-1, 1)
times = jnp.array([1.0, 2.0, 3.0])

def eval_at_theta(theta):
    pmf, moments = model(theta, times)
    return moments

all_moments = jax.vmap(eval_at_theta)(theta_values)  # (100, 2)

Environment variables

Phasic recognizes these environment variables. The PHASIC_* ones are also reachable via phasic.configure(...); the JAX_* / XLA_* / SLURM_* ones are read but not written by configure.

Variable                        Description
PTDALG_CPUS                     Override the number of CPUs JAX uses (read at JAX activation time)
OMP_NUM_THREADS                 OpenMP thread count for SCC composer (auto-set by phasic at import if unset)
PHASIC_HIERAR_ELIMINATION       1 enables the hierarchical SCC composer (equivalent to configure(parallel_elimination=True))
PHASIC_MAX_PARALLEL_SCCS        Cap simultaneous SCC computes per level (equivalent to parallel_elimination_max_concurrent)
PHASIC_MIN_SCC_SIZE_TO_CACHE    Threshold for caching per-SCC PRCs (equivalent to parallel_elimination_min_subgraph)
PHASIC_CACHE_DIR                Cache root directory (equivalent to cache_dir)
PHASIC_DISABLE_CACHE            1 disables on-disk caching (equivalent to cache_enabled=False)
XLA_FLAGS                       JAX/XLA flags; phasic sets --xla_force_host_platform_device_count on JAX activation
JAX_PLATFORMS                   Platform selection; phasic sets to cpu by default on JAX activation
JAX_ENABLE_X64                  Enable 64-bit precision; phasic enables on JAX activation
SLURM_CPUS_PER_TASK             Read by phasic on SLURM clusters to set OMP and JAX device count
SLURM_COORDINATOR_ADDRESS       Manual override for coordinator node address (multi-node JAX)
JAX_COORDINATOR_PORT            Port for JAX distributed coordinator (default: 12345)

Tips and troubleshooting

Particle count: When using pmap, the number of particles must be divisible by the number of devices. If not, SVGD will raise a ValueError.

JAX import timing: import phasic no longer imports JAX. The first JAX-touching operation (any of init_parallel(), phasic.configure(compute='jax-*'), phasic.SVGD, Graph.svgd(...), etc.) triggers activation. If you want JAX ready before timing-sensitive code, call phasic.init_parallel() near the top of your script or notebook.

Debugging: Pass parallel='none' to Graph.svgd() directly, or wrap the call in with phasic.configure(svgd_strategy='none'):. Either gives sequential execution and readable error messages — parallel execution can obscure the source of errors.

Performance: For small models (< 100 vertices), the overhead of pmap may outweigh the benefit. Use parallel='vmap' or let the auto-detection choose.

SLURM proxies: On some HPC systems, HTTP proxy variables can interfere with JAX distributed initialization. Phasic’s initialize_jax_distributed() temporarily unsets proxy variables during initialization.
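
The "temporarily unset" pattern can be sketched as a small context manager. The list of variable names and the helper `without_proxies` are assumptions for illustration; which variables phasic actually clears is not specified here.

```python
import os
from contextlib import contextmanager

# Assumed set of proxy variables; adjust to what your site defines.
PROXY_VARS = ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY")

@contextmanager
def without_proxies(environ=os.environ):
    """Remove proxy variables for the duration of a block, then restore them."""
    saved = {name: environ.pop(name) for name in PROXY_VARS if name in environ}
    try:
        yield
    finally:
        environ.update(saved)

fake_env = {"HTTP_PROXY": "http://proxy:3128", "PATH": "/usr/bin"}
with without_proxies(fake_env):
    print("inside:", sorted(fake_env))   # proxy variable is gone
print("after: ", sorted(fake_env))       # proxy variable is back
```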

Graph construction scaling: Graph construction is CPU-bound and single-threaded. For very large state spaces, consider using graph_cache=True to avoid rebuilding across sessions.

SLURM cell magic

This feature is part of the independent vscodenb library, which contains various utilities for working with Jupyter notebooks in VS Code. Install it with pixi, conda, or pip:

# pixi
pixi workspace channel add munch-group
pixi add vscodenb

# conda
conda install -c munch-group -c conda-forge vscodenb

# pip
pip install vscodenb

SLURM cell magic with vscodenb

If you use the “Remote Development” extension to work on notebooks on the front-end node (or a single node) of a SLURM cluster, you can use the %%slurm cell magic to run a single cell as a SLURM job with as many resources as you need. Once the job completes, cell variables are created or updated and the output appears in the cell output as if you had executed the cell locally.

%%slurm -m 4G -c 20 -t 00:10:00 -A xy-drive

svgd = graph.svgd(
    observed_data,
    n_particles=100,
    n_iterations=200,
    progress=False
)

This makes the results available to subsequent cells:

result = svgd.get_results()
print(f"Posterior mean: {result['theta_mean']}")
print(f"Posterior std:  {result['theta_std']}")

See the vscodenb documentation for details on the available options.