Wake-Sleep Algorithm
Key Points
- The Wake–Sleep algorithm trains a pair of models: a generative model that explains how data are produced and a recognition model that guesses hidden causes from observed data.
- Learning alternates between two phases: the wake phase updates the generative model using real data, and the sleep phase updates the recognition model using simulated (fantasy) data from the generative model.
- The two phases optimize Kullback–Leibler (KL) divergences in opposite directions; this keeps the updates simple even for discrete latent variables, though it means the algorithm does not optimize a single unified objective.
- Wake updates maximize the joint log-likelihood of observed data and latent variables sampled from the recognition model.
- Sleep updates make the recognition model approximate the true posterior by fitting it to latent–data pairs sampled from the generative model.
- This algorithm is especially useful when exact inference is intractable but ancestral sampling from the generative model is easy.
- A classic application is the Helmholtz machine with binary latent units and Bernoulli outputs, trained via simple stochastic gradient steps.
- Although Wake–Sleep predates modern variational autoencoders, its ideas connect closely to the ELBO and variational inference.
Prerequisites
- Basic probability and conditional distributions — Wake–Sleep relies on priors, likelihoods, and conditional probabilities such as p(x|h) and p(h|x).
- Logistic regression and Bernoulli likelihoods — Both generator and recognizer use sigmoid activations with Bernoulli log-likelihood gradients.
- KL divergence and variational inference — Understanding the two KL directions clarifies the wake and sleep objectives and their biases.
- Stochastic gradient descent — Parameter updates in both phases are simple SGD steps on sampled objectives.
- Directed graphical models — Ancestral sampling and factorization p(x,h)=p(x|h)p(h) are core to the algorithm.
Detailed Explanation
01 Overview
The Wake–Sleep algorithm is a two-phase learning procedure for training latent-variable generative models alongside an auxiliary recognition (inference) network. The generative model, parameterized by (\theta), specifies how hidden variables (causes) produce observable data, while the recognition model, parameterized by (\phi), maps data back to plausible hidden states. Because computing the true posterior (p_\theta(h\mid x)) is typically intractable, Wake–Sleep uses a practical alternative: it teaches the generative model using posterior samples from the recognition model (wake) and teaches the recognition model using synthetic data produced by the generative model (sleep). In the wake phase, the algorithm observes real data (x), samples latent variables (h) from the recognition model (q_\phi(h\mid x)), and updates (\theta) to make those sampled pairings more likely under (p_\theta(x,h)). In the sleep phase, the algorithm draws latent variables from the generative prior, generates synthetic observations, and then updates (\phi) so that (q_\phi(h\mid x)) better recovers the sampled latent variables. This alternation side-steps expensive exact inference while keeping both models synchronized. The method was originally proposed to train Helmholtz machines with discrete latent variables, but the core idea reappears in modern amortized variational inference. Wake–Sleep provides an intuitive, sampling-based route to joint learning of generative and inference mechanisms and remains a foundational concept in probabilistic machine learning.
02 Intuition & Analogies
Imagine learning to lip-read in a dark theater with a friend. Your friend (the generative model) knows how words produce lip movements; you (the recognition model) try to guess the spoken words from what you see. You practice in two complementary ways. Wake phase: you watch real actors (true data) and make your best guess of the words. Your friend then adapts their understanding of how those guessed words would have produced the observed lips so that next time, those word–lip pairs feel more plausible. In other words, the explanation mechanism (generative model) is tuned to explain what your perception believes is happening. Sleep phase: later at home, your friend imagines random words and plays them out with their lips (fantasy data). Now you practice reading those lips and adjust your perception so that your guesses match the words your friend actually imagined. In other words, your perception learns to invert your friend’s imagination. Alternating these phases does two things: it ensures the explainer can account for real-world observations and it ensures the perceiver can decode the explainer’s internal causes. This back-and-forth is powerful when directly computing the exact mapping from lips to words is too hard, but simulating lips from words is easy. The same dynamic applies in machine learning: it’s often easy to sample from a generative model (ancestral sampling) but hard to compute exact posteriors. Wake–Sleep exploits this asymmetry to co-train a generator and an inference network using realistic and synthetic experiences.
03 Formal Definition
Let the generative model factor as (p_\theta(x,h) = p_\theta(x\mid h)\,p_\theta(h)) and let (q_\phi(h\mid x)) be the recognition model. The wake phase observes data (x), samples (h \sim q_\phi(h\mid x)), and performs stochastic gradient ascent on (\log p_\theta(x,h)) in (\theta); since the entropy of (q_\phi) does not depend on (\theta), this maximizes the ELBO with respect to the generative parameters. The sleep phase draws fantasy pairs ((h,x) \sim p_\theta(h)\,p_\theta(x\mid h)) and performs gradient ascent on (\log q_\phi(h\mid x)) in (\phi), which in expectation minimizes the forward divergence (\mathbb{E}_{p_\theta(x)}\,\mathrm{KL}\big(p_\theta(h\mid x)\,\|\,q_\phi(h\mid x)\big)). Because the two phases optimize different objectives, Wake–Sleep does not descend a single loss function, but each phase improves its own model given the other.
04 When to Use
Use the Wake–Sleep algorithm when you have a directed latent-variable model for which (1) sampling from the prior and likelihood is straightforward, (2) exact posterior inference is intractable, and (3) you can parametrize an efficient recognition model to approximate the posterior. Classic scenarios include Helmholtz machines with discrete latent units, hierarchical generative models with multiple stochastic layers, and situations where reparameterization tricks are unavailable (e.g., non-differentiable discrete latents). It is well-suited when you want amortized inference: a single recognition network that can quickly infer latent variables for many datapoints at test time. Wake–Sleep is also a practical pretraining or initialization strategy: it can produce a reasonable generator and recognizer that can later be refined by tighter objectives such as variational EM, the ELBO, or importance-weighted bounds. In small to medium-scale problems where coding simplicity and sampling efficiency matter more than achieving state-of-the-art likelihoods, Wake–Sleep offers a conceptually clean, computationally light approach with intuitive update rules. Choose it when ancestral sampling is cheap and you are comfortable with slightly looser learning signals than modern gradient-estimators provide for continuous latents.
⚠️ Common Mistakes
Common pitfalls include: (1) Ignoring the KL direction. The wake update aligns (p_\theta(h\mid x)) toward (q_\phi(h\mid x)) (reverse KL), while sleep aligns (q_\phi(h\mid x)) toward the true posterior under (p_\theta) (forward KL). Swapping or mixing these without care can lead to mode-seeking or mode-covering biases that harm learning. (2) Poor calibration of the generative prior. If the prior is too sharp or too diffuse early on, sleep samples may be uninformative, producing a recognition model that learns the wrong regions of (x)-space. Use conservative learning rates and consider entropy-encouraging initializations. (3) Updating only one model. Failing to alternate sufficiently (e.g., many wake steps without sleep) breaks the mutual improvement loop; balance steps or tune ratios empirically. (4) Sampling noise with tiny minibatches. Since both phases rely on samples, gradients are noisy; use adequate batch sizes or average over multiple samples per datapoint when feasible. (5) Mis-specified likelihoods. For Bernoulli data, forgetting to clip probabilities away from 0 or 1 before taking logs causes numerical issues. (6) Forgetting that classical Wake–Sleep does not optimize a single tight bound. Comparing likelihoods to ELBO-trained models may be unfair; to improve, switch to reweighted Wake–Sleep or variational fine-tuning. (7) Implementation bugs: mismatched shapes in matrix multiplications, treating sampled values as differentiable when they should be held fixed, or inadvertently reusing RNG states. Careful unit tests on synthetic data can catch these issues early.
Key Formulas
ELBO: (\mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(h\mid x)}\big[\log p_\theta(x,h) - \log q_\phi(h\mid x)\big] \le \log p_\theta(x))
Explanation: This is the evidence lower bound. Maximizing it encourages the generative model to explain the data while keeping the recognition posterior close to the true posterior. It lower-bounds the average log-likelihood of the data.
Wake gradient: (\Delta\theta \propto \nabla_\theta \log p_\theta(x,h), \quad h \sim q_\phi(h\mid x),\ x \sim \text{data})
Explanation: In the wake phase, we update the generative parameters by sampling latent variables from the recognition model and taking a stochastic gradient step on the joint log-probability. This avoids computing the intractable true posterior.
Sleep objective: (\min_\phi\; \mathbb{E}_{(x,h)\sim p_\theta}\big[-\log q_\phi(h\mid x)\big] = \min_\phi\; \mathbb{E}_{x\sim p_\theta}\,\mathrm{KL}\big(p_\theta(h\mid x)\,\|\,q_\phi(h\mid x)\big) + \text{const})
Explanation: In the sleep phase, we sample from the generative model and fit the recognition model to those samples. This minimizes the forward KL from the true posterior under the current generator to the recognition model.
Bernoulli likelihood: (\log p_\theta(x\mid h) = \sum_{i=1}^{D}\big[x_i\log p_i + (1-x_i)\log(1-p_i)\big], \quad p_i = \sigma(W^g_i h + b^g_i))
Explanation: For binary data, the conditional log-likelihood decomposes into a sum of Bernoulli log-probabilities. Its gradients yield simple residual forms used in wake updates.
Bernoulli gradient (generator): (\partial \log p_\theta(x\mid h)/\partial W^g_{ij} = (x_i - p_i)\,h_j, \quad \partial \log p_\theta(x\mid h)/\partial b^g_i = x_i - p_i)
Explanation: The gradient of the Bernoulli log-likelihood with respect to the generator's weights and biases is the prediction error times inputs. This leads to efficient outer-product updates.
Bernoulli gradient (recognizer): (\partial \log q_\phi(h\mid x)/\partial W^r_{ji} = (h_j - q_j)\,x_i, \quad q_j = \sigma(W^r_j x + b^r_j))
Explanation: Similarly, the recognition model’s gradient is an error term between sampled latents and predicted posterior probabilities times the observed input.
KL divergence: (\mathrm{KL}(p\,\|\,q) = \sum_{z} p(z)\log\dfrac{p(z)}{q(z)} \ge 0)
Explanation: KL measures dissimilarity between two distributions and is nonnegative. Its asymmetry explains why wake and sleep emphasize different approximation behaviors.
Per-epoch time complexity: (O(N \cdot D \cdot H))
Explanation: With N datapoints, D visible units, and H latent units, each phase requires matrix–vector products and outer products scaling with D×H per example. Thus total per-epoch runtime is linear in N and bilinear in D and H.
Complexity Analysis
Each wake or sleep update touches every weight once, so one epoch over N examples costs (O(N \cdot D \cdot H)) time and (O(D \cdot H)) memory for the parameters and gradient accumulators. Both code examples below follow this pattern: every per-example step is a matrix–vector product plus an outer-product accumulation.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

