Mean Field Variational Family
Key Points
- Mean field variational family assumes the joint posterior over latent variables factorizes into independent pieces q(z) = ∏_i q_i(z_i).
- It turns Bayesian inference into optimization by maximizing the Evidence Lower BOund (ELBO) instead of integrating intractable posteriors.
- Coordinate Ascent Variational Inference (CAVI) updates each factor by taking an expectation of the log joint with respect to the other factors.
- This approach is fast and scalable, especially for conjugate-exponential models where closed-form updates exist.
- The trade-off is bias: independence assumptions can underestimate posterior correlations and uncertainty.
- Monitoring the ELBO or parameter changes is crucial to ensure convergence and detect numerical issues.
- Mean field VI is widely used in mixture models, topic models (LDA), Bayesian regression, and matrix factorization.
- Careful selection of variational families, correct handling of expectations, and numerical stability (log-sum-exp, matrix conditioning) are key to good performance.
Prerequisites
- Probability distributions and expectations — Understanding how to compute expectations E_q[...] and work with common distributions (Gaussian, Gamma) is fundamental to variational updates.
- Bayesian inference and priors/posteriors — Mean field VI approximates posteriors; knowing priors, likelihoods, and Bayes’ rule provides context.
- KL divergence and entropy — The ELBO is defined via expectations and entropy; KL divergence links ELBO to evidence.
- Exponential family and conjugacy — Closed-form mean field updates rely on conjugacy and exponential-family structure.
- Linear algebra basics — Updates often involve matrices (X^T X, inverses, traces); numerical stability depends on conditioning.
- Optimization basics — CAVI is coordinate ascent; concepts like convergence criteria and monotonic ascent are essential.
- Numerical stability and conditioning — Matrix inversion and log-sum-exp tricks prevent overflow/underflow and improve robustness.
Detailed Explanation
Overview
Mean field variational family is a simplifying assumption in variational inference where a complicated posterior distribution p(z | x) over many latent variables z is approximated by a product of simpler factors q(z) = ∏_i q_i(z_i). Instead of computing an exact posterior (often intractable due to integrals or sums over huge spaces), we reframe inference as an optimization problem: find the q within this factorized family that is closest to the true posterior in Kullback–Leibler (KL) divergence. The optimization target is the Evidence Lower BOund (ELBO), a tractable surrogate objective whose maximization is equivalent to minimizing KL(q || p). In many probabilistic models—especially those in the conjugate-exponential family—this yields closed-form coordinate updates for each factor q_i. For nonconjugate models, one may use gradient-based or stochastic variational methods while still relying on the factorization structure. The mean field approach is attractive because it is conceptually simple, computationally efficient, and scales well to high-dimensional problems or large datasets. However, the independence assumption can limit accuracy by ignoring posterior correlations, leading to underestimation of uncertainty and biased credible intervals. Understanding both its strengths and limitations helps in deciding when to deploy mean field VI or when to consider richer variational families.
Intuition & Analogies
Imagine trying to understand a crowded conversation in a noisy room. Instead of decoding every overlapping voice together (which is hard), you ask each person to speak one at a time and summarize what they would say, assuming the others keep speaking in their typical way. You then rotate through the group, letting each person update their summary based on the others’ summaries. Over time, the group summaries become consistent. Mean field variational inference works the same way: you split a complicated joint distribution into independent parts (one per “speaker”), and iteratively refine each part by holding the others fixed and updating it to best explain the data and the probabilistic model. The ELBO acts like a global “score” of how coherent the combined summaries are with the model and data. The independence assumption is the simplifying lens: each part pretends the others only matter through their average behavior (expectations). This keeps each update simple and cheap. The cost of this simplicity is that you can miss important interactions—like two people who always speak in sync (correlation). When the real conversation has strong couplings, independent summaries can be overconfident or systematically biased. Still, for many practical rooms—mixture models, topic models, or regression with many features—this approach gives a remarkably good, fast summary compared to the impossible task of listening to everyone at once.
Formal Definition
Let p(x, z) be a joint model over observed data x and latent variables z = (z_1, …, z_m). The mean field family is the set of distributions that factorize as q(z) = ∏_i q_i(z_i). Variational inference selects q* = argmin_q KL(q(z) || p(z | x)) over this family, which is equivalent to maximizing the ELBO, E_q[log p(x, z)] − E_q[log q(z)]. CAVI performs coordinate ascent on this objective: holding all factors but one fixed, the optimal j-th factor satisfies q_j*(z_j) ∝ exp(E_{q_{−j}}[log p(x, z)]), and cycling through the factors monotonically increases the ELBO.
When to Use
Use the mean field variational family when exact Bayesian inference is intractable but you need scalable approximate posteriors. Typical applications include: (1) Conjugate-exponential models like Gaussian mixtures, latent Dirichlet allocation (topic models), Bayesian linear regression with unknown noise, and probabilistic matrix factorization, where CAVI gives closed-form updates and fast convergence. (2) Large datasets where Markov chain Monte Carlo (MCMC) would be prohibitively slow; mean field VI provides orders-of-magnitude speedups with acceptable accuracy. (3) Online or streaming settings via stochastic variational inference, exploiting the factorization to process mini-batches. (4) As a baseline variational family before trying richer structures (e.g., structured mean field, normalizing flows) if mean field’s independence bias is acceptable. Avoid mean field when posterior dependencies are critical—for example, strong identifiability couplings, narrow curved posteriors, or multi-modality that independence cannot capture. In such cases, consider structured VI, hierarchical couplings, or MCMC for more faithful uncertainty quantification.
⚠️Common Mistakes
- Ignoring the independence bias: Mean field underestimates posterior variance and ignores correlations. Mitigation: validate uncertainty (coverage), compare with MCMC on small subsets, or upgrade to structured families.
- Wrong family/parameterization: Choosing an incompatible q_i (e.g., Gaussian over a variable constrained to be positive) leads to poor fits or numerical issues. Mitigation: match supports (e.g., Gamma for precisions, Dirichlet for probabilities) and prefer natural parameterizations.
- Mixing up Gamma rate vs scale: Many updates use the Gamma rate b (density p(τ) = b^a/Γ(a) · τ^{a−1} e^{−bτ}). Using the scale 1/b instead silently breaks expectations like E[τ] = a/b. Always check conventions.
- Not monitoring convergence: Relying on a fixed iteration count may stop too early or waste time. Use ELBO increases or parameter change thresholds; beware of oscillations due to poor initialization.
- Numerical instability: Failing to use log-sum-exp for categorical responsibilities, inverting ill-conditioned matrices directly, or allowing variances to become negative. Use stabilized computations (e.g., Cholesky, jitter), clip variances, and check positive definiteness.
- Overlooking data preprocessing: Unscaled features in regression can cause poor conditioning and slow convergence. Standardize features and center responses.
- Confusing CAVI with EM: They look similar, but EM maximizes likelihood over parameters, while VI approximates a full posterior over latents and parameters. Do not interpret variational parameters as MLEs.
- Overfitting priors: Setting hyperparameters too tight can dominate the posterior under mean field. Use weakly informative priors and sensitivity analysis.
Key Formulas
Mean field factorization
q(z) = ∏_i q_i(z_i)
Explanation: The variational posterior is assumed to split into independent factors, one per variable or block. This simplification makes updates tractable at the cost of ignoring correlations.
ELBO
ELBO(q) = E_q[log p(x, z)] − E_q[log q(z)]
Explanation: This is the objective function variational inference maximizes. Increasing the ELBO improves the approximation to the true posterior.
ELBO–evidence identity
log p(x) = ELBO(q) + KL(q(z) || p(z | x))
Explanation: The evidence (log marginal likelihood) equals the ELBO plus a nonnegative KL term. Therefore maximizing the ELBO is equivalent to minimizing KL(q || p).
CAVI update
q_j*(z_j) ∝ exp( E_{q_{−j}}[ log p(x, z) ] )
Explanation: The optimal update for a variational factor is proportional to the exponential of the expected log-joint over the other factors. This guarantees ELBO ascent.
Gamma expectations (rate parameterization)
E[τ] = a / b,  E[log τ] = ψ(a) − log b
Explanation: For τ ~ Gamma(a, b) with rate b, these expectations enter many CAVI updates. Here ψ is the digamma function, the derivative of log Γ.
Mean field update for q(μ) in Normal–Gamma model
q(μ) = Normal(m, v),  m = (λ₀ μ₀ + n x̄) / (λ₀ + n),  v = 1 / (E[τ] (λ₀ + n))
Explanation: Given q(τ), the optimal q(μ) is Gaussian with the shown mean and variance. It uses the expected precision E[τ].
Mean field update for q(τ) in Normal–Gamma model
q(τ) = Gamma(a, b),  a = a₀ + (n + 1)/2,  b = b₀ + ½ (λ₀ E[(μ − μ₀)²] + Σ_i E[(x_i − μ)²]),
where E[(μ − μ₀)²] = (m − μ₀)² + v and E[(x_i − μ)²] = (x_i − m)² + v
Explanation: Given q(μ), the optimal q(τ) is Gamma with updated shape and rate. Expectations replace the quadratic terms, incorporating both mean and variance of q(μ).
q(w) update for Bayesian linear regression (MFVI)
q(w) = Normal(m, S),  S = (α I + E[τ] XᵀX)⁻¹,  m = E[τ] S Xᵀy
Explanation: Given q(τ), the variational posterior over weights is Gaussian with covariance S and mean m. It resembles a ridge regression solution with data precision weighted by E[τ].
q(τ) update for Bayesian linear regression (MFVI)
q(τ) = Gamma(a, b),  a = a₀ + n/2,  b = b₀ + ½ (‖y − X m‖² + tr(S XᵀX))
Explanation: The Gamma factor over the noise precision uses the residual sum of squares at the current mean and an uncertainty correction via the trace term.
Jensen’s inequality for ELBO
log p(x) = log E_q[ p(x, z) / q(z) ] ≥ E_q[ log p(x, z) − log q(z) ] = ELBO(q)
Explanation: Applying concavity of log gives a lower bound on the log evidence. This is the standard derivation of the ELBO.
Complexity Analysis
For the Normal–Gamma example, the data statistics (x̄ and Σ x_i²) are computed once in O(n); each CAVI iteration then costs O(1), so the total cost is O(n + T) for T iterations. For Bayesian linear regression, forming XᵀX and Xᵀy costs O(n d²) once; each iteration is dominated by inverting a d × d matrix at O(d³) (plus O(d²) for the trace and mean updates), giving O(n d² + T d³) overall with O(n d + d²) memory. CAVI typically converges in tens of iterations on conjugate models, which is why mean field VI is often orders of magnitude faster than MCMC.
Code Examples
```cpp
#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <limits>
#include <algorithm>

// This program performs Coordinate Ascent Variational Inference (CAVI)
// for a univariate Gaussian with unknown mean mu and precision tau.
// Model:
//   x_i | mu, tau ~ Normal(mu, 1/tau)
//   mu | tau ~ Normal(mu0, 1/(lambda0 * tau))
//   tau ~ Gamma(a0, b0)   (rate parameterization)
// Mean-field family: q(mu) q(tau)
// Updates (see formulas section):
//   E[tau] = a / b
//   q(mu) = Normal(m_mu, v_mu),
//     m_mu = (lambda0 * mu0 + n * xbar) / (lambda0 + n)
//     v_mu = 1 / (E[tau] * (lambda0 + n))
//   q(tau) = Gamma(a, b),
//     a = a0 + (n + 1)/2
//     b = b0 + 0.5 * [ lambda0 * E[(mu - mu0)^2] + sum_i E[(x_i - mu)^2] ]
//   E[(mu - mu0)^2] = (m_mu - mu0)^2 + v_mu
//   E[(x_i - mu)^2] = (x_i - m_mu)^2 + v_mu

struct VariationalNG {
    double m_mu; // mean of q(mu)
    double v_mu; // variance of q(mu)
    double a;    // shape of q(tau)
    double b;    // rate of q(tau)
};

int main() {
    // 1) Generate synthetic data
    std::mt19937 rng(42);
    double mu_true = 2.0;
    double tau_true = 4.0; // precision = 4 => variance = 0.25
    std::normal_distribution<double> noise(0.0, std::sqrt(1.0 / tau_true));

    int n = 200;
    std::vector<double> x(n);
    for (int i = 0; i < n; ++i) x[i] = mu_true + noise(rng);

    // Precompute sample mean and sum of squares for speed
    auto mean = [&](const std::vector<double>& v){
        double s = 0.0; for (double z : v) s += z; return s / v.size();
    };
    auto sumsq = [&](const std::vector<double>& v){
        double s = 0.0; for (double z : v) s += z * z; return s;
    };

    double xbar = mean(x);
    double sx2 = sumsq(x);
    double sx = xbar * n;

    // 2) Hyperparameters of Normal–Gamma prior
    double mu0 = 0.0;
    double lambda0 = 1.0; // strength of prior on mu (scaled by tau)
    double a0 = 1.0;      // shape
    double b0 = 1.0;      // rate

    // 3) Initialize variational parameters
    VariationalNG q{ /*m_mu*/ xbar, /*v_mu*/ 1.0, /*a*/ a0 + (n + 1) * 0.5, /*b*/ b0 + 1.0 };

    // 4) CAVI iterations
    const int max_iter = 1000;
    const double tol = 1e-9;
    double prev_m_mu = q.m_mu, prev_v_mu = q.v_mu, prev_a = q.a, prev_b = q.b;

    for (int it = 0; it < max_iter; ++it) {
        // E[tau]
        double Etau = q.a / q.b;

        // Update q(mu): Gaussian parameters
        double lambda = lambda0 + n;
        q.m_mu = (lambda0 * mu0 + n * xbar) / lambda;
        q.v_mu = 1.0 / (Etau * lambda);

        // Expectations needed for q(tau)
        double Emu_minus_mu0_sq = (q.m_mu - mu0) * (q.m_mu - mu0) + q.v_mu;
        // sum_i (x_i - m_mu)^2 = sum x_i^2 - 2 m_mu sum x_i + n m_mu^2
        double sum_xi_minus_m_sq = sx2 - 2.0 * q.m_mu * sx + n * q.m_mu * q.m_mu;
        // Add n * Var(mu) to get the summed E[(x_i - mu)^2]
        double sum_E_xi_minus_mu_sq = sum_xi_minus_m_sq + n * q.v_mu;

        // Update q(tau): Gamma parameters
        q.a = a0 + 0.5 * (n + 1);
        q.b = b0 + 0.5 * (lambda0 * Emu_minus_mu0_sq + sum_E_xi_minus_mu_sq);

        // Check convergence
        double diff = std::fabs(q.m_mu - prev_m_mu) + std::fabs(q.v_mu - prev_v_mu)
                    + std::fabs(q.a - prev_a) + std::fabs(q.b - prev_b);
        prev_m_mu = q.m_mu; prev_v_mu = q.v_mu; prev_a = q.a; prev_b = q.b;
        if (diff < tol) break;
    }

    // Report results
    double Etau = q.a / q.b;
    std::cout << "True mu=" << mu_true << ", tau=" << tau_true << "\n";
    std::cout << "q(mu): mean=" << q.m_mu << ", var=" << q.v_mu << "\n";
    std::cout << "q(tau): E[tau]=" << Etau << " (a=" << q.a << ", b=" << q.b << ")\n";

    return 0;
}
```
We approximate the posterior over the mean μ and precision τ of a univariate Gaussian using a factorized variational family q(μ) q(τ). Thanks to conjugacy, CAVI updates are closed-form. The code precomputes simple data statistics, then alternates between updating q(μ) (a Gaussian that uses E[τ]) and q(τ) (a Gamma that uses expectations under q(μ)). We stop when parameter changes are tiny. The output reports posterior means/variances, which should be close to the true generating parameters for enough data and reasonable priors.
```cpp
#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <limits>
#include <cassert>

// Simple matrix utilities (naive, for teaching; not optimized)
using Matrix = std::vector<std::vector<double>>;
using Vector = std::vector<double>;

Matrix zeros(int r, int c){ return Matrix(r, std::vector<double>(c, 0.0)); }
Matrix eye(int d){ Matrix I = zeros(d, d); for(int i = 0; i < d; ++i) I[i][i] = 1.0; return I; }
Matrix transpose(const Matrix& A){ int r = A.size(), c = A[0].size(); Matrix AT = zeros(c, r); for(int i = 0; i < r; ++i) for(int j = 0; j < c; ++j) AT[j][i] = A[i][j]; return AT; }
Matrix add(const Matrix& A, const Matrix& B){ int r = A.size(), c = A[0].size(); Matrix C = zeros(r, c); for(int i = 0; i < r; ++i) for(int j = 0; j < c; ++j) C[i][j] = A[i][j] + B[i][j]; return C; }
Matrix scale(const Matrix& A, double s){ int r = A.size(), c = A[0].size(); Matrix B = zeros(r, c); for(int i = 0; i < r; ++i) for(int j = 0; j < c; ++j) B[i][j] = A[i][j] * s; return B; }
Matrix add_scaled_identity(const Matrix& A, double alpha){ int d = A.size(); Matrix B = A; for(int i = 0; i < d; ++i) B[i][i] += alpha; return B; }
Matrix matmul(const Matrix& A, const Matrix& B){ int r = A.size(), m = A[0].size(), c = B[0].size(); Matrix C = zeros(r, c); for(int i = 0; i < r; ++i) for(int k = 0; k < m; ++k){ double aik = A[i][k]; for(int j = 0; j < c; ++j) C[i][j] += aik * B[k][j]; } return C; }
Vector matvec(const Matrix& A, const Vector& x){ int r = A.size(), c = A[0].size(); Vector y(r, 0.0); for(int i = 0; i < r; ++i) for(int j = 0; j < c; ++j) y[i] += A[i][j] * x[j]; return y; }
Matrix outer(const Vector& a, const Vector& b){ int r = a.size(), c = b.size(); Matrix C = zeros(r, c); for(int i = 0; i < r; ++i) for(int j = 0; j < c; ++j) C[i][j] = a[i] * b[j]; return C; }

// Gauss-Jordan inversion (naive, O(d^3))
Matrix inverse(Matrix A){
    int n = A.size();
    Matrix I = eye(n);
    for(int i = 0; i < n; ++i){
        // Partial pivoting: pick the largest remaining entry in this column
        int piv = i; double mx = std::fabs(A[i][i]);
        for(int r = i + 1; r < n; ++r) if(std::fabs(A[r][i]) > mx){ mx = std::fabs(A[r][i]); piv = r; }
        if(mx < 1e-15){ std::cerr << "Matrix nearly singular; add regularization.\n"; }
        if(piv != i){ std::swap(A[piv], A[i]); std::swap(I[piv], I[i]); }
        // Normalize row
        double diag = A[i][i];
        for(int j = 0; j < n; ++j){ A[i][j] /= diag; I[i][j] /= diag; }
        // Eliminate others
        for(int r = 0; r < n; ++r){ if(r == i) continue; double f = A[r][i]; if(f == 0) continue; for(int j = 0; j < n; ++j){ A[r][j] -= f * A[i][j]; I[r][j] -= f * I[i][j]; } }
    }
    return I;
}

double dot(const Vector& a, const Vector& b){ double s = 0.0; for(size_t i = 0; i < a.size(); ++i) s += a[i] * b[i]; return s; }
Vector axpy(double a, const Vector& x, const Vector& y){ Vector z = y; for(size_t i = 0; i < x.size(); ++i) z[i] += a * x[i]; return z; }

double trace(const Matrix& A){ double s = 0.0; for(size_t i = 0; i < A.size(); ++i) s += A[i][i]; return s; }
Matrix xtx(const Matrix& X){ Matrix XT = transpose(X); return matmul(XT, X); }
Vector xty(const Matrix& X, const Vector& y){ Matrix XT = transpose(X); return matvec(XT, y); }

int main(){
    std::mt19937 rng(123);
    int n = 150; int d = 3;

    // True parameters
    Vector w_true = {1.0, -2.0, 0.5};
    double tau_true = 5.0; // precision

    // Generate X
    std::normal_distribution<double> N01(0.0, 1.0);
    Matrix X(n, Vector(d));
    for(int i = 0; i < n; ++i) for(int j = 0; j < d; ++j) X[i][j] = N01(rng);

    // Generate y = X w_true + eps
    std::normal_distribution<double> noise(0.0, std::sqrt(1.0 / tau_true));
    Vector y(n, 0.0);
    for(int i = 0; i < n; ++i){ double yi = 0.0; for(int j = 0; j < d; ++j) yi += X[i][j] * w_true[j]; yi += noise(rng); y[i] = yi; }

    // Prior: w ~ N(0, alpha^{-1} I), tau ~ Gamma(a0, b0) (rate)
    double alpha = 1.0; // weight prior precision
    double a0 = 1.0, b0 = 1.0;

    // Precompute G = X^T X and g = X^T y
    Matrix G = xtx(X);    // d x d
    Vector g = xty(X, y); // d

    // Initialize variational parameters
    Vector m(d, 0.0); // mean of q(w)
    Matrix S = inverse(add_scaled_identity(zeros(d, d), alpha)); // initial S = (alpha I)^{-1}
    double a = a0 + 0.5 * n; // shape of q(tau)
    double b = b0 + 1.0;     // rate of q(tau)

    const int max_iter = 500;
    const double tol = 1e-8;

    for(int it = 0; it < max_iter; ++it){
        double Etau = a / b;

        // Update q(w): S = (alpha I + Etau * G)^{-1}; m = S * (Etau * g)
        Matrix A = add_scaled_identity(scale(G, Etau), alpha);
        Matrix S_new = inverse(A);
        Vector rhs(d, 0.0); for(int j = 0; j < d; ++j) rhs[j] = Etau * g[j];
        // m = S_new * rhs
        Vector m_new(d, 0.0);
        for(int i = 0; i < d; ++i){ for(int j = 0; j < d; ++j) m_new[i] += S_new[i][j] * rhs[j]; }

        // Update q(tau): a = a0 + n/2; b = b0 + 0.5 (||y - X m||^2 + tr(S G))
        // Compute residuals r = y - X m_new
        Vector Xm = matvec(X, m_new);
        double rss = 0.0; for(int i = 0; i < n; ++i){ double r = y[i] - Xm[i]; rss += r * r; }
        // tr(X S X^T) = tr(S X^T X) = tr(S G)
        double trSG = 0.0; for(int i = 0; i < d; ++i) for(int j = 0; j < d; ++j) trSG += S_new[i][j] * G[j][i];
        double b_new = b0 + 0.5 * (rss + trSG);
        double a_new = a0 + 0.5 * n;

        // Convergence check (parameter change)
        double diff = std::fabs(b_new - b) + std::fabs(a_new - a);
        for(int i = 0; i < d; ++i) diff += std::fabs(m_new[i] - m[i]);
        for(int i = 0; i < d; ++i) for(int j = 0; j < d; ++j) diff += std::fabs(S_new[i][j] - S[i][j]) * 1e-3; // scaled

        // Assign
        S = std::move(S_new); m = std::move(m_new); a = a_new; b = b_new;

        if(diff < tol) break;
    }

    double Etau = a / b;
    std::cout << "True w: "; for(double wi : w_true) std::cout << wi << ' '; std::cout << "\n";
    std::cout << "q(w) mean: "; for(double wi : m) std::cout << wi << ' '; std::cout << "\n";
    std::cout << "E[tau] (noise precision): " << Etau << "\n";

    return 0;
}
```
This program implements CAVI for Bayesian linear regression with a Gaussian prior on weights and a Gamma prior on noise precision. The mean-field family factorizes as q(w) q(τ). Given q(τ), q(w) is Gaussian with covariance S = (α I + E[τ] X^T X)^{-1} and mean m = S E[τ] X^T y. Given q(w), q(τ) is Gamma with a = a0 + n/2 and b = b0 + 0.5(||y − X m||^2 + tr(X S X^T)). We generate synthetic data, precompute X^T X and X^T y, and iterate until parameter changes are small. The solution resembles ridge regression, corrected for posterior uncertainty via the trace term.