Expectation Maximization (EM)
Key Points
- Expectation Maximization (EM) is an iterative algorithm to estimate parameters when some variables are hidden or unobserved.
- EM alternates between estimating hidden variables (E-step) and maximizing parameters given those estimates (M-step).
- The E-step computes Q(θ | θ^(t)) = E_{z ~ p(z | x, θ^(t))}[log p(x, z | θ)], using the current parameters to infer the posterior over latent variables.
- The M-step finds new parameters that maximize Q, guaranteeing the observed-data likelihood does not decrease.
- EM is widely used for mixture models (e.g., Gaussian Mixture Models), clustering, missing-data problems, and hidden-state models.
- EM can converge to local optima, so initialization, regularization, and monitoring the log-likelihood are critical.
- Numerically stable implementations use log-sum-exp, enforce constraints (like mixture weights summing to 1), and avoid degenerate solutions (e.g., zero variance).
Prerequisites
- Basic probability and Bayes’ rule — Understanding joint, marginal, and conditional probabilities is essential to define p(x, z | θ) and p(z | x, θ).
- Maximum likelihood estimation (MLE) — EM aims to maximize the observed-data likelihood indirectly, so familiarity with MLE helps interpret updates.
- Log-likelihood and concavity/convexity — EM relies on properties of the log function (concave) and Jensen’s inequality to build a lower bound.
- Gaussian and Binomial distributions — Common EM applications (GMMs and coin mixtures) use these distributions in E- and M-steps.
- Linear algebra (means, variances, covariance) — Parameter updates in mixture models require computing weighted means and variances.
- Numerical stability (log-sum-exp) — Avoids underflow/overflow when summing probabilities in the E-step.
- Optimization basics — Understanding gradients and constrained updates clarifies M-step solutions and Generalized EM.
- Clustering concepts — Relates EM responsibilities to soft clustering and component assignments.
Detailed Explanation
01 Overview
Expectation Maximization (EM) is a powerful strategy to estimate model parameters when your data is incomplete or contains hidden (latent) parts. Instead of directly maximizing the likelihood of the observed data—which is often intractable because it involves summing or integrating over unknown variables—EM turns the problem into two alternating, simpler subproblems. In the E-step, you compute expectations of the complete-data log-likelihood with respect to the current posterior distribution of the latent variables. In the M-step, you maximize that expected complete-data log-likelihood with respect to the parameters. These steps are repeated until convergence. A key property is monotonic improvement: each full E-then-M update does not decrease the observed-data log-likelihood. This makes EM a go-to workhorse for mixture models (like Gaussian Mixture Models), clustering, imputation for missing data, and a range of latent variable models such as factor analysis and hidden Markov models (via Baum–Welch). While EM is conceptually simple and often easy to implement, it optimizes a generally non-convex objective, so it may converge to local rather than global maxima. Practical success hinges on good initialization, careful numerical handling, and suitable model choices.
02 Intuition & Analogies
Imagine sorting coins into different jars when the labels have worn off. You get a pile of coin toss results, but you don’t know which jar each result came from. If you somehow knew which toss belonged to which jar (the hidden assignment), estimating each jar’s bias would be easy: just count heads and tails. Conversely, if you knew the jars’ biases, you could probabilistically guess which jar produced each toss. EM alternates between these two kinds of guessing. First, with current parameter guesses, you softly assign each observation to jars (E-step: expected assignments). Next, treating those soft assignments like weighted data, you update the jar biases (M-step). Repeat this, and your jars’ labels gradually make sense, and the fit to the data improves. A second analogy: think of EM as resolving a blurry photo. The hidden variables are the missing pixels. With your current best picture (parameters), you infer what the missing pixels likely are (E-step). Then, given that filled-in photo, you refine the camera settings to best explain it (M-step). As you iterate, the picture sharpens. The magic comes from using expectations of the complete-data log-likelihood, which turns the otherwise hard observed-only problem into one that alternates between probabilistic filling-in and straightforward optimization. This process guarantees you won’t make the picture worse at each step, though you might still get stuck at a pretty good—but not perfect—version.
03 Formal Definition
Let x denote the observed data, z the latent (hidden) variables, and θ the model parameters. Direct maximum likelihood requires maximizing log p(x | θ) = log Σ_z p(x, z | θ), which is typically intractable because the sum over z sits inside the logarithm. EM instead works with the expected complete-data log-likelihood Q(θ | θ^(t)) = E_{z ~ p(z | x, θ^(t))}[log p(x, z | θ)]. The E-step computes the posterior p(z | x, θ^(t)) needed to form Q; the M-step sets θ^(t+1) = argmax_θ Q(θ | θ^(t)). Each full iteration satisfies log p(x | θ^(t+1)) ≥ log p(x | θ^(t)), so the observed-data log-likelihood never decreases and the iterates converge to a stationary point, typically a local maximum.
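The guarantee that EM never decreases the likelihood rests on a standard lower-bound argument via Jensen's inequality; written out (consistent with the ELBO and decomposition formulas in the Key Formulas section):

```latex
% Introduce any distribution q(z) and apply Jensen's inequality (log is concave):
\log p(x \mid \theta)
  = \log \sum_z q(z)\, \frac{p(x, z \mid \theta)}{q(z)}
  \;\ge\; \sum_z q(z) \log \frac{p(x, z \mid \theta)}{q(z)}
  \;=\; \mathcal{L}(q, \theta) \quad \text{(the ELBO)}.

% The gap between the two sides is exactly a KL divergence:
\log p(x \mid \theta) = \mathcal{L}(q, \theta)
  + \mathrm{KL}\!\big(q(z) \,\|\, p(z \mid x, \theta)\big).

% E-step: set q(z) = p(z \mid x, \theta^{(t)}), making the KL term zero
% (the bound is tight). M-step: maximize \mathcal{L} over \theta, which can
% only raise it. Chaining the two gives monotonic improvement:
\log p(x \mid \theta^{(t+1)})
  \;\ge\; \mathcal{L}(q^{(t+1)}, \theta^{(t+1)})
  \;\ge\; \mathcal{L}(q^{(t+1)}, \theta^{(t)})
  \;=\; \log p(x \mid \theta^{(t)}).
```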
04 When to Use
Use EM when your model involves hidden structure that makes direct maximum likelihood hard but the complete-data problem easy. Classic scenarios include: (1) Mixture models such as Gaussian Mixture Models for soft clustering, where latent variables indicate component membership. (2) Missing data problems, where z represents missing entries; EM imputes them in expectation, enabling parameter updates as if the data were complete. (3) Hidden-state models like Hidden Markov Models (Baum–Welch algorithm), where z are hidden states; EM efficiently handles the temporal structure via dynamic programming in the E-step. (4) Factor analysis and probabilistic PCA, where latent factors z explain correlations in x. (5) Semi-supervised learning, where some labels are missing and treated as latent variables. EM is particularly attractive when the complete-data likelihood is in the exponential family, yielding closed-form M-steps, and when the E-step posterior can be computed exactly or efficiently (sometimes approximately). When posteriors are intractable, consider Variational EM or Monte Carlo EM as extensions.
⚠️Common Mistakes
- Poor initialization: Randomly initializing means and very small variances in GMMs can cause singularities (variances collapsing to zero). Use k-means or k-means++ to initialize means, add variance floors, and spread initial components.
- Ignoring numerical stability: Directly computing responsibilities with tiny probabilities causes underflow. Use log-domain computations and log-sum-exp, normalize carefully, and clamp variances away from zero.
- Violating constraints: Mixture weights must be nonnegative and sum to one. Re-normalize after M-step, and ensure covariance matrices stay positive definite (add small diagonal jitter).
- Not monitoring convergence: Stopping after a fixed number of iterations may halt too soon or waste time. Track the observed-data log-likelihood and stop when its improvement falls below a tolerance.
- Overfitting with too many components: EM maximizes likelihood, which can grow without bound (e.g., GMM components collapsing on single points). Use cross-validation, BIC/AIC, priors (MAP-EM), or variance floors to regularize.
- Misinterpreting local optima: EM can converge to different solutions from different starts. Run multiple restarts and pick the best log-likelihood.
- Data scaling issues: For GMMs, features with different scales distort responsibilities. Standardize features or use full covariances appropriately.
- Treating soft assignments as hard too early: Premature hard-clustering can freeze poor solutions. Keep soft responsibilities for correct M-step updates.
Key Formulas
Q-function
Q(θ | θ^(t)) = E_{z ~ p(z | x, θ^(t))}[ log p(x, z | θ) ]
Explanation: This is the expected complete-data log-likelihood under the current posterior of the latent variables. EM maximizes this function in the M-step.
ELBO
L(q, θ) = E_{q(z)}[ log p(x, z | θ) ] − E_{q(z)}[ log q(z) ] ≤ log p(x | θ)
Explanation: The ELBO lower-bounds the observed-data log-likelihood. EM alternates between setting q to the true posterior (E-step) and maximizing with respect to θ (M-step).
Decomposition
log p(x | θ) = L(q, θ) + KL( q(z) ‖ p(z | x, θ) )
Explanation: The log evidence equals the ELBO plus a nonnegative KL divergence. Maximizing the ELBO tightens the bound; it is tight when q equals the true posterior.
E-step update
q^(t+1)(z) = p(z | x, θ^(t))
Explanation: Given current parameters, set the variational distribution over latent variables to the true posterior. This makes the ELBO equal the Q-function plus a constant.
M-step update
θ^(t+1) = argmax_θ Q(θ | θ^(t))
Explanation: Update parameters by maximizing the expected complete-data log-likelihood. In many models, this yields closed-form updates.
Monotonic improvement
log p(x | θ^(t+1)) ≥ log p(x | θ^(t))
Explanation: Each EM iteration does not decrease the observed-data log-likelihood. This follows from the ELBO construction and the E/M updates.
GMM responsibilities (1D)
r_{nk} = π_k N(x_n | μ_k, σ_k²) / Σ_{j=1..K} π_j N(x_n | μ_j, σ_j²)
Explanation: In a 1D Gaussian Mixture Model, responsibilities are the posterior probabilities that component k generated x_n. They are used as soft weights in the M-step.
GMM M-step (1D)
N_k = Σ_n r_{nk},  π_k = N_k / N,  μ_k = (1 / N_k) Σ_n r_{nk} x_n,  σ_k² = (1 / N_k) Σ_n r_{nk} (x_n − μ_k)²
Explanation: Mixture weights are normalized responsibilities, means are responsibility-weighted averages, and variances are responsibility-weighted second moments.
Two-coin responsibilities
r_n = π_A P(h_n | m_n, p_A) / [ π_A P(h_n | m_n, p_A) + (1 − π_A) P(h_n | m_n, p_B) ]
Explanation: For the two-coin mixture, responsibilities are the posterior probabilities a sequence came from coin A versus coin B.
Binomial likelihood
P(h | m, p_A) = C(m, h) p_A^h (1 − p_A)^(m − h)
Explanation: Given coin A with bias p_A and m tosses with h heads, the probability of the sequence is binomial. The combinatorial term C(m, h) cancels in responsibility ratios, since it is the same for both coins within a trial.
Two-coin M-step
π_A = (1 / N) Σ_n r_n,  p_A = Σ_n r_n h_n / Σ_n r_n m_n,  p_B = Σ_n (1 − r_n) h_n / Σ_n (1 − r_n) m_n
Explanation: Mixing weight is the average responsibility for coin A. The bias p_A is the responsibility-weighted fraction of heads; similarly for coin B.
Mixture log-likelihood
log p(x | θ) = Σ_{n=1..N} log Σ_{k=1..K} π_k p(x_n | θ_k)
Explanation: The observed-data log-likelihood for a mixture sums the log of weighted component likelihoods over data points. EM maximizes this indirectly via Q.
Log-sum-exp
log Σ_k exp(a_k) = m + log Σ_k exp(a_k − m),  where m = max_k a_k
Explanation: A numerically stable way to compute the log of a sum of exponentials, avoiding underflow/overflow.
Complexity Analysis
For a 1D GMM with N data points and K components, each EM iteration costs O(NK) time: the E-step evaluates K component densities per point, and the M-step accumulates responsibility-weighted sums over the same N × K table. Storing responsibilities takes O(NK) memory, though a streaming implementation can reduce this to O(K) by accumulating sufficient statistics on the fly. The two-coin example costs O(N) per iteration. The number of iterations until the log-likelihood improvement falls below tolerance is not bounded in general; it depends on initialization and how well separated the components are.
Code Examples
```cpp
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <algorithm>
#include <numeric>
#include <limits>

// Numerically stable log-sum-exp for a vector of log-values
double logsumexp(const std::vector<double>& a) {
    double m = *std::max_element(a.begin(), a.end());
    double sum = 0.0;
    for (double v : a) sum += std::exp(v - m);
    return m + std::log(sum);
}

// Log of 1D Gaussian density N(x | mu, var)
double log_gaussian_1d(double x, double mu, double var) {
    static const double LOG_SQRT_2PI = 0.5 * std::log(2.0 * M_PI);
    double diff = x - mu;
    double log_det_term = 0.5 * std::log(var);
    double quad = 0.5 * (diff * diff) / var;
    return -(LOG_SQRT_2PI + log_det_term + quad);
}

struct GMM1D {
    int K;                       // number of components
    std::vector<double> pi;     // mixing weights (size K), sum to 1
    std::vector<double> mu;     // means (size K)
    std::vector<double> var;    // variances (size K), positive
    double var_floor = 1e-6;    // variance floor to avoid singularities

    GMM1D(int K_) : K(K_), pi(K_, 1.0 / K_), mu(K_, 0.0), var(K_, 1.0) {}

    // Initialize parameters using simple heuristics
    void initialize(const std::vector<double>& x, unsigned seed = 42) {
        std::mt19937 rng(seed);
        std::uniform_int_distribution<int> uid(0, (int)x.size() - 1);
        // Pick K random points as initial means
        for (int k = 0; k < K; ++k) mu[k] = x[uid(rng)];
        // Set variances to data variance
        double mean = std::accumulate(x.begin(), x.end(), 0.0) / x.size();
        double v = 0.0;
        for (double xi : x) { double d = xi - mean; v += d * d; }
        v = std::max(v / x.size(), var_floor);
        std::fill(var.begin(), var.end(), v);
        // Uniform mixing weights
        std::fill(pi.begin(), pi.end(), 1.0 / K);
    }

    // Compute log-likelihood of observed data under current parameters
    double log_likelihood(const std::vector<double>& x) const {
        double ll = 0.0;
        std::vector<double> logs(K);
        for (double xi : x) {
            for (int k = 0; k < K; ++k)
                logs[k] = std::log(pi[k]) + log_gaussian_1d(xi, mu[k], var[k]);
            ll += logsumexp(logs);
        }
        return ll;
    }

    // One EM iteration: E-step (responsibilities) and M-step (update parameters)
    void em_step(const std::vector<double>& x) {
        const int N = (int)x.size();
        std::vector<std::vector<double>> log_r(N, std::vector<double>(K));
        std::vector<std::vector<double>> r(N, std::vector<double>(K));

        // E-step: compute responsibilities in log-domain for stability
        for (int n = 0; n < N; ++n) {
            std::vector<double> logs(K);
            for (int k = 0; k < K; ++k) {
                logs[k] = std::log(pi[k]) + log_gaussian_1d(x[n], mu[k], var[k]);
            }
            double lse = logsumexp(logs);
            for (int k = 0; k < K; ++k) {
                log_r[n][k] = logs[k] - lse;  // log responsibility
                r[n][k] = std::exp(log_r[n][k]);
            }
        }

        // M-step: update pi, mu, var using responsibility-weighted sums
        std::vector<double> Nk(K, 0.0);
        for (int k = 0; k < K; ++k) {
            for (int n = 0; n < N; ++n) Nk[k] += r[n][k];
        }

        // Update mixing weights
        for (int k = 0; k < K; ++k) pi[k] = std::max(Nk[k] / N, 1e-15);
        // Re-normalize to sum to 1 (avoid rounding drift)
        double pisum = std::accumulate(pi.begin(), pi.end(), 0.0);
        for (int k = 0; k < K; ++k) pi[k] /= pisum;

        // Update means
        for (int k = 0; k < K; ++k) {
            double num = 0.0;
            for (int n = 0; n < N; ++n) num += r[n][k] * x[n];
            mu[k] = num / std::max(Nk[k], 1e-15);
        }

        // Update variances (with floor for stability)
        for (int k = 0; k < K; ++k) {
            double num = 0.0;
            for (int n = 0; n < N; ++n) {
                double diff = x[n] - mu[k];
                num += r[n][k] * diff * diff;
            }
            var[k] = std::max(num / std::max(Nk[k], 1e-15), var_floor);
        }
    }
};

int main() {
    // Generate simple 1D data from a 2-component mixture for demonstration
    std::mt19937 rng(7);
    std::normal_distribution<double> n1(-2.0, 0.7);  // mean -2, std 0.7
    std::normal_distribution<double> n2(3.0, 0.5);   // mean 3, std 0.5
    std::bernoulli_distribution choose(0.4);         // mixing weight ~ 0.4 for component 1

    std::vector<double> x; x.reserve(600);
    for (int i = 0; i < 600; ++i) {
        bool c = choose(rng);
        x.push_back(c ? n1(rng) : n2(rng));
    }

    int K = 2;
    GMM1D model(K);
    model.initialize(x, 123);

    double prev_ll = -std::numeric_limits<double>::infinity();
    const int max_iter = 200;
    const double tol = 1e-6;

    for (int it = 0; it < max_iter; ++it) {
        model.em_step(x);
        double ll = model.log_likelihood(x);
        std::cout << "Iter " << it << ": log-likelihood = " << ll << "\n";
        if (std::abs(ll - prev_ll) < tol) break;
        prev_ll = ll;
    }

    std::cout << "Estimated parameters:\n";
    for (int k = 0; k < K; ++k) {
        std::cout << "  k=" << k
                  << ", pi=" << model.pi[k]
                  << ", mu=" << model.mu[k]
                  << ", var=" << model.var[k] << "\n";
    }
    return 0;
}
```
This program implements EM for a 1D Gaussian Mixture Model. The E-step computes responsibilities in log-domain to avoid underflow via log-sum-exp. The M-step updates mixing weights, means, and variances using responsibility-weighted formulas with a small variance floor to prevent singularities. The loop tracks the observed-data log-likelihood and stops when improvements become negligible. The example generates synthetic data from two Gaussians and recovers parameters close to the ground truth.
```cpp
#include <iostream>
#include <vector>
#include <cmath>
#include <limits>
#include <algorithm>

struct Trial { int heads; int tails; };  // one experiment: m = heads + tails tosses

// Log binomial likelihood, up to the constant C(m, h) that cancels in responsibilities
double log_binom_likelihood(int heads, int tails, double p) {
    // We omit log C(m, h) since it's the same for both coins in a trial
    return heads * std::log(p) + tails * std::log(1.0 - p);
}

int main() {
    // Observed data: sequences of coin toss outcomes from an unknown mixture of two coins
    std::vector<Trial> data = {
        {5,5}, {9,1}, {8,2}, {4,6}, {7,3}, {2,8}, {6,4}, {1,9}, {5,5}, {3,7}
    };

    // Parameters: mixing weight piA, coin biases pA and pB
    double piA = 0.5;  // initial mixing weight for coin A
    double pA = 0.6;   // initial bias for coin A (probability of heads)
    double pB = 0.4;   // initial bias for coin B

    const int max_iter = 200;
    const double tol = 1e-8;
    double prev_ll = -std::numeric_limits<double>::infinity();

    for (int it = 0; it < max_iter; ++it) {
        // E-step: responsibilities r_nA and r_nB for each trial
        std::vector<double> rA(data.size());
        double ll = 0.0;  // observed-data log-likelihood
        for (size_t n = 0; n < data.size(); ++n) {
            int h = data[n].heads, t = data[n].tails;
            double a = std::log(piA) + log_binom_likelihood(h, t, pA);
            double b = std::log(1.0 - piA) + log_binom_likelihood(h, t, pB);
            // log-sum-exp for two terms
            double m = std::max(a, b);
            double lse = m + std::log(std::exp(a - m) + std::exp(b - m));
            ll += lse;
            rA[n] = std::exp(a - lse);
        }

        // Check convergence of log-likelihood
        if (std::abs(ll - prev_ll) < tol) {
            std::cout << "Converged at iter " << it << ", log-likelihood = " << ll << "\n";
            break;
        }
        prev_ll = ll;

        // M-step: update piA, pA, pB using responsibility-weighted counts
        double sum_rA = 0.0, sum_rB = 0.0;
        double numA_heads = 0.0, denA = 0.0;
        double numB_heads = 0.0, denB = 0.0;
        for (size_t n = 0; n < data.size(); ++n) {
            int h = data[n].heads, t = data[n].tails;
            int m = h + t;
            double r = rA[n];
            double s = 1.0 - r;
            sum_rA += r; sum_rB += s;
            numA_heads += r * h; denA += r * m;
            numB_heads += s * h; denB += s * m;
        }
        piA = std::min(std::max(sum_rA / data.size(), 1e-12), 1.0 - 1e-12);
        pA = std::min(std::max(numA_heads / std::max(denA, 1e-12), 1e-12), 1.0 - 1e-12);
        pB = std::min(std::max(numB_heads / std::max(denB, 1e-12), 1e-12), 1.0 - 1e-12);

        std::cout << "Iter " << it
                  << ": ll=" << prev_ll
                  << ", piA=" << piA
                  << ", pA=" << pA
                  << ", pB=" << pB << "\n";
    }

    std::cout << "Final estimates: piA=" << piA << ", pA=" << pA << ", pB=" << pB << "\n";
    return 0;
}
```
This example fits a two-coin mixture via EM. Each trial consists of m tosses with h heads, generated by either coin A or B. The E-step computes responsibilities r_nA using log-domain probabilities. The M-step updates the mixing weight and coin biases as responsibility-weighted averages. Small clamps keep parameters in valid ranges. The code reports log-likelihood and parameter estimates across iterations.