struct RNG {
    mt19937_64 gen;
    uniform_real_distribution<double> unif;
    RNG(uint64_t seed = 1234567ULL) : gen(seed), unif(0.0, 1.0) {}
    double randu() { return unif(gen); }
    bool bernoulli(double p) {
        p = min(1.0, max(0.0, p));
        return randu() < p;
    }
};

static inline double sigmoid(double z) {
    if (z >= 0) {
        double ez = exp(-z);
        return 1.0 / (1.0 + ez);
    } else {
        double ez = exp(z);
        return ez / (1.0 + ez);
    }
}

struct Helmholtz {
    int D, H;
    // Generative: x|h ~ Bernoulli(sigmoid(Wg * h + bg)), h ~ Bernoulli(sigmoid(aprior))
    vector<vector<double>> Wg;   // D x H
    vector<double> bg;           // D
    vector<double> aprior;       // H (logits)
    // Recognition: h|x ~ Bernoulli(sigmoid(Wr * x + br))
    vector<vector<double>> Wr;   // H x D
    vector<double> br;           // H
    RNG rng;

    Helmholtz(int D_, int H_, uint64_t seed = 42)
        : D(D_), H(H_), Wg(D_, vector<double>(H_)), bg(D_), aprior(H_),
          Wr(H_, vector<double>(D_)), br(H_), rng(seed) {
        // Xavier/Glorot-like small init for stability
        double scaleWg = sqrt(2.0 / (D + H));
        double scaleWr = sqrt(2.0 / (D + H));
        normal_distribution<double> nd(0.0, 0.1);
        for (int i = 0; i < D; i++)
            for (int j = 0; j < H; j++) Wg[i][j] = nd(rng.gen) * scaleWg;
        for (int i = 0; i < D; i++) bg[i] = 0.0;
        for (int j = 0; j < H; j++) aprior[j] = 0.0; // prior probs ~ 0.5
        for (int j = 0; j < H; j++)
            for (int i = 0; i < D; i++) Wr[j][i] = nd(rng.gen) * scaleWr;
        for (int j = 0; j < H; j++) br[j] = 0.0;
    }

