27 JAX Integration
Introduction
This file formalizes the integration of phasic’s C++ graph algorithms with the JAX automatic differentiation framework (Bradbury et al. 2018). The core challenge is that phasic’s distribution computations (PDF, PMF, moments) are implemented in C++ for performance, but JAX requires all operations in a computation graph to be JAX-traceable for transformations such as jit (compilation), grad (differentiation), vmap (batching), and pmap (parallelization). The solution uses two complementary mechanisms:
JAX Foreign Function Interface (FFI): C++ computations are registered as XLA custom calls, enabling zero-copy data transfer and OpenMP-parallelized batching. The graph structure (topology and coefficients) is passed as a static string attribute, while the parameter vector \boldsymbol{\theta} and evaluation points \mathbf{t} are passed as dynamic XLA buffers that JAX can trace and batch.
Custom Vector-Jacobian Products (VJPs): Since C++ functions are opaque to JAX’s autodiff (Baydin et al. 2018), gradient computation is enabled by registering custom VJP rules. The forward pass calls the C++ implementation; the backward pass computes \nabla_{\boldsymbol{\theta}} f via finite differences through the same C++ code path, or (for the C-level forward algorithm) via the exact gradient computation formalized in [16].
These mechanisms allow phasic models to participate in JAX-based inference workflows (SVGD [18], MCMC [19]) while retaining the performance of C++ graph algorithms.
Prerequisites: [01], [09], [11], [14], [16]
Source files:
- src/phasic/ffi_wrappers.py (functions: compute_pmf_ffi, compute_moments_ffi, compute_pmf_and_moments_ffi, _register_ffi_targets)
- src/phasic/__init__.py (functions: pmf_from_graph, pmf_and_moments_from_graph, custom VJP definitions)
Definitions
Recall (Definition 13.2). A trace \mathcal{T} is a sequence of linear operations recorded during graph elimination, enabling fast re-evaluation with different parameter vectors.
Recall (Definition 16.1). Uniformization of a CPH graph with granularity g produces the sub-transition matrix \mathbf{T}_g = \mathbf{I} + \frac{1}{g}\mathbf{S}, connecting discrete steps to continuous time via the Poisson distribution.
Recall (Definition 11.1). A symbolic expression represents an edge weight as a function of the parameter vector \boldsymbol{\theta} with stored coefficients.
Definition 27.1 (JAX Pure Callback) Let h : \mathbb{R}^d \times \mathbb{R}^n \to \mathbb{R}^n be a function implemented in C++ that computes the PDF or PMF of a phase-type distribution:
h(\boldsymbol{\theta}, \mathbf{t}) = (f(t_1; \boldsymbol{\theta}), f(t_2; \boldsymbol{\theta}), \ldots, f(t_n; \boldsymbol{\theta})), \tag{27.1}
where f(t; \boldsymbol{\theta}) is the PDF (or PMF for DPH) of the phase-type distribution parameterized by \boldsymbol{\theta}. A JAX pure callback wraps h as a JAX-traceable operation \tilde{h} such that:
- \tilde{h}(\boldsymbol{\theta}, \mathbf{t}) = h(\boldsymbol{\theta}, \mathbf{t}) for all inputs (semantic equivalence),
- \tilde{h} is compatible with jax.jit (the callback is invoked at run time, not at trace time),
- \tilde{h} is compatible with jax.vmap via sequential or batched evaluation (the vmap_method parameter controls whether JAX calls h once per batch element or passes the whole batch to a single call, which the C++ handler can then process in parallel with OpenMP).
The callback converts JAX arrays to NumPy arrays at the boundary, calls the C++ implementation, and converts results back to JAX arrays.
Intuition
A pure callback is a “black box” from JAX’s perspective: JAX knows the input and output shapes and dtypes but cannot trace into the function body. This is sufficient for jit and vmap but not for grad, which requires either a custom VJP rule or a traceable implementation. The term “pure” indicates the callback has no side effects and always returns the same output for the same input, so JAX is free to elide or repeat calls.
Example
The function compute_pmf_ffi wraps the C++ GraphBuilder::compute_pmf method. Given a serialized graph structure J (a JSON string), it computes h(\boldsymbol{\theta}, \mathbf{t}) by: (i) constructing a GraphBuilder from J, (ii) calling build(theta) to instantiate a concrete graph, and (iii) evaluating the forward algorithm (Algorithm 16.2) at each time point.
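The three steps can be sketched in NumPy as follows. Everything here is hypothetical: the JSON schema (edges as [source, target, parameter index, coefficient], with target -1 meaning absorption), the helper names build_subgenerator and pdf_uniformized, and the dense uniformization sum from Definition 16.1, which stands in for phasic's graph-based forward algorithm. The result is compared with the closed-form hypoexponential density of the two-state chain.

```python
import json
import numpy as np

# Hypothetical serialized structure J (not phasic's real format): a two-state
# chain 0 -> 1 -> absorption, with edge weight = coefficient * theta[index].
J = json.dumps({"n_states": 2,
                "edges": [[0, 1, 0, 1.0], [1, -1, 1, 1.0]]})


def build_subgenerator(structure_json, theta):
    """Steps (i) and (ii): parse J, then instantiate S(theta)."""
    spec = json.loads(structure_json)
    p = spec["n_states"]
    S = np.zeros((p, p))
    for src, dst, k, coef in spec["edges"]:
        rate = coef * theta[k]      # linear symbolic edge weight
        S[src, src] -= rate         # outflow accumulates on the diagonal
        if dst >= 0:                # dst == -1 encodes the absorbing state
            S[src, dst] += rate
    return S


def pdf_uniformized(S, times, g, K=200):
    """Step (iii): f(t) = sum_k Pois(k; g t) * alpha T_g^k s, T_g = I + S/g.

    Assumes g >= max_i |S_ii| (so T_g is substochastic), the chain starts in
    state 0, and g * t is small enough that exp(-g t) does not underflow.
    """
    p = S.shape[0]
    alpha = np.zeros(p)
    alpha[0] = 1.0
    s = -S.sum(axis=1)              # exit rates to the absorbing state
    T_g = np.eye(p) + S / g
    out = np.empty(len(times))
    for i, t in enumerate(times):
        w = np.exp(-g * t)          # Poisson weight for k = 0
        v = alpha.copy()            # v holds alpha T_g^k
        total = 0.0
        for k in range(K):
            total += w * (v @ s)
            v = v @ T_g
            w *= g * t / (k + 1)    # Poisson weight for k + 1
        out[i] = total
    return out


theta = np.array([1.0, 2.5])
times = np.array([0.1, 0.5, 1.0, 2.0, 3.0])
y = pdf_uniformized(build_subgenerator(J, theta), times, g=10.0)

# Closed-form hypoexponential density for comparison.
a, b = theta
y_exact = a * b / (b - a) * (np.exp(-a * times) - np.exp(-b * times))
```

The granularity g = 10 exceeds the largest exit rate (2.5), so T_g has non-negative entries, as uniformization requires.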
Definition 27.2 (Custom VJP for Phase-Type PDF) Let h(\boldsymbol{\theta}, \mathbf{t}) be as in Definition 27.1. A custom VJP for h consists of two functions:
The forward function h_{\mathrm{fwd}} : (\boldsymbol{\theta}, \mathbf{t}) \mapsto (h(\boldsymbol{\theta}, \mathbf{t}), (\boldsymbol{\theta}, \mathbf{t})), which computes the output and saves residuals (the inputs) for use in the backward pass.
The backward function h_{\mathrm{bwd}} : ((\boldsymbol{\theta}, \mathbf{t}), \bar{\mathbf{y}}) \mapsto (\bar{\boldsymbol{\theta}}, \mathbf{0}), which computes the vector-Jacobian product. Given the upstream gradient \bar{\mathbf{y}} \in \mathbb{R}^n (the cotangent vector), it returns:
\bar{\theta}_j = \bar{\mathbf{y}}^\top \frac{\partial h}{\partial \theta_j}(\boldsymbol{\theta}, \mathbf{t}) = \sum_{i=1}^{n} \bar{y}_i \frac{\partial f(t_i; \boldsymbol{\theta})}{\partial \theta_j}, \quad j = 1, \ldots, d. \tag{27.2}
The partial derivatives \frac{\partial f(t_i; \boldsymbol{\theta})}{\partial \theta_j} are computed via central finite differences (Nocedal and Wright 2006):
\frac{\partial f(t_i; \boldsymbol{\theta})}{\partial \theta_j} \approx \frac{h(\boldsymbol{\theta} + \epsilon \mathbf{e}_j, \mathbf{t})_i - h(\boldsymbol{\theta} - \epsilon \mathbf{e}_j, \mathbf{t})_i}{2\epsilon}, \tag{27.3}
with step size \epsilon = 10^{-7}. The gradient with respect to \mathbf{t} is set to zero (time points are not differentiated).
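A minimal NumPy sketch of h_{\mathrm{bwd}} from Equations (27.2) and (27.3). The hyperexponential density used as h is a hypothetical stand-in for the C++ forward pass, chosen because its exact gradient, \partial f / \partial \theta_j = \frac{1}{2}(1 - \theta_j t) e^{-\theta_j t}, is easy to write down; the illustrative counter calls confirms that the backward pass costs 2d forward evaluations.

```python
import numpy as np

calls = 0  # number of forward evaluations made by the backward pass


def h(theta, t):
    """Hypothetical stand-in for the C++ forward pass (hyperexponential PDF):
    f(t; theta) = 0.5 * theta_1 exp(-theta_1 t) + 0.5 * theta_2 exp(-theta_2 t)."""
    global calls
    calls += 1
    return 0.5 * (theta[0] * np.exp(-theta[0] * t) + theta[1] * np.exp(-theta[1] * t))


def h_bwd(residuals, y_bar, eps=1e-7):
    """Central-difference VJP: returns (theta_bar, t_bar) with t_bar = 0."""
    theta, t = residuals
    theta_bar = np.zeros_like(theta)
    for j in range(theta.size):
        e_j = np.zeros_like(theta)
        e_j[j] = 1.0
        grad_j = (h(theta + eps * e_j, t) - h(theta - eps * e_j, t)) / (2 * eps)
        theta_bar[j] = y_bar @ grad_j          # contraction with the cotangent
    return theta_bar, np.zeros_like(t)


theta = np.array([1.0, 3.0])
t = np.linspace(0.1, 2.0, 20)
y_bar = np.linspace(-1.0, 1.0, 20)            # upstream cotangent

theta_bar, t_bar = h_bwd((theta, t), y_bar)
exact = np.array([y_bar @ (0.5 * (1 - th * t) * np.exp(-th * t)) for th in theta])
```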
Intuition
The VJP computes how a scalar loss L = L(h(\boldsymbol{\theta}, \mathbf{t})) changes with each parameter \theta_j, given the chain rule factor \bar{\mathbf{y}} = \nabla_{\mathbf{y}} L. Finite differences require 2d evaluations of the full C++ forward algorithm (two per parameter), making the cost O(d) times the cost of a single forward pass. For models with few parameters (typical: d \leq 10) this is efficient; for high-dimensional \boldsymbol{\theta}, the exact C gradient from [16] is preferred.
Example
For SVGD with a two-parameter coalescent model (d = 2), the backward pass evaluates the C++ forward algorithm 4 times (two perturbations per parameter). With a 67-vertex graph and 100 time points, this takes approximately 20 ms, well within SVGD iteration budgets.
Definition 27.3 (Static vs. Dynamic Arguments) In the JAX FFI interface for phase-type computation, arguments are classified as:
- Static arguments: the graph structure JSON string J, the discrete flag, the granularity parameter, and nr_moments. These are passed as XLA attributes and are fixed at JIT compilation time. Changing a static argument triggers recompilation.
- Dynamic arguments: the parameter vector \boldsymbol{\theta} \in \mathbb{R}^d, the time points \mathbf{t} \in \mathbb{R}^n, and optional reward vectors \mathbf{r}. These are passed as XLA buffers and can vary across calls without recompilation, provided their shapes and dtypes stay the same. JAX’s vmap adds a batch dimension to dynamic arguments.
Formally, the FFI call has signature:
\texttt{ffi\_call}(J, \texttt{granularity}, \texttt{discrete}; \boldsymbol{\theta}, \mathbf{t}) \mapsto \mathbf{y} \in \mathbb{R}^n, \tag{27.4}
where arguments before the semicolon are static (attributes) and arguments after are dynamic (buffers).
Intuition
The static/dynamic distinction reflects the separation of structure from parameters that is central to phasic’s design. The graph topology and edge coefficients (encoded in J) change rarely (once per model), while the parameter vector \boldsymbol{\theta} changes at every SVGD or MCMC iteration. By marking J as static, JAX compiles specialized code for each graph structure, and the C++ GraphBuilder can cache the parsed JSON across calls.
Example
In jax.jit(compute_pmf_ffi, static_argnums=(0, 3, 4)), arguments 0 (structure_json), 3 (discrete), and 4 (granularity) are static. The JIT-compiled function accepts different \boldsymbol{\theta} and \mathbf{t} values (of unchanged shape) without recompilation.
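The recompilation behaviour can be observed with a toy function whose argument layout mirrors compute_pmf_ffi (structure string, \boldsymbol{\theta}, \mathbf{t}, discrete flag, granularity). The function toy_pmf and its JSON field "scale" are hypothetical. A Python counter increments only while JAX traces, so it counts compilations; static arguments must also be hashable, which a JSON string is.

```python
import json
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
trace_count = 0  # incremented only while JAX traces (i.e. compiles) toy_pmf


def toy_pmf(structure_json, theta, t, discrete, granularity):
    """Hypothetical stand-in with compute_pmf_ffi's argument layout."""
    global trace_count
    trace_count += 1
    scale = json.loads(structure_json)["scale"]  # static: a plain Python str
    rate = scale * theta[0]
    if discrete:  # Python branching on a static flag is allowed under jit
        p = 1.0 - jnp.exp(-rate / granularity)
        return p * (1.0 - p) ** (t - 1.0)   # geometric PMF on t = 1, 2, ...
    return rate * jnp.exp(-rate * t)        # exponential PDF


fast = jax.jit(toy_pmf, static_argnums=(0, 3, 4))
J = json.dumps({"scale": 1.0})
t = jnp.linspace(0.0, 2.0, 4)

fast(J, jnp.array([1.0]), t, False, 10)        # first call: trace + compile
fast(J, jnp.array([2.0]), t + 1.0, False, 10)  # new theta and t: cache hit
count_after_dynamic_change = trace_count
fast(J, jnp.array([2.0]), t, False, 20)        # new granularity: retrace
count_after_static_change = trace_count
```

Changing the values of the dynamic arguments reuses the compiled code, whereas changing any static argument, even one the branch taken never reads, forces a new trace.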
Theorems and Proofs
Theorem 27.1 (Pure Callback Preserves Semantics) Let G = (V, E, W) be a parameterized phase-type graph with serialized structure J = \operatorname{serialize}(G). Let h_{\mathrm{C++}}(\boldsymbol{\theta}, \mathbf{t}) denote the C++ implementation of PDF computation via the forward algorithm (Algorithm 16.2), and let \tilde{h}(\boldsymbol{\theta}, \mathbf{t}) denote the JAX pure callback wrapper (Definition 27.1). Then
\tilde{h}(\boldsymbol{\theta}, \mathbf{t}) = h_{\mathrm{C++}}(\boldsymbol{\theta}, \mathbf{t}) \tag{27.5}
for all valid \boldsymbol{\theta} \in \mathbb{R}^d_{>0} and \mathbf{t} \in \mathbb{R}^n_{\geq 0}, with bitwise-identical float64 results (the wrapper introduces no rounding).
Proof. The pure callback wrapper performs three operations: (1) convert JAX arrays to NumPy arrays via np.asarray(), (2) call the C++ implementation, and (3) return the result as a JAX array. Step (1) is a zero-copy view when the JAX array is backed by a CPU buffer (which is the case for platform="cpu"); even where a copy is made, copying float64 values into float64 storage moves bits without arithmetic, so no numerical transformation occurs. Step (2) invokes h_{\mathrm{C++}} with identical data. Step (3) wraps the result buffer, again copying at most. Since no arithmetic is performed by the wrapper, the output is bitwise identical to the C++ result.
For the FFI path (as opposed to pure_callback), the argument is analogous: the FFI handler receives pointers to XLA buffers containing the same IEEE 754 double-precision values, passes them to the same C++ GraphBuilder code, and writes results directly into the output XLA buffer. The data path is: XLA buffer \to C++ pointer \to GraphBuilder::compute_pmf \to output pointer \to XLA buffer. No intermediate conversion alters the values. \square
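The proof rests on the wrapper moving bits without doing arithmetic. A small NumPy check, with a hypothetical array x standing in for an XLA buffer, shows that a view or a float64 copy is bitwise identical to the source, whereas a float32 round trip, which is an arithmetic conversion, is not; this is why the argument relies on float64 data end to end.

```python
import numpy as np

x = np.array([0.1, 2.0 / 3.0, 1e-300], dtype=np.float64)  # stand-in buffer

view = np.asarray(x)                     # float64 input: no copy is made
copy = np.array(x, copy=True)            # explicit copy: moves bits only
round_trip = x.astype(np.float32).astype(np.float64)  # rounds each value

view_shares_memory = np.shares_memory(view, x)
copy_is_bitwise = copy.tobytes() == x.tobytes()
round_trip_is_bitwise = round_trip.tobytes() == x.tobytes()
```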
Theorem 27.2 (Custom VJP Computes Correct Gradients) Let h(\boldsymbol{\theta}, \mathbf{t}) be the phase-type PDF function (Definition 27.1), and let h_{\mathrm{bwd}} be the backward function (Definition 27.2) using central finite differences with step size \epsilon. Then for each j = 1, \ldots, d:
\left| \bar{\theta}_j - \bar{\mathbf{y}}^\top \frac{\partial h}{\partial \theta_j}(\boldsymbol{\theta}, \mathbf{t}) \right| \leq \frac{\epsilon^2}{6} \sum_{i=1}^{n} |\bar{y}_i| \cdot \max_{|s| \leq \epsilon} \left| \frac{\partial^3 f}{\partial \theta_j^3}(t_i; \boldsymbol{\theta} + s\,\mathbf{e}_j) \right| + O(\epsilon_{\mathrm{mach}} / \epsilon), \tag{27.6}
where \epsilon_{\mathrm{mach}} \approx 2.2 \times 10^{-16} is machine epsilon and the O(\epsilon_{\mathrm{mach}} / \epsilon) term hides constants that depend on \bar{\mathbf{y}} and on the magnitudes |f(t_i; \boldsymbol{\theta})|. With \epsilon = 10^{-7} and third derivatives and function values of order one, the truncation error is O(10^{-14}) and the cancellation error is O(10^{-9}), giving a total error of O(10^{-9}) per gradient component.
Proof. Central finite differences have truncation error O(\epsilon^2). By Taylor expansion of f(t_i; \boldsymbol{\theta} \pm \epsilon \mathbf{e}_j) around \boldsymbol{\theta}:
\frac{f(t_i; \boldsymbol{\theta} + \epsilon \mathbf{e}_j) - f(t_i; \boldsymbol{\theta} - \epsilon \mathbf{e}_j)}{2\epsilon} = \frac{\partial f}{\partial \theta_j}(t_i; \boldsymbol{\theta}) + \frac{\epsilon^2}{6} \frac{\partial^3 f}{\partial \theta_j^3}(t_i; \boldsymbol{\xi})
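for some \boldsymbol{\xi} on the segment between \boldsymbol{\theta} - \epsilon \mathbf{e}_j and \boldsymbol{\theta} + \epsilon \mathbf{e}_j. Each of the two third-order expansions carries a Lagrange remainder \pm\frac{\epsilon^3}{6}\frac{\partial^3 f}{\partial \theta_j^3} evaluated at its own intermediate point; after subtracting and dividing by 2\epsilon, the remainder is \frac{\epsilon^2}{6} times the average of these two third derivatives, and since \partial^3 f / \partial \theta_j^3 is continuous, the intermediate value theorem writes that average as its value at a single point \boldsymbol{\xi}. The truncation error per component is therefore bounded by \frac{\epsilon^2}{6} |f'''|, with f''' = \partial^3 f / \partial \theta_j^3 maximized over the segment.
The cancellation error arises from subtracting two nearly equal floating-point numbers. Each evaluation of f carries a rounding error of order \epsilon_{\mathrm{mach}} |f|, while the exact difference is small: |f(\boldsymbol{\theta} + \epsilon \mathbf{e}_j) - f(\boldsymbol{\theta} - \epsilon \mathbf{e}_j)| \approx 2\epsilon |f'|. The computed numerator is therefore off by up to about 2\epsilon_{\mathrm{mach}} |f|, and dividing by 2\epsilon gives an absolute error of about \epsilon_{\mathrm{mach}} |f| / \epsilon = O(\epsilon_{\mathrm{mach}} / \epsilon).
The VJP \bar{\theta}_j = \sum_i \bar{y}_i \cdot \widehat{\partial f / \partial \theta_j} accumulates these per-component errors with weights \bar{y}_i. The bound (Equation 27.6) follows by the triangle inequality. For phase-type PDFs with bounded third derivatives (which holds for all finite-state models with bounded rates, since f depends smoothly on the rates), the cancellation term O(\epsilon_{\mathrm{mach}} / \epsilon) \approx 2 \times 10^{-9} dominates the truncation term (about 10^{-14}) for \epsilon = 10^{-7}, giving total error O(10^{-9}) per gradient component. \square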
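The bound can be checked numerically on a one-parameter stand-in, f(\theta) = \theta e^{-\theta t} with t = 0.5 (the exponential PDF viewed as a function of its rate), whose third derivative is e^{-\theta t}(3t^2 - \theta t^3). In this sketch the rounding allowance 10 \epsilon_{\mathrm{mach}} / \epsilon is a hypothetical, deliberately generous constant for the O(\epsilon_{\mathrm{mach}} / \epsilon) term, not a value derived in the text.

```python
import numpy as np

EPS_MACH = np.finfo(np.float64).eps      # about 2.2e-16
t, theta = 0.5, 1.0


def f(th):
    return th * np.exp(-th * t)


def d3f(th):  # third derivative of f with respect to theta
    return np.exp(-th * t) * (3 * t**2 - th * t**3)


exact = (1 - theta * t) * np.exp(-theta * t)   # first derivative at theta
steps = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
errors, bounds = {}, {}
for eps in steps:
    fd = (f(theta + eps) - f(theta - eps)) / (2 * eps)
    errors[eps] = abs(fd - exact)
    grid = np.linspace(theta - eps, theta + eps, 1001)
    truncation = eps**2 / 6 * np.max(np.abs(d3f(grid)))
    bounds[eps] = truncation + 10 * EPS_MACH / eps   # Equation (27.6), d = n = 1

for eps in steps:
    print(f"eps={eps:.0e}  error={errors[eps]:.2e}  bound={bounds[eps]:.2e}")
```

For large steps the error tracks the \epsilon^2 term; for small steps it grows roughly like \epsilon_{\mathrm{mach}} / \epsilon, so the smallest errors appear at intermediate step sizes.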
Algorithms
JAX-Compatible PMF Computation
Description. This algorithm describes the end-to-end computation of phase-type PDF/PMF values through JAX’s FFI. It separates the one-time setup (FFI registration, JSON parsing) from the per-call computation (graph construction, forward algorithm). The algorithm supports three JAX transformations: jit (static arguments are compiled in, so changing one triggers recompilation), vmap (a batch dimension is added to the dynamic arguments and processed in parallel via OpenMP), and grad (the custom VJP is invoked for the backward pass).
Algorithm 27.1 (JAX-Compatible PMF Computation)
1: Let J be a JSON string encoding graph structure (static)
2: Let theta = (theta_1, ..., theta_d) be the parameter vector (dynamic)
3: Let t = (t_1, ..., t_n) be evaluation points (dynamic)
4: Let discrete be a boolean flag and g the granularity (static)
5:
6: function ComputePMF_FFI(J, theta, t, discrete, g)
7: ▷ Phase 1: Ensure FFI targets are registered (once per process)
8: if not FFI_REGISTERED then
9: capsule ← cpp_module.get_compute_pmf_ffi_capsule()
10: jax.ffi.register_ffi_target("ptd_compute_pmf", capsule)
11: FFI_REGISTERED ← true
12: end if
13:
14: ▷ Phase 2: Construct FFI call with static/dynamic separation
15: ffi_fn ← jax.ffi.ffi_call("ptd_compute_pmf",
16: output_shape = (n,),
17: vmap_method = "expand_dims")
18:
19: ▷ Phase 3: Invoke FFI (dispatches to C++ handler)
20: y ← ffi_fn(theta, t,
21: structure_json = J, ▷ Static attribute
22: granularity = g, ▷ Static attribute
23: discrete = discrete) ▷ Static attribute
24:
25: return y ▷ y_i = f(t_i; theta)
26: end function
27:
28: ▷ C++ FFI Handler (executed by XLA runtime)
29: function FFI_Handler(theta_buf, times_buf; J, g, discrete)
30: builder ← GraphBuilder(J) ▷ Parse JSON, cache structure
31: graph ← builder.build(theta_buf) ▷ Instantiate with parameters
32: for i ← 1 to n do
33: if discrete then
34: y_i ← graph.dph_pmf(t_i) ▷ DPH PMF (DPH Forward Step)
35: else
36: y_i ← graph.pdf(t_i, g) ▷ CPH PDF (CPH Uniformization)
37: end if
38: end for
39: return (y_1, ..., y_n)
40: end function
41:
42: ▷ Custom VJP for gradient computation
43: function ComputePMF_fwd(theta, t)
44: y ← ComputePMF_FFI(J, theta, t, discrete, g)
45: return (y, (theta, t)) ▷ Save residuals
46: end function
47:
48: function ComputePMF_bwd((theta, t), y_bar)
49: epsilon ← 10^{-7}
50: theta_bar ← zero vector of length d
51: for j ← 1 to d do
52: e_j ← standard basis vector j
53: y_plus ← ComputePMF_FFI(J, theta + epsilon * e_j, t, discrete, g)
54: y_minus ← ComputePMF_FFI(J, theta - epsilon * e_j, t, discrete, g)
55: grad_j ← (y_plus - y_minus) / (2 * epsilon) ▷ Central differences
56: theta_bar_j ← y_bar^T * grad_j ▷ VJP contraction
57: end for
58: return (theta_bar, 0) ▷ No gradient w.r.t. t
59: end function
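The forward and backward functions of the algorithm can be wired together with jax.custom_vjp. This is a sketch, not phasic’s implementation: jax.pure_callback stands in for the FFI call, the NumPy function h_numpy is a hypothetical hyperexponential stand-in for the C++ handler, and the names compute_pmf and pmf are illustrative. It assumes a recent JAX release with 64-bit mode enabled.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
EPS = 1e-7  # fixed absolute step, as in ComputePMF_bwd


def h_numpy(theta, t):
    """Hypothetical C++ stand-in: hyperexponential PDF with two rates."""
    theta, t = np.asarray(theta), np.asarray(t)
    return 0.5 * (theta[0] * np.exp(-theta[0] * t) + theta[1] * np.exp(-theta[1] * t))


def compute_pmf(theta, t):
    """Opaque forward evaluation (plays the role of ComputePMF_FFI)."""
    out_type = jax.ShapeDtypeStruct(t.shape, jnp.float64)
    return jax.pure_callback(h_numpy, out_type, theta, t)


@jax.custom_vjp
def pmf(theta, t):
    return compute_pmf(theta, t)


def pmf_fwd(theta, t):
    return compute_pmf(theta, t), (theta, t)      # save residuals


def pmf_bwd(residuals, y_bar):
    theta, t = residuals
    d = theta.shape[0]
    theta_bar = []
    for j in range(d):                            # 2d forward evaluations
        e_j = jnp.zeros(d).at[j].set(1.0)
        y_plus = compute_pmf(theta + EPS * e_j, t)
        y_minus = compute_pmf(theta - EPS * e_j, t)
        grad_j = (y_plus - y_minus) / (2 * EPS)   # central differences
        theta_bar.append(jnp.dot(y_bar, grad_j))  # VJP contraction
    return jnp.stack(theta_bar), jnp.zeros_like(t)  # no gradient w.r.t. t


pmf.defvjp(pmf_fwd, pmf_bwd)

theta = jnp.array([1.0, 3.0], dtype=jnp.float64)
t = jnp.linspace(0.1, 2.0, 20)
y_bar = jnp.linspace(-1.0, 1.0, 20)


def loss(th):
    return jnp.dot(y_bar, pmf(th, t))             # dL/dy = y_bar


grad_fd = jax.jit(jax.grad(loss))(theta)

th_np, t_np, yb_np = np.asarray(theta), np.asarray(t), np.asarray(y_bar)
grad_exact = np.array([yb_np @ (0.5 * (1 - th * t_np) * np.exp(-th * t_np))
                       for th in th_np])
```

Because the loss is linear in the output, the cotangent reaching pmf_bwd is exactly y_bar, so grad_fd should match the analytic VJP up to the finite-difference error bounded in Theorem 27.2.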
Correspondence table:
| Pseudocode variable | Math symbol | Code variable (file:function) |
|---|---|---|
| J | J | structure_json (ffi_wrappers.py:compute_pmf_ffi) |
| theta | \boldsymbol{\theta} | theta (ffi_wrappers.py:compute_pmf_ffi) |
| t | \mathbf{t} | times (ffi_wrappers.py:compute_pmf_ffi) |
| y | \mathbf{y} = h(\boldsymbol{\theta}, \mathbf{t}) | result (ffi_wrappers.py:compute_pmf_ffi) |
| builder | (none) | builder (graph_builder.cpp:GraphBuilder) |
| graph | G(\boldsymbol{\theta}) | g (graph_builder.cpp:build) |
| y_bar | \bar{\mathbf{y}} | g (cotangent in __init__.py:jax_model_bwd) |
| theta_bar | \bar{\boldsymbol{\theta}} | theta_bar (__init__.py:jax_model_bwd) |
| epsilon | \epsilon | eps (__init__.py:jax_model_bwd) |
| grad_j | \partial h / \partial \theta_j | grad_i (__init__.py:jax_model_bwd) |
| FFI_REGISTERED | (none) | _FFI_REGISTERED (ffi_wrappers.py) |
| ffi_fn | (none) | ffi_fn (ffi_wrappers.py:compute_pmf_ffi) |
Complexity. Let p = |V| (vertices), m = |E| (edges), n = |\mathbf{t}| (evaluation points), d = |\boldsymbol{\theta}| (parameters), and K the number of uniformization steps (typically K = O(g \cdot t_{\max})).
- Forward pass: Time O(n \cdot K \cdot m) (one forward algorithm per time point). Space O(p + m).
- Backward pass (finite differences): Time O(d \cdot n \cdot K \cdot m) (2d forward evaluations). Space O(p + m + n).
- FFI registration: Time O(1) (amortized, once per process).
- JSON parsing: Time O(|J|) per GraphBuilder construction.
Under vmap, the batch dimension is processed by the FFI handler’s OpenMP loop. The results match sequential evaluation; any discrepancy is limited to O(\epsilon_{\mathrm{mach}}) effects of floating-point non-associativity.
Numerical Considerations
Finite difference step size. The step size \epsilon = 10^{-7} gives a truncation error of O(\epsilon^2) \approx 10^{-14} and a cancellation error of O(\epsilon_{\mathrm{mach}} / \epsilon) \approx 10^{-9}, so cancellation dominates. The two terms balance near \epsilon \approx \epsilon_{\mathrm{mach}}^{1/3} \approx 6 \times 10^{-6}, where the total error is of order 10^{-11} (for function values and third derivatives of order one). For parameters \theta_j of very different magnitudes, relative step sizes \epsilon_j = \epsilon \cdot |\theta_j| would be more appropriate, but the current implementation uses a fixed absolute step size. This is adequate for phasic’s typical use cases where parameters are O(1).
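The effect of parameter magnitude can be illustrated with the stand-in f(\theta) = \theta e^{-\theta t}, here at a small rate \theta = 10^{-4} observed at t = 2 \times 10^4 (so \theta t = 2). Its third derivative is about 5.4 \times 10^7, so the fixed step 10^{-7} incurs a truncation error near (\epsilon^2 / 6) \cdot 5.4 \times 10^7 \approx 9 \times 10^{-8}, whereas the relative step \epsilon |\theta| = 10^{-11} makes truncation negligible. The model and numbers are illustrative assumptions, not phasic parameters.

```python
import numpy as np

theta, t = 1e-4, 2e4           # small rate, late time point: theta * t = 2
eps = 1e-7


def f(th):
    return th * np.exp(-th * t)


exact = (1 - theta * t) * np.exp(-theta * t)   # df/dtheta


def central_difference(step):
    return (f(theta + step) - f(theta - step)) / (2 * step)


err_absolute = abs(central_difference(eps) - exact)                # fixed step
err_relative = abs(central_difference(eps * abs(theta)) - exact)   # eps * |theta|
print(f"absolute step: {err_absolute:.1e}, relative step: {err_relative:.1e}")
```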
Float64 enforcement. All FFI calls enforce float64 output dtype. Phase-type PDF values can span many orders of magnitude (especially in the tails), and float32 precision is insufficient for reliable likelihood computation in SVGD.
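A short NumPy illustration of the tail problem, using a hypothetical exponential model with rate 1 and an observation at t = 120: the density e^{-120} \approx 7.7 \times 10^{-53} lies below the smallest positive float32 value (about 1.4 \times 10^{-45}), so in float32 it underflows to zero and the log-likelihood becomes -\infty, while float64 represents it without difficulty.

```python
import numpy as np


def exp_pdf(rate, t, dtype):
    """Exponential PDF evaluated entirely in the given floating-point type."""
    rate = dtype(rate)
    t = dtype(t)
    return rate * np.exp(-rate * t)


pdf32 = exp_pdf(1.0, 120.0, np.float32)
pdf64 = exp_pdf(1.0, 120.0, np.float64)

with np.errstate(divide="ignore"):       # log(0) would otherwise warn
    loglik32 = np.log(pdf32)
loglik64 = np.log(pdf64)

print(pdf32, pdf64)        # float32 underflows to 0.0
print(loglik32, loglik64)  # -inf versus approximately -120.0
```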
GIL management. The C++ FFI handler releases the Python GIL during computation, enabling true multi-threaded parallelism when vmap dispatches multiple parameter vectors to OpenMP threads. This is critical for SVGD, where 100+ particles are evaluated per iteration.
Implementation Notes
Source code mapping:
| Algorithm | File | Function | Lines |
|---|---|---|---|
| Algorithm 27.1 (FFI path) | src/phasic/ffi_wrappers.py | compute_pmf_ffi | L496–L573 |
| Algorithm 27.1 (registration) | src/phasic/ffi_wrappers.py | _register_ffi_targets | L158–L298 |
| Algorithm 27.1 (VJP) | src/phasic/__init__.py | jax_model_fwd, jax_model_bwd | L3861–L3883 |
| Algorithm 27.1 (C++ handler) | src/cpp/parameterized/graph_builder.cpp | GraphBuilder::compute_pmf | L309–L361 |
Deviations from pseudocode:
- The pseudocode shows a single ComputePMF_FFI function; the implementation has two code paths: the FFI path (using jax.ffi.ffi_call with XLA capsules) and a commented-out fallback path (using jax.pure_callback). The FFI path is the production path.
- The VJP backward pass in the pseudocode uses a loop over parameters; the implementation accumulates gradients into a list and converts it to a JAX array at the end.
- The implementation supports both pmf_from_graph (which uses pure_callback with a custom VJP) and compute_pmf_ffi (which uses direct FFI). The custom VJP is defined on the pmf_from_graph path. The FFI path defines no differentiation rule of its own (JAX cannot differentiate through an XLA custom call), so gradients on that path must come from a custom VJP wrapper or from an external gradient computation.
Symbol Index
| Symbol | Name | First appearance |
|---|---|---|
| h(\boldsymbol{\theta}, \mathbf{t}) | C++ PDF/PMF computation function | Definition 27.1 |
| \tilde{h} | JAX-wrapped version of h | Definition 27.1 |
| h_{\mathrm{fwd}} | Custom VJP forward function | Definition 27.2 |
| h_{\mathrm{bwd}} | Custom VJP backward function | Definition 27.2 |
| \bar{\mathbf{y}} | Upstream cotangent vector | Definition 27.2 |
| \bar{\boldsymbol{\theta}} | Parameter cotangent (VJP output) | Definition 27.2 |
| \epsilon | Finite difference step size | Definition 27.2 |
| J | Serialized graph structure (JSON) | Definition 27.3 |