20 Stein Variational Gradient Descent
Introduction
This file formalizes Stein Variational Gradient Descent [SVGD; Liu and Wang (2016)], a deterministic particle-based method for approximate Bayesian inference. Given observed data \mathbf{y} and a target posterior p(\boldsymbol{\theta} \mid \mathbf{y}) over the parameter vector \boldsymbol{\theta} \in \mathbb{R}^d of a phase-type model, SVGD maintains a finite set of n particles \{\boldsymbol{\theta}^{(i)}\}_{i=1}^{n}. It iteratively transports them toward the posterior along the direction that most rapidly decreases the KL divergence from the particle distribution to the target. The key insight is that, when the perturbation is restricted to the unit ball of a reproducing kernel Hilbert space (RKHS), this optimal direction has a closed form. That closed form requires only evaluations of \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta} \mid \mathbf{y}) at the particle locations and kernel evaluations between particles.
In the phasic pipeline, SVGD serves as the primary inference engine for estimating posterior distributions of phase-type model parameters. The model likelihood is computed via the forward algorithm ([14]) and, optionally, moment-based regularization uses the moment computation machinery from [15]. SVGD replaces traditional MCMC sampling with a deterministic optimization that converges faster for moderate-dimensional problems (d \leq 50) typical of phase-type models.
Prerequisites: [01], [14], [15]
Source files:
src/phasic/svgd.py (classes: SVGD, SVGDKernel, Prior, GaussPrior, HalfCauchyPrior, DataPrior, SparseObservations; functions: rbf_kernel, batch_median_heuristic, batch_median_heuristic_per_dim, svgd_step, run_svgd, _svgd_update_jitted, dense_to_sparse)
Definitions
Definition 20.1 (Stein Operator and Stein Discrepancy) Let p be a smooth density on \mathbb{R}^d and \boldsymbol{\phi} \colon \mathbb{R}^d \to \mathbb{R}^d a smooth vector field. The Stein operator of p applied to \boldsymbol{\phi} is
\mathcal{A}_p \boldsymbol{\phi}(\boldsymbol{\theta}) = \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta})^\top \boldsymbol{\phi}(\boldsymbol{\theta}) + \nabla_{\boldsymbol{\theta}} \cdot \boldsymbol{\phi}(\boldsymbol{\theta}), \tag{20.1}
where \nabla_{\boldsymbol{\theta}} \cdot \boldsymbol{\phi} = \sum_{j=1}^{d} \frac{\partial \phi_j}{\partial \theta_j} is the divergence. The Stein discrepancy between a distribution q and p over a function class \mathcal{F} is
\mathbb{S}(q, p) = \sup_{\boldsymbol{\phi} \in \mathcal{F}} \left( \mathbb{E}_q [\mathcal{A}_p \boldsymbol{\phi}(\boldsymbol{\theta})] \right)^2. \tag{20.2}
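Stein's identity, used in the discussion below, is easy to check numerically. The NumPy sketch below is illustrative and not part of svgd.py; the helper names are assumptions. It takes the standard Gaussian target, whose score is \nabla \log p(\theta) = -\theta, and the test function \phi(\theta) = \theta. For that pair, \mathcal{A}_p \phi(\theta) = 1 - \theta^2, and for q = N(\mu, s^2) the expectation is \mathbb{E}_q[\mathcal{A}_p \phi] = 1 - \mu^2 - s^2.

```python
import numpy as np

def stein_term(theta):
    """A_p phi(theta) for p = N(0, 1) and the test function phi(theta) = theta.

    grad log p(theta) = -theta, so A_p phi = (-theta) * theta + d(theta)/d(theta) = 1 - theta**2.
    """
    return 1.0 - theta ** 2

def stein_expectation(samples):
    """Monte Carlo estimate of E_q[A_p phi] from samples drawn from q."""
    return float(np.mean(stein_term(samples)))

rng = np.random.default_rng(0)
n = 200_000
same = stein_expectation(rng.normal(0.0, 1.0, n))      # q = p, exact value 0
wider = stein_expectation(rng.normal(0.0, 2.0, n))     # q = N(0, 4), exact value 1 - 4 = -3
shifted = stein_expectation(rng.normal(1.0, 1.0, n))   # q = N(1, 1), exact value 1 - 1 - 1 = -1
print(same, wider, shifted)
```

A single test function is not enough to separate q from p. Any q = N(\mu, s^2) with \mu^2 + s^2 = 1, such as N(0.6, 0.64), also gives zero, which is why (20.2) takes a supremum over a class \mathcal{F}.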
Intuition
The Stein operator (Stein 1972) measures how much a test function \boldsymbol{\phi} distinguishes q from p. By Stein's identity, \mathbb{E}_p[\mathcal{A}_p \boldsymbol{\phi}(\boldsymbol{\theta})] = 0 for all suitable \boldsymbol{\phi} (under mild boundary conditions). A nonzero value of \mathbb{E}_q[\mathcal{A}_p \boldsymbol{\phi}] for any single \boldsymbol{\phi} therefore certifies that q \neq p. Conversely, if \mathcal{F} is rich enough, every q \neq p yields a nonzero value for some \boldsymbol{\phi} \in \mathcal{F}, so the supremum over \mathcal{F} quantifies the discrepancy. The Stein discrepancy depends on p only through \nabla \log p, so the normalizing constant of p is not needed.
Example
For a univariate Gaussian target p(\theta) \propto \exp(-\theta^2 / 2), we have \nabla \log p(\theta) = -\theta. The Stein operator applied to \phi(\theta) = 1 gives \mathcal{A}_p \phi(\theta) = -\theta, and \mathbb{E}_q[-\theta] = -\mathbb{E}_q[\theta], which is zero if and only if q has mean zero. This single test function therefore detects only a mismatch in the mean.
Definition 20.2 (RBF Kernel and Bandwidth Selection) The radial basis function (RBF) kernel is
k(\boldsymbol{\theta}, \boldsymbol{\theta}') = \exp\!\left( -\frac{\lVert \boldsymbol{\theta} - \boldsymbol{\theta}' \rVert^2}{2h^2} \right), \tag{20.3}
where h > 0 is the bandwidth parameter. The isotropic median heuristic, a variant of the rule proposed by Liu and Wang (2016), sets
h = \frac{\operatorname{median}\{\lVert \boldsymbol{\theta}^{(i)} - \boldsymbol{\theta}^{(j)} \rVert : i < j\}}{\sqrt{\log(n + 1)}}, \tag{20.4}
where n is the number of particles. The anisotropic (per-dimension) median heuristic computes a separate bandwidth h_l for each dimension l = 1, \ldots, d:
h_l = \sqrt{\frac{\operatorname{median}\{(\theta^{(i)}_l - \theta^{(j)}_l)^2 : i < j\}}{\log(n + 1)}}, \tag{20.5}
with the anisotropic kernel k(\boldsymbol{\theta}, \boldsymbol{\theta}') = \exp\!\left( -\sum_{l=1}^{d} \frac{(\theta_l - \theta'_l)^2}{2h_l^2} \right).
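Equations (20.3) to (20.5) can be sketched in NumPy as follows. The names echo rbf_kernel, batch_median_heuristic and batch_median_heuristic_per_dim from svgd.py, but the signatures here are assumptions, not the library's API. The 10^{-8} floor follows the bandwidth clamp described under Numerical Considerations.

```python
import numpy as np

def rbf_kernel(x, y, h):
    """RBF kernel (20.3); h may be a scalar or a per-dimension array (anisotropic form)."""
    h = np.asarray(h, dtype=float)
    return float(np.exp(-np.sum((x - y) ** 2 / (2.0 * h ** 2))))

def _upper_pair_diffs(particles):
    """Differences theta_i - theta_j for all pairs i < j, shape (n*(n-1)/2, d)."""
    i, j = np.triu_indices(particles.shape[0], k=1)
    return particles[i] - particles[j]

def median_heuristic(particles, min_h=1e-8):
    """Isotropic bandwidth (20.4): median pairwise distance / sqrt(log(n + 1)), floored at min_h."""
    n = particles.shape[0]
    dists = np.linalg.norm(_upper_pair_diffs(particles), axis=1)
    return max(float(np.median(dists)) / np.sqrt(np.log(n + 1)), min_h)

def median_heuristic_per_dim(particles, min_h=1e-8):
    """Per-dimension bandwidths (20.5): one h_l per coordinate, floored at min_h."""
    n = particles.shape[0]
    sq = _upper_pair_diffs(particles) ** 2            # squared differences, (pairs, d)
    h_sq = np.median(sq, axis=0) / np.log(n + 1)
    return np.maximum(np.sqrt(h_sq), min_h)

# Usage: the second coordinate is constant, so its per-dimension bandwidth hits the floor.
pts = np.array([[0.0, 0.0], [0.5, 0.0], [1.0, 0.0]])
print(median_heuristic(pts), median_heuristic_per_dim(pts))
```

The floor keeps 1/h_l^2 finite when a coordinate collapses, which is the failure mode the clamp is meant to prevent.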
Intuition
The bandwidth h controls the interaction range between particles. The median heuristic adapts h to the current particle spread: when particles are concentrated, h shrinks to provide fine-grained repulsion; when particles are dispersed, h grows to maintain long-range interactions. The factor 1/\sqrt{\log(n+1)} makes h shrink slowly as the number of particles grows. Liu and Wang (2016) motivate this logarithmic scaling as balancing each particle's own gradient against the combined influence of the other particles. The per-dimension variant lets the kernel adapt to each parameter's scale independently, which matters when parameters differ in magnitude (e.g., a rate parameter of order 1 versus a population size of order 10^3).
Example
For n = 100 particles in \mathbb{R}^2, if the median pairwise distance is 0.5, then h = 0.5 / \sqrt{\log 101} \approx 0.233.
Definition 20.3 (SVGD Update Rule) Given n particles \{\boldsymbol{\theta}^{(i)}\}_{i=1}^{n}, a kernel k, and the target log-density \log p, the SVGD perturbation direction (Liu and Wang 2016) for particle i is
\hat{\boldsymbol{\phi}}^*(\boldsymbol{\theta}^{(i)}) = \frac{1}{n} \sum_{j=1}^{n} \Big[ k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) \, \nabla_{\boldsymbol{\theta}^{(j)}} \log p(\boldsymbol{\theta}^{(j)}) + \nabla_{\boldsymbol{\theta}^{(j)}} k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) \Big]. \tag{20.6}
The particle update is \boldsymbol{\theta}^{(i)} \leftarrow \boldsymbol{\theta}^{(i)} + \epsilon \, \hat{\boldsymbol{\phi}}^*(\boldsymbol{\theta}^{(i)}), where \epsilon > 0 is the step size.
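A vectorized NumPy sketch of (20.6) for the isotropic RBF kernel follows. It assumes the scores \nabla \log p(\boldsymbol{\theta}^{(j)}) have already been evaluated. It uses the closed-form kernel gradient \nabla_{\boldsymbol{\theta}^{(j)}} k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) = k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) (\boldsymbol{\theta}^{(i)} - \boldsymbol{\theta}^{(j)}) / h^2. The function name svgd_direction is hypothetical and the code is not taken from svgd.py.

```python
import numpy as np

def svgd_direction(particles, scores, h):
    """SVGD direction (20.6) for every particle at once.

    particles: (n, d) positions theta^(j); scores: (n, d) values of grad log p at them;
    h: scalar RBF bandwidth. Row i of the result is phi_hat*(theta^(i)).
    """
    n = particles.shape[0]
    diff = particles[:, None, :] - particles[None, :, :]       # diff[i, j] = theta_i - theta_j
    K = np.exp(-np.sum(diff ** 2, axis=-1) / (2.0 * h ** 2))   # K[i, j] = k(theta_i, theta_j), symmetric
    attraction = K @ scores                                    # sum_j k(theta_j, theta_i) grad log p(theta_j)
    repulsion = np.sum(K[:, :, None] * diff, axis=1) / h ** 2  # sum_j grad_{theta_j} k(theta_j, theta_i)
    return (attraction + repulsion) / n

# Usage: with zero scores only the repulsive term acts, pushing the two particles apart.
two = np.array([[0.0], [1.0]])
print(svgd_direction(two, np.zeros_like(two), h=1.0))
```

With n = 1 the repulsion vanishes and the direction equals the score itself, which is the MAP limit described in the example that follows.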
Intuition
The update has two terms. The first term k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) \nabla \log p(\boldsymbol{\theta}^{(j)}) moves particle i along a kernel-weighted average of the scores at all particles. It drives particles toward high-probability regions, and nearby particles exert more influence. The second term \nabla_{\boldsymbol{\theta}^{(j)}} k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) acts as a repulsive force that prevents particles from collapsing to a single mode. For the RBF kernel it equals k(\boldsymbol{\theta}^{(j)}, \boldsymbol{\theta}^{(i)}) (\boldsymbol{\theta}^{(i)} - \boldsymbol{\theta}^{(j)}) / h^2, which points from particle j toward particle i and so pushes i away from j. The balance between these terms yields a particle system that approximates the full posterior, not just its mode.
Example
With a single particle (n = 1), the repulsive term vanishes because the RBF kernel gradient \nabla_{\boldsymbol{\theta}} k(\boldsymbol{\theta}, \boldsymbol{\theta}') is \mathbf{0} at \boldsymbol{\theta}' = \boldsymbol{\theta}. Since k(\boldsymbol{\theta}, \boldsymbol{\theta}) = 1, SVGD then reduces to gradient ascent on \log p, recovering maximum a posteriori (MAP) estimation.
Definition 20.4 (Particle System) An SVGD particle system is a tuple (\{\boldsymbol{\theta}^{(i)}\}_{i=1}^{n}, k, p, \epsilon) where:
- \boldsymbol{\theta}^{(i)} \in \mathbb{R}^d for i = 1, \ldots, n are the particle positions,
- k is a positive-definite kernel (Definition 20.2),
- p is the target posterior density,
- \epsilon > 0 is the step size (or step size schedule, Definition 20.5).
The particle positions at iteration t are denoted \{\boldsymbol{\theta}^{(i)}_t\}_{i=1}^{n}, with \boldsymbol{\theta}^{(i)}_0 drawn from an initial distribution q_0 (typically prior samples or a perturbed point estimate). The empirical measure \hat{q}_t = \frac{1}{n} \sum_{i=1}^{n} \delta_{\boldsymbol{\theta}^{(i)}_t} approximates the posterior as t \to \infty.
Intuition
The particle system represents the posterior nonparametrically: each particle is one plausible parameter vector, and the fraction of particles in a region approximates the posterior probability of that region. Unlike independent MCMC chains, all particles evolve simultaneously and interact through the kernel, enabling cooperative exploration of the parameter space.
Definition 20.5 (Step Size Schedules) A step size schedule is a function \epsilon \colon \mathbb{N}_0 \to \mathbb{R}_{>0} mapping iteration t to step size \epsilon_t. The implementations are:
- Constant: \epsilon_t = \epsilon_0 for all t.
- Exponential decay: \epsilon_t = \epsilon_{\text{first}} \, e^{-t/\tau} + \epsilon_{\text{last}} \, (1 - e^{-t/\tau}), with time constant \tau > 0.
- Adaptive: \epsilon_t is adjusted multiplicatively based on a particle-spread proxy for the KL divergence. If \widehat{\text{KL}}_t > \text{target}, set \epsilon_{t+1} = (1 - \rho)\epsilon_t; otherwise \epsilon_{t+1} = (1 + \rho)\epsilon_t, where \rho \in (0,1) is the adjustment rate.
- Warmup with decay: \epsilon_t = \epsilon_{\text{peak}} \cdot t / t_w for t \leq t_w (linear ramp), followed by exponential decay \epsilon_t = \epsilon_{\text{peak}} \, e^{-(t - t_w)/\tau} + \epsilon_{\text{last}}(1 - e^{-(t - t_w)/\tau}) for t > t_w.
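Each schedule in Definition 20.5 maps to a small scalar function. The sketch below uses hypothetical names and plain Python. This section does not specify how the particle-spread proxy \widehat{\text{KL}}_t is computed, so the adaptive rule takes it as an input.

```python
import math

def constant_step(t, eps0):
    """Constant schedule: eps_t = eps0 for every t."""
    return eps0

def exp_decay_step(t, eps_first, eps_last, tau):
    """Exponential decay from eps_first toward eps_last with time constant tau."""
    w = math.exp(-t / tau)
    return eps_first * w + eps_last * (1.0 - w)

def adaptive_step(eps_t, kl_hat, target, rho):
    """Shrink by (1 - rho) if the KL proxy exceeds the target, otherwise grow by (1 + rho)."""
    return (1.0 - rho) * eps_t if kl_hat > target else (1.0 + rho) * eps_t

def warmup_decay_step(t, eps_peak, eps_last, t_w, tau):
    """Linear ramp to eps_peak over t_w iterations, then exponential decay toward eps_last."""
    if t <= t_w:
        return eps_peak * t / t_w  # as written in Definition 20.5, this is 0 at t = 0
    w = math.exp(-(t - t_w) / tau)
    return eps_peak * w + eps_last * (1.0 - w)

print([round(exp_decay_step(t, 0.1, 0.01, 50.0), 4) for t in (0, 50, 500)])
print([round(warmup_decay_step(t, 0.2, 0.01, 100, 30.0), 4) for t in (0, 50, 100, 130)])
```

The warmup schedule is continuous at t = t_w, where both branches equal \epsilon_{\text{peak}}.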
Intuition
Constant step sizes are simplest but may require manual tuning. Exponential decay starts with large exploratory steps and refines as convergence approaches. The adaptive schedule monitors particle spread as a proxy for how well the particles cover the posterior: if they are too spread, the step size decreases to prevent divergence; if too concentrated, it increases to encourage exploration. Warmup is particularly useful with Adam optimization, where moment estimates are unreliable in early iterations.
Definition 20.6 (Prior Distributions) A prior on the parameter vector \boldsymbol{\theta} \in \mathbb{R}^d_{>0} specifies a log-density \log \pi(\boldsymbol{\theta}) and a sampling method. The implementations are:
- Gaussian prior: \log \pi(\boldsymbol{\theta}) = -\frac{1}{2} \sum_{l=1}^{d} \left(\frac{\theta_l - \mu_l}{\sigma_l}\right)^2 + C, where \mu_l, \sigma_l are per-dimension mean and standard deviation.
- Half-Cauchy prior: \log \pi(\boldsymbol{\theta}) = \sum_{l=1}^{d} \left[\log 2 - \log \pi - \log s_l - \log\!\left(1 + (\theta_l / s_l)^2\right)\right] for \theta_l > 0, with scale s_l > 0.
- Data-driven prior: Uses method-of-moments or probability matching on the observed data to estimate \mu_l and \sigma_l, producing a Gaussian prior centered on the data-informed point estimate.
When parameters are constrained to be positive via the softplus transformation \theta_l = \log(1 + e^{\phi_l}), the prior is evaluated in the constrained space \boldsymbol{\theta} with a Jacobian correction: \log \pi(\boldsymbol{\phi}) = \log \pi(\boldsymbol{\theta}(\boldsymbol{\phi})) + \sum_{l=1}^{d} \log \sigma(\phi_l), where \sigma(\phi_l) = 1/(1 + e^{-\phi_l}) is the sigmoid function.
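The priors and the softplus change of variables can be sketched as plain log-density functions. This is not the GaussPrior or HalfCauchyPrior implementation: the function names are hypothetical, and the Gaussian constant C is written out so that the density is normalized. The Jacobian term uses \log \sigma(\phi) = -\log(1 + e^{-\phi}), the log-derivative of softplus.

```python
import numpy as np

def softplus(phi):
    """theta = log(1 + exp(phi)), computed stably."""
    return np.logaddexp(0.0, phi)

def log_sigmoid(phi):
    """log sigma(phi) = -log(1 + exp(-phi)), the log of d theta / d phi."""
    return -np.logaddexp(0.0, -phi)

def gauss_log_prior(theta, mu, sigma):
    """Gaussian log-density summed over the last axis, including the normalizing constant."""
    z = (theta - mu) / sigma
    return np.sum(-0.5 * z ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi), axis=-1)

def half_cauchy_log_prior(theta, s):
    """Half-Cauchy log-density for theta > 0 with per-dimension scale s."""
    return np.sum(np.log(2.0) - np.log(np.pi) - np.log(s) - np.log1p((theta / s) ** 2), axis=-1)

def log_prior_unconstrained(phi, log_prior_theta):
    """Prior in phi-space: log pi(theta(phi)) + sum_l log sigma(phi_l)."""
    return log_prior_theta(softplus(phi)) + np.sum(log_sigmoid(phi), axis=-1)

# Usage: the phi-space density of a Gaussian(3, 0.5) prior on theta should integrate to ~1.
grid = np.linspace(-20.0, 20.0, 200_001)[:, None]
dx = grid[1, 0] - grid[0, 0]
density = np.exp(log_prior_unconstrained(grid, lambda th: gauss_log_prior(th, 3.0, 0.5)))
total_mass = density.sum() * dx
print(total_mass)
```

Dropping the log_sigmoid term would give a function of \boldsymbol{\phi} that integrates to more than 1, because \sigma(\phi_l) < 1.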
Intuition
The Gaussian prior encodes a belief that each parameter is near some value with known uncertainty. The half-Cauchy prior has heavier tails and is a standard choice for scale parameters (Gelman 2006) that could be much larger than expected. The data-driven prior automates prior specification by extracting information from the data using method-of-moments, which is especially useful when no strong domain knowledge is available. The Jacobian correction ensures that the change of variables from constrained (\boldsymbol{\theta}) to unconstrained (\boldsymbol{\phi}) space preserves the correct probability density.
Definition 20.7 (Sparse Observations) For a multivariate phase-type model with m feature dimensions, a sparse observation format is a tuple (\mathbf{v}, \mathbf{f}, m, \mathcal{S}) where:
- \mathbf{v} = (v_1, \ldots, v_N) \in \mathbb{R}^N is the vector of all valid observation values (no missing entries),
- \mathbf{f} = (f_1, \ldots, f_N) \in \{0, \ldots, m-1\}^N is the feature index for each observation,
- m is the number of feature dimensions,
- \mathcal{S} = \{(s_j, e_j)\}_{j=0}^{m-1} are pre-computed slice indices such that \mathbf{v}[s_j : e_j] are the observations for feature j.
The log-likelihood under a model with per-feature PMF f_j(\cdot \mid \boldsymbol{\theta}) is
\ell(\boldsymbol{\theta}; \mathbf{v}, \mathbf{f}) = \sum_{j=0}^{m-1} \sum_{k=s_j}^{e_j - 1} \log f_j(v_k \mid \boldsymbol{\theta}). \tag{20.7}
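The conversion from a NaN-padded dense layout and the likelihood (20.7) can be sketched in NumPy as below. The dense layout assumed here (one column per feature, NaN as padding) and the names to_sparse and sparse_log_likelihood are illustrative assumptions. This is not the dense_to_sparse function from svgd.py, whose signature is not given in this section. The boolean mask is applied once during conversion, so the likelihood only uses fixed integer slices.

```python
import numpy as np

def to_sparse(dense):
    """Convert a NaN-padded (max_obs, m) array into (values, feature_idx, slices).

    Assumed layout: column j holds the observations of feature j, padded with NaN.
    """
    values, feats, slices, start = [], [], [], 0
    for j in range(dense.shape[1]):
        col = dense[:, j]
        col = col[~np.isnan(col)]                  # boolean mask used only here, during conversion
        values.append(col)
        feats.append(np.full(col.shape[0], j, dtype=np.int32))
        slices.append((start, start + col.shape[0]))
        start += col.shape[0]
    return np.concatenate(values), np.concatenate(feats), slices

def sparse_log_likelihood(values, slices, log_pdf_fns):
    """Equation (20.7): sum over features j of sum_k log f_j(v_k | theta), via integer slices."""
    return sum(float(np.sum(log_pdf_fns[j](values[s:e]))) for j, (s, e) in enumerate(slices))

# Usage: three features with 10, 5 and 8 observations.
dense = np.full((10, 3), np.nan)
dense[:10, 0] = np.arange(1.0, 11.0)
dense[:5, 1] = 0.5 * np.arange(1.0, 6.0)
dense[:8, 2] = 2.0 * np.arange(1.0, 9.0)
values, feats, slices = to_sparse(dense)
rates = [1.0, 2.0, 0.5]
log_pdfs = [lambda v, r=r: np.log(r) - r * v for r in rates]  # hypothetical exponential models
print(slices, sparse_log_likelihood(values, slices, log_pdfs))
```

With feature counts 10, 5 and 8, this yields N = 23 values and the slices (0, 10), (10, 15), (15, 23), matching the worked example for Definition 20.7.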
Intuition
In multivariate phase-type models, different marginal distributions may have different numbers of observations (e.g., some features are observed more frequently than others). The dense representation pads missing observations with NaN, which propagates through JAX gradient computations and causes numerical issues. The sparse format stores only valid observations grouped by feature, with pre-computed index slices enabling efficient JIT-compatible access without boolean indexing.
Example
For 3 features with 10, 5, and 8 observations respectively, N = 23, \mathbf{v} contains all 23 values sorted by feature, \mathbf{f} = (0, \ldots, 0, 1, \ldots, 1, 2, \ldots, 2), and \mathcal{S} = \{(0, 10), (10, 15), (15, 23)\}.
Theorems and Proofs
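Theorem 20.1 (SVGD Update Decreases KL Divergence (Liu and Wang 2016)) Let q be a distribution on \mathbb{R}^d, p the target distribution, and \mathcal{H} the RKHS of a positive-definite kernel k. For \boldsymbol{\phi} \in \mathcal{H}^d, let q_{[\epsilon \boldsymbol{\phi}]} be the distribution of \boldsymbol{\theta} + \epsilon \boldsymbol{\phi}(\boldsymbol{\theta}) when \boldsymbol{\theta} \sim q, and define
\boldsymbol{\phi}^*(\cdot) = \mathbb{E}_q[\mathcal{A}_p k(\boldsymbol{\theta}, \cdot)] = \mathbb{E}_q\big[k(\boldsymbol{\theta}, \cdot) \, \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta}) + \nabla_{\boldsymbol{\theta}} k(\boldsymbol{\theta}, \cdot)\big],
where \mathcal{A}_p acts on the scalar function \boldsymbol{\theta} \mapsto k(\boldsymbol{\theta}, \cdot) through the vector-valued form of the Stein operator (Definition 20.1). Then:
(i) If \boldsymbol{\phi}^* \neq \mathbf{0}, the steepest descent direction of the KL divergence in the unit ball, \arg\max_{\lVert \boldsymbol{\phi} \rVert_{\mathcal{H}^d} \leq 1} \{-\frac{d}{d\epsilon}\operatorname{KL}(q_{[\epsilon \boldsymbol{\phi}]} \| p) \big|_{\epsilon=0}\}, is \boldsymbol{\phi}^* / \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}, and the maximal rate of decrease is \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}.
(ii) Along the unnormalized direction \boldsymbol{\phi}^*, the rate of KL decrease is -\frac{d}{d\epsilon}\operatorname{KL}(q_{[\epsilon \boldsymbol{\phi}^*]} \| p)\big|_{\epsilon=0} = \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}^2 \geq 0. If q = p, then \boldsymbol{\phi}^* = \mathbf{0}. Conversely, \boldsymbol{\phi}^* = \mathbf{0} implies q = p when k is sufficiently rich (e.g., integrally strictly positive definite, as the RBF kernel is).
Evaluating \boldsymbol{\phi}^* with q replaced by the empirical measure \hat{q}_t (Definition 20.4) gives exactly the SVGD direction (20.6).
Proof. Write T_\epsilon(\boldsymbol{\theta}) = \boldsymbol{\theta} + \epsilon \boldsymbol{\phi}(\boldsymbol{\theta}), which is invertible for small \epsilon. By the change-of-variables formula,
\operatorname{KL}(q_{[\epsilon \boldsymbol{\phi}]} \| p) = \mathbb{E}_q\big[\log q(\boldsymbol{\theta}) - \log p(T_\epsilon(\boldsymbol{\theta})) - \log \lvert \det(\mathbf{I} + \epsilon \nabla \boldsymbol{\phi}(\boldsymbol{\theta})) \rvert\big].
Differentiating at \epsilon = 0 and using \frac{d}{d\epsilon} \log \det(\mathbf{I} + \epsilon \nabla \boldsymbol{\phi})\big|_{\epsilon=0} = \operatorname{tr}(\nabla \boldsymbol{\phi}) = \nabla_{\boldsymbol{\theta}} \cdot \boldsymbol{\phi} gives
-\frac{d}{d\epsilon} \operatorname{KL}(q_{[\epsilon \boldsymbol{\phi}]} \| p)\Big|_{\epsilon=0} = \mathbb{E}_q \Big[ \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{\theta})^\top \boldsymbol{\phi}(\boldsymbol{\theta}) + \nabla_{\boldsymbol{\theta}} \cdot \boldsymbol{\phi}(\boldsymbol{\theta}) \Big] = \mathbb{E}_q [\mathcal{A}_p \boldsymbol{\phi}(\boldsymbol{\theta})].
For \boldsymbol{\phi} \in \mathcal{H}^d, the reproducing property gives \phi_l(\boldsymbol{\theta}) = \langle \phi_l, k(\boldsymbol{\theta}, \cdot) \rangle_{\mathcal{H}}. For differentiable k it also gives \partial_{\theta_l} \phi_l(\boldsymbol{\theta}) = \langle \phi_l, \partial_{\theta_l} k(\boldsymbol{\theta}, \cdot) \rangle_{\mathcal{H}}. Substituting both into the Stein operator componentwise and exchanging expectation and inner product:
\mathbb{E}_q[\mathcal{A}_p \boldsymbol{\phi}(\boldsymbol{\theta})] = \sum_{l=1}^{d} \langle \phi_l, \, \mathbb{E}_q[\partial_{\theta_l} \log p(\boldsymbol{\theta}) \, k(\boldsymbol{\theta}, \cdot) + \partial_{\theta_l} k(\boldsymbol{\theta}, \cdot)] \rangle_{\mathcal{H}} = \langle \boldsymbol{\phi}, \boldsymbol{\phi}^* \rangle_{\mathcal{H}^d}.
By the Cauchy–Schwarz inequality in \mathcal{H}^d, the supremum over \lVert \boldsymbol{\phi} \rVert_{\mathcal{H}^d} \leq 1 is attained at \boldsymbol{\phi} = \boldsymbol{\phi}^* / \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d} and equals \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}, which proves (i). Taking \boldsymbol{\phi} = \boldsymbol{\phi}^* gives the rate \langle \boldsymbol{\phi}^*, \boldsymbol{\phi}^* \rangle_{\mathcal{H}^d} = \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}^2 \geq 0. If q = p, Stein's identity (which requires mild boundary conditions on p and k) gives \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}^2 = \mathbb{E}_p[\mathcal{A}_p \boldsymbol{\phi}^*(\boldsymbol{\theta})] = 0. The converse does not follow from Stein's identity alone. It is a known characterization result for sufficiently rich kernels and is not proved here. \square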
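Part (ii) can be checked numerically. Expanding \lVert \boldsymbol{\phi}^* \rVert_{\mathcal{H}^d}^2 with the reproducing property gives a double expectation \mathbb{E}_{\boldsymbol{\theta}, \boldsymbol{\theta}' \sim q}[u_p(\boldsymbol{\theta}, \boldsymbol{\theta}')]. Write \mathbf{s} = \nabla \log p(\boldsymbol{\theta}), \mathbf{s}' = \nabla \log p(\boldsymbol{\theta}') and \boldsymbol{\delta} = \boldsymbol{\theta} - \boldsymbol{\theta}'. For the RBF kernel (20.3), the integrand is then u_p = k(\boldsymbol{\theta}, \boldsymbol{\theta}') [\mathbf{s}^\top \mathbf{s}' + (\mathbf{s} - \mathbf{s}')^\top \boldsymbol{\delta} / h^2 + d/h^2 - \lVert \boldsymbol{\delta} \rVert^2 / h^4]. This closed form, known as the kernelized Stein discrepancy, is a standard expansion that this section does not derive. The sketch is a standalone NumPy illustration, not code from svgd.py; the function names and the fixed bandwidth h = 1 are assumptions.

```python
import numpy as np

def stein_kernel_rbf(x, score, h):
    """Matrix u_p(theta_i, theta_j) for the RBF kernel with bandwidth h.

    x: (n, d) samples from q; score: (n, d) values of grad log p at x.
    """
    d = x.shape[1]
    diff = x[:, None, :] - x[None, :, :]                  # delta = theta_i - theta_j
    sq = np.sum(diff ** 2, axis=-1)
    K = np.exp(-sq / (2.0 * h ** 2))
    score_dot = score @ score.T                           # s_i . s_j
    cross = np.sum((score[:, None, :] - score[None, :, :]) * diff, axis=-1) / h ** 2
    trace_term = d / h ** 2 - sq / h ** 4                 # tr(grad_i grad_j k) divided by k
    return K * (score_dot + cross + trace_term)

def squared_phi_star_norm(x, score, h):
    """||phi*||^2 for the empirical measure of x (average over all pairs, including i = j)."""
    return float(np.mean(stein_kernel_rbf(x, score, h)))

# Usage with p = N(0, 1), whose score is -theta.
rng = np.random.default_rng(0)
x_match = rng.normal(0.0, 1.0, size=(400, 1))   # q = p
x_shift = rng.normal(2.0, 1.0, size=(400, 1))   # q = N(2, 1)
ksd_match = squared_phi_star_norm(x_match, -x_match, 1.0)
ksd_shift = squared_phi_star_norm(x_shift, -x_shift, 1.0)
print(ksd_match, ksd_shift)
```

For p = N(0, 1) and h = 1, a direct calculation gives \lVert \boldsymbol{\phi}^* \rVert^2 \approx 2.31 for q = N(2, 1). Samples drawn from p itself give a value near zero. It is slightly positive because the average includes the diagonal terms i = j.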
Algorithms
Algorithm 20.1: SVGD Optimization
Description. Performs Bayesian posterior inference by evolving n particles under the SVGD dynamics. At each iteration, the algorithm: (1) computes the kernel matrix and its gradient using the median heuristic bandwidth; (2) evaluates the log-posterior gradient at each particle; (3) combines these via the SVGD update rule (Definition 20.3); and (4) applies the step size (possibly via an adaptive optimizer such as Adam (Kingma and Ba 2015)). The algorithm operates in unconstrained space \boldsymbol{\phi} and maps back to constrained space \boldsymbol{\theta} via a parameter transformation (e.g., softplus for positivity).
SVGD Optimization
1: Let {phi_i^(0)}_{i=1}^n be initial particles in unconstrained space R^d
2: Let p(theta | y) be the target posterior with observed data y
3: Let T be the number of iterations and k the RBF kernel
4: Let g: R^d -> R^d_{>0} be the parameter transform (e.g., softplus)
5:
6: function SVGDOptimize({phi^(0)}, log_p, T, k, epsilon)
7:   for t = 0, 1, ..., T-1 do
8:     h <- MedianHeuristic({phi_i^(t)})   ▷ Bandwidth (eq. 20.4 or 20.5)
9:     K_ij <- k(phi_i^(t), phi_j^(t); h)   ▷ Kernel matrix, n x n
10:    for i = 1, ..., n do
11:      grad_i <- nabla_phi [ log p(g(phi_i^(t)) | y) + log|det J_g(phi_i^(t))| ]
12:      ▷ Score in phi-space: posterior term plus Jacobian term
13:    end for
14:    for i = 1, ..., n do
15:      psi_i <- (1/n) sum_{j=1}^n [K_ji * grad_j + nabla_{phi_j} K_ji]
16:      ▷ SVGD direction (eq. 20.6)
17:    end for
18:    epsilon_t <- StepSize(t, {phi^(t)})   ▷ Schedule (Def. 20.5)
19:    for i = 1, ..., n do
20:      phi_i^(t+1) <- OptimizerUpdate(phi_i^(t), psi_i, epsilon_t)
21:      ▷ Adam or SGD step
22:    end for
23:  end for
24:  theta_i <- g(phi_i^(T)) for i = 1, ..., n   ▷ Map to constrained space
25:  return {theta_i}_{i=1}^n
26: end function
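To make Algorithm 20.1 concrete, here is a minimal NumPy sketch of the loop under simplifying assumptions. It uses one isotropic bandwidth and a fixed step size with a plain gradient step in place of Adam. The target is a hypothetical one-dimensional Gamma density (shape 5, rate 2) for a positive rate \theta. The score on line 11 includes the softplus Jacobian term, whose derivative is \frac{d}{d\phi} \log \sigma(\phi) = 1 - \sigma(\phi). The names svgd_optimize and grad_log_p_phi are illustrative and are not the svgd.py API.

```python
import numpy as np

def svgd_optimize(phi0, grad_log_p_phi, n_iter=2000, step=0.02):
    """Plain SVGD loop (fixed step, no Adam) on unconstrained particles phi0 of shape (n, d)."""
    phi = phi0.copy()
    n = phi.shape[0]
    for _ in range(n_iter):
        diff = phi[:, None, :] - phi[None, :, :]                  # phi_i - phi_j
        sq = np.sum(diff ** 2, axis=-1)
        med = np.median(np.sqrt(sq[np.triu_indices(n, k=1)]))
        h = max(med / np.sqrt(np.log(n + 1)), 1e-8)               # line 8: median heuristic (20.4)
        K = np.exp(-sq / (2.0 * h ** 2))                          # line 9: kernel matrix
        grads = grad_log_p_phi(phi)                               # line 11: score incl. Jacobian
        psi = (K @ grads + np.sum(K[:, :, None] * diff, axis=1) / h ** 2) / n  # line 15: eq. 20.6
        phi = phi + step * psi                                    # line 20 as a plain SGD-style step
    return np.logaddexp(0.0, phi)                                 # line 24: theta = softplus(phi)

a, b = 5.0, 2.0   # hypothetical Gamma(shape a, rate b) target on theta > 0

def grad_log_p_phi(phi):
    """d/dphi [log p(theta(phi)) + log sigmoid(phi)] with log p(theta) = (a - 1) log theta - b theta."""
    theta = np.logaddexp(0.0, phi)                # softplus
    sig = 0.5 * (1.0 + np.tanh(0.5 * phi))        # sigmoid(phi) = d theta / d phi
    return ((a - 1.0) / theta - b) * sig + (1.0 - sig)

rng = np.random.default_rng(1)
theta = svgd_optimize(rng.normal(0.0, 1.0, size=(100, 1)), grad_log_p_phi)
print(theta.mean(), theta.std())   # compare with the Gamma(5, 2) mean 2.5 and sd sqrt(5)/2
```

Because particles live in \boldsymbol{\phi}-space, positivity of \boldsymbol{\theta} holds by construction, and the bounded score (between -2 and 5 here) keeps a fixed step stable.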
Correspondence table:
| Pseudocode variable | Math symbol | Code variable (file:function) |
|---|---|---|
| phi_i^(t) | \boldsymbol{\phi}^{(i)}_t | particles (svgd.py:svgd_step) |
| n | n | n_particles (svgd.py:SVGD.__init__) |
| T | T | n_iterations / n_steps (svgd.py:SVGD.__init__, run_svgd) |
| K_ij | k(\boldsymbol{\phi}^{(i)}, \boldsymbol{\phi}^{(j)}) | K (svgd.py:SVGDKernel.compute_kernel_grad) |
| grad_i | \nabla_{\boldsymbol{\phi}} \log p | grad_log_p (svgd.py:svgd_step) |
| psi_i | \hat{\boldsymbol{\phi}}^*(\boldsymbol{\phi}^{(i)}) | phi (svgd.py:_svgd_update_jitted) |
| epsilon_t | \epsilon_t | step_size (svgd.py:run_svgd) |
| h | h or \mathbf{h} | bandwidth (svgd.py:batch_median_heuristic, batch_median_heuristic_per_dim) |
| g | parameter transform | self.param_transform (svgd.py:SVGD.__init__) |
| OptimizerUpdate | optimizer step | optimizer.apply_gradients (svgd.py:Adam.update) |
| log_p | \log p(\boldsymbol{\theta} \mid \mathbf{y}) | log_prob_fn (svgd.py:SVGD._log_prob_unified) |
| K_ji * grad_j | k(\boldsymbol{\phi}^{(j)}, \boldsymbol{\phi}^{(i)}) \nabla \log p(\boldsymbol{\phi}^{(j)}) | positive_term (svgd.py:_svgd_update_jitted) |
| nabla_{phi_j} K_ji | \nabla_{\boldsymbol{\phi}^{(j)}} k(\boldsymbol{\phi}^{(j)}, \boldsymbol{\phi}^{(i)}) | negative_term (svgd.py:_svgd_update_jitted) |
Complexity. Time: O(T \cdot (n^2 d + n \cdot C_{\text{model}})), where C_{\text{model}} is the cost of one log-posterior gradient evaluation (dominated by the forward algorithm from [14]). The kernel computation is O(n^2 d) per iteration. Space: O(n^2 + n d) for the kernel matrix and particle storage, plus O(n^2 d) if the kernel-gradient tensor is materialized in full, as the grad_K reduction described below suggests.
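In _svgd_update_jitted, positive_term is the product \mathbf{K} \nabla \log p, computed as jnp.einsum('ij,jk->ik', K, grad_log_p). Because \mathbf{K} is symmetric, row i of this product equals \sum_j k(\boldsymbol{\phi}^{(j)}, \boldsymbol{\phi}^{(i)}) \nabla \log p(\boldsymbol{\phi}^{(j)}). The negative_term is \sum_j \nabla_{\boldsymbol{\phi}^{(j)}} k(\boldsymbol{\phi}^{(j)}, \boldsymbol{\phi}^{(i)}), computed as jnp.sum(grad_K, axis=1). The kernel gradient for the RBF kernel is \nabla_{\boldsymbol{\theta}} k(\boldsymbol{\theta}, \boldsymbol{\theta}') = -k(\boldsymbol{\theta}, \boldsymbol{\theta}') \cdot (\boldsymbol{\theta} - \boldsymbol{\theta}') / h^2, computed in _compute_kernel_grad_impl. With per-dimension bandwidths, h^2 is replaced by h_l^2 in coordinate l.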
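The closed-form kernel gradient and the einsum contraction can be sanity-checked in a few lines. This NumPy sketch compares the formula against central finite differences and confirms that einsum('ij,jk->ik', K, G) is the ordinary matrix product K G. It mirrors the expressions quoted above rather than the code in _compute_kernel_grad_impl, and the helper names are hypothetical.

```python
import numpy as np

def rbf(x, y, h):
    """Isotropic RBF kernel (20.3)."""
    return np.exp(-np.sum((x - y) ** 2) / (2.0 * h ** 2))

def rbf_grad_closed_form(x, y, h):
    """grad_x k(x, y) = -k(x, y) (x - y) / h^2."""
    return -rbf(x, y, h) * (x - y) / h ** 2

def rbf_grad_finite_diff(x, y, h, eps=1e-6):
    """Central differences in each coordinate of x."""
    g = np.zeros_like(x)
    for l in range(x.size):
        e = np.zeros_like(x)
        e[l] = eps
        g[l] = (rbf(x + e, y, h) - rbf(x - e, y, h)) / (2.0 * eps)
    return g

rng = np.random.default_rng(0)
x, y, h = rng.normal(size=3), rng.normal(size=3), 0.8
max_err = float(np.max(np.abs(rbf_grad_closed_form(x, y, h) - rbf_grad_finite_diff(x, y, h))))

# The positive_term contraction on a symmetric kernel matrix K and scores G.
K = rng.random((4, 4))
K = 0.5 * (K + K.T)
G = rng.normal(size=(4, 3))
same = bool(np.allclose(np.einsum('ij,jk->ik', K, G), K @ G))
print(max_err, same)
```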
Numerical Considerations
Positivity constraint. Phase-type distribution parameters (rates) must be positive. SVGD operates in unconstrained space \boldsymbol{\phi} \in \mathbb{R}^d and maps to \boldsymbol{\theta} \in \mathbb{R}^d_{>0} via the softplus transform \theta_l = \log(1 + e^{\phi_l}), which is smooth and monotonic with bounded gradient \sigma(\phi_l) \in (0, 1). The Jacobian term \log \sigma(\phi_l) = -\log(1 + e^{-\phi_l}) is added to the log-density to account for the change of variables, as in Definition 20.6. For large arguments (above 20), softplus is numerically the identity, \theta_l \approx \phi_l, and the inverse softplus uses this identity approximation to avoid overflow.
Bandwidth stability. The median heuristic bandwidth is clamped to a minimum of 10^{-8} per dimension to prevent division by zero when particles collapse. The per-dimension variant (batch_median_heuristic_per_dim) first forms the squared bandwidth h_l^2 of (20.5) and then returns h_l = \max(\sqrt{h_l^2}, 10^{-8}).
Log-space likelihood. The model log-likelihood is computed as \sum_k \log(f(\cdot \mid \boldsymbol{\theta}) + 10^{-10}), with the small additive constant preventing \log(0) when the PDF evaluates to zero at some observation points.
Sparse observations. The SparseObservations format avoids NaN propagation through JAX callbacks. Each feature’s observations are accessed via pre-computed integer slices rather than boolean masks, which is compatible with jax.jit (boolean indexing produces variable-length outputs that are not JIT-traceable).
Implementation Notes
Source code mapping:
| Algorithm | File | Function | Lines |
|---|---|---|---|
| Algorithm 20.1 (kernel) | src/phasic/svgd.py | SVGDKernel.compute_kernel_grad | L3152–L3213 |
| Algorithm 20.1 (update) | src/phasic/svgd.py | _svgd_update_jitted | L3217–L3254 |
| Algorithm 20.1 (step) | src/phasic/svgd.py | svgd_step | L3257–L3440 |
| Algorithm 20.1 (loop) | src/phasic/svgd.py | run_svgd | L3447–L3750 |
| Algorithm 20.1 (entry) | src/phasic/svgd.py | SVGD.optimize | L5176–L5440 |
| Bandwidth (isotropic) | src/phasic/svgd.py | batch_median_heuristic | L2481–L2492 |
| Bandwidth (per-dim) | src/phasic/svgd.py | batch_median_heuristic_per_dim | L2494–L2525 |
| RBF kernel (scalar) | src/phasic/svgd.py | rbf_kernel | L2463–L2466 |
| Prior (Gaussian) | src/phasic/svgd.py | GaussPrior.__call__ | L493–L520 |
| Prior (Half-Cauchy) | src/phasic/svgd.py | HalfCauchyPrior.__call__ | L655–L691 |
| Prior (Data-driven) | src/phasic/svgd.py | DataPrior.__init__ | L815–L856 |
| Sparse observations | src/phasic/svgd.py | SparseObservations, dense_to_sparse | L96–L232 |
Deviations from pseudocode:
- The pseudocode shows explicit loops over particles; the implementation vectorizes the SVGD update via jnp.einsum and jnp.sum, and the gradient computation across particles via jax.vmap or jax.pmap.
- Fixed parameters are handled by projecting particles to a learnable subspace before the SVGD step and expanding back to the full space afterward (not shown in pseudocode for clarity).
- The optional preconditioner normalizes particles before kernel computation and transforms gradients back, making the kernel isotropic in preconditioned space. This is not shown in the pseudocode.
- When an optimizer [e.g., Adam; Kingma and Ba (2015)] is used, line 20 applies the optimizer’s adaptive update (maintaining first and second moment estimates) rather than a simple \boldsymbol{\phi} + \epsilon \boldsymbol{\psi} step.
Symbol Index
| Symbol | Name | First appearance |
|---|---|---|
| \mathcal{A}_p | Stein operator | Definition 20.1 |
| d | Parameter dimension | Definition 20.1 |
| \epsilon_t | Step size at iteration t | Definition 20.5 |
| \mathbf{f} | Feature index vector (sparse obs.) | Definition 20.7 |
| h | Kernel bandwidth | Definition 20.2 |
| h_l | Per-dimension bandwidth | Definition 20.2 |
| k(\cdot, \cdot) | RBF kernel function | Definition 20.2 |
| n | Number of particles | Definition 20.2 |
| \hat{\boldsymbol{\phi}}^* | SVGD perturbation direction | Definition 20.3 |
| \boldsymbol{\phi} | Unconstrained parameter vector | Definition 20.6 |
| \pi(\boldsymbol{\theta}) | Prior density | Definition 20.6 |
| \hat{q}_t | Empirical particle measure at iteration t | Definition 20.4 |
| \rho | Adaptive step adjustment rate | Definition 20.5 |
| \mathcal{S} | Slice indices (sparse obs.) | Definition 20.7 |
| \mathbb{S}(q, p) | Stein discrepancy | Definition 20.1 |
| \tau | Exponential decay time constant | Definition 20.5 |
| T | Number of iterations | Algorithm 20.1 |
| \mathbf{v} | Observation values (sparse obs.) | Definition 20.7 |