    vector<double> gen_probs_from_h(const vector<int>& hbin) const {
        vector<double> z(D, 0.0);
        for (int i = 0; i < D; i++) {
            double s = bg[i];
            for (int j = 0; j < H; j++) s += Wg[i][j] * (double)hbin[j];
            z[i] = sigmoid(s);
        }
        return z;
    }

    vector<double> rec_probs_from_x(const vector<int>& xbin) const {
        vector<double> z(H, 0.0);
        for (int j = 0; j < H; j++) {
            double s = br[j];
            for (int i = 0; i < D; i++) s += Wr[j][i] * (double)xbin[i];
            z[j] = sigmoid(s);
        }
        return z;
    }

    vector<int> sample_prior_h() {
        vector<int> h(H, 0);
        for (int j = 0; j < H; j++) h[j] = rng.bernoulli(sigmoid(aprior[j])) ? 1 : 0;
        return h;
    }

    vector<int> sample_from_probs(const vector<double>& probs) {
        vector<int> out(probs.size());
        for (size_t i = 0; i < probs.size(); ++i) out[i] = rng.bernoulli(probs[i]) ? 1 : 0;
        return out;
    }

    // Proxy reconstruction probs: use expected h = q(h=1|x) for speed
    vector<double> recon_probs_expectation(const vector<int>& xbin) const {
        // h_bar = q(h=1|x) as probabilities
        vector<double> hbar = rec_probs_from_x(xbin);
        vector<double> z(D, 0.0);
        for (int i = 0; i < D; i++) {
            double s = bg[i];
            for (int j = 0; j < H; j++) s += Wg[i][j] * hbar[j];
            z[i] = sigmoid(s);
        }
        return z;
    }

