22  Backward Filtering Forward Guiding

Introduction

This file formalizes Backward Filtering Forward Guiding [BFFG; Schauer et al. (2017)], an importance-weighting framework for inference with phase-type distributions under inhomogeneous (time-varying) parameters. The core problem is computing the likelihood p(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) when the target parameter vector varies over time, making direct forward-algorithm computation intractable. BFFG addresses this by sampling paths under a tractable homogeneous proposal model (with fixed parameters \boldsymbol{\theta}_{\text{prop}}) and reweighting those paths to account for the target model. The importance weights are derived from the ratio of path densities under the target and proposal continuous-time Markov chains.

In the phasic pipeline, BFFG enables Bayesian inference for models where transition rates depend on time (e.g., population-size changes affecting coalescent rates). The proposal model is evaluated via the fast forward algorithm and trace cache ([14], [11]), while the importance weight correction is computed from sampled paths ([17]). The combined log-probability feeds into MCMC ([19]) or SVGD ([18]) for posterior inference.

Prerequisites: [01], [14], [17]

Source files:

  • src/phasic/bffg.py (functions: path_to_rewards, path_exit_rates, path_exit_rates_by_param, importance_log_weight_from_rates, importance_weighted_log_likelihood, bffg_log_prob)

Definitions

Definition 22.1 (Proposal and Target Distributions) Let G = (V, E) be a parameterized phase-type graph with parameter vector \boldsymbol{\theta} \in \mathbb{R}^d_{>0} and coefficient vectors \mathbf{c}_e \in \mathbb{R}^d on each edge e \in E. The proposal distribution is the phase-type distribution \operatorname{PH}(\boldsymbol{\alpha}, \mathbf{S}(\boldsymbol{\theta}_{\text{prop}})) with fixed parameter vector \boldsymbol{\theta}_{\text{prop}}, where each edge weight is w(e) = \mathbf{c}_e^\top \boldsymbol{\theta}_{\text{prop}}. The target distribution is determined by a time-varying parameter function \boldsymbol{\theta}_{\text{target}}: \mathbb{R}_{\geq 0} \to \mathbb{R}^d_{>0}, so that at time t the edge weight under the target is w_{\text{target}}(e, t) = \mathbf{c}_e^\top \boldsymbol{\theta}_{\text{target}}(t).

Intuition The proposal model is a standard homogeneous phase-type distribution that can be efficiently evaluated using the forward algorithm. The target model has time-varying rates, which prevents direct use of the forward algorithm (since \mathbf{S} changes over time). By sampling paths from the proposal and reweighting, we avoid the need to solve the time-inhomogeneous Kolmogorov equations directly.
Example In a coalescent model with 5 lineages, the coalescent rate in the state with n lineages is \binom{n}{2} \theta^{-1}, where \theta is the effective population size. To keep edge weights linear in the parameter, the model parameter is taken to be \theta^{-1}: each edge n \to n-1 has coefficient \binom{n}{2}, so that w(e) = \binom{n}{2} \theta^{-1}. The proposal uses a fixed \theta_{\text{prop}}, while the target has a time-varying \theta_{\text{target}}(t) (e.g., an exponential growth model).
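The coalescent example can be written out as a minimal sketch. It assumes a single-parameter model (d = 1) whose parameter is the inverse population size, and a hypothetical exponential-growth target N(t) = n0 · exp(-growth · t) with t measured backward in time. The graph format and the names `coalescent_edges`, `edge_weight`, and `theta_target` are illustrative only and are not phasic's API.

```python
import math

def coalescent_edges(n_lineages):
    """Map each state n (number of lineages) to its single outgoing edge n -> n-1.

    Each edge carries a 1-dimensional coefficient vector (C(n, 2),).
    """
    return {n: [(n - 1, (math.comb(n, 2),))] for n in range(n_lineages, 1, -1)}

def edge_weight(coeffs, theta):
    """Linear edge weight c_e^T theta (Definition 22.1)."""
    return sum(c * th for c, th in zip(coeffs, theta))

# Hypothetical parameterization: theta = (1 / N,), the inverse population size.
theta_prop = (1.0,)                      # fixed N_prop = 1

def theta_target(t, n0=1.0, growth=0.5):
    """Hypothetical exponential growth, t backward in time:
    N(t) = n0 * exp(-growth * t), so theta(t) = (1 / N(t),)."""
    return (1.0 / (n0 * math.exp(-growth * t)),)

edges = coalescent_edges(5)
for n, out in edges.items():
    _, coeffs = out[0]
    # state, coefficient C(n, 2), proposal weight, target weight at t = 1
    print(n, coeffs[0], edge_weight(coeffs, theta_prop),
          edge_weight(coeffs, theta_target(1.0)))
```

The coefficients depend only on the graph, while all time dependence enters through `theta_target(t)`. This is the separation that Definition 22.1 relies on.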

Definition 22.2 (Path Exit Rate Decomposition) Let \omega = (v_0, v_1, \ldots, v_K, v_{\text{abs}}) be a path through the graph with sojourn times s_1, \ldots, s_K at the transient vertices v_1, \ldots, v_K. The exit rate at vertex v_k under parameter vector \boldsymbol{\theta} is

\lambda_{v_k}(\boldsymbol{\theta}) = \sum_{e \in \operatorname{out}(v_k)} \mathbf{c}_e^\top \boldsymbol{\theta}, \tag{22.1}

where \operatorname{out}(v_k) denotes the set of outgoing edges from v_k. This decomposes per parameter as

\lambda_{v_k}(\boldsymbol{\theta}) = \sum_{j=1}^{d} \theta_j \sum_{e \in \operatorname{out}(v_k)} c_{e,j}, \tag{22.2}

where c_{e,j} is the j-th coefficient of edge e.

Intuition The exit rate from a vertex is the total rate at which the process leaves that state. For parameterized graphs, this rate is a linear function of the parameters, and the per-parameter decomposition (Equation 22.2) allows efficient computation of exit rates under different parameter vectors without re-traversing the graph structure.
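Equations 22.1 and 22.2 can be compared directly in a short sketch. The per-vertex representation (a list of `(target, coefficient_tuple)` pairs) is a hypothetical stand-in for phasic's internal graph format. The point is that the coefficient sums in Equation 22.2 do not depend on θ, so they can be computed once per vertex and reused for every parameter vector.

```python
def exit_rate(out_edges, theta):
    """Eq. (22.1): sum over outgoing edges e of c_e^T theta."""
    return sum(sum(c * th for c, th in zip(coeffs, theta)) for _, coeffs in out_edges)

def coefficient_sums(out_edges, d):
    """Per-parameter sums  sum_{e in out(v)} c_{e,j},  j = 1..d; independent of theta."""
    sums = [0.0] * d
    for _, coeffs in out_edges:
        for j, c in enumerate(coeffs):
            sums[j] += c
    return sums

def exit_rate_by_param(coef_sums, theta):
    """Eq. (22.2): sum_j theta_j * (precomputed coefficient sum j)."""
    return sum(th * s for th, s in zip(theta, coef_sums))

# Hypothetical vertex with three outgoing edges and d = 2 parameters.
out_v = [("a", (1.0, 0.0)), ("b", (0.5, 2.0)), ("abs", (0.0, 1.0))]
sums_v = coefficient_sums(out_v, d=2)     # [1.5, 3.0], computed once per vertex

for theta in [(1.0, 1.0), (2.0, 0.25)]:
    # Both routes give the same exit rate; the second costs O(d) per theta.
    print(theta, exit_rate(out_v, theta), exit_rate_by_param(sums_v, theta))
```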

Definition 22.3 (Importance Log-Weight) For a path \omega with transient vertices v_1, \ldots, v_K, sojourn times s_1, \ldots, s_K, and taken transitions e_1, \ldots, e_K, the importance log-weight is the log-ratio of the path density under the target to the path density under the proposal:

\log w(\omega) = \sum_{k=1}^{K} \left[ \log \lambda_{v_k}^{\text{tgt}} - \log \lambda_{v_k}^{\text{prop}} - \left(\lambda_{v_k}^{\text{tgt}} - \lambda_{v_k}^{\text{prop}}\right) s_k + \log \frac{p_k^{\text{tgt}}}{p_k^{\text{prop}}} \right], \tag{22.3}

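where \lambda_{v_k}^{\text{tgt}} = \lambda_{v_k}(\boldsymbol{\theta}_{\text{target}}(t_k)) and \lambda_{v_k}^{\text{prop}} = \lambda_{v_k}(\boldsymbol{\theta}_{\text{prop}}) are the exit rates; p_k^{\text{tgt}} = w_{\text{target}}(e_k, t_k) / \lambda_{v_k}^{\text{tgt}} and p_k^{\text{prop}} = w_{\text{prop}}(e_k) / \lambda_{v_k}^{\text{prop}} are the transition probabilities of the taken edge under target and proposal, respectively; w_{\text{prop}}(e) = \mathbf{c}_e^\top \boldsymbol{\theta}_{\text{prop}} is the proposal edge weight of Definition 22.1; and t_k is the path time at which the target parameters are evaluated for step k (for example, the entry time t_k = s_1 + \cdots + s_{k-1} into v_k).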
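Equation 22.3 can be simplified as a worked step. Because \lambda_{v_k}^{\text{tgt}} p_k^{\text{tgt}} = w_{\text{target}}(e_k, t_k) and \lambda_{v_k}^{\text{prop}} p_k^{\text{prop}} = w_{\text{prop}}(e_k), the exit-rate and transition-probability log-ratios combine into a single edge-weight log-ratio. This is an algebraic consequence of the definitions above, not a separate formula from the source:

```latex
\log w(\omega)
  = \sum_{k=1}^{K} \left[
      \log \frac{w_{\text{target}}(e_k, t_k)}{w_{\text{prop}}(e_k)}
      - \left(\lambda_{v_k}^{\text{tgt}} - \lambda_{v_k}^{\text{prop}}\right) s_k
    \right].
```

In this form only the taken edge's weights and the two exit rates are needed at each step.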

Intuition In a continuous-time Markov chain, the density contribution from visiting vertex v_k for time s_k and then taking edge e_k is \lambda_{v_k} \exp(-\lambda_{v_k} s_k) \cdot p_k. The importance weight is the ratio of this quantity under the target and proposal. Taking logarithms decomposes the ratio into an exit-rate log-ratio, a sojourn-time correction (the rate difference times the sojourn time), and a transition probability log-ratio. The first two terms come from the exponential sojourn time density; the third from the embedded chain transition probability.
Example If at vertex v_k the proposal has exit rate \lambda^{\text{prop}} = 3.0 and the target has \lambda^{\text{tgt}} = 4.5, with sojourn time s_k = 0.2, and the taken edge has proposal probability p^{\text{prop}} = 0.5 and target probability p^{\text{tgt}} = 0.6, the contribution to the log-weight is \log(4.5) - \log(3.0) - (4.5 - 3.0)(0.2) + \log(0.6) - \log(0.5) \approx 0.4055 - 0.3000 + 0.1823 \approx 0.288.
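A minimal Python sketch of Equation 22.3 reproduces the example. The function names are hypothetical, and the inputs are plain floats rather than graph objects.

```python
import math

def step_log_weight(lam_tgt, lam_prop, s, p_tgt, p_prop):
    """One summand of Eq. (22.3) for a single transient vertex v_k."""
    return (math.log(lam_tgt) - math.log(lam_prop)
            - (lam_tgt - lam_prop) * s
            + math.log(p_tgt) - math.log(p_prop))

def path_log_weight(steps):
    """Eq. (22.3): sum of per-step contributions.

    `steps` holds tuples (lam_tgt, lam_prop, s_k, p_tgt, p_prop), one per transient vertex.
    """
    return sum(step_log_weight(*step) for step in steps)

# Worked example from the text.
contrib = step_log_weight(lam_tgt=4.5, lam_prop=3.0, s=0.2, p_tgt=0.6, p_prop=0.5)
print(round(contrib, 4))  # 0.2878

# Same value from the taken edge's weights, since lambda * p = w(e):
# w_tgt(e_k) = 0.6 * 4.5 = 2.7 and w_prop(e_k) = 0.5 * 3.0 = 1.5.
alt = math.log(2.7 / 1.5) - (4.5 - 3.0) * 0.2
assert abs(contrib - alt) < 1e-12
```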

Definition 22.4 (Importance-Weighted Log-Likelihood) Given M paths \omega_1, \ldots, \omega_M sampled independently from the proposal distribution, with log-likelihoods \ell_1, \ldots, \ell_M (of observed data given each path) and log importance weights \log w_1, \ldots, \log w_M, the importance-weighted log-likelihood estimator is

\hat{\ell} = -\log M + \operatorname{logsumexp}(\ell_1 + \log w_1, \; \ldots, \; \ell_M + \log w_M). \tag{22.4}

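When the data likelihood does not have to be carried by the weights, for example because the paths are sampled conditionally on the observations and the proposal likelihood is computed separately by the forward algorithm (as in Algorithm 22.1), set \ell_m = 0 for all m. Equation 22.4 then reduces to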

\hat{\ell}_{\text{correction}} = -\log M + \operatorname{logsumexp}(\log w_1, \ldots, \log w_M). \tag{22.5}
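The following derivation shows why \ell_m = 0 is appropriate when the proposal paths are drawn conditionally on the data. It writes q(\omega) and p(\omega) for the proposal and target path densities, L(\omega) for the data likelihood given the path, and q(\omega \mid \mathbf{y}) \propto L(\omega) q(\omega) for the conditioned proposal. It assumes, as an explicit assumption, that proposal and target share the observation model L and that the conditioned sampler draws exactly from q(\omega \mid \mathbf{y}):

```latex
\begin{aligned}
P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}})
  &= \int L(\omega)\, p(\omega)\, d\omega
   = \int L(\omega)\, w(\omega)\, q(\omega)\, d\omega \\
  &= P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{prop}})
     \int w(\omega)\, q(\omega \mid \mathbf{y})\, d\omega,
  \qquad
  q(\omega \mid \mathbf{y})
    = \frac{L(\omega)\, q(\omega)}{P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{prop}})}, \\
\log P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}})
  &\approx \log P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{prop}})
     + \hat{\ell}_{\text{correction}},
  \qquad \omega_m \sim q(\omega \mid \mathbf{y}).
\end{aligned}
```

The first term is the forward-algorithm log-probability under the proposal; the second is Equation 22.5 evaluated on conditioned paths.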

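Intuition This is the standard (not self-normalized) importance sampling estimator, computed in log-space: the weighted likelihoods are averaged over M rather than divided by the sum of the weights. The \operatorname{logsumexp} function provides numerical stability by factoring out the maximum summand before exponentiating. Each path contributes its data likelihood weighted by how much more (or less) likely that path is under the target than under the proposal. The -\log M term normalizes by the number of samples. Because \exp(\hat{\ell}) is an unbiased estimator of the likelihood, \hat{\ell} itself is biased downward for finite M (by Jensen's inequality).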
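A stdlib-only sketch of Equations 22.4 and 22.5 follows. The function names are hypothetical. In the JAX implementation, `jax.scipy.special.logsumexp` plays the role of the hand-written `logsumexp` below.

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x))): factor out the maximum before exponentiating."""
    m = max(xs)
    if math.isinf(m):          # all entries -inf, or some entry +inf
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))

def iw_log_likelihood(log_lik, log_w):
    """Eq. (22.4): -log M + logsumexp(l_m + log w_m)."""
    M = len(log_w)
    return -math.log(M) + logsumexp([l + lw for l, lw in zip(log_lik, log_w)])

def correction(log_w):
    """Eq. (22.5): the special case l_m = 0."""
    return iw_log_likelihood([0.0] * len(log_w), log_w)

# exp(1000.0) overflows a float, but the stable logsumexp handles these values.
log_w = [1000.0, 1000.0, 999.0]
print(correction(log_w))   # approximately 999.76
```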

Theorems and Proofs

Theorem 22.1 (Unbiasedness of Importance Weights) Let \omega \sim \operatorname{PH}(\boldsymbol{\alpha}, \mathbf{S}(\boldsymbol{\theta}_{\text{prop}})) be a path sampled from the proposal distribution, and let w(\omega) = \exp(\log w(\omega)) with \log w(\omega) as in Definition 22.3. Assume that every edge with positive weight under the target also has positive weight under the proposal, so that q(\omega) > 0 wherever p(\omega) > 0. If \boldsymbol{\theta}_{\text{target}} is constant (homogeneous target), then \mathbb{E}_{\text{prop}}[w(\omega)] = 1.

Proof. Write the proposal path density as q(\omega) = \prod_{k=1}^{K} \lambda_{v_k}^{\text{prop}} \exp(-\lambda_{v_k}^{\text{prop}} s_k) \cdot p_k^{\text{prop}} and the target path density as p(\omega) = \prod_{k=1}^{K} \lambda_{v_k}^{\text{tgt}} \exp(-\lambda_{v_k}^{\text{tgt}} s_k) \cdot p_k^{\text{tgt}}. By Definition 22.3, w(\omega) = p(\omega)/q(\omega). Then

\mathbb{E}_{\text{prop}}[w(\omega)] = \int w(\omega) \, q(\omega) \, d\omega = \int \frac{p(\omega)}{q(\omega)} \, q(\omega) \, d\omega = \int p(\omega) \, d\omega = 1,

where the last equality holds because p(\omega) integrates to 1 over all paths (it is a valid density for a continuous-time Markov chain on a finite state space with absorption). \square
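Theorem 22.1 can be checked numerically with a small Monte Carlo sketch. It assumes a hypothetical two-parameter graph with a cycle (A ⇄ B, both absorbing), a deterministic start at A, and a constant target. Paths are simulated under the proposal, weighted with Equation 22.3, and the sample mean of w should be close to 1.

```python
import math
import random

# Hypothetical 2-parameter graph: vertex -> list of (next_vertex, coefficient tuple).
GRAPH = {
    "A": [("B", (1.0, 0.0)), ("abs", (0.0, 1.0))],
    "B": [("A", (0.0, 0.5)), ("abs", (1.0, 0.5))],
}

def weight(coeffs, theta):
    return sum(c * t for c, t in zip(coeffs, theta))

def sample_path(rng, theta, start="A"):
    """Simulate the CTMC under constant theta; return [(vertex, edge_index, sojourn)]."""
    path, v = [], start
    while v != "abs":
        ws = [weight(c, theta) for _, c in GRAPH[v]]
        s = rng.expovariate(sum(ws))                       # exponential sojourn time
        k = rng.choices(range(len(ws)), weights=ws)[0]     # embedded-chain jump
        path.append((v, k, s))
        v = GRAPH[v][k][0]
    return path

def log_weight(path, theta_tgt, theta_prop):
    """Eq. (22.3) with a constant (homogeneous) target."""
    total = 0.0
    for v, k, s in path:
        w_t = [weight(c, theta_tgt) for _, c in GRAPH[v]]
        w_p = [weight(c, theta_prop) for _, c in GRAPH[v]]
        lam_t, lam_p = sum(w_t), sum(w_p)
        total += (math.log(lam_t) - math.log(lam_p) - (lam_t - lam_p) * s
                  + math.log(w_t[k] / lam_t) - math.log(w_p[k] / lam_p))
    return total

theta_prop, theta_tgt = (1.0, 1.0), (1.5, 0.8)
rng = random.Random(0)
weights = [math.exp(log_weight(sample_path(rng, theta_prop), theta_tgt, theta_prop))
           for _ in range(100_000)]
mean_w = sum(weights) / len(weights)
print(mean_w)   # close to 1, as Theorem 22.1 predicts
```

The parameters are chosen so that the target exit rates exceed the proposal exit rates at both vertices. This keeps the per-step factors bounded and the Monte Carlo variance finite.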

Theorem 22.2 (Convergence of Importance-Weighted Likelihood) Let \omega_1, \ldots, \omega_M be i.i.d. paths from the proposal, let L(\omega_m) = \exp(\ell_m) be the data likelihood given path \omega_m, and let w_m = w(\omega_m) be the importance weight. Then

\hat{L} = \frac{1}{M} \sum_{m=1}^{M} L(\omega_m) \, w(\omega_m) \xrightarrow{M \to \infty} \mathbb{E}_{\text{target}}[L(\omega)] = P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) \quad \text{almost surely.}

Proof. By the strong law of large numbers applied to the i.i.d. sequence L(\omega_m) w(\omega_m):

\frac{1}{M} \sum_{m=1}^{M} L(\omega_m) w(\omega_m) \xrightarrow{\text{a.s.}} \mathbb{E}_{\text{prop}}[L(\omega) w(\omega)].

Expanding w(\omega) = p(\omega)/q(\omega):

\mathbb{E}_{\text{prop}}[L(\omega) w(\omega)] = \int L(\omega) \frac{p(\omega)}{q(\omega)} q(\omega) \, d\omega = \int L(\omega) \, p(\omega) \, d\omega = \mathbb{E}_{\text{target}}[L(\omega)].

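The final expectation equals P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) by the law of total probability (marginalizing the likelihood over the latent path). The strong law applies because the summands L(\omega_m) w(\omega_m) are nonnegative with finite mean \mathbb{E}_{\text{prop}}[L(\omega) w(\omega)] = P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}). Each summand is finite almost surely, since a path reaches absorption after finitely many jumps with finite sojourn times. The weights are not bounded in general: if \lambda_{v_k}^{\text{tgt}} < \lambda_{v_k}^{\text{prop}}, the factor \exp((\lambda_{v_k}^{\text{prop}} - \lambda_{v_k}^{\text{tgt}}) s_k) grows without bound in s_k. Convergence can therefore be slow, and the variance of \hat{L} may be infinite. \square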
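Theorem 22.2 can be illustrated on the simplest possible chain: one transient state with a single exit edge, so a path is just one sojourn time T. The observation is hypothetical: the data "y" is the event T > τ, so L(ω) = 1[T > τ] and the exact target likelihood is exp(-λ_tgt τ). The names are illustrative.

```python
import math
import random

def estimate(lam_tgt, lam_prop, tau, M, seed=0):
    """Theorem 22.2 on a one-state chain: L(omega) = 1[T > tau], T ~ Exp(lam_prop)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(M):
        s = rng.expovariate(lam_prop)                 # the path is a single sojourn
        # Eq. (22.3) with one step; the transition-probability ratio is 1.
        log_w = math.log(lam_tgt / lam_prop) - (lam_tgt - lam_prop) * s
        total += (1.0 if s > tau else 0.0) * math.exp(log_w)
    return total / M

lam_tgt, lam_prop, tau = 3.0, 2.0, 0.5
exact = math.exp(-lam_tgt * tau)                      # P(T > tau) under the target
for M in (100, 10_000, 200_000):
    print(M, estimate(lam_tgt, lam_prop, tau, M), exact)
```

Swapping the rates so that lam_tgt < lam_prop makes the weight grow like exp((lam_prop - lam_tgt) s). For this one-state chain, the second moment of the weight is finite only when 2 · lam_tgt > lam_prop, which illustrates the variance caveat in the proof.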

Algorithms

Algorithm 22.1: BFFG Log-Probability Estimation

Description. Computes an estimate of the log-probability \log P(\mathbf{y} \mid \boldsymbol{\theta}_{\text{target}}) by combining a proposal model evaluation with a stochastic importance weight correction. For each observed data point (locus), the proposal log-probability is obtained via the forward algorithm on the homogeneous proposal graph, and M paths are sampled from the proposal conditioned on that locus's data. Each path's importance log-weight is computed by traversing the path vertices, computing exit rates and transition probabilities under both proposal and target, and summing the log-ratios (Equation 22.3). The per-locus correction is the logsumexp of the log-weights minus \log M (Equation 22.5). The total log-probability is the sum over loci of the proposal log-probabilities and corrections.

BFFG Log-Probability Estimation
1: Let G = (V, E) be a parameterized phase-type graph
2: Let theta_prop be the proposal parameter vector
3: Let theta_target(t) be the time-varying target parameter function
4: Let y = (y_1, ..., y_L) be the observed data (L loci)
5: Let M be the number of importance sampling paths per locus
6:
7: function BFFGLogProb(G, theta_prop, theta_target, y, M)
8:   total <- 0
9:   for l = 1, ..., L do                           ▷ Loop over loci
10:    log_p_model <- ForwardAlgorithm(G, theta_prop, y_l)
11:                                                  ▷ Proposal log-prob via [14]
12:    for m = 1, ..., M do
13:      omega_m <- SamplePathConditioned(G, theta_prop, y_l)
14:                                                  ▷ Conditioned path via [17]
15:      log_w_m <- 0
16:      for k = 1, ..., K_m do                      ▷ Step k of omega_m: v_k, e_k, s_k, t_k
17:        lambda_prop <- sum of w(e) for e in out(v_k) under theta_prop
18:        lambda_tgt <- sum of c_e^T theta_target(t_k) for e in out(v_k)
19:        p_prop <- w(e_k) / lambda_prop             ▷ Proposal transition prob
20:        p_tgt <- c_{e_k}^T theta_target(t_k) / lambda_tgt
21:                                                  ▷ Target transition prob
22:        log_w_m <- log_w_m + log(lambda_tgt) - log(lambda_prop)
23:                  - (lambda_tgt - lambda_prop) * s_k
24:                  + log(p_tgt) - log(p_prop)       ▷ Eq. (22.3)
25:      end for
26:    end for
27:    correction <- logsumexp(log_w_1, ..., log_w_M) - log(M)
28:                                                  ▷ Eq. (22.5)
29:    total <- total + log_p_model + correction
30:  end for
31:  return total
32: end function
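The pseudocode above can be read alongside a sequential Python sketch. This is not phasic's implementation. The graph format is hypothetical, and the two callables `forward_algorithm` and `sample_path_conditioned` are placeholders for the forward algorithm of [14] and the conditioned sampler of [17], which the caller must supply. Comments refer to the pseudocode line numbers.

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(xs))): factor out the maximum before exponentiating."""
    m = max(xs)
    if math.isinf(m):
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))

def edge_weight(coeffs, theta):
    """c_e^T theta for one edge."""
    return sum(c * th for c, th in zip(coeffs, theta))

def bffg_log_prob_sketch(graph, theta_prop, theta_target, data, M,
                         forward_algorithm, sample_path_conditioned):
    """graph: vertex -> list of (next_vertex, coeffs); paths: list of (v_k, e_k, s_k, t_k)."""
    total = 0.0                                                     # line 8
    for y_l in data:                                                # line 9
        log_p_model = forward_algorithm(graph, theta_prop, y_l)     # line 10
        log_ws = []
        for _ in range(M):                                          # line 12
            path = sample_path_conditioned(graph, theta_prop, y_l)  # line 13
            log_w = 0.0
            for v, k, s, t in path:                                 # line 16
                out = graph[v]
                w_prop = [edge_weight(c, theta_prop) for _, c in out]
                w_tgt = [edge_weight(c, theta_target(t)) for _, c in out]
                lam_prop, lam_tgt = sum(w_prop), sum(w_tgt)         # lines 17-18
                p_prop = w_prop[k] / lam_prop                       # line 19
                p_tgt = w_tgt[k] / lam_tgt                          # line 20
                log_w += (math.log(lam_tgt) - math.log(lam_prop)    # lines 22-24
                          - (lam_tgt - lam_prop) * s
                          + math.log(p_tgt) - math.log(p_prop))
            log_ws.append(log_w)
        correction = logsumexp(log_ws) - math.log(M)                # line 27
        total += log_p_model + correction                           # line 29
    return total

# Toy usage with hypothetical stubs: one transient vertex "A" with a single edge
# to absorption; the stub "conditioned sampler" always returns the same path.
graph = {"A": [("abs", (1.0,))]}
forward_stub = lambda g, theta, y: -1.0
sampler_stub = lambda g, theta, y: [("A", 0, 0.3, 0.0)]
result = bffg_log_prob_sketch(graph, (2.0,), lambda t: (3.0,), [1, 2, 3], 4,
                              forward_stub, sampler_stub)
print(result)
```

With a target equal to the proposal, every log-weight is zero and the result reduces to the sum of the `log_p_model` values. This makes a quick sanity check for any implementation of the loop.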

Correspondence table:

Pseudocode variable | Math symbol | Code variable (file:function)
G | G = (V, E) | jg_continuous (bffg.py:bffg_log_prob)
theta_prop | \boldsymbol{\theta}_{\text{prop}} | theta_proposal (bffg.py:bffg_log_prob)
theta_target | \boldsymbol{\theta}_{\text{target}}(t) | theta_target_fn (bffg.py:bffg_log_prob)
y | \mathbf{y} | observed_data (bffg.py:bffg_log_prob)
M | M | n_paths (bffg.py:bffg_log_prob)
omega_m | \omega_m | path (bffg.py:_full_importance_log_weight)
log_w_m | \log w(\omega_m) | log_w (bffg.py:_full_importance_log_weight)
lambda_prop | \lambda_{v_k}^{\text{prop}} | r_prop (bffg.py:_full_importance_log_weight)
lambda_tgt | \lambda_{v_k}^{\text{tgt}} | r_tgt (bffg.py:_full_importance_log_weight)
p_prop | p_k^{\text{prop}} | p_prop (bffg.py:_full_importance_log_weight)
p_tgt | p_k^{\text{tgt}} | p_tgt (bffg.py:_full_importance_log_weight)
s_k | s_k | s_k (bffg.py:_full_importance_log_weight)
log_p_model | \log P_{\text{model}} | log_p_model (bffg.py:log_prob_fn)
correction | \hat{\ell}_{\text{correction}} | log_ratio (bffg.py:log_prob_fn)
total | \hat{\ell} | total (bffg.py:log_prob_fn)
logsumexp | \operatorname{logsumexp} | logsumexp (jax.scipy.special)

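Complexity. Time: O(L \cdot (C_{\text{fwd}} + M \cdot K_{\max} \cdot E_{\max} \cdot d)), where C_{\text{fwd}} is the cost of one forward algorithm evaluation ([14]), K_{\max} is the maximum path length, E_{\max} = \max_v |\operatorname{out}(v)| is the maximum out-degree, and d is the parameter dimension. The E_{\max} \cdot d factor arises because, at each path vertex, computing \mathbf{c}_e^\top \boldsymbol{\theta} for every outgoing edge costs O(|\operatorname{out}(v_k)| \cdot d). If the per-parameter coefficient sums of Equation 22.2 are precomputed, this per-vertex cost drops to O(d). The conditioned path sampling costs O(K_{\max} \cdot E_{\max}) per path, which the bound above already covers. Space: O(|V| \cdot E_{\max} \cdot d) for the precomputed edge coefficient arrays, plus O(M) per locus for the log-weights (O(L \cdot M) when all loci are batched together).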

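Algorithm 22.1: Correctness. Theorem 22.1 shows that, for a homogeneous target, the weights have unit mean under the unconditioned proposal. In Algorithm 22.1 the paths are instead drawn from the proposal conditioned on y_l (line 13). If proposal and target share the observation model, the change-of-measure argument in the proof of Theorem 22.2 gives M^{-1} \sum_m w_m \to P(\mathbf{y}_l \mid \boldsymbol{\theta}_{\text{target}}) / P(\mathbf{y}_l \mid \boldsymbol{\theta}_{\text{prop}}) almost surely as M \to \infty. Hence log_p_model + correction is a consistent estimator of \log P(\mathbf{y}_l \mid \boldsymbol{\theta}_{\text{target}}). For the inhomogeneous case, Equation 22.3 holds the target rates fixed at \boldsymbol{\theta}_{\text{target}}(t_k) over each sojourn. It is therefore the exact path-density ratio when \boldsymbol{\theta}_{\text{target}} is constant on each sojourn interval, and otherwise a piecewise-constant approximation (the exact survival term is \exp(-\int \lambda_{v_k}^{\text{tgt}}(u)\, du) over the sojourn interval). The log-space computation via \operatorname{logsumexp} (line 27) is mathematically equal to \log(M^{-1} \sum_m \exp(\log w_m)) but avoids overflow and underflow.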

Numerical Considerations

Log-space computation. All weight computations are performed in log-space to avoid overflow and underflow. The importance log-weight (Equation 22.3) is accumulated as a sum of log-rate differences and sojourn-time terms. Long paths or large rate differences therefore produce large but finite log-weights, instead of overflowing or underflowing a product of per-step ratios. The final aggregation uses \operatorname{logsumexp}, which subtracts the maximum log-weight before exponentiating.

Degenerate rates. When exit rates are zero (absorbing vertex reached prematurely) or the taken edge cannot be identified, the step contribution is set to zero. Guard conditions r_prop > 0, r_tgt > 0, and p_prop > 0, p_tgt > 0 prevent \log 0 evaluations. In the JAX-compatible path (_importance_weight_one_path), a small constant 10^{-30} is added to rates and probabilities before taking logarithms.
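The two guard conventions can be compared on a single step. This sketch assumes the guards act exactly as described: either skip a step whose rates or probabilities are not positive, or add 10^{-30} inside every logarithm. The function names are hypothetical, and the real code may combine the guards differently.

```python
import math

EPS = 1e-30  # small constant added before taking logarithms (JAX-path convention)

def step_contrib_branch(lam_t, lam_p, s, p_t, p_p):
    """Branching guard: the step contributes 0 unless all rates/probabilities are > 0."""
    if lam_p > 0 and lam_t > 0 and p_p > 0 and p_t > 0:
        return (math.log(lam_t) - math.log(lam_p) - (lam_t - lam_p) * s
                + math.log(p_t) - math.log(p_p))
    return 0.0

def step_contrib_eps(lam_t, lam_p, s, p_t, p_p):
    """Branch-free guard: shift every logarithm's argument by EPS."""
    return (math.log(lam_t + EPS) - math.log(lam_p + EPS) - (lam_t - lam_p) * s
            + math.log(p_t + EPS) - math.log(p_p + EPS))

regular = (4.5, 3.0, 0.2, 0.6, 0.5)
degenerate = (1.0, 1.0, 0.1, 0.0, 0.5)     # target probability of the taken edge is 0
print(step_contrib_branch(*regular), step_contrib_eps(*regular))
print(step_contrib_branch(*degenerate), step_contrib_eps(*degenerate))
```

In this sketch the two conventions agree on regular steps but not on degenerate ones. The branching guard contributes 0, while the additive guard contributes a large negative value (about log 10^{-30}), which effectively removes that path from the logsumexp.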

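Precomputed coefficients. Edge coefficients are extracted once from the graph and stored in a dense array (dense_edge_coeffs, shape (|V|, E_{\max}, d), where E_{\max} is the maximum out-degree and vertices with fewer outgoing edges are padded) for efficient JAX vectorization. This trades memory for speed: the padding costs O(|V| \cdot E_{\max} \cdot d) storage but enables vmap over all paths and loci simultaneously.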

JIT and non-JIT paths. The implementation provides two execution paths: a non-JIT fallback (_full_importance_log_weight) that directly queries the graph data structure for each path step, and a JIT-compatible path (_importance_weight_one_path) that uses precomputed dense arrays and jax.lax.scan for fixed-size iteration. The return_model=True flag selects the JIT path.
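The ideas behind the JIT path can be sketched in plain Python: a zero-padded dense coefficient array, and a fixed number of loop iterations in which padded path steps are masked out. This is what lets a fixed-size jax.lax.scan and vmap handle paths of different lengths. The layout, names, and masking scheme are assumptions for illustration. In JAX the `if not ok` branch would typically become jnp.where, which evaluates both branches; that is a common reason an additive guard such as 10^{-30} is needed.

```python
import math

def build_dense(graph, vertices, d):
    """Pad per-vertex edge coefficients to a (|V|, E_max, d) nested list.

    Padded slots are zero vectors, so they contribute nothing to exit rates.
    """
    e_max = max(len(graph.get(v, [])) for v in vertices)
    dense = [[[0.0] * d for _ in range(e_max)] for _ in vertices]
    for i, v in enumerate(vertices):
        for j, (_, coeffs) in enumerate(graph.get(v, [])):
            dense[i][j] = list(coeffs)
    return dense

def log_weight_fixed_length(dense, path, k_max, theta_prop, theta_target):
    """Eq. (22.3) over a path padded to k_max steps, as a fixed-size scan would do.

    path: list of (vertex_index, edge_index, s_k, t_k); padded steps are masked out.
    """
    padding = [(0, 0, 0.0, 0.0)] * (k_max - len(path))
    valid = [True] * len(path) + [False] * len(padding)
    log_w = 0.0                                   # the scan "carry"
    for (vi, k, s, t), ok in zip(path + padding, valid):
        if not ok:                                # masked step contributes 0
            continue
        w_p = [sum(c * th for c, th in zip(row, theta_prop)) for row in dense[vi]]
        w_t = [sum(c * th for c, th in zip(row, theta_target(t))) for row in dense[vi]]
        lam_p, lam_t = sum(w_p), sum(w_t)
        log_w += (math.log(lam_t) - math.log(lam_p) - (lam_t - lam_p) * s
                  + math.log(w_t[k] / lam_t) - math.log(w_p[k] / lam_p))
    return log_w

# Hypothetical graph: "B" has one real edge and one zero-padded slot.
graph = {"A": [("B", (1.0, 0.0)), ("abs", (0.0, 1.0))], "B": [("abs", (1.0, 1.0))]}
dense = build_dense(graph, ["A", "B"], d=2)
path = [(0, 0, 0.2, 0.0), (1, 0, 0.1, 0.2)]      # A --edge 0--> B --edge 0--> abs
theta_prop = (1.0, 1.0)
theta_tgt = lambda t: (2.0, 1.0)
print(log_weight_fixed_length(dense, path, 5, theta_prop, theta_tgt))
```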

Implementation Notes

Source code mapping:

Algorithm | File | Function | Lines
Algorithm 22.1 (entry, backward compat) | src/phasic/bffg.py | bffg_log_prob / log_prob_fn | L570–L612
Algorithm 22.1 (JIT model) | src/phasic/bffg.py | model | L396–L427
Algorithm 22.1 (JIT correction) | src/phasic/bffg.py | likelihood_correction_jit | L546–L563
Algorithm 22.1 (non-JIT correction) | src/phasic/bffg.py | likelihood_correction | L431–L456
Importance log-weight (non-JIT) | src/phasic/bffg.py | _full_importance_log_weight | L353–L392
Importance log-weight (JIT) | src/phasic/bffg.py | _importance_weight_one_path | L490–L522
Path exit rates | src/phasic/bffg.py | path_exit_rates | L77–L118
Path exit rates by param | src/phasic/bffg.py | path_exit_rates_by_param | L121–L176
Importance log-weight from rates | src/phasic/bffg.py | importance_log_weight_from_rates | L179–L215
Importance-weighted log-likelihood | src/phasic/bffg.py | importance_weighted_log_likelihood | L218–L250
Path to rewards | src/phasic/bffg.py | path_to_rewards | L22–L74

Deviations from pseudocode:

  • The pseudocode shows a sequential loop over loci and paths. The JIT implementation (likelihood_correction_jit) uses jax.vmap to batch all L \times M path samples and weight computations into a single vectorized call, reshaping the results to (L, M) for per-locus aggregation.
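  • The non-JIT implementation (_full_importance_log_weight) includes both exit-rate and transition-probability log-ratios as shown in Equation 22.3. The utility function importance_log_weight_from_rates computes only the exit-rate terms (without transition probability ratios). It is therefore suitable only when the taken edges have the same transition probabilities under proposal and target, so that every \log(p_k^{\text{tgt}} / p_k^{\text{prop}}) term vanishes; sharing the same graph structure is not sufficient on its own. This holds, for example, when \boldsymbol{\theta}_{\text{target}}(t) is a (possibly time-varying) scalar multiple of \boldsymbol{\theta}_{\text{prop}}, since every edge weight and exit rate is then rescaled by the same factor.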
  • The bffg_log_prob function returns either a single log_prob_fn (backward compatible) or a (model, likelihood_correction) tuple for integration with MCMC ([19]). The pseudocode shows only the combined log-probability for simplicity.
  • Edge coefficients are precomputed and cached in closure variables (_vertex_edge_coeffs, _vertex_prop_rates) at graph construction time, not recomputed per path traversal.
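The condition in the second bullet above can be checked on a single step. This sketch splits one summand of Equation 22.3 into its exit-rate/sojourn part and its transition-probability part. It shows that the latter vanishes for a target proportional to the proposal but not for a target that changes the ratio between parameters. The vertex and parameter values are hypothetical.

```python
import math

def step_terms(out_edges, k, s, theta_tgt, theta_prop):
    """Split one summand of Eq. (22.3) into its exit-rate/sojourn part and
    its transition-probability part log(p_tgt / p_prop)."""
    w_t = [sum(c * th for c, th in zip(cs, theta_tgt)) for _, cs in out_edges]
    w_p = [sum(c * th for c, th in zip(cs, theta_prop)) for _, cs in out_edges]
    lam_t, lam_p = sum(w_t), sum(w_p)
    rate_part = math.log(lam_t) - math.log(lam_p) - (lam_t - lam_p) * s
    prob_part = math.log(w_t[k] / lam_t) - math.log(w_p[k] / lam_p)
    return rate_part, prob_part

out_v = [("a", (1.0, 0.0)), ("b", (0.5, 2.0))]   # hypothetical vertex, d = 2
theta_prop = (1.0, 1.0)

# Target proportional to proposal: same jump chain, so prob_part is (numerically) 0.
_, prob_scaled = step_terms(out_v, 0, 0.3, (1.7, 1.7), theta_prop)
# Target that changes the ratio theta_1 / theta_2: prob_part no longer vanishes.
_, prob_general = step_terms(out_v, 0, 0.3, (2.0, 0.5), theta_prop)
print(prob_scaled, prob_general)
```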

Symbol Index

Symbol | Name | First appearance
\mathbf{c}_e | Edge coefficient vector | Definition 22.1
c_{e,j} | j-th coefficient of edge e | Definition 22.2
e_k | Taken edge at step k | Definition 22.3
\hat{\ell} | Importance-weighted log-likelihood estimator | Definition 22.4
\hat{\ell}_{\text{correction}} | Importance weight correction | Definition 22.4
\lambda_{v_k}(\boldsymbol{\theta}) | Exit rate at vertex v_k under parameter vector \boldsymbol{\theta} | Definition 22.2
\lambda_{v_k}^{\text{prop}} | Exit rate under proposal at vertex v_k | Definition 22.3
\lambda_{v_k}^{\text{tgt}} | Exit rate under target at vertex v_k | Definition 22.3
L | Number of loci (observations) | Algorithm 22.1
M | Number of importance sampling paths | Definition 22.4
\omega | Path through the graph | Definition 22.2
p_k^{\text{prop}} | Proposal transition probability at step k | Definition 22.3
p_k^{\text{tgt}} | Target transition probability at step k | Definition 22.3
s_k | Sojourn time at vertex v_k | Definition 22.2
t_k | Time at which the target parameters are evaluated for step k | Definition 22.3
\boldsymbol{\theta}_{\text{prop}} | Proposal parameter vector | Definition 22.1
\boldsymbol{\theta}_{\text{target}}(t) | Time-varying target parameter function | Definition 22.1
w(\omega) | Importance weight for path \omega | Definition 22.3