22 Backward Filtering Forward Guiding
Introduction
This file formalizes Backward Filtering Forward Guiding [BFFG; Schauer et al. (2017)], an importance-weighting framework for inference with phase-type distributions under inhomogeneous (time-varying) parameters. The core problem is computing the likelihood p(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) when the target parameter vector varies over time, making direct forward-algorithm computation intractable. BFFG addresses this by sampling paths under a tractable homogeneous proposal model (with fixed parameters \boldsymbol{\theta}_{\text{prop}}) and reweighting those paths to account for the target model. The importance weights are derived from the ratio of path densities under the target and proposal continuous-time Markov chains.
In the phasic pipeline, BFFG enables Bayesian inference for models where transition rates depend on time (e.g., population-size changes affecting coalescent rates). The proposal model is evaluated via the fast forward algorithm and trace cache ([14], [11]), while the importance weight correction is computed from sampled paths ([17]). The combined log-probability feeds into MCMC ([19]) or SVGD ([18]) for posterior inference.
Prerequisites: [01], [14], [17]
Source files:
src/phasic/bffg.py (functions: path_to_rewards, path_exit_rates, path_exit_rates_by_param, importance_log_weight_from_rates, importance_weighted_log_likelihood, bffg_log_prob)
Definitions
Definition 22.1 (Proposal and Target Distributions) Let G = (V, E) be a parameterized phase-type graph with parameter vector \boldsymbol{\theta} \in \mathbb{R}^d_{>0} and coefficient vectors \mathbf{c}_e \in \mathbb{R}^d on each edge e \in E. The proposal distribution is the phase-type distribution \operatorname{PH}(\boldsymbol{\alpha}, \mathbf{S}(\boldsymbol{\theta}_{\text{prop}})) with fixed parameter vector \boldsymbol{\theta}_{\text{prop}}, where each edge weight is w(e) = \mathbf{c}_e^\top \boldsymbol{\theta}_{\text{prop}}. The target distribution is determined by a time-varying parameter function \boldsymbol{\theta}_{\text{target}}: \mathbb{R}_{\geq 0} \to \mathbb{R}^d_{>0}, so that at time t the edge weight under the target is w_{\text{target}}(e, t) = \mathbf{c}_e^\top \boldsymbol{\theta}_{\text{target}}(t).
Intuition
The proposal model is a standard homogeneous phase-type distribution that can be efficiently evaluated using the forward algorithm. The target model has time-varying rates, which prevents direct use of the forward algorithm (since \mathbf{S} changes over time). By sampling paths from the proposal and reweighting, we avoid the need to solve the time-inhomogeneous Kolmogorov equations directly.
Example
In a coalescent model with 5 lineages, the coalescent rate at state n is \binom{n}{2} \theta^{-1} where \theta is the effective population size. The proposal uses a fixed \theta_{\text{prop}}, while the target has \theta_{\text{target}}(t) varying (e.g., an exponential growth model). Each edge coefficient encodes \binom{n}{2}, and the parameter \theta^{-1} gives the rate scaling.
Definition 22.2 (Path Exit Rate Decomposition) Let \omega = (v_0, v_1, \ldots, v_K, v_{\text{abs}}) be a path through the graph with sojourn times s_1, \ldots, s_K at the transient vertices v_1, \ldots, v_K. The exit rate at vertex v_k under parameter vector \boldsymbol{\theta} is
\lambda_{v_k}(\boldsymbol{\theta}) = \sum_{e \in \operatorname{out}(v_k)} \mathbf{c}_e^\top \boldsymbol{\theta}, \tag{22.1}
where \operatorname{out}(v_k) denotes the set of outgoing edges from v_k. This decomposes per parameter as
\lambda_{v_k}(\boldsymbol{\theta}) = \sum_{j=1}^{d} \theta_j \sum_{e \in \operatorname{out}(v_k)} c_{e,j}, \tag{22.2}
where c_{e,j} is the j-th coefficient of edge e.
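As a concrete illustration (a hypothetical two-edge vertex, not taken from the source), the per-parameter decomposition lets exit rates be re-evaluated under a new parameter vector with a single dot product:

```python
import numpy as np

# Hypothetical vertex with two outgoing edges and d = 2 parameters;
# row e holds the coefficient vector c_e.
C = np.array([[1.0, 0.0],   # edge 1
              [3.0, 2.0]])  # edge 2
theta = np.array([0.5, 1.5])

# Equation (22.1): sum over outgoing edges of c_e^T theta.
exit_rate_direct = float(sum(C[e] @ theta for e in range(len(C))))

# Equation (22.2): precompute per-parameter coefficient sums once; any
# new theta then costs one dot product, with no graph traversal.
coeff_sums = C.sum(axis=0)            # sum_e c_{e,j} for each j
exit_rate_decomposed = float(coeff_sums @ theta)
# both give 0.5*(1+3) + 1.5*(0+2) = 5.0
```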
Intuition
The exit rate from a vertex is the total rate at which the process leaves that state. For parameterized graphs, this rate is a linear function of the parameters, and the per-parameter decomposition (Equation 22.2) allows efficient computation of exit rates under different parameter vectors without re-traversing the graph structure.
Definition 22.3 (Importance Log-Weight) For a path \omega with transient vertices v_1, \ldots, v_K, sojourn times s_1, \ldots, s_K, and taken transitions e_1, \ldots, e_K, the importance log-weight is the log-ratio of the path density under the target to the path density under the proposal:
\log w(\omega) = \sum_{k=1}^{K} \left[ \log \lambda_{v_k}^{\text{tgt}} - \log \lambda_{v_k}^{\text{prop}} - \left(\lambda_{v_k}^{\text{tgt}} - \lambda_{v_k}^{\text{prop}}\right) s_k + \log \frac{p_k^{\text{tgt}}}{p_k^{\text{prop}}} \right], \tag{22.3}
where \lambda_{v_k}^{\text{tgt}} = \lambda_{v_k}(\boldsymbol{\theta}_{\text{target}}(t_k)) and \lambda_{v_k}^{\text{prop}} = \lambda_{v_k}(\boldsymbol{\theta}_{\text{prop}}) are the exit rates, and p_k^{\text{tgt}} = w_{\text{target}}(e_k, t_k) / \lambda_{v_k}^{\text{tgt}}, p_k^{\text{prop}} = w_{\text{prop}}(e_k) / \lambda_{v_k}^{\text{prop}} are the transition probabilities of the taken edge under target and proposal respectively.
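A per-step helper makes the three ingredients of Equation (22.3) explicit. This is a minimal sketch, not the packaged API; the function name and signature are illustrative:

```python
import math

def step_log_weight(lam_tgt, lam_prop, s, p_tgt, p_prop):
    """One summand of Equation (22.3): exit-rate log-ratio, sojourn-time
    correction, and embedded-chain transition-probability log-ratio."""
    return (math.log(lam_tgt) - math.log(lam_prop)
            - (lam_tgt - lam_prop) * s
            + math.log(p_tgt) - math.log(p_prop))

# With lam_prop = 3.0, lam_tgt = 4.5, s = 0.2, p_prop = 0.5, p_tgt = 0.6:
contribution = step_log_weight(4.5, 3.0, 0.2, 0.6, 0.5)   # ≈ 0.2878
```

When proposal and target agree at a step, all three terms vanish and the step contributes zero to the log-weight.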
Intuition
In a continuous-time Markov chain, the density contribution from visiting vertex v_k for time s_k and then taking edge e_k is \lambda_{v_k} \exp(-\lambda_{v_k} s_k) \cdot p_k. The importance weight is the ratio of this quantity under the target and proposal. Taking logarithms decomposes the ratio into an exit-rate log-ratio, a sojourn-time correction (the rate difference times the sojourn time), and a transition probability log-ratio. The first two terms come from the exponential sojourn time density; the third from the embedded chain transition probability.
Example
If at vertex v_k the proposal has exit rate \lambda^{\text{prop}} = 3.0 and the target has \lambda^{\text{tgt}} = 4.5, with sojourn time s_k = 0.2, and the taken edge has proposal probability p^{\text{prop}} = 0.5 and target probability p^{\text{tgt}} = 0.6, the contribution to the log-weight is \log(4.5) - \log(3.0) - (4.5 - 3.0)(0.2) + \log(0.6) - \log(0.5) \approx 0.405 - 0.300 + 0.182 = 0.287.
Definition 22.4 (Importance-Weighted Log-Likelihood) Given M paths \omega_1, \ldots, \omega_M sampled independently from the proposal distribution, with log-likelihoods \ell_1, \ldots, \ell_M (of observed data given each path) and log importance weights \log w_1, \ldots, \log w_M, the importance-weighted log-likelihood estimator is
\hat{\ell} = -\log M + \operatorname{logsumexp}(\ell_1 + \log w_1, \; \ldots, \; \ell_M + \log w_M). \tag{22.4}
When the path-conditional likelihoods are all equal (as in the homogeneous case where we only need the weight correction), set \ell_m = 0 for all m and Equation 22.4 reduces to
\hat{\ell}_{\text{correction}} = -\log M + \operatorname{logsumexp}(\log w_1, \ldots, \log w_M). \tag{22.5}
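Equations (22.4) and (22.5) can be sketched in a few lines. The implementation uses jax.scipy.special.logsumexp; the hand-rolled version here is only to keep the example self-contained:

```python
import math

def logsumexp(xs):
    # Stable log(sum(exp(x))): factor out the maximum before exponentiating.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def importance_weighted_loglik(ells, log_ws):
    # Equation (22.4); with ells all zero it reduces to the
    # correction-only estimator of Equation (22.5).
    M = len(ells)
    return -math.log(M) + logsumexp([l + lw for l, lw in zip(ells, log_ws)])

# All importance weights equal to 1 (log-weights 0) give a zero correction.
importance_weighted_loglik([0.0] * 4, [0.0] * 4)   # → 0.0
```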
Intuition
This is the standard self-normalized importance sampling estimator in log-space. The \operatorname{logsumexp} function provides numerical stability by factoring out the maximum log-weight before exponentiating. Each path contributes its data likelihood weighted by how much more (or less) likely that path is under the target than under the proposal. The -\log M term normalizes by the number of samples.
Theorems and Proofs
Theorem 22.1 (Unbiasedness of Importance Weights) Let \omega \sim \operatorname{PH}(\boldsymbol{\alpha}, \mathbf{S}(\boldsymbol{\theta}_{\text{prop}})) be a path sampled from the proposal distribution, and let w(\omega) = \exp(\log w(\omega)) with \log w(\omega) as in Definition 22.3. If \boldsymbol{\theta}_{\text{target}} is constant (homogeneous target), then \mathbb{E}_{\text{prop}}[w(\omega)] = 1.
Proof. Write the proposal path density as q(\omega) = \prod_{k=1}^{K} \lambda_{v_k}^{\text{prop}} \exp(-\lambda_{v_k}^{\text{prop}} s_k) \cdot p_k^{\text{prop}} and the target path density as p(\omega) = \prod_{k=1}^{K} \lambda_{v_k}^{\text{tgt}} \exp(-\lambda_{v_k}^{\text{tgt}} s_k) \cdot p_k^{\text{tgt}}. By Definition 22.3, w(\omega) = p(\omega)/q(\omega). Then
\mathbb{E}_{\text{prop}}[w(\omega)] = \int w(\omega) \, q(\omega) \, d\omega = \int \frac{p(\omega)}{q(\omega)} \, q(\omega) \, d\omega = \int p(\omega) \, d\omega = 1,
where the last equality holds because p(\omega) integrates to 1 over all paths (it is a valid density for a continuous-time Markov chain on a finite state space with absorption). \square
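Theorem 22.1 can be checked empirically on a toy chain (a hypothetical setup, not from the source): a single transient state with one outgoing edge, so the embedded-chain term cancels (p^{\text{tgt}} = p^{\text{prop}} = 1) and the weight depends only on the sojourn time:

```python
import math
import random

random.seed(0)
lam_prop, lam_tgt = 3.0, 4.5   # constant (homogeneous) target

# For a one-step path with sojourn time s ~ Exp(lam_prop), the weight
# of Definition 22.3 reduces to
#   w = (lam_tgt / lam_prop) * exp(-(lam_tgt - lam_prop) * s).
M = 100_000
mean_w = sum(
    (lam_tgt / lam_prop) * math.exp(-(lam_tgt - lam_prop) * random.expovariate(lam_prop))
    for _ in range(M)
) / M
# Analytically E[w] = (lam_tgt/lam_prop) * lam_prop/lam_tgt = 1; mean_w ≈ 1.
```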
Theorem 22.2 (Convergence of Importance-Weighted Likelihood) Let \omega_1, \ldots, \omega_M be i.i.d. paths from the proposal, let L(\omega_m) = \exp(\ell_m) be the data likelihood given path \omega_m, and let w_m = w(\omega_m) be the importance weight. Then
\hat{L} = \frac{1}{M} \sum_{m=1}^{M} L(\omega_m) \, w(\omega_m) \xrightarrow{M \to \infty} \mathbb{E}_{\text{target}}[L(\omega)] = P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) \quad \text{almost surely.}
Proof. By the strong law of large numbers applied to the i.i.d. sequence L(\omega_m) w(\omega_m):
\frac{1}{M} \sum_{m=1}^{M} L(\omega_m) w(\omega_m) \xrightarrow{\text{a.s.}} \mathbb{E}_{\text{prop}}[L(\omega) w(\omega)].
Expanding w(\omega) = p(\omega)/q(\omega):
\mathbb{E}_{\text{prop}}[L(\omega) w(\omega)] = \int L(\omega) \frac{p(\omega)}{q(\omega)} q(\omega) \, d\omega = \int L(\omega) \, p(\omega) \, d\omega = \mathbb{E}_{\text{target}}[L(\omega)].
The final expectation equals P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) by the law of total probability (marginalizing the likelihood over the latent path). The estimator \hat{L} is finite with probability 1: the data likelihood L is bounded, w(\omega) is finite for every path of finite length with strictly positive rates, and the limiting expectation \mathbb{E}_{\text{prop}}[L(\omega) w(\omega)] = \mathbb{E}_{\text{target}}[L(\omega)] is finite. \square
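The same single-state toy chain illustrates Theorem 22.2: sampling sojourn times from the proposal and averaging L(\omega) w(\omega) recovers the target expectation. Here L(\omega) = \exp(-s) is an arbitrary bounded test likelihood chosen only for illustration:

```python
import math
import random

random.seed(1)
lam_prop, lam_tgt = 3.0, 4.5

# Under the target, s ~ Exp(lam_tgt), so E_target[exp(-s)] = lam_tgt / (lam_tgt + 1).
target_value = lam_tgt / (lam_tgt + 1.0)   # ≈ 0.8182

M = 100_000
acc = 0.0
for _ in range(M):
    s = random.expovariate(lam_prop)       # path sampled from the proposal
    w = (lam_tgt / lam_prop) * math.exp(-(lam_tgt - lam_prop) * s)
    acc += math.exp(-s) * w                # L(omega) * w(omega)
L_hat = acc / M                            # converges to target_value as M grows
```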
Algorithms
22.0.0.1 BFFG Log-Probability Estimation
Description. Computes an estimate of the log-probability \log P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) by combining a proposal model evaluation with a stochastic importance weight correction. For each observed data point (locus), the proposal log-probability is obtained via the forward algorithm on the homogeneous proposal graph, and M conditioned paths are sampled from the proposal. Each path’s importance log-weight is computed by traversing the path vertices, computing exit rates and transition probabilities under both proposal and target, and summing the log-ratios. The per-locus correction is the logsumexp of the log-weights minus \log M. The total log-probability is the sum of per-locus proposal log-probabilities and corrections.
BFFG Log-Probability Estimation
1: Let G = (V, E) be a parameterized phase-type graph
2: Let theta_prop be the proposal parameter vector
3: Let theta_target(t) be the time-varying target parameter function
4: Let y = (y_1, ..., y_L) be the observed data (L loci)
5: Let M be the number of importance sampling paths per locus
6:
7: function BFFGLogProb(G, theta_prop, theta_target, y, M)
8:     total <- 0
9:     for l = 1, ..., L do                                          ▷ Loop over loci
10:        log_p_model <- ForwardAlgorithm(G, theta_prop, y_l)       ▷ Proposal log-prob via [14]
11:        for m = 1, ..., M do
12:            omega_m <- SamplePathConditioned(G, theta_prop, y_l)  ▷ Conditioned path via [17]
13:            log_w_m <- 0
14:            for k = 1, ..., K_m do                                ▷ K_m transient vertices in path
15:                lambda_prop <- sum of w(e) for e in out(v_k) under theta_prop
16:                lambda_tgt <- sum of c_e^T theta_target(t_k) for e in out(v_k)
17:                p_prop <- w(e_k) / lambda_prop                    ▷ Proposal transition prob
18:                p_tgt <- c_{e_k}^T theta_target(t_k) / lambda_tgt ▷ Target transition prob
19:                log_w_m <- log_w_m + log(lambda_tgt) - log(lambda_prop)
                             - (lambda_tgt - lambda_prop) * s_k
                             + log(p_tgt) - log(p_prop)              ▷ Eq. (22.3)
20:            end for
21:        end for
22:        correction <- logsumexp(log_w_1, ..., log_w_M) - log(M)   ▷ Eq. (22.5)
23:        total <- total + log_p_model + correction
24:    end for
25:    return total
26: end function
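The per-path weight accumulation can be transcribed directly into Python. The path representation below is hypothetical (the real code reads these quantities from the graph structures in bffg.py):

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def path_log_weight(path, theta_prop, theta_target, out_coeffs):
    """Accumulate Equation (22.3) along one sampled path.

    Hypothetical representation (for illustration only):
      path         -- list of steps (vertex, taken_edge, sojourn s_k, time t_k)
      out_coeffs   -- dict: vertex -> {edge_id: coefficient vector c_e}
      theta_prop   -- fixed proposal parameter vector
      theta_target -- callable t -> target parameter vector
    """
    log_w = 0.0
    for v, e_k, s_k, t_k in path:
        coeffs = out_coeffs[v]
        th_t = theta_target(t_k)
        lam_prop = sum(dot(c, theta_prop) for c in coeffs.values())
        lam_tgt = sum(dot(c, th_t) for c in coeffs.values())
        p_prop = dot(coeffs[e_k], theta_prop) / lam_prop
        p_tgt = dot(coeffs[e_k], th_t) / lam_tgt
        log_w += (math.log(lam_tgt) - math.log(lam_prop)
                  - (lam_tgt - lam_prop) * s_k
                  + math.log(p_tgt) - math.log(p_prop))
    return log_w
```

Note that when the target is a scalar rescaling of the proposal, the embedded-chain probabilities coincide and only the exit-rate and sojourn terms survive.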
Correspondence table:
| Pseudocode variable | Math symbol | Code variable (file:function) |
|---|---|---|
| G | G = (V, E) | jg_continuous (bffg.py:bffg_log_prob) |
| theta_prop | \boldsymbol{\theta}_{\text{prop}} | theta_proposal (bffg.py:bffg_log_prob) |
| theta_target | \boldsymbol{\theta}_{\text{target}}(t) | theta_target_fn (bffg.py:bffg_log_prob) |
| y | \mathbf{y} | observed_data (bffg.py:bffg_log_prob) |
| M | M | n_paths (bffg.py:bffg_log_prob) |
| omega_m | \omega_m | path (bffg.py:_full_importance_log_weight) |
| log_w_m | \log w(\omega_m) | log_w (bffg.py:_full_importance_log_weight) |
| lambda_prop | \lambda_{v_k}^{\text{prop}} | r_prop (bffg.py:_full_importance_log_weight) |
| lambda_tgt | \lambda_{v_k}^{\text{tgt}} | r_tgt (bffg.py:_full_importance_log_weight) |
| p_prop | p_k^{\text{prop}} | p_prop (bffg.py:_full_importance_log_weight) |
| p_tgt | p_k^{\text{tgt}} | p_tgt (bffg.py:_full_importance_log_weight) |
| s_k | s_k | s_k (bffg.py:_full_importance_log_weight) |
| log_p_model | \log P_{\text{model}} | log_p_model (bffg.py:log_prob_fn) |
| correction | \hat{\ell}_{\text{correction}} | log_ratio (bffg.py:log_prob_fn) |
| total | \hat{\ell} | total (bffg.py:log_prob_fn) |
| logsumexp | \operatorname{logsumexp} | logsumexp (jax.scipy.special) |
Complexity. Time: O(L \cdot (C_{\text{fwd}} + M \cdot K_{\max} \cdot d)), where C_{\text{fwd}} is the cost of one forward algorithm evaluation ([14]), K_{\max} is the maximum path length, d is the parameter dimension, and the inner loop over edges at each vertex costs O(|\operatorname{out}(v_k)| \cdot d) for computing \mathbf{c}_e^\top \boldsymbol{\theta}. The conditioned path sampling costs O(K_{\max} \cdot |\operatorname{out}(v)|) per path. Space: O(|V| \cdot E_{\max} \cdot d) for the precomputed edge coefficient arrays, plus O(M) for the log-weight storage.
Numerical Considerations
Log-space computation. All weight computations are performed in log-space to avoid overflow and underflow. The importance log-weight (Equation 22.3) involves differences of log-rates and products of sojourn times with rate differences, both of which remain in a numerically tractable range. The final aggregation uses \operatorname{logsumexp}, which subtracts the maximum log-weight before exponentiating.
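The effect of the max-subtraction is easy to see with extreme (hypothetical) log-weights that underflow a naive sum of exponentials:

```python
import math

log_ws = [-1000.0, -1001.0, -1002.0]   # hypothetical extreme log-weights

# Naive: exp(-1000) underflows to 0.0, so log of the sum would be -inf.
naive_sum = sum(math.exp(x) for x in log_ws)        # 0.0

# logsumexp: subtract the maximum first, exponentiate safely.
m = max(log_ws)
stable = m + math.log(sum(math.exp(x - m) for x in log_ws))
# stable = -1000 + log(1 + e^-1 + e^-2) ≈ -999.59
```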
Degenerate rates. When exit rates are zero (absorbing vertex reached prematurely) or the taken edge cannot be identified, the step contribution is set to zero. Guard conditions r_prop > 0, r_tgt > 0, and p_prop > 0, p_tgt > 0 prevent \log 0 evaluations. In the JAX-compatible path (_importance_weight_one_path), a small constant 10^{-30} is added to rates and probabilities before taking logarithms.
Precomputed coefficients. Edge coefficients are extracted once from the graph and stored in dense arrays (dense_edge_coeffs, shape (|V|, E_{\max}, d)) for efficient JAX vectorization. This trades memory for speed, enabling vmap over all paths and loci simultaneously.
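A sketch of the zero-padded layout (hypothetical sizes; dense_edge_coeffs is the array name used in the implementation, but this construction is illustrative):

```python
import numpy as np

# Hypothetical graph: |V| = 3 vertices, at most E_max = 2 outgoing edges,
# d = 2 parameters; slots for absent edges stay zero.
V, E_max, d = 3, 2, 2
dense_edge_coeffs = np.zeros((V, E_max, d))
dense_edge_coeffs[0, 0] = [1.0, 0.0]
dense_edge_coeffs[0, 1] = [3.0, 2.0]
dense_edge_coeffs[1, 0] = [0.5, 0.5]   # vertex 1 has a single edge
# vertex 2 (absorbing) has no outgoing edges

theta = np.array([0.5, 1.5])

# Exit rate of every vertex in one vectorized step: the zero padding
# makes missing edges contribute nothing (cf. Equation 22.2).
exit_rates = dense_edge_coeffs.sum(axis=1) @ theta   # [5.0, 1.0, 0.0]
```

The same dense layout is what allows a vmap over all paths and loci: every vertex exposes a fixed-size edge block regardless of its true out-degree.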
JIT and non-JIT paths. The implementation provides two execution paths: a non-JIT fallback (_full_importance_log_weight) that directly queries the graph data structure for each path step, and a JIT-compatible path (_importance_weight_one_path) that uses precomputed dense arrays and jax.lax.scan for fixed-size iteration. The return_model=True flag selects the JIT path.
Implementation Notes
Source code mapping:
| Algorithm | File | Function | Lines |
|---|---|---|---|
| Algorithm 22.1 (entry, backward compat) | src/phasic/bffg.py | bffg_log_prob / log_prob_fn | L570–L612 |
| Algorithm 22.1 (JIT model) | src/phasic/bffg.py | model | L396–L427 |
| Algorithm 22.1 (JIT correction) | src/phasic/bffg.py | likelihood_correction_jit | L546–L563 |
| Algorithm 22.1 (non-JIT correction) | src/phasic/bffg.py | likelihood_correction | L431–L456 |
| Importance log-weight (non-JIT) | src/phasic/bffg.py | _full_importance_log_weight | L353–L392 |
| Importance log-weight (JIT) | src/phasic/bffg.py | _importance_weight_one_path | L490–L522 |
| Path exit rates | src/phasic/bffg.py | path_exit_rates | L77–L118 |
| Path exit rates by param | src/phasic/bffg.py | path_exit_rates_by_param | L121–L176 |
| Importance log-weight from rates | src/phasic/bffg.py | importance_log_weight_from_rates | L179–L215 |
| Importance-weighted log-likelihood | src/phasic/bffg.py | importance_weighted_log_likelihood | L218–L250 |
| Path to rewards | src/phasic/bffg.py | path_to_rewards | L22–L74 |
Deviations from pseudocode:
- The pseudocode shows a sequential loop over loci and paths. The JIT implementation (likelihood_correction_jit) uses jax.vmap to batch all L \times M path samples and weight computations into a single vectorized call, reshaping the results to (L, M) for per-locus aggregation.
- The non-JIT implementation (_full_importance_log_weight) includes both exit-rate and transition-probability log-ratios as shown in Equation 22.3. The utility function importance_log_weight_from_rates computes only the exit-rate terms (without transition probability ratios), suitable for cases where proposal and target share the same embedded chain structure.
- The bffg_log_prob function returns either a single log_prob_fn (backward compatible) or a (model, likelihood_correction) tuple for integration with MCMC ([19]). The pseudocode shows only the combined log-probability for simplicity.
- Edge coefficients are precomputed and cached in closure variables (_vertex_edge_coeffs, _vertex_prop_rates) at graph construction time, not recomputed per path traversal.
Symbol Index
| Symbol | Name | First appearance |
|---|---|---|
| \mathbf{c}_e | Edge coefficient vector | Definition 22.1 |
| c_{e,j} | j-th coefficient of edge e | Definition 22.2 |
| e_k | Taken edge at step k | Definition 22.3 |
| \hat{\ell} | Importance-weighted log-likelihood estimator | Definition 22.4 |
| \hat{\ell}_{\text{correction}} | Importance weight correction | Definition 22.4 |
| \lambda_{v_k}^{\text{prop}} | Exit rate under proposal at vertex v_k | Definition 22.2 |
| \lambda_{v_k}^{\text{tgt}} | Exit rate under target at vertex v_k | Definition 22.2 |
| L | Number of loci (observations) | Algorithm 22.1 |
| M | Number of importance sampling paths | Definition 22.4 |
| \omega | Path through the graph | Definition 22.2 |
| p_k^{\text{prop}} | Proposal transition probability at step k | Definition 22.3 |
| p_k^{\text{tgt}} | Target transition probability at step k | Definition 22.3 |
| s_k | Sojourn time at vertex v_k | Definition 22.2 |
| \boldsymbol{\theta}_{\text{prop}} | Proposal parameter vector | Definition 22.1 |
| \boldsymbol{\theta}_{\text{target}}(t) | Time-varying target parameter function | Definition 22.1 |
| w(\omega) | Importance weight for path \omega | Definition 22.3 |