    // One epoch of Wake–Sleep over dataset X (binary vectors)
    void train_epoch(const vector<vector<int>>& X, double lr_gen, double lr_rec, int batch_size) {
        int N = (int)X.size();
        vector<int> idx(N); iota(idx.begin(), idx.end(), 0);
        shuffle(idx.begin(), idx.end(), rng.gen);
        for (int start = 0; start < N; start += batch_size) {
            int end = min(N, start + batch_size);
            int B = end - start;
            // Accumulate gradients
            vector<vector<double>> dWg(D, vector<double>(H, 0.0));
            vector<double> dbg(D, 0.0);
            vector<double> da(H, 0.0);
            vector<vector<double>> dWr(H, vector<double>(D, 0.0));
            vector<double> dbr(H, 0.0);

            // Wake phase: update generator using real data and h ~ q(h|x)
            for (int t = start; t < end; ++t) {
                const vector<int>& x = X[idx[t]];
                // Sample h from recognition
                vector<double> qh = rec_probs_from_x(x);
                vector<int> h(H, 0);
                for (int j = 0; j < H; j++) h[j] = rng.bernoulli(qh[j]) ? 1 : 0;
                // Compute generator probs and residuals
                vector<double> px = gen_probs_from_h(h);
                for (int i = 0; i < D; i++) {
                    double err = (double)x[i] - px[i]; // gradient for Bernoulli logits
                    dbg[i] += err;
                    for (int j = 0; j < H; j++) dWg[i][j] += err * (double)h[j];
                }
                // Prior logits gradient: h - sigma(a)
                for (int j = 0; j < H; j++) da[j] += (double)h[j] - sigmoid(aprior[j]);
            }
            // Apply generator updates
            double scale_g = lr_gen / (double)B;
            for (int i = 0; i < D; i++) {
                bg[i] += scale_g * dbg[i];
                for (int j = 0; j < H; j++) Wg[i][j] += scale_g * dWg[i][j];
            }
            for (int j = 0; j < H; j++) aprior[j] += scale_g * da[j];

            // Sleep phase: update recognition using fantasy pairs (h,x) ~ p(h)p(x|h)
            for (int t = 0; t < B; ++t) {
                vector<int> h = sample_prior_h();
                vector<double> px = gen_probs_from_h(h);
                vector<int> x = sample_from_probs(px);
                vector<double> qh = rec_probs_from_x(x);
                for (int j = 0; j < H; j++) {
                    double err = (double)h[j] - qh[j];
                    dbr[j] += err;
                    for (int i = 0; i < D; i++) dWr[j][i] += err * (double)x[i];
                }
            }
            double scale_r = lr_rec / (double)B;
            for (int j = 0; j < H; j++) {
                br[j] += scale_r * dbr[j];
                for (int i = 0; i < D; i++) Wr[j][i] += scale_r * dWr[j][i];
            }
        }
    }

