Variational Inference Theory
Key Points
- •Variational Inference (VI) replaces an intractable posterior with a simpler distribution and optimizes it by minimizing KL divergence, which is equivalent to maximizing the ELBO.
- •The ELBO equals expected log joint minus expected log variational density; it lower-bounds the log evidence and provides a tractable optimization target.
- •Mean-field VI assumes independence across latent variables, turning a hard global problem into coordinate-wise updates or stochastic gradients.
- •The reparameterization trick (z = μ + σε with ε ~ N(0,1)) enables low-variance Monte Carlo gradients by moving randomness outside the parameters.
- •Amortized inference uses a learned encoder to predict variational parameters for each datapoint, making inference fast at test time and enabling VAEs.
- •VI underestimates uncertainty when the variational family is too simple or when using KL(q||p); choose the family carefully and monitor calibration.
- •Black-box VI forms Monte Carlo gradient estimates using only evaluations of log p(x,z) and log q(z), enabling VI on non-conjugate models and deep generative models.
- •VAEs marry amortized VI with neural decoders, optimizing a stochastic ELBO that balances reconstruction quality and KL regularization.
Prerequisites
- →Probability distributions and Bayes’ rule — Understanding priors, likelihoods, and posteriors is essential to interpret p(x,z) and p(z|x).
- →Expectation and variance — ELBO and KL are expectations; VI manipulates and estimates them.
- →Multivariate Gaussians and quadratic forms — Common variational families and many models use Gaussians and precisions.
- →Convexity and Jensen’s inequality — The ELBO is derived using Jensen’s inequality and lower-bounding arguments.
- →Gradient-based optimization — Black-box VI and VAEs rely on stochastic gradients and learning rates.
- →Monte Carlo estimation — ELBO and its gradients are estimated via sampling for non-conjugate models.
- →Linear algebra — Coordinate updates in mean-field for Gaussians involve matrix operations.
- →Automatic differentiation/backpropagation — Reparameterization requires differentiating through sampling operations.
Detailed Explanation
01Overview
Imagine you want to understand hidden causes (latents z) behind visible data (x), but computing the exact Bayesian posterior p(z|x) is either astronomically expensive or outright impossible. Variational Inference (VI) solves this by turning inference into optimization: pick a family of tractable distributions q(z; λ) and tune its parameters λ so that q(z; λ) is as close as possible to the true posterior. Closeness is typically measured by the KL divergence KL(q||p), which heavily penalizes q for placing probability mass where p has little. This converts integration problems into optimization problems, opening the door to scalable, gradient-based methods and parallel hardware.
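As a minimal sketch of what "closeness" means here (an added illustration; the particular p and q values are made up), the KL divergence between two univariate Gaussians has a closed form, so you can see the single number that VI drives down as λ = (μ, σ) is adjusted:

#include <bits/stdc++.h>
using namespace std;

// Closed-form KL(q || p) for univariate Gaussians q = N(mu_q, s_q^2), p = N(mu_p, s_p^2):
//   KL = log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 * s_p^2) - 0.5
double kl_gauss(double mu_q, double s_q, double mu_p, double s_p){
    return log(s_p / s_q)
         + (s_q*s_q + (mu_q - mu_p)*(mu_q - mu_p)) / (2.0 * s_p*s_p)
         - 0.5;
}

int main(){
    // Hypothetical "true posterior" p and two candidate approximations q
    double mu_p = 1.2, s_p = 0.4;
    cout << fixed << setprecision(4);
    cout << "KL(q1 || p) = " << kl_gauss(0.0, 1.0,  mu_p, s_p) << "  (loose fit)\n";
    cout << "KL(q2 || p) = " << kl_gauss(1.1, 0.45, mu_p, s_p) << "  (tight fit)\n";
    return 0;
}

Adjusting (μ, σ) of q to shrink this number is exactly the optimization VI performs, just with richer models and gradient-based updates.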
02Intuition & Analogies
Hook: Think of trying to fit a simple cardboard template over a complex, bumpy sculpture. You can’t reproduce every curve, but you can adjust the template’s position and bending so it hugs the sculpture as tightly as possible where it matters.
Concept: In VI, the complicated sculpture is the true posterior p(z|x). The template is a simpler, tractable family q(z; λ)—maybe factorized Gaussians. We ‘press’ the template onto the sculpture by minimizing a distance (KL divergence) so the template covers high-probability regions of the sculpture.
Example: Suppose you need uncertainty over a hidden temperature reading z from a noisy thermometer x. The exact posterior might be messy; a Gaussian q(z; μ, σ^2) is the cardboard template. You adjust μ and σ so q sits over the most plausible temperatures given x, using gradients that measure how to push μ and σ to hug the posterior better. Reparameterization (z = μ + σε) is like moving the template in a controlled way: the randomness ε is fixed in shape, so you can smoothly adjust μ and σ and feel how the fit changes, enabling backpropagation.
03Formal Definition
Given a model p(x, z) with observed data x and latent variables z, choose a tractable family Q = {q(z; λ)} and solve λ* = argmin_λ KL(q(z; λ) || p(z|x)). Because the posterior p(z|x) = p(x, z)/p(x) involves the intractable evidence p(x), VI instead maximizes the evidence lower bound ELBO(λ) = E_q[log p(x, z)] − E_q[log q(z; λ)], which satisfies log p(x) = ELBO(λ) + KL(q(z; λ) || p(z|x)) ≥ ELBO(λ). Since log p(x) does not depend on λ, maximizing the ELBO is equivalent to minimizing the KL divergence.
04When to Use
Use VI when exact posteriors are intractable, but you can evaluate or differentiate log p(x, z). It excels in large-scale Bayesian modeling (topic models, Bayesian regression), probabilistic deep learning (VAEs), and situations requiring fast test-time inference (amortized VI). If your model is conjugate and small, exact inference or Gibbs sampling could be simpler; if gradients are available and data are massive, stochastic VI with mini-batches is ideal. VI is also a good fit when you need differentiable objectives for integration into larger systems (e.g., end-to-end learning). For streaming or online scenarios, stochastic/online VI updates maintain scalability. When uncertainty calibration is critical and the posterior is multimodal, consider richer q families (normalizing flows, mixture variational families) or alternative divergences to avoid underestimating uncertainty.
⚠️Common Mistakes
- Confusing KL directions: KL(q||p) vs KL(p||q) behave very differently. KL(q||p) avoids placing mass where p has none (mode-seeking) and can miss modes; don’t assume it will capture all posterior structure.
- Overly simple q: Mean-field can severely underestimate uncertainty when variables are correlated. Use diagnostics (posterior predictive checks, calibration) and consider richer families.
- High-variance gradients: Using score-function (REINFORCE) estimators without control variates can stall learning. Prefer reparameterization when possible and use multiple samples or baselines (see the sketch after this list).
- Dropping the entropy term: The ELBO includes -E_q[log q(z)]; forgetting it biases optimization and collapses σ to zero.
- Poor initialization and learning rates: Variational parameters can diverge or collapse; use sensible initial scales and adaptive optimizers.
- Amortization gap: In VAEs, the encoder may underfit even if per-datapoint optima exist. Consider more expressive encoders, normalizing flows, or semi-amortized VI.
- Misinterpreting ELBO: Higher ELBO doesn’t always mean better generative quality (especially with strong decoders); use held-out likelihood estimates or downstream metrics too.
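To make the high-variance-gradients point concrete, here is a small added sketch (not one of the worked examples below) comparing the score-function and pathwise estimators of the same gradient, ∇_μ E_{z~N(μ,σ^2)}[z^2], whose exact value is 2μ. The choice f(z) = z^2 and all parameter values are arbitrary illustrations:

#include <bits/stdc++.h>
using namespace std;

// Two Monte Carlo estimators of d/dmu E_{z ~ N(mu, sigma^2)}[ f(z) ] for f(z) = z^2:
//  - Score function (REINFORCE): f(z) * d/dmu log q(z) = z^2 * (z - mu) / sigma^2
//  - Pathwise (reparameterization): z = mu + sigma*eps  =>  f'(z) * dz/dmu = 2*z
// Both are unbiased (exact gradient = 2*mu), but their variances differ sharply.
int main(){
    const double mu = 1.0, sigma = 1.0;
    const int N = 100000;
    mt19937 gen(7);
    normal_distribution<double> stdn(0.0, 1.0);

    double sum_sf = 0.0, sumsq_sf = 0.0;   // score-function estimator
    double sum_pw = 0.0, sumsq_pw = 0.0;   // pathwise estimator
    for(int i=0; i<N; ++i){
        double eps = stdn(gen);
        double z = mu + sigma * eps;
        double g_sf = z * z * (z - mu) / (sigma * sigma);
        double g_pw = 2.0 * z;
        sum_sf += g_sf;  sumsq_sf += g_sf * g_sf;
        sum_pw += g_pw;  sumsq_pw += g_pw * g_pw;
    }
    double mean_sf = sum_sf / N, var_sf = sumsq_sf / N - mean_sf * mean_sf;
    double mean_pw = sum_pw / N, var_pw = sumsq_pw / N - mean_pw * mean_pw;

    cout << fixed << setprecision(4);
    cout << "exact gradient : " << 2.0 * mu << "\n";
    cout << "score-function : mean= " << mean_sf << "  var= " << var_sf << "\n";
    cout << "pathwise       : mean= " << mean_pw << "  var= " << var_pw << "\n";
    return 0;
}

Both estimators recover the exact gradient on average, but the score-function estimator's variance is several times larger here, which is why control variates or reparameterization are standard practice.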
Key Formulas
VI Objective (KL)
KL(q(z) || p(z|x)) = E_q[log q(z) − log p(z|x)]
Explanation: This is the quantity VI minimizes to make q(z) close to the true posterior. It measures extra bits needed when using q instead of p.
ELBO
ELBO(q) = E_q[log p(x, z)] − E_q[log q(z)]
Explanation: The ELBO is the tractable lower bound optimized in VI. It balances data fit via log p(x,z) and regularization via the entropy term −E_q[log q(z)].
ELBO–Evidence Decomposition
log p(x) = ELBO(q) + KL(q(z) || p(z|x)), so ELBO(q) ≤ log p(x)
Explanation: Because KL is non-negative, the ELBO lower-bounds the log evidence, and maximizing the ELBO is equivalent to minimizing the KL. A numerical check of this identity appears after this formula list.
Jensen’s Inequality Derivation
log p(x) = log E_q[p(x, z)/q(z)] ≥ E_q[log p(x, z) − log q(z)] = ELBO(q)
Explanation: Applying Jensen’s inequality to the concave log function yields the ELBO as a lower bound to log evidence.
Mean-field Factorization
q(z) = Π_i q_i(z_i)
Explanation: The mean-field assumption breaks dependencies among latent variables, making coordinate updates tractable.
CAVI Update
q_i*(z_i) ∝ exp( E_{−i}[log p(x, z)] )
Explanation: At the optimum (with the other factors held fixed), q_i*(z_i) is proportional to the exponential of the log joint averaged over the remaining factors. This yields closed-form updates in conjugate models.
Reparameterization Trick
z = μ + σ ⊙ ε,  ε ~ N(0, I)  (scalar case: z = μ + σε, ε ~ N(0,1))
Explanation: Sampling is re-expressed as a deterministic function of parameters and auxiliary noise, enabling backpropagation through z.
Pathwise Gradient
∇_λ E_{q_λ(z)}[f(z)] = E_ε[∇_λ f(g(λ, ε))],  where z = g(λ, ε)
Explanation: Differentiating through the sampling path yields low-variance Monte Carlo gradients when f is differentiable in z.
Monte Carlo ELBO
ELBO ≈ (1/K) Σ_{k=1..K} [log p(x, z^(k)) − log q(z^(k))],  z^(k) ~ q(z)
Explanation: A stochastic estimate of the ELBO computed by sampling from q. Increasing K reduces estimator variance at higher cost.
Gaussian Entropy
H[N(μ, σ^2)] = (1/2) log(2πe σ^2)
Explanation: The entropy term in the ELBO for Gaussian q has a closed form, simplifying gradients and stability analysis.
Gaussian–Gaussian KL
KL(N(μ, σ^2) || N(0, 1)) = (1/2)(μ^2 + σ^2 − log σ^2 − 1)
Explanation: Used in VAEs with standard normal priors and Gaussian approximate posteriors. The closed form (summed over dimensions for diagonal covariances) avoids Monte Carlo noise in the regularizer and stabilizes training.
VAE Objective
L(θ, φ; x) = E_{q_φ(z|x)}[log p_θ(x|z)] − KL(q_φ(z|x) || p(z))
Explanation: The amortized ELBO used in VAEs trades off reconstruction accuracy and regularization towards the prior.
Per-step Complexity (Stochastic VI/VAEs)
O(B · K · p) per training step
Explanation: For batch size B, K samples per datapoint, and model evaluation cost p, the training step time scales linearly in all three.
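As a sanity check of the decomposition log p(x) = ELBO + KL, the added sketch below reuses the conjugate toy model from the first code example (z ~ N(0,1), x | z ~ N(z, 0.25)), where both the evidence and the exact posterior are Gaussian and available in closed form; the particular q(z) = N(1.2, 0.3) is an arbitrary choice. The Monte Carlo ELBO plus the exact KL should reproduce log p(x) up to sampling noise:

#include <bits/stdc++.h>
using namespace std;

// Conjugate toy model: z ~ N(0,1), x | z ~ N(z, sx2).
// Marginal:  x ~ N(0, 1 + sx2)                       => log p(x) in closed form
// Posterior: z | x ~ N(x/(1+sx2), sx2/(1+sx2))
// Check: log p(x) = ELBO(q) + KL(q || p(z|x)) for an arbitrary Gaussian q.

double log_normal_pdf(double v, double mean, double var){
    return -0.5 * ((v - mean) * (v - mean) / var + log(2.0 * M_PI * var));
}

// KL(N(m1, v1) || N(m2, v2)) in closed form
double kl_gauss(double m1, double v1, double m2, double v2){
    return 0.5 * (log(v2 / v1) + (v1 + (m1 - m2) * (m1 - m2)) / v2 - 1.0);
}

int main(){
    const double x = 2.0, sx2 = 0.25;   // same datum and noise as the first code example
    const double mu = 1.2, s2 = 0.3;    // an arbitrary variational q(z) = N(mu, s2)

    // Exact quantities
    double log_evidence = log_normal_pdf(x, 0.0, 1.0 + sx2);
    double post_mean = x / (1.0 + sx2);
    double post_var  = sx2 / (1.0 + sx2);
    double kl = kl_gauss(mu, s2, post_mean, post_var);

    // Monte Carlo ELBO: (1/K) Σ_k [ log p(x, z_k) - log q(z_k) ],  z_k ~ q
    mt19937 gen(1);
    normal_distribution<double> stdn(0.0, 1.0);
    const int K = 200000;
    double elbo = 0.0;
    for(int k=0; k<K; ++k){
        double z = mu + sqrt(s2) * stdn(gen);
        double log_joint = log_normal_pdf(z, 0.0, 1.0) + log_normal_pdf(x, z, sx2);
        elbo += log_joint - log_normal_pdf(z, mu, s2);
    }
    elbo /= K;

    cout << fixed << setprecision(4);
    cout << "log p(x)           = " << log_evidence << "\n";
    cout << "MC ELBO + exact KL = " << elbo + kl << "  (should match)\n";
    cout << "ELBO gap           = " << log_evidence - elbo << "  vs exact KL = " << kl << "\n";
    return 0;
}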
Complexity Analysis
A stochastic VI or VAE training step costs O(B · K · p) time for batch size B, K Monte Carlo samples per datapoint, and per-sample model evaluation cost p, so runtime grows linearly in each factor; memory typically scales similarly with B and K. For coordinate ascent mean-field VI on a Gaussian target with a dense n × n precision matrix (second code example), each full sweep over the n coordinates costs O(n^2), and the number of sweeps to convergence grows with the strength of the coupling between coordinates.
Code Examples
#include <bits/stdc++.h>
using namespace std;

// Model: z ~ N(0,1); x | z ~ N(z, sigma_x^2)
// Variational family: q(z|x) = N(mu, sigma^2)
// Optimize ELBO(mu, sigma) via reparameterization: z = mu + sigma * eps, eps ~ N(0,1)

struct RNG {
    mt19937 gen;
    normal_distribution<double> stdn{0.0, 1.0};
    RNG(unsigned seed = 42u) : gen(seed) {}
    double std_normal() { return stdn(gen); }
};

// Derivative of f(z) = log p(z) + log p(x|z) w.r.t z
// log p(z)   = -0.5*(z^2 + c);                   d/dz = -z
// log p(x|z) = -0.5*((x - z)^2 / sigma_x^2 + c); d/dz = (x - z)/sigma_x^2
inline double df_dz(double z, double x, double sigma_x){
    return -z + (x - z)/(sigma_x * sigma_x);
}

// Compute a single-sample stochastic ELBO value for monitoring
inline double elbo_sample(double z, double mu, double sigma, double x, double sigma_x){
    // log p(z)
    double log_pz = -0.5*(z*z + log(2*M_PI));
    // log p(x|z)
    double log_px_z = -0.5*(((x - z)*(x - z))/(sigma_x*sigma_x) + log(2*M_PI*sigma_x*sigma_x));
    // log q(z|x)
    double log_q = -0.5*(((z - mu)*(z - mu))/(sigma*sigma) + 2.0*log(sigma) + log(2*M_PI));
    return log_pz + log_px_z - log_q;
}

int main(){
    // Observed datum
    const double x = 2.0;        // thermometer reading
    const double sigma_x = 0.5;  // known noise std

    // Variational parameters (initialize)
    double mu = 0.0;   // mean
    double rho = -0.5; // log sigma (to keep sigma > 0)

    // Optimization hyperparameters
    const int K = 8;          // MC samples per iteration
    const int iters = 2000;   // iterations
    const double lr = 0.02;   // learning rate

    RNG rng(123);

    for(int t=1; t<=iters; ++t){
        double g_mu = 0.0;     // gradient w.r.t mu
        double g_rho = 0.0;    // gradient w.r.t rho
        double elbo_est = 0.0; // for monitoring

        double sigma = exp(rho);
        for(int k=0; k<K; ++k){
            double eps = rng.std_normal();
            double z = mu + sigma * eps; // reparameterization

            // Pathwise gradients for Gaussian q
            // dL/dmu = f'(z)
            double fp = df_dz(z, x, sigma_x);
            double dL_dmu = fp;
            // dL/dsigma = f'(z)*eps + 1/sigma  =>  chain to rho: dL/drho = dL/dsigma * sigma
            double dL_dsigma = fp * eps + 1.0 / sigma;
            double dL_drho = dL_dsigma * sigma;

            g_mu += dL_dmu;
            g_rho += dL_drho;

            elbo_est += elbo_sample(z, mu, sigma, x, sigma_x);
        }
        g_mu /= K; g_rho /= K; elbo_est /= K;

        // Gradient ascent on ELBO
        mu += lr * g_mu;
        rho += lr * g_rho;

        if(t % 200 == 0){
            cout << "iter " << t
                 << "\tELBO~ " << fixed << setprecision(4) << elbo_est
                 << "\tmu= " << mu
                 << "\tsigma= " << exp(rho) << "\n";
        }
    }

    cout << "Final: mu= " << mu << ", sigma= " << exp(rho) << "\n";
    return 0;
}
We fit a simple latent Gaussian model with one observation and approximate posterior q(z|x) = N(μ, σ^2). Using the reparameterization z = μ + σε with ε ~ N(0,1), we estimate the ELBO’s gradient via pathwise derivatives. The derivative with respect to μ simplifies to f′(z) = d/dz[log p(z) + log p(x|z)]. The derivative with respect to σ includes an entropy term, yielding dL/dσ = f′(z)ε + 1/σ, and we optimize ρ = log σ for positivity. This demonstrates black-box VI with low-variance gradients in a fully self-contained C++ program.
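To spell out the gradient expressions used above: with z = μ + σε and ε held fixed, ∂z/∂μ = 1 and ∂z/∂σ = ε, so the pathwise derivatives of f(z) = log p(z) + log p(x|z) are ∂f/∂μ = f′(z) and ∂f/∂σ = f′(z)ε. The remaining ELBO term is the entropy of q, H[N(μ, σ^2)] = (1/2) log(2πe σ^2), whose σ-derivative is 1/σ and whose μ-derivative is 0. Adding the pieces gives dL/dμ = E[f′(z)] and dL/dσ = E[f′(z)ε] + 1/σ, exactly what the inner loop accumulates; the chain rule through ρ = log σ then multiplies by dσ/dρ = σ.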
#include <bits/stdc++.h>
using namespace std;

// Target posterior (up to normalization): p(z) ∝ exp(-0.5 * (z - m)^T Λ (z - m))
// We approximate with mean-field q(z) = Π_i N(m_i_hat, s2_i), where s2_i = 1 / Λ_ii
// Coordinate update (Gaussian case):
//   m_i_hat ← m_i - (1/Λ_ii) * Σ_{j≠i} Λ_ij * (m_j_hat - m_j)
// Repeat until convergence.

struct MeanFieldGaussian {
    vector<double> m_true;          // target mean m
    vector<vector<double>> Lambda;  // precision matrix Λ (symmetric PD)
    vector<double> m_hat;           // variational means
    vector<double> s2;              // variational variances (fixed): 1 / Λ_ii

    MeanFieldGaussian(const vector<double>& m, const vector<vector<double>>& L)
        : m_true(m), Lambda(L) {
        int n = (int)m.size();
        m_hat.assign(n, 0.0);
        s2.resize(n);
        for(int i=0;i<n;++i){
            s2[i] = 1.0 / Lambda[i][i];
        }
    }

    void run(int max_iters = 1000, double tol = 1e-9){
        int n = (int)m_hat.size();
        for(int it=1; it<=max_iters; ++it){
            double max_delta = 0.0;
            for(int i=0; i<n; ++i){
                // Expected coupling to the other coordinates, measured relative to m
                double off = 0.0;
                for(int j=0; j<n; ++j){
                    if(j==i) continue;
                    off += Lambda[i][j] * (m_hat[j] - m_true[j]);
                }
                double new_mi = m_true[i] - off / Lambda[i][i];
                max_delta = max(max_delta, fabs(new_mi - m_hat[i]));
                m_hat[i] = new_mi;
            }
            if(max_delta < tol){
                // converged
                break;
            }
        }
    }

    void print_state() const{
        int n = (int)m_hat.size();
        cout << fixed << setprecision(6);
        cout << "Variational means (m_hat): ";
        for(int i=0;i<n;++i) cout << m_hat[i] << (i+1==n?"\n":" ");
        cout << "Variational variances (s2): ";
        for(int i=0;i<n;++i) cout << s2[i] << (i+1==n?"\n":" ");
    }
};

int main(){
    // 2D example with correlation
    vector<double> m = {1.0, -2.0};
    // Precision Λ for covariance Σ = [[1, 0.8],[0.8, 1]] (PD); Λ = Σ^{-1}
    double rho = 0.8;
    double denom = 1.0 - rho*rho;
    vector<vector<double>> Lambda = {
        {  1.0/denom, -rho/denom },
        { -rho/denom,  1.0/denom }
    };

    MeanFieldGaussian mf(m, Lambda);
    mf.run(10000, 1e-12);
    mf.print_state();

    return 0;
}
For a correlated Gaussian target with precision Λ and mean m, the optimal mean-field factors are univariate Gaussians whose variances are 1/Λ_ii and whose means satisfy fixed-point equations depending on the off-diagonal entries of Λ. The coordinate ascent update m_i_hat = m_i − (1/Λ_ii) Σ_{j≠i} Λ_ij (m_j_hat − m_j) iteratively accounts for correlations via the expectations of the other variables, even though the final q factorizes. This small program converges to the mean-field optimum (here m_hat = m) and prints the factor means and variances.
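A point worth noting in this example: with Σ = [[1, 0.8],[0.8, 1]] the true marginal variances are 1.0, while the mean-field variances are 1/Λ_ii = 1 − 0.8^2 = 0.36. The factorized approximation recovers the exact means but understates each marginal's spread, a concrete instance of the variance underestimation warned about under Common Mistakes.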