Stochastic Variational Inference
Key Points
- Stochastic Variational Inference (SVI) scales variational inference to large datasets by taking noisy but unbiased gradient steps using minibatches.
- SVI maximizes the Evidence Lower Bound (ELBO) with stochastic optimization, often using reparameterization gradients or natural gradients.
- It separates model variables into global parameters shared by all data and local variables tied to individual data points, updating globals with subsampled statistics.
- The key trick is to multiply minibatch contributions by N/|B| so the gradient is an unbiased estimate of the full-data gradient.
- SVI works best for exponential-family models (with natural gradients) but also handles nonconjugate models via black-box gradients.
- Learning-rate schedules (Robbins–Monro or Adam) are crucial to ensure convergence and stability.
- Reparameterization (e.g., w = μ + σ ⊙ ε) reduces gradient variance and simplifies the entropy gradient.
- Practical SVI needs good initialization, feature scaling, and variance-control techniques to avoid slow or unstable training.
Prerequisites
- Probability and Random Variables – Understanding distributions, expectations, and conditional independence is essential for Bayesian models and ELBO definitions.
- Bayesian Inference Basics – Knowledge of priors, likelihoods, and posteriors is needed to interpret variational objectives and gradients.
- Exponential Families – SVI often uses natural parameters and sufficient statistics that arise from exponential-family structure.
- Optimization and SGD – SVI relies on stochastic optimization; step-size schedules and convergence properties come from SGD theory.
- Multivariate Calculus – Gradients of log densities and the chain rule through reparameterized samples require calculus.
- Linear Algebra – Vectorized operations, inner products, and covariance/precision representations are ubiquitous in VI.
- C++ Programming Basics – Implementing SVI requires data structures, random number generation, and numerical stability in C++.
- Numerical Stability Techniques – Log-sum-exp, gradient clipping, and safe sigmoid computations prevent overflow/underflow in SVI.
Detailed Explanation
01 Overview
Stochastic Variational Inference (SVI) is a method for fitting probabilistic models to massive datasets by combining variational inference (VI) with stochastic optimization. Rather than processing all data at once, SVI looks at small random minibatches and takes gradient steps that, in expectation, improve a global objective called the Evidence Lower Bound (ELBO). This makes SVI especially effective for models with both global latent variables (shared across all data points) and local latent variables (specific to each observation). Classic examples include topic models like Latent Dirichlet Allocation, matrix factorization, and Bayesian regression models. VI turns Bayesian inference into an optimization problem: approximate the intractable posterior with a simpler family q_λ(z) and adjust λ to make q close to the true posterior. SVI does this at scale by using unbiased gradient estimates of the ELBO obtained from small random data subsets. When models are in the conjugate exponential family, SVI can use natural gradient updates with streamed sufficient statistics; otherwise, it can use black-box gradients via the reparameterization trick. The result is a general, scalable inference engine that converges under mild conditions while using memory and compute that depend on the minibatch size rather than the whole dataset.
02 Intuition & Analogies
Imagine trying to estimate the average height of people in a huge city. You could measure everyone (exact inference), but that's expensive. Or you could sample a small random group each day, update your estimate a bit, and keep going. Over time your estimate converges, even though each daily update is noisy. That is the essence of SVI: learn from small, random bites of the data. In Bayesian modeling, instead of a single number like height, we have distributions over hidden quantities (parameters and latent variables). Variational inference pretends these distributions come from a simpler, manageable family and adjusts their parameters to best match the truth. The ELBO plays the role of a score we want to maximize: the higher, the closer our approximation is to the real posterior. But calculating the exact ELBO over billions of data points is as hard as measuring everyone in the city. SVI says: sample a minibatch (say 256 people), compute how this batch suggests we should move our belief (the gradient), scale up by how much of the city it represents, and take a small step. Repeat with a new batch. Some days the batch is unrepresentative, so the step is a bit off, but on average the steps push in the right direction. If steps are not too big and gradually calm down (learning-rate schedule), the process stabilizes near the best possible approximation. When we can rewrite the model so the updates use summary statistics (exponential-family conjugacy), each batch provides the right kind of summaries and we can make especially efficient "natural" steps. When that's not possible, we simulate small noise and use calculus through a differentiable sampling path (reparameterization) to keep the updates stable and low variance.
03 Formal Definition
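A standard formal statement, consistent with the key points and notation used elsewhere in this section (observations x = (x_1, …, x_N), local latents z_n, global latents β, variational family q_λ):

```latex
\begin{aligned}
\mathcal{L}(\lambda)
  &= \mathbb{E}_{q_\lambda}\big[\log p(x, z, \beta)\big]
   - \mathbb{E}_{q_\lambda}\big[\log q_\lambda(z, \beta)\big]
   = \log p(x) - \mathrm{KL}\big(q_\lambda(z, \beta)\,\|\,p(z, \beta \mid x)\big), \\
p(x, z, \beta)
  &= p(\beta) \prod_{n=1}^{N} p(x_n, z_n \mid \beta), \\
\widehat{\nabla_\lambda \mathcal{L}}
  &= \nabla_\lambda\, \mathbb{E}_{q_\lambda}\big[\log p(\beta) - \log q_\lambda(z, \beta)\big]
   + \frac{N}{|B|} \sum_{n \in B} \nabla_\lambda\, \mathbb{E}_{q_\lambda}\big[\log p(x_n, z_n \mid \beta)\big].
\end{aligned}
```

Because the KL divergence is nonnegative, L(λ) lower-bounds log p(x); maximizing it over λ with stochastic steps along the unbiased minibatch estimator is the SVI procedure.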
04 When to Use
Use SVI when your dataset is too large to fit into memory or to process in a single pass, but you still want Bayesian uncertainty estimates or latent structures. It shines in models with a clear separation between global and local variables, where each datapoint contributes similar sufficient statistics: examples include topic models (LDA), Bayesian matrix factorization, mixture models, and hierarchical regressions. SVI is also appropriate when you have streaming data and want to update beliefs online without revisiting all past observations. If your model is in the conjugate exponential family, SVI provides very efficient natural-gradient updates based on minibatch sufficient statistics. If it isn't, SVI can still be used with black-box, reparameterization-based gradients, especially for differentiable likelihoods (e.g., logistic regression) and Gaussian variational families. Choose SVI over full-batch VI when the dataset is large (N in millions or more) or when training time per epoch is a constraint. Prefer SVI over Markov Chain Monte Carlo when you need faster approximate inference, are comfortable with an optimization-based approximation, and can tolerate some bias for large speed gains. Avoid SVI when datasets are tiny (full-batch VI or exact methods may be simpler) or when the model involves complex discrete latents without a good gradient estimator (unless you use specialized control variates or relaxations).
⚠️ Common Mistakes
- Forgetting the N/|B| scaling on minibatch contributions, which biases the gradient and can prevent convergence to the correct solution.
- Using a constant, too-large learning rate without decay, causing oscillations or divergence. Follow Robbins–Monro schedules or use adaptive optimizers like Adam with care.
- Ignoring the entropy term in the ELBO when computing gradients for nonconjugate models. In a reparameterized Gaussian q, the entropy contributes a +1 term per dimension to the gradient with respect to the log-standard-deviation (rho).
- Not standardizing features or targets in regression-like models, which leads to ill-conditioned optimization and very slow or unstable SVI steps.
- Drawing too few Monte Carlo samples for gradient estimation, yielding high-variance updates. One sample is common but can be noisy; consider multiple samples or control variates if unstable.
- Misusing natural gradients: updating mean parameters with Euclidean gradients instead of natural parameters (or forgetting the Fisher correction) breaks the theory and can slow convergence dramatically.
- Poor initialization of variational variances: too small causes vanishing exploration and local optima, too large causes noisy gradients. Initialize log-std (rho) around log(0.1) to log(1.0) as a reasonable default.
- For conjugate SVI, forgetting to convert between natural parameters and mean/variance correctly, leading to negative variances or numerical errors. Always keep the precision (inverse variance) positive and check bounds.
Key Formulas
ELBO Definition
L(λ) = E_{q_λ(z)}[log p(x, z)] − E_{q_λ(z)}[log q_λ(z)]
Explanation: The ELBO is the objective maximized in variational inference. Increasing the ELBO improves the fit of the variational posterior to the true posterior.
Global-Local Factorization
p(x, z, β) = p(β) ∏_{n=1}^{N} p(x_n, z_n | β)
Explanation: Many hierarchical models decompose into per-data local terms plus a global prior. SVI exploits this to use minibatches for unbiased gradient estimates.
Stochastic ELBO Gradient
∇̂_λ L = ∇_λ E_q[log p(β) − log q_λ] + (N/|B|) Σ_{n∈B} ∇_λ E_q[log p(x_n, z_n | β)]
Explanation: Gradients computed on a minibatch B are scaled by N/|B| to form an unbiased estimate of the full-data ELBO gradient. This is the core of SVI.
Reparameterization Trick
z = g(ε, λ), ε ~ p(ε)  ⇒  ∇_λ E_{q_λ}[f(z)] = E_{p(ε)}[∇_λ f(g(ε, λ))]
Explanation: Sampling via a differentiable transform allows low-variance gradients by backpropagating through the sample. For Gaussian q, g is affine.
Gaussian Reparameterization
w = μ + σ ⊙ ε, ε ~ N(0, I), σ = exp(ρ)
Explanation: A diagonal Gaussian variational posterior is sampled by shifting and scaling standard normal noise. This enables pathwise derivatives.
Mean-field Gaussian ELBO Gradients
∇_μ L = E_ε[∇_w log p(x, w)],  ∇_ρ L = E_ε[(σ ⊙ ε) ⊙ ∇_w log p(x, w)] + 1
Explanation: With w = μ + σ ⊙ ε, the entropy term contributes +1 per dimension to ∇_ρ L. The rest comes from the joint log-density via the chain rule.
Robbins–Monro Step Size
ρ_t = (t0 + t)^(−κ), κ ∈ (0.5, 1]
Explanation: A standard decaying learning-rate schedule that satisfies convergence conditions for stochastic approximation. t0 delays decay to stabilize early steps.
Natural Gradient
∇̃_λ L = F(λ)^(−1) ∇_λ L, where F(λ) is the Fisher information of q_λ
Explanation: Preconditioning by the inverse Fisher information aligns steps with the geometry of probability distributions, often yielding faster convergence.
SVI Natural-Parameter Update
λ_{t+1} = (1 − ρ_t) λ_t + ρ_t λ̂_t
Explanation: For conjugate exponential-family models, natural parameters are moved toward the minibatch-implied posterior natural parameters λ̂_t with a step size ρ_t.
Logistic Function
σ(z) = 1 / (1 + e^(−z))
Explanation: Maps real numbers to probabilities. In logistic regression, gradients involve (y − σ(wᵀx)) x.
Adam Optimizer
m_t = β₁ m_{t−1} + (1 − β₁) g_t,  v_t = β₂ v_{t−1} + (1 − β₂) g_t²,  θ_{t+1} = θ_t − α m̂_t / (√v̂_t + ε)
Explanation: Adaptive moment estimation rescales gradients using running averages of first and second moments (with bias-corrected m̂_t, v̂_t), improving stability for noisy gradients like those in SVI.
Complexity Analysis
Per iteration, SVI touches only the minibatch. The logistic-regression example below costs O(|B| · d) time per iteration and O(d) memory for the variational and optimizer state; the conjugate Normal–Normal example costs O(|B|) time per iteration and O(1) state. Over T iterations the total work is O(T · |B| · d), independent of N apart from storing (or streaming) the data.
Code Examples
#include <bits/stdc++.h>
using namespace std;

// Utility: numerically stable sigmoid function
static inline double sigmoid(double z) {
  if (z >= 0) {
    double ez = exp(-z);
    return 1.0 / (1.0 + ez);
  } else {
    double ez = exp(z);
    return ez / (1.0 + ez);
  }
}

// Simple Adam optimizer for vectors
struct Adam {
  vector<double> m, v;
  double beta1, beta2, alpha, eps;
  int t;
  Adam(int d, double alpha=0.05, double beta1=0.9, double beta2=0.999, double eps=1e-8)
    : m(d,0.0), v(d,0.0), beta1(beta1), beta2(beta2), alpha(alpha), eps(eps), t(0) {}
  // Ascent step: the gradients below point uphill on the ELBO, so we add.
  void step(vector<double>& params, const vector<double>& grad) {
    t++;
    double b1t = 1.0 - pow(beta1, t);
    double b2t = 1.0 - pow(beta2, t);
    for (size_t i = 0; i < params.size(); ++i) {
      m[i] = beta1 * m[i] + (1.0 - beta1) * grad[i];
      v[i] = beta2 * v[i] + (1.0 - beta2) * grad[i] * grad[i];
      double mhat = m[i] / b1t;
      double vhat = v[i] / b2t;
      params[i] += alpha * mhat / (sqrt(vhat) + eps);
    }
  }
};

// Generate synthetic logistic regression data
struct Dataset {
  vector<vector<double>> X; // N x d
  vector<int> y;            // N
};

Dataset make_data(int N, int d, unsigned seed=42) {
  mt19937 rng(seed);
  normal_distribution<double> nd(0.0, 1.0);
  vector<double> w_true(d, 0.0);
  for (int j = 0; j < d; ++j) w_true[j] = nd(rng) * 0.5; // true weights

  Dataset D;
  D.X.assign(N, vector<double>(d, 0.0));
  D.y.assign(N, 0);
  uniform_real_distribution<double> unif(0.0, 1.0);

  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < d; ++j) {
      D.X[i][j] = nd(rng); // standard normal features
    }
    double z = 0.0;
    for (int j = 0; j < d; ++j) z += D.X[i][j] * w_true[j];
    double p = sigmoid(z);
    D.y[i] = (unif(rng) < p) ? 1 : 0;
  }
  return D;
}

// Compute gradient wrt w for prior + minibatch logistic likelihood
// prior: w ~ N(0, alpha^2 I) => grad log p(w) = - w / alpha^2
// likelihood: grad sum log p(y|x,w) = sum_i (y_i - sigmoid(w^T x_i)) x_i
vector<double> grad_log_joint_w(const vector<vector<double>>& Xb,
                                const vector<int>& yb,
                                const vector<double>& w,
                                double alpha,
                                double scale) {
  int b = (int)Xb.size();
  int d = (int)w.size();
  vector<double> g(d, 0.0);
  // Prior gradient
  for (int j = 0; j < d; ++j) g[j] += - w[j] / (alpha * alpha);
  // Likelihood gradient (scaled by N / b)
  for (int i = 0; i < b; ++i) {
    double z = 0.0;
    for (int j = 0; j < d; ++j) z += Xb[i][j] * w[j];
    double p = sigmoid(z);
    double r = (double)yb[i] - p; // residual
    for (int j = 0; j < d; ++j) g[j] += scale * r * Xb[i][j];
  }
  return g;
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  // Hyperparameters
  int N = 50000;            // dataset size
  int d = 20;               // feature dimension
  int iters = 2000;         // SVI iterations
  int batch = 256;          // minibatch size
  double prior_alpha = 1.0; // prior std for w
  unsigned seed = 123;

  // Generate data
  Dataset D = make_data(N, d, seed);

  // Variational parameters for q(w) = N(mu, diag(sigma^2)) with sigma = exp(rho)
  vector<double> mu(d, 0.0);
  vector<double> rho(d, -2.0); // log-std init ~ exp(-2) ≈ 0.135

  // Optimizers for mu and rho
  Adam opt_mu(d, 0.05), opt_rho(d, 0.01);

  mt19937 rng(seed + 1);
  normal_distribution<double> stdn(0.0, 1.0);
  uniform_int_distribution<int> uid(0, N - 1);

  vector<double> sigma(d, 0.0), eps(d, 0.0), w(d, 0.0);

  for (int t = 1; t <= iters; ++t) {
    // Sample a minibatch
    vector<vector<double>> Xb;
    vector<int> yb;
    Xb.reserve(batch); yb.reserve(batch);
    for (int i = 0; i < batch; ++i) {
      int idx = uid(rng);
      Xb.push_back(D.X[idx]);
      yb.push_back(D.y[idx]);
    }
    double scale = (double)N / (double)batch; // N/|B|

    // Sample epsilon and form w = mu + sigma * epsilon
    for (int j = 0; j < d; ++j) {
      sigma[j] = exp(rho[j]);
      eps[j] = stdn(rng);
      w[j] = mu[j] + sigma[j] * eps[j];
    }

    // Gradient wrt w of log p(y,w)
    vector<double> gw = grad_log_joint_w(Xb, yb, w, prior_alpha, scale);

    // Reparameterization gradients for ELBO
    // dL/dmu = gw
    // dL/drho = (sigma * eps) * gw + 1 (elementwise), entropy gives +1
    vector<double> g_mu(d, 0.0), g_rho(d, 0.0);
    for (int j = 0; j < d; ++j) {
      g_mu[j] = gw[j];
      g_rho[j] = (sigma[j] * eps[j]) * gw[j] + 1.0;
    }

    // Optional: gradient clipping for stability
    auto clip = [](vector<double>& g, double c){
      double n2 = 0.0; for (double v: g) n2 += v*v; n2 = sqrt(n2);
      if (n2 > c && n2 > 0) { double s = c / n2; for (double& v: g) v *= s; }
    };
    clip(g_mu, 10.0);
    clip(g_rho, 10.0);

    // Adam ascent steps on the ELBO
    opt_mu.step(mu, g_mu);
    opt_rho.step(rho, g_rho);

    if (t % 200 == 0) {
      // Report average predictive log-likelihood on a small validation subsample
      int val = 1000;
      double ll = 0.0;
      for (int i = 0; i < val; ++i) {
        int idx = uid(rng);
        double z = 0.0;
        for (int j = 0; j < d; ++j) z += D.X[idx][j] * mu[j]; // use mu as a point estimate
        double p = sigmoid(z);
        ll += (D.y[idx] ? log(max(p, 1e-12)) : log(max(1.0 - p, 1e-12)));
      }
      ll /= val;
      cerr << "Iter " << t << ": avg loglik ~ " << ll << "\n";
    }
  }

  // Print learned mean parameters
  cout << fixed << setprecision(4);
  cout << "Learned mu (first 10): ";
  for (int j = 0; j < min(d,10); ++j) cout << mu[j] << (j+1<min(d,10)?" ":"\n");
  cout << "Learned sigma (first 10): ";
  for (int j = 0; j < min(d,10); ++j) cout << exp(rho[j]) << (j+1<min(d,10)?" ":"\n");
  return 0;
}
This program fits a Bayesian logistic regression model using SVI with a diagonal Gaussian variational posterior. It uses the reparameterization trick w = μ + σ ⊙ ε to form low-variance gradients. The gradient with respect to μ equals the gradient of the joint log-density evaluated at the sampled w. The gradient with respect to ρ = log σ includes a chain-rule term (σ ⊙ ε) ⊙ ∇_w log p plus a +1 entropy contribution. Minibatches are scaled by N/|B| for unbiasedness, and Adam stabilizes the ascent updates.
#include <bits/stdc++.h>
using namespace std;

// Model: x_n ~ N(mu, sigma2), prior mu ~ N(mu0, tau02). We use SVI on q(mu) in the same family.
// Natural parameters for Normal over mu with known variance form:
//   log q(mu) = eta1 * mu + eta2 * mu^2 + const, where eta2 < 0.
// Convert: tau2 = -1 / (2 * eta2), mu = tau2 * eta1.

struct Data { vector<double> x; };

Data make_data(int N, double mu_true, double sigma2, unsigned seed=7) {
  mt19937 rng(seed);
  normal_distribution<double> nd(mu_true, sqrt(sigma2));
  Data D; D.x.resize(N);
  for (int i = 0; i < N; ++i) D.x[i] = nd(rng);
  return D;
}

int main(){
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  // Hyperparameters
  int N = 100000;
  double mu_true = 2.5;
  double sigma2 = 1.0;   // known observation variance
  double mu0 = 0.0;      // prior mean
  double tau02 = 25.0;   // prior variance (weak prior)

  // SVI schedule: rho_t = (t0 + t)^{-kappa}
  double t0 = 10.0, kappa = 0.6;
  int iters = 1000;
  int batch = 256;

  Data D = make_data(N, mu_true, sigma2);
  mt19937 rng(123);
  uniform_int_distribution<int> uid(0, N-1);

  // Natural parameters (initialize to prior)
  double eta1 = mu0 / tau02;
  double eta2 = -1.0 / (2.0 * tau02);

  for (int t = 1; t <= iters; ++t) {
    // Minibatch sum
    double sumx = 0.0;
    for (int i = 0; i < batch; ++i) sumx += D.x[ uid(rng) ];
    double scale = (double)N / (double)batch; // N/|B|

    // Posterior natural params implied by this minibatch (conjugate update)
    // Each data point contributes: eta1 += x / sigma2, eta2 += -1/(2*sigma2)
    double eta1_post_hat = (mu0 / tau02) + scale * (sumx / sigma2);
    double eta2_post_hat = (-1.0 / (2.0 * tau02)) + scale * (- (double)batch / (2.0 * sigma2));

    // Robbins-Monro step
    double rho = pow(t0 + t, -kappa);
    eta1 = (1.0 - rho) * eta1 + rho * eta1_post_hat;
    eta2 = (1.0 - rho) * eta2 + rho * eta2_post_hat;

    // Occasionally report mean/var of q(mu)
    if (t % 100 == 0) {
      double tau2 = -1.0 / (2.0 * eta2);
      double muq = tau2 * eta1;
      cerr << "Iter " << t << ": mu_q=" << muq << ", var_q=" << tau2 << "\n";
    }
  }

  double tau2 = -1.0 / (2.0 * eta2);
  double muq = tau2 * eta1;
  cout << fixed << setprecision(6);
  cout << "Posterior mean mu_q ~ " << muq << "\n";
  cout << "Posterior var tau2 ~ " << tau2 << "\n";
  return 0;
}
This example performs SVI for a conjugate Normal–Normal model with unknown mean and known variance. The posterior over μ is Normal with natural parameters η1 and η2. Each datapoint contributes x/σ² to η1 and −1/(2σ²) to η2. Using a minibatch, we form an unbiased estimate of the full posterior natural parameters and take a Robbins–Monro step toward them. Converting back to mean/variance shows μ_q and Var_q converging to the exact posterior.