    // Evaluate average reconstruction cross-entropy using expectation hbar = q(h=1|x)
    double avg_recon_ce(const vector<vector<int>>& X) const {
        double ce = 0.0; int N = (int)X.size();
        const double eps = 1e-7;
        for (const auto& x : X) {
            vector<double> y = recon_probs_expectation(x);
            for (int i = 0; i < D; i++) {
                double p = min(1.0 - eps, max(eps, y[i]));
                ce += -(x[i] * log(p) + (1 - x[i]) * log(1 - p));
            }
        }
        return ce / (double)N;
    }
};

// Create a synthetic dataset from a hidden ground-truth generator
vector<vector<int>> make_synthetic(int N, int D, int H, uint64_t seed = 2024) {
    RNG rng(seed);
    // Ground-truth parameters
    vector<vector<double>> Wg(D, vector<double>(H));
    vector<double> bg(D, 0.0);
    vector<double> a(H, 0.0);
    normal_distribution<double> nd(0.0, 1.0);
    for (int i = 0; i < D; i++)
        for (int j = 0; j < H; j++)
            Wg[i][j] = nd(rng.gen) * ((j % 2 == 0) ? 1.0 : -1.0) * 0.8 / sqrt((double)H);
    for (int j = 0; j < H; j++) a[j] = (j % 2 == 0 ? 0.5 : -0.5); // prior probs ~ 0.62 or 0.38

    auto sigmoid_local = [](double z) { return 1.0 / (1.0 + exp(-z)); };

    vector<vector<int>> X; X.reserve(N);
    for (int n = 0; n < N; n++) {
        // sample h
        vector<int> h(H, 0);
        for (int j = 0; j < H; j++) h[j] = (rng.randu() < sigmoid_local(a[j])) ? 1 : 0;
        // sample x
        vector<double> p(D, 0.0);
        for (int i = 0; i < D; i++) {
            double s = bg[i];
            for (int j = 0; j < H; j++) s += Wg[i][j] * (double)h[j];
            p[i] = sigmoid_local(s);
        }
        vector<int> x(D, 0);
        for (int i = 0; i < D; i++) x[i] = rng.randu() < p[i] ? 1 : 0;
        X.push_back(move(x));
    }
    return X;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int D = 16;      // visible dimension
    int H = 4;       // latent dimension
    int N = 2000;    // dataset size
    int epochs = 50;
    int batch = 64;
    double lr_gen = 0.2;
    double lr_rec = 0.2;

    auto X = make_synthetic(N, D, H);
    Helmholtz model(D, H, 1337);

    for (int e = 1; e <= epochs; ++e) {
        model.train_epoch(X, lr_gen, lr_rec, batch);
        if (e % 5 == 0) {
            double ce = model.avg_recon_ce(X);
            cout << "Epoch " << e << ": recon-CE = " << ce << "\n";
        }
    }

    // Show a few fantasy samples after training
    cout << "\nFantasy samples (first 5):\n";
    for (int k = 0; k < 5; k++) {
        vector<int> h = model.sample_prior_h();
        vector<double> px = model.gen_probs_from_h(h);
        vector<int> x = model.sample_from_probs(px);
        cout << "x: ";
        for (int i = 0; i < D; i++) cout << x[i];
        cout << "\n";
    }
    return 0;
}
```
This program implements a single-layer Helmholtz machine with Bernoulli latent and visible units and trains it using the classical Wake–Sleep algorithm. The wake phase samples h from the recognition model given real data and updates the generator via Bernoulli gradients. The sleep phase samples (h, x) from the generator and updates the recognition model to better invert the generator. We train on a synthetic dataset produced by a hidden ground-truth generator and periodically report reconstruction cross-entropy using the recognition network’s mean latents.
```cpp
#include <bits/stdc++.h>
using namespace std;

static inline double sigmoid(double z) {
    if (z >= 0) { double e = exp(-z); return 1.0 / (1.0 + e); }
    else        { double e = exp(z);  return e / (1.0 + e); }
}

struct SimpleHM {
    int D, H;
    vector<vector<double>> Wg; vector<double> bg; vector<double> aprior;
    vector<vector<double>> Wr; vector<double> br;
    SimpleHM(int D_, int H_)
        : D(D_), H(H_), Wg(D_, vector<double>(H_, 0.0)), bg(D_, 0.0), aprior(H_, 0.0),
          Wr(H_, vector<double>(D_, 0.0)), br(H_, 0.0) {}
    vector<double> rec_probs(const vector<int>& x) const {
        vector<double> q(H, 0.0);
        for (int j = 0; j < H; j++) {
            double s = br[j];
            for (int i = 0; i < D; i++) s += Wr[j][i] * x[i];
            q[j] = sigmoid(s);
        }
        return q;
    }
    vector<double> gen_probs(const vector<double>& hbar) const {
        vector<double> p(D, 0.0);
        for (int i = 0; i < D; i++) {
            double s = bg[i];
            for (int j = 0; j < H; j++) s += Wg[i][j] * hbar[j];
            p[i] = sigmoid(s);
        }
        return p;
    }
};

int main() {
    int D = 8, H = 3;
    SimpleHM m(D, H);
    // Pretend these were learned; here we set small arbitrary values for demo
    m.Wg = {{1.0,-0.5,0.2},{-0.3,0.8,-0.1},{0.6,0.1,0.0},{-0.7,0.2,0.3},
            {0.4,-0.9,0.5},{0.1,0.2,-0.4},{-0.2,0.3,0.7},{0.5,0.0,-0.6}};
    m.bg = {0.0,0.1,-0.2,0.0,0.2,0.0,-0.1,0.0};
    m.aprior = {0.2,-0.2,0.0};
    m.Wr = {{0.5,-0.3,0.7,-0.1,0.2,0.0,-0.4,0.1},
            {-0.2,0.6,-0.1,0.3,-0.5,0.1,0.2,-0.3},
            {0.1,0.2,0.0,-0.4,0.3,-0.2,0.5,0.1}};
    m.br = {0.0,0.1,-0.1};

    // Inference: estimate posterior means hbar = q(h=1|x)
    vector<int> x = {1,0,1,0,1,0,0,1};
    vector<double> hbar = m.rec_probs(x);
    cout << "Posterior means hbar: ";
    for (double v : hbar) cout << fixed << setprecision(3) << v << " ";
    cout << "\n";

    // Reconstruction: use expected h to predict x probabilities
    vector<double> xhat = m.gen_probs(hbar);
    cout << "Reconstruction probs xhat: ";
    for (double v : xhat) cout << fixed << setprecision(3) << v << " ";
    cout << "\n";

    // Prior probabilities for latents
    cout << "Prior probs for h: ";
    for (double a : m.aprior) cout << fixed << setprecision(3) << 1.0 / (1.0 + exp(-a)) << " ";
    cout << "\n";

    return 0;
}
```
This example shows how, after training, you can use the recognition model to infer posterior mean probabilities for the latent variables and then reconstruct the input via the generator using those means. It also prints the prior probabilities implied by the prior logits. In a real workflow you would load the learned parameters from a checkpoint instead of hardcoding them.