📚 Theory · Advanced

Variational Inference Theory

Key Points

  • Variational Inference (VI) replaces an intractable posterior with a simpler distribution and optimizes it by minimizing KL divergence, which is equivalent to maximizing the ELBO.
  • The ELBO equals expected log joint minus expected log variational density; it lower-bounds the log evidence and provides a tractable optimization target.
  • Mean-field VI assumes independence across latent variables, turning a hard global problem into coordinate-wise updates or stochastic gradients.
  • The reparameterization trick (z = μ + σε with ε ~ N(0,1)) enables low-variance Monte Carlo gradients by moving randomness outside the parameters.
  • Amortized inference uses a learned encoder to predict variational parameters for each datapoint, making inference fast at test time and enabling VAEs.
  • VI underestimates uncertainty when the variational family is too simple or when using KL(q||p); choose the family carefully and monitor calibration.
  • Black-box VI forms Monte Carlo gradient estimates from evaluations of log p(x,z) and log q(z) alone, enabling VI on non-conjugate models and deep generative models.
  • VAEs marry amortized VI with neural decoders, optimizing a stochastic ELBO that balances reconstruction quality and KL regularization.

Prerequisites

  • Probability distributions and Bayes’ rule: Understanding priors, likelihoods, and posteriors is essential to interpret p(x,z) and p(z|x).
  • Expectation and variance: The ELBO and KL divergence are expectations; VI manipulates and estimates them.
  • Multivariate Gaussians and quadratic forms: Common variational families and many models use Gaussians and precisions.
  • Convexity and Jensen’s inequality: The ELBO is derived using Jensen’s inequality and lower-bounding arguments.
  • Gradient-based optimization: Black-box VI and VAEs rely on stochastic gradients and learning rates.
  • Monte Carlo estimation: The ELBO and its gradients are estimated via sampling for non-conjugate models.
  • Linear algebra: Coordinate updates in mean-field VI for Gaussians involve matrix operations.
  • Automatic differentiation/backpropagation: Reparameterization requires differentiating through sampling operations.

Detailed Explanation


01 Overview

Imagine you want to understand hidden causes (latents z) behind visible data (x), but computing the exact Bayesian posterior p(z|x) is either astronomically expensive or outright impossible. Variational Inference (VI) solves this by turning inference into optimization: pick a family of tractable distributions q(z; λ) and tune its parameters λ so that q(z; λ) is as close as possible to the true posterior. The closeness is typically measured by KL divergence KL(q||p), which measures how much probability mass q assigns where p does not. This converts integration problems into optimization problems, opening the door to scalable, gradient-based methods and parallel hardware.

02 Intuition & Analogies

Hook: Think of trying to fit a simple cardboard template over a complex, bumpy sculpture. You can’t reproduce every curve, but you can adjust the template’s position and bending so it hugs the sculpture as tightly as possible where it matters.

Concept: In VI, the complicated sculpture is the true posterior p(z|x). The template is a simpler, tractable family q(z; λ), maybe factorized Gaussians. We ‘press’ the template onto the sculpture by minimizing a distance (KL divergence) so the template covers high-probability regions of the sculpture.

Example: Suppose you need uncertainty over a hidden temperature reading z from a noisy thermometer x. The exact posterior might be messy; a Gaussian q(z; μ, σ^2) is the cardboard template. You adjust μ and σ so q sits over the most plausible temperatures given x, using gradients that measure how to push μ and σ to hug the posterior better. Reparameterization (z = μ + σε) is like moving the template in a controlled way: the randomness ε is fixed in shape, so you can smoothly adjust μ and σ and feel how the fit changes, enabling backpropagation.

03 Formal Definition

We choose a variational family Q = {q(z; λ)} and minimize the KL divergence KL(q(z; λ) || p(z | x)) = E_q[log q(z; λ) − log p(z | x)]. Substituting log p(z | x) = log p(x, z) − log p(x) gives KL(q || p) = log p(x) − ELBO(λ), where ELBO(λ) = E_q[log p(x, z) − log q(z; λ)]; because log p(x) does not depend on λ, minimizing the KL is equivalent to maximizing the ELBO. A common restriction is mean-field VI: q(z; λ) = ∏_i q_i(z_i; λ_i), which leads to coordinate-ascent updates q_i*(z_i) ∝ exp(E_{q_{−i}}[log p(x, z)]). For non-conjugate models, black-box VI uses Monte Carlo estimates of ∇_λ ELBO(λ), often via the reparameterization trick z = g(λ, ε) with ε ~ p(ε), yielding low-variance pathwise gradients. In amortized VI (e.g., VAEs), the variational parameters are outputs of an encoder network q_φ(z | x) shared across datapoints, and the ELBO is maximized jointly over generative parameters θ and inference parameters φ.

04 When to Use

Use VI when exact posteriors are intractable, but you can evaluate or differentiate log p(x, z). It excels in large-scale Bayesian modeling (topic models, Bayesian regression), probabilistic deep learning (VAEs), and situations requiring fast test-time inference (amortized VI). If your model is conjugate and small, exact inference or Gibbs sampling could be simpler; if gradients are available and data are massive, stochastic VI with mini-batches is ideal. VI is also a good fit when you need differentiable objectives for integration into larger systems (e.g., end-to-end learning). For streaming or online scenarios, stochastic/online VI updates maintain scalability. When uncertainty calibration is critical and the posterior is multimodal, consider richer q families (normalizing flows, mixture variational families) or alternative divergences to avoid underestimating uncertainty.
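
For a concrete picture of the mini-batch scaling behind stochastic VI, here is a minimal, self-contained sketch (the model, constants, and the helper log_normal are illustrative assumptions, not part of this article's worked examples): with a global latent z and N i.i.d. observations, an unbiased estimate of the full-data ELBO rescales the mini-batch likelihood sum by N/B.

#include <bits/stdc++.h>
using namespace std;

// Illustrative model: z ~ N(0,1) is a global latent, x_i | z ~ N(z, 1) for i = 1..N.
// With q(z) = N(mu, sigma^2), a mini-batch of size B gives an unbiased ELBO estimate
// by rescaling the likelihood sum by N/B.
int main(){
    const int N = 10000, B = 100;       // dataset size and mini-batch size
    const double mu = 0.4, sigma = 0.3; // fixed variational parameters for the demo
    mt19937 gen(1);
    normal_distribution<double> stdn(0.0, 1.0);

    // Synthetic data drawn from the model with true z = 0.5
    vector<double> x(N);
    for(double& xi : x) xi = 0.5 + stdn(gen);

    auto log_normal = [](double v, double m, double s){ // log N(v; m, s^2)
        return -0.5*((v - m)*(v - m))/(s*s) - log(s) - 0.5*log(2*M_PI);
    };

    uniform_int_distribution<int> pick(0, N - 1);
    double full_avg = 0.0, mb_avg = 0.0;
    const int trials = 200;
    for(int t = 0; t < trials; ++t){
        double z = mu + sigma*stdn(gen);                  // one reparameterized sample from q
        double prior_and_entropy = log_normal(z, 0, 1) - log_normal(z, mu, sigma);

        double full_lik = 0.0;                            // exact likelihood sum over all N points
        for(double xi : x) full_lik += log_normal(xi, z, 1.0);

        double mb_lik = 0.0;                              // mini-batch estimate, rescaled by N/B
        for(int b = 0; b < B; ++b) mb_lik += log_normal(x[pick(gen)], z, 1.0);
        mb_lik *= double(N)/B;

        full_avg += (full_lik + prior_and_entropy)/trials;
        mb_avg   += (mb_lik  + prior_and_entropy)/trials;
    }
    cout << fixed << setprecision(1)
         << "full-data ELBO estimate:  " << full_avg << "\n"
         << "mini-batch ELBO estimate: " << mb_avg  << " (unbiased, higher variance)\n";
    return 0;
}

Averaged over many trials the two estimates agree; per step, the mini-batch version touches only B points, which is what makes stochastic VI scale to massive datasets.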

⚠️ Common Mistakes

  • Confusing KL directions: KL(q||p) vs KL(p||q) behave very differently. KL(q||p) avoids placing mass where p has none (mode-seeking) and can miss modes; don’t assume it will capture all posterior structure.
  • Overly simple q: Mean-field can severely underestimate uncertainty when variables are correlated. Use diagnostics (posterior predictive checks, calibration) and consider richer families.
  • High-variance gradients: Using score-function (REINFORCE) estimators without control variates can stall learning. Prefer reparameterization when possible and use multiple samples or baselines (see the sketch after this list).
  • Dropping the entropy term: The ELBO includes −E_q[log q(z)]; forgetting it biases optimization and collapses σ to zero.
  • Poor initialization and learning rates: Variational parameters can diverge or collapse; use sensible initial scales and adaptive optimizers.
  • Amortization gap: In VAEs, the encoder may underfit even if per-datapoint optima exist. Consider more expressive encoders, normalizing flows, or semi-amortized VI.
  • Misinterpreting the ELBO: A higher ELBO doesn’t always mean better generative quality (especially with strong decoders); use held-out likelihood estimates or downstream metrics too.
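
To make the variance point concrete, here is a minimal, self-contained comparison (the model matches the 1D latent Gaussian used in the code examples later in this article; the fixed values of μ, σ and the constant-baseline scheme are illustrative assumptions): it estimates d/dμ of the ELBO with the score-function (REINFORCE) estimator, with and without a baseline, and reports the empirical variance of each.

#include <bits/stdc++.h>
using namespace std;

// Score-function (REINFORCE) estimator of d/dmu ELBO for the illustrative model
// z ~ N(0,1), x | z ~ N(z, sigma_x^2), q(z) = N(mu, sigma^2), compared with and
// without a constant baseline b ≈ E_q[log p(x,z) - log q(z)].
int main(){
    const double x = 2.0, sigma_x = 0.5;   // observed datum and known noise std
    const double mu = 0.5, sigma = 0.8;    // fixed variational parameters for the comparison
    const int N = 100000;                  // number of single-sample gradient estimates

    mt19937 gen(7);
    normal_distribution<double> stdn(0.0, 1.0);

    auto log_joint = [&](double z){        // log p(z) + log p(x|z), additive constants dropped
        return -0.5*z*z - 0.5*(x - z)*(x - z)/(sigma_x*sigma_x);
    };
    auto log_q = [&](double z){            // log q(z; mu, sigma), additive constants dropped
        return -0.5*(z - mu)*(z - mu)/(sigma*sigma) - log(sigma);
    };

    // First pass: estimate the constant baseline from a small pilot sample
    double b = 0.0;
    for(int i = 0; i < 2000; ++i){ double z = mu + sigma*stdn(gen); b += log_joint(z) - log_q(z); }
    b /= 2000.0;

    // Second pass: compare the two unbiased estimators of d/dmu ELBO
    double m0 = 0, s0 = 0, m1 = 0, s1 = 0;
    for(int i = 0; i < N; ++i){
        double z = mu + sigma*stdn(gen);
        double score  = (z - mu)/(sigma*sigma);   // d/dmu log q(z)
        double reward = log_joint(z) - log_q(z);
        double g_plain = score * reward;          // no baseline
        double g_base  = score * (reward - b);    // with baseline (still unbiased)
        m0 += g_plain; s0 += g_plain*g_plain;
        m1 += g_base;  s1 += g_base*g_base;
    }
    m0 /= N; s0 /= N; m1 /= N; s1 /= N;
    cout << fixed << setprecision(4)
         << "mean (no baseline)   = " << m0 << ", var = " << s0 - m0*m0 << "\n"
         << "mean (with baseline) = " << m1 << ", var = " << s1 - m1*m1 << "\n";
    return 0;
}

Both estimators have the same expectation, because the baseline only subtracts a constant times the score, whose expectation is zero; the baseline changes the variance, not the mean.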

Key Formulas

VI Objective (KL)

λ* = argmin_λ KL( q(z; λ) || p(z | x) ) = argmin_λ E_q[ log q(z; λ) − log p(z | x) ]

Explanation: This is the quantity VI minimizes to make q(z) close to the true posterior. It measures extra bits needed when using q instead of p.

ELBO

ELBO(λ) = E_q[ log p(x, z) ] − E_q[ log q(z; λ) ]

Explanation: The ELBO is the tractable lower bound optimized in VI. It balances data fit via E_q[log p(x,z)] and regularization via the entropy term −E_q[log q(z; λ)].

ELBO–Evidence Decomposition

log p(x) = ELBO(q) + KL( q(z) || p(z | x) ) ≥ ELBO(q)

Explanation: Because KL is non-negative, the ELBO lower-bounds the log evidence. Maximizing the ELBO is equivalent to minimizing the KL.

Jensen’s Inequality Derivation

log p(x) = log E_q[ p(x, z) / q(z) ] ≥ E_q[ log p(x, z) − log q(z) ] = ELBO(q)

Explanation: Applying Jensen’s inequality to the concave log function yields the ELBO as a lower bound to log evidence.

Mean-field Factorization

q(z) = ∏_{i=1}^{m} q_i(z_i; λ_i)

Explanation: The mean-field assumption breaks dependencies among latent variables, making coordinate updates tractable.

CAVI Update

q_i*(z_i) ∝ exp( E_{q_{−i}}[ log p(x, z) ] )

Explanation: At the optimum (with the other factors held fixed), q_i*(z_i) is proportional to the exponential of the expected log-joint, where the expectation is taken over all other factors. This yields closed-form updates in conjugate models.

Reparameterization Trick

z = μ + σ ⊙ ε, with ε ~ N(0, I)

Explanation: Sampling is re-expressed as a deterministic function of parameters and auxiliary noise, enabling backpropagation through z.

Pathwise Gradient

∇_λ E_{q(z; λ)}[ f(z) ] = E_{ε ~ p(ε)}[ ∇_z f(g(λ, ε)) · ∇_λ g(λ, ε) ], where z = g(λ, ε)

Explanation: Differentiating through the sampling path yields low-variance Monte Carlo gradients when f is differentiable in z.
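
Specializing to a Gaussian family with z = μ + σε (the setting of the first code example below), the pathwise gradients take a simple form: with f(z) = log p(x, z), ∇_μ E_q[f(z)] = E_ε[ f′(μ + σε) ] and ∇_σ E_q[f(z)] = E_ε[ f′(μ + σε) · ε ]. Adding the Gaussian entropy H[q] = log σ + ½ log(2πe) contributes ∂H/∂σ = 1/σ, so ∇_σ ELBO = E_ε[ f′(z) ε ] + 1/σ, which is the per-sample gradient the code averages over its K samples.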

Monte Carlo ELBO

ELBO(λ) ≈ (1/K) ∑_{k=1}^{K} [ log p(x, z^(k)) − log q(z^(k); λ) ], with z^(k) ~ q(z; λ)

Explanation: A stochastic estimate of the ELBO computed by sampling from q. Increasing K reduces estimator variance at higher cost.

Gaussian Entropy

H[ N(μ, σ²) ] = ½ log(2π e σ²); for a diagonal Gaussian in d dimensions, H = ∑_{i=1}^{d} ½ log(2π e σ_i²)

Explanation: The entropy term in the ELBO for Gaussian q has a closed form, simplifying gradients and stability analysis.

Gaussian–Gaussian KL

KL( N(μ, σ²) || N(0, 1) ) = ½ ( σ² + μ² − 1 − log σ² )

Explanation: Used in VAEs with standard Gaussian priors and Gaussian approximate posteriors. Summed over latent dimensions, this closed form gives the KL term exactly, which stabilizes training.
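
As a quick illustration (a minimal sketch with example values chosen here, not one of this article's main code examples), the closed form can be evaluated per latent dimension and summed, which is how the KL regularizer in the VAE objective below is typically computed from an encoder's means and log-variances:

#include <bits/stdc++.h>
using namespace std;

// Closed-form KL( N(mu_i, sigma_i^2) || N(0, 1) ) summed over independent latent dimensions.
// The inputs mimic hypothetical encoder outputs: per-dimension means and log-variances.
double kl_gaussian_to_std_normal(const vector<double>& mu, const vector<double>& log_var){
    double kl = 0.0;
    for(size_t i = 0; i < mu.size(); ++i){
        // 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2) for each dimension
        kl += 0.5 * (exp(log_var[i]) + mu[i]*mu[i] - 1.0 - log_var[i]);
    }
    return kl;
}

int main(){
    vector<double> mu      = {0.3, -1.2};   // example means
    vector<double> log_var = {-0.5,  0.1};  // example log-variances
    cout << "KL term: " << kl_gaussian_to_std_normal(mu, log_var) << "\n";
    return 0;
}

The term is zero only when every dimension has μ_i = 0 and σ_i = 1, i.e., when q matches the standard normal prior.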

VAE Objective

L(θ, φ; x) = E_{q_φ(z|x)}[ log p_θ(x | z) ] − KL( q_φ(z | x) || p(z) )

Explanation: The amortized ELBO used in VAEs trades off reconstruction accuracy and regularization towards the prior.

Per-step Complexity (Stochastic VI/VAEs)

O(B · K · p)

Explanation: For batch size B, K samples per datapoint, and model evaluation cost p, the training step time scales linearly in all three.

Complexity Analysis

The computational complexity of VI depends on the model, variational family, and optimizer. In mean-field coordinate ascent for conjugate-exponential models, each factor update typically costs time proportional to computing its expected sufficient statistics, and a full sweep updates all m factors; with dense dependencies among the factors this is often O(m²) per sweep. Memory scales with the number of variational parameters; for a diagonal-Gaussian q(z) in d dimensions, storage is O(d) for means and variances.

In black-box VI with reparameterization, evaluating a Monte Carlo ELBO estimate with K samples per datapoint requires K evaluations of log p(x,z) and log q(z). For a minibatch of size B, the time per step is O(B K p), where p is the cost to compute the joint (and its gradients). The space cost is dominated by parameters and activations needed for backpropagation (in neural settings), typically O(P + A), where P is the parameter count and A is the size of cached activations. Reparameterized gradients often reduce variance, enabling smaller K (e.g., K = 1) and thus lowering compute. For VAEs, forward and backward passes through the encoder and decoder dominate complexity; with L dense layers of width w, these passes cost roughly O(B K L w²) per step. The KL term is often closed-form and costs O(d). Overall memory is O(B K d + P) for latent samples and parameters.

Convergence rate depends on curvature and estimator variance; adaptive optimizers (Adam) and control variates can reduce the number of iterations. Compared to MCMC, VI trades asymptotic exactness for speed and parallelism, usually achieving orders-of-magnitude faster wall-clock times on large datasets.
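
As a rough, illustrative calculation (all numbers here are assumptions, not measurements): with B = 128, K = 1, and L = 4 dense layers of width w = 512, the dominant per-step cost is on the order of B · K · L · w² = 128 · 1 · 4 · 512² ≈ 1.3 × 10⁸ multiply-accumulate operations for a forward pass (roughly tripled once the backward pass is included), while the closed-form KL adds only about B · d operations for latent dimension d.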

Code Examples

Reparameterized Monte Carlo ELBO for a 1D Latent Gaussian Model
#include <bits/stdc++.h>
using namespace std;

// Model: z ~ N(0,1); x | z ~ N(z, sigma_x^2)
// Variational family: q(z|x) = N(mu, sigma^2)
// Optimize ELBO(mu, sigma) via reparameterization: z = mu + sigma * eps, eps ~ N(0,1)

struct RNG {
    mt19937 gen;
    normal_distribution<double> stdn{0.0, 1.0};
    RNG(unsigned seed = 42u) : gen(seed) {}
    double std_normal() { return stdn(gen); }
};

// Derivative of f(z) = log p(z) + log p(x|z) w.r.t z
// log p(z) = -0.5*(z^2 + c); d/dz = -z
// log p(x|z) = -0.5*((x - z)^2 / sigma_x^2 + c); d/dz = (x - z)/sigma_x^2
inline double df_dz(double z, double x, double sigma_x){
    return -z + (x - z)/(sigma_x * sigma_x);
}

// Compute a single-sample stochastic ELBO value for monitoring
inline double elbo_sample(double z, double mu, double sigma, double x, double sigma_x){
    // log p(z)
    double log_pz = -0.5*(z*z + log(2*M_PI));
    // log p(x|z)
    double log_px_z = -0.5*(((x - z)*(x - z))/(sigma_x*sigma_x) + log(2*M_PI*sigma_x*sigma_x));
    // log q(z|x)
    double log_q = -0.5*(((z - mu)*(z - mu))/(sigma*sigma) + 2.0*log(sigma) + log(2*M_PI));
    return log_pz + log_px_z - log_q;
}

int main(){
    // Observed datum
    const double x = 2.0;       // thermometer reading
    const double sigma_x = 0.5; // known noise std

    // Variational parameters (initialize)
    double mu = 0.0;   // mean
    double rho = -0.5; // log sigma (to keep sigma > 0)

    // Optimization hyperparameters
    const int K = 8;        // MC samples per iteration
    const int iters = 2000; // iterations
    const double lr = 0.02; // learning rate

    RNG rng(123);

    for(int t=1; t<=iters; ++t){
        double g_mu = 0.0;     // gradient w.r.t mu
        double g_rho = 0.0;    // gradient w.r.t rho
        double elbo_est = 0.0; // for monitoring

        double sigma = exp(rho);
        for(int k=0; k<K; ++k){
            double eps = rng.std_normal();
            double z = mu + sigma * eps; // reparameterization

            // Pathwise gradients for Gaussian q
            // dL/dmu = f'(z)
            double fp = df_dz(z, x, sigma_x);
            double dL_dmu = fp;
            // dL/dsigma = f'(z)*eps + 1/sigma => chain to rho: dL/drho = dL/dsigma * sigma
            double dL_dsigma = fp * eps + 1.0 / sigma;
            double dL_drho = dL_dsigma * sigma;

            g_mu += dL_dmu;
            g_rho += dL_drho;

            elbo_est += elbo_sample(z, mu, sigma, x, sigma_x);
        }
        g_mu /= K; g_rho /= K; elbo_est /= K;

        // Gradient ascent on ELBO
        mu += lr * g_mu;
        rho += lr * g_rho;

        if(t % 200 == 0){
            cout << "iter " << t
                 << "\tELBO~ " << fixed << setprecision(4) << elbo_est
                 << "\tmu= " << mu
                 << "\tsigma= " << exp(rho) << "\n";
        }
    }

    cout << "Final: mu= " << mu << ", sigma= " << exp(rho) << "\n";
    return 0;
}

We fit a simple latent Gaussian model with one observation and approximate posterior q(z|x) = N(μ, σ^2). Using the reparameterization z = μ + σε with ε ~ N(0,1), we estimate the ELBO’s gradient via pathwise derivatives. The derivative with respect to μ simplifies to f′(z) = d/dz[log p(z) + log p(x|z)]. The derivative with respect to σ includes an entropy term, yielding dL/dσ = f′(z)ε + 1/σ, and we optimize ρ = log σ for positivity. This demonstrates black-box VI with low-variance gradients in a fully self-contained C++ program.

Time: O(T · K), where T is the number of iterations and K is the number of Monte Carlo samples per iteration.
Space: O(1) beyond constants; only a few doubles are stored regardless of T and K.
Mean-field Coordinate Ascent VI for a Correlated Gaussian Target
#include <bits/stdc++.h>
using namespace std;

// Target posterior (up to normalization): p(z) ∝ exp(-0.5 * (z - m)^T Λ (z - m))
// We approximate with mean-field q(z) = Π_i N(m_i_hat, s2_i), where s2_i = 1 / Λ_ii
// Coordinate update (Gaussian case):
//   m_i_hat ← m_i - (1/Λ_ii) * Σ_{j≠i} Λ_ij * (m_j_hat - m_j)
// Repeat until convergence.

struct MeanFieldGaussian {
    vector<double> m_true;         // target mean m
    vector<vector<double>> Lambda; // precision matrix Λ (symmetric PD)
    vector<double> m_hat;          // variational means
    vector<double> s2;             // variational variances (fixed): 1 / Λ_ii

    MeanFieldGaussian(const vector<double>& m, const vector<vector<double>>& L)
        : m_true(m), Lambda(L) {
        int n = (int)m.size();
        m_hat.assign(n, 0.0);
        s2.resize(n);
        for(int i=0;i<n;++i){
            s2[i] = 1.0 / Lambda[i][i];
        }
    }

    void run(int max_iters = 1000, double tol = 1e-9){
        int n = (int)m_hat.size();
        for(int it=1; it<=max_iters; ++it){
            double max_delta = 0.0;
            for(int i=0; i<n; ++i){
                double off = 0.0;
                for(int j=0; j<n; ++j){
                    if(j==i) continue;
                    // coupling with the other factors, centered at the target mean
                    off += Lambda[i][j] * (m_hat[j] - m_true[j]);
                }
                double new_mi = m_true[i] - off / Lambda[i][i];
                max_delta = max(max_delta, fabs(new_mi - m_hat[i]));
                m_hat[i] = new_mi;
            }
            if(max_delta < tol){
                // converged
                break;
            }
        }
    }

    void print_state() const{
        int n = (int)m_hat.size();
        cout << fixed << setprecision(6);
        cout << "Variational means (m_hat): ";
        for(int i=0;i<n;++i) cout << m_hat[i] << (i+1==n?"\n":" ");
        cout << "Variational variances (s2): ";
        for(int i=0;i<n;++i) cout << s2[i] << (i+1==n?"\n":" ");
    }
};

int main(){
    // 2D example with correlation
    vector<double> m = {1.0, -2.0};
    // Precision Λ for covariance Σ = [[1, 0.8],[0.8, 1]] (PD); Λ = Σ^{-1}
    double rho = 0.8;
    double denom = 1.0 - rho*rho;
    vector<vector<double>> Lambda = {
        {  1.0/denom, -rho/denom },
        { -rho/denom,  1.0/denom }
    };

    MeanFieldGaussian mf(m, Lambda);
    mf.run(10000, 1e-12);
    mf.print_state();

    return 0;
}

For a correlated Gaussian target with precision Λ and mean m, the optimal mean-field factors are univariate Gaussians whose variances are 1/Λ_ii and whose means satisfy fixed-point equations depending on the off-diagonal entries of Λ. The coordinate ascent update m_i_hat = m_i − (1/Λ_ii) Σ_{j≠i} Λ_ij (m_j_hat − m_j) iteratively accounts for correlations via expectations of the other variables, even though the final q factorizes. This small program converges to the mean-field optimum (here, the variational means recover the target mean) and prints the factor means and variances.

Time: O(T · n^2) for n dimensions and T coordinate-sweep iterations, due to summing over off-diagonal Λ_ij.
Space: O(n^2) to store the dense precision matrix Λ and O(n) for variational parameters.