Disentangled Representations
Key Points
- Disentangled representations aim to encode independent factors of variation (like shape, size, or color) into separate coordinates of a latent vector.
- They make models more interpretable, controllable, and modular by turning complex variations into simple, independent knobs.
- In theory, disentanglement is tied to independence, mutual information, and total correlation, and often assumes a latent generative process.
- Without inductive biases, supervision, or interventions, purely unsupervised disentanglement is generally unidentifiable.
- Metrics such as total correlation, mutual information gap, and modularity explicitly measure how separated the factors are.
- Popular approaches include \(\beta\)-VAE (strengthening the KL term), InfoGAN (maximizing mutual information with codes), and ICA for linear mixtures.
- You can estimate disentanglement quality in C++ by computing covariance, Gaussian total correlation, and simple ICA recovery.
- Disentanglement helps in controllable generation, robotics and planning, fairness, domain adaptation, and scientific discovery.
Prerequisites
- Probability and Random Variables — Understanding independence, distributions, and expectations is essential for defining and measuring disentanglement.
- Information Theory Basics — Mutual information, KL divergence, entropy, and total correlation underpin the objectives and metrics.
- Linear Algebra — Matrices, eigen-decomposition, and whitening are needed for ICA and covariance-based metrics.
- Optimization — Training objectives like \(\beta\)-VAE require gradient-based optimization concepts.
- Gaussian Distributions — Closed-form KL and Gaussian TC approximations rely on properties of multivariate normals.
- Matrix Factorization (PCA/Whitening) — Preprocessing for ICA and understanding variance directions requires PCA/whitening.
- Causal Thinking (optional but helpful) — Disentanglement often aligns with independent causal mechanisms and interventions.
Detailed Explanation
01 Overview
Hook: Imagine a music studio with many sliders—one for volume, one for bass, one for treble. You can change each sound quality independently because each slider controls exactly one factor. Concept: Disentangled representations bring that same clarity to machine learning. Instead of mixing all variations of an object (like orientation, color, size) into a single opaque vector, a disentangled representation tries to place each independent factor of variation into its own coordinate or small subset of coordinates. This makes the representation easy to interpret and manipulate. Example: In a generative model of faces, changing one latent coordinate might only affect rotation, another might only affect lighting, and another might only affect smile intensity. With this setup, editing a face becomes as easy as moving a single slider. In practice, disentanglement is grounded in probability and information theory: we want latent variables to be statistically independent and to align with meaningful, human-interpretable factors. Many approaches regularize models (e.g., via \(\beta\)-VAE) so that the learned latent space prefers factorized distributions, pushing the model to isolate independent causes. While useful and intuitive, it is also subtle: without additional structure (inductive biases, weak labels, or interventions), perfect disentanglement is often impossible to guarantee. Nevertheless, aiming for partial disentanglement can yield large gains in interpretability, robustness, and controllability across real-world tasks.
02 Intuition & Analogies
Hook: Think of making a smoothie with strawberries, bananas, and yogurt. Once blended, it is hard to taste each ingredient separately or remove just the strawberries. Concept: A raw data vector is like that smoothie—many causes blended together. Disentangled representations try to unblend the smoothie so that each ingredient sits in its own bottle. If you want more strawberry flavor, you just pour more from the strawberry bottle. Analogy 1: Home lighting. Multiple light sources (sunlight, lamp, monitor) affect brightness in a room. If your representation entangles them, increasing one might unpredictably change others. A disentangled representation would keep controls independent: you could turn up the lamp without affecting sunlight or the monitor’s glow. Analogy 2: Camera settings. ISO, shutter speed, and aperture each independently change some aspect of the photo. If your camera UI mixed these controls into a single dial, you’d get frustrating, unpredictable results. A disentangled UI separates them into independent dials. Example: In image generation, suppose we want to vary an object’s rotation without changing its color or shape. In a disentangled latent space, rotation has its own dimension. You can rotate objects by moving just that coordinate, leaving everything else fixed. This property also helps downstream tasks: a classifier can focus on a single relevant coordinate rather than decoding a complicated mixture. In robotics, a few disentangled coordinates might correspond to joint angles, while other coordinates can capture lighting or background, making planning more stable and interpretable. The big idea: when independent causes become independent coordinates, reasoning, editing, and control all get simpler.
03 Formal Definition
Assume data are generated as \( x = g(s) \) from ground-truth factors \( s = (s_1, \dots, s_K) \) with a factorized prior \( p(s) = \prod_k p(s_k) \). A representation \( z = f(x) \) is disentangled when two conditions hold. First, the coordinates of z are mutually independent: the aggregated posterior factorizes, \( q(z) = \prod_j q(z_j) \), or equivalently the total correlation \( \mathrm{TC}(z) = 0 \). Second, each coordinate \( z_j \) carries high mutual information with exactly one factor \( s_k \) and little with the others. The first condition alone is not enough—an independent code can still scramble the factors—which is why alignment (measured, for example, by the mutual information gap) is part of the definition.
04 When to Use
Hook: If you need reliable knobs to control a system, disentanglement is probably your friend. Concept: Use disentangled representations when interpretability, control, or transfer across environments matters more than squeezing out the last percent of raw predictive accuracy. Example scenarios: (1) Controllable generation: in graphics or content creation, separate codes for shape, pose, and texture let artists or systems edit content predictably. (2) Robotics and planning: separating task-relevant state (e.g., joint angles) from nuisance variables (e.g., lighting) stabilizes policies and model-based control. (3) Fairness: separating protected attributes (e.g., gender) from job-relevant skills can support fairness-aware decision-making and auditing. (4) Scientific discovery: when latent variables correspond to interpretable factors (e.g., physical parameters), scientists can test hypotheses more directly. (5) Domain adaptation and robustness: decoupling invariant content from style or domain shifts leads to models that generalize across new environments. When not to use: if your task is purely predictive with abundant labeled data and interpretability/control are not needed, the extra regularization required for disentanglement can hurt accuracy. Also, if you cannot inject any inductive bias or supervision and your data lacks the structure that makes factors separable, strong disentanglement may be unattainable.
⚠️ Common Mistakes
Hook: Not every neat-looking latent axis is a true "knob." Concept: Disentanglement is about independent causal factors, not just sparse or axis-aligned coordinates. Mistake 1: Confusing sparsity with disentanglement. A sparse code can still entangle factors if multiple coordinates co-vary to represent a single change. Fix: Measure independence (e.g., total correlation) and alignment to known factors, not just sparsity. Mistake 2: Believing unsupervised training always finds semantic factors. Theoretical results show this is generally unidentifiable without biases or signals. Fix: Add inductive biases (architecture, prior), weak labels, or interventions (e.g., time, actions). Mistake 3: Overfitting to metrics. A model might game a metric like MIG without genuinely aligning with semantics. Fix: Use multiple metrics, qualitative traversals, and task-based evaluations. Mistake 4: Ignoring permutation and sign ambiguities. Even in ideal settings (e.g., ICA), recovered components are identifiable only up to permutation and scaling/sign. Fix: Accept these equivalences and align components post hoc. Mistake 5: Equating low correlation with independence. Zero correlation does not imply independence (except for jointly Gaussian variables). Fix: Prefer information-theoretic measures or higher-order statistics. Mistake 6: Forcing too much regularization (e.g., very large \(\beta\) in \(\beta\)-VAE), which can collapse useful information and hurt reconstructions. Fix: Tune regularization carefully and monitor reconstruction quality.
Key Formulas
Generative Process
\( x = g(s, \epsilon), \qquad p(s) = \prod_{k=1}^{K} p(s_k) \)
Explanation: Observed data x is produced by a function g of independent latent factors s and possibly noise \(\epsilon\). This assumption underpins the idea that a representation can recover the factors.
Factorized Latent Aggregation
\( q(z) = \mathbb{E}_{p(x)}\!\left[ q(z \mid x) \right] = \prod_{j=1}^{d} q(z_j) \)
Explanation: The aggregated posterior over the learned code z should factorize across dimensions if the latents are independent. Independence means the joint equals the product of marginals.
Total Correlation
\( \mathrm{TC}(z) = D_{\mathrm{KL}}\!\left( q(z) \,\middle\|\, \prod_{j=1}^{d} q(z_j) \right) = \sum_{j=1}^{d} H(z_j) - H(z) \)
Explanation: Total correlation measures multivariate dependence. It is zero exactly when the coordinates of z are mutually independent.
Mutual Information
\( I(X;Y) = H(X) - H(X \mid Y) = \sum_{x,y} p(x,y)\,\log \frac{p(x,y)}{p(x)\,p(y)} \)
Explanation: Mutual information quantifies how much knowing X reduces uncertainty about Y. High MI between a latent and a single factor indicates alignment.
\(\beta\)-VAE Objective
\( \mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q(z \mid x)}\!\left[ \log p(x \mid z) \right] - \beta \, D_{\mathrm{KL}}\!\left( q(z \mid x) \,\|\, p(z) \right) \)
Explanation: This loss balances reconstruction quality with a penalty that pushes latents toward a factorized prior p(z). Larger \(\beta\) encourages more independence but can hurt reconstructions.
Gaussian KL to Standard Normal
\( D_{\mathrm{KL}}\!\left( \mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I) \right) = \frac{1}{2} \sum_{j=1}^{d} \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right) \)
Explanation: For Gaussian posteriors, the KL term has a closed form. With diagonal \(\Sigma\), it simplifies to a sum over dimensions.
Gaussian Total Correlation
\( \mathrm{TC}(z) = \frac{1}{2} \log \frac{\prod_{j=1}^{d} \Sigma_{jj}}{\det \Sigma} = \frac{1}{2} \left( \sum_{j=1}^{d} \log \Sigma_{jj} - \log \det \Sigma \right) \)
Explanation: Under a multivariate normal approximation, total correlation equals half the log of the ratio between the product of marginal variances and the joint covariance determinant.
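As a worked instance of this formula, take a bivariate normal with unit variances and correlation \(\rho\), so \(\Sigma_{11} = \Sigma_{22} = 1\) and \(\det \Sigma = 1 - \rho^2\):

```latex
\mathrm{TC}(z)
  = \frac{1}{2} \log \frac{\Sigma_{11}\,\Sigma_{22}}{\det \Sigma}
  = \frac{1}{2} \log \frac{1 \cdot 1}{1 - \rho^2}
  = -\frac{1}{2} \log\!\left(1 - \rho^2\right).
```

This is zero at \(\rho = 0\) (independence) and diverges as \(|\rho| \to 1\), matching the behavior measured by the first code example below.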
Mutual Information Gap
\( \mathrm{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} \left( I\!\left(z_{j^{(1)}_k}; v_k\right) - I\!\left(z_{j^{(2)}_k}; v_k\right) \right) \)
where \( j^{(1)}_k \) and \( j^{(2)}_k \) index the latents with the highest and second-highest mutual information with factor \( v_k \).
Explanation: MIG averages, over factors, the gap between the top two mutual informations normalized by the factor entropy. Larger values indicate cleaner one-to-one alignment.
ICA Model
\( x = A s, \qquad \hat{s} = W x, \qquad W \approx A^{-1} \ \text{(up to permutation and scaling)} \)
Explanation: In linear ICA, observations are mixtures of independent sources. Disentanglement corresponds to recovering an unmixing matrix W that yields independent components.
Equivariance Relation
\( f(g \cdot x) = \rho(g) \cdot f(x) \)
Explanation: If data transforms via a group action g, an equivariant encoder f reflects that structure in the latent space via a representation \(\rho\). This can support structured disentanglement.
Reconstruction Loss (MSE)
\( \mathcal{L}_{\mathrm{rec}} = \frac{1}{N} \sum_{i=1}^{N} \left\| x_i - \hat{x}_i \right\|^2 \)
Explanation: A common reconstruction term in generative models. Lower MSE means reconstructed data is closer to the input.
Complexity Analysis
For the code examples below: covariance estimation over N samples in d dimensions costs O(N d²); Cholesky factorization and the log-determinant are O(d³); each FastICA fixed-point iteration costs O(N d) per component; and the \(\beta\)-VAE loss over a batch of B examples with P-dimensional inputs and D-dimensional latents is O(B (P + D)). Memory is dominated by the O(N d) sample matrix.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Generate standard normal using Box-Muller (polar form)
static double randn() {
    static bool hasSpare = false;
    static double spare;
    if (hasSpare) { hasSpare = false; return spare; }
    hasSpare = true;
    double u, v, s;
    do {
        u = 2.0 * rand() / (double)RAND_MAX - 1.0;
        v = 2.0 * rand() / (double)RAND_MAX - 1.0;
        s = u*u + v*v;
    } while (s >= 1.0 || s == 0.0);
    s = sqrt(-2.0 * log(s) / s);
    spare = v * s;
    return u * s;
}

// Cholesky decomposition for SPD matrix; returns lower-triangular L s.t. A = L L^T
bool cholesky(const vector<vector<double>>& A, vector<vector<double>>& L) {
    int d = (int)A.size();
    L.assign(d, vector<double>(d, 0.0));
    for (int i = 0; i < d; ++i) {
        for (int j = 0; j <= i; ++j) {
            double sum = 0.0;
            for (int k = 0; k < j; ++k) sum += L[i][k] * L[j][k];
            if (i == j) {
                double val = A[i][i] - sum;
                if (val <= 0.0) return false;
                L[i][j] = sqrt(val);
            } else {
                L[i][j] = (A[i][j] - sum) / L[j][j];
            }
        }
    }
    return true;
}

// Determinant via Gaussian elimination with partial pivoting; returns log|det| and sign
pair<double,int> logdet(vector<vector<double>> A) {
    int n = (int)A.size();
    double logabs = 0.0; int sign = 1;
    for (int i = 0; i < n; ++i) {
        int piv = i;
        for (int r = i+1; r < n; ++r) if (fabs(A[r][i]) > fabs(A[piv][i])) piv = r;
        if (fabs(A[piv][i]) < 1e-15) return { -INFINITY, 0 }; // singular
        if (piv != i) { swap(A[piv], A[i]); sign *= -1; }
        double pivot = A[i][i];
        logabs += log(fabs(pivot));
        for (int r = i+1; r < n; ++r) {
            double factor = A[r][i] / pivot;
            for (int c = i; c < n; ++c) A[r][c] -= factor * A[i][c];
        }
    }
    return { logabs, sign };
}

// Compute mean of rows (N samples, d dims)
vector<double> mean_rows(const vector<vector<double>>& X) {
    int N = (int)X.size();
    int d = (int)X[0].size();
    vector<double> m(d, 0.0);
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < d; ++j) m[j] += X[i][j];
    for (int j = 0; j < d; ++j) m[j] /= max(1, N);
    return m;
}

// Covariance (unbiased) from samples-as-rows
vector<vector<double>> covariance(const vector<vector<double>>& X) {
    int N = (int)X.size();
    int d = (int)X[0].size();
    vector<double> m = mean_rows(X);
    vector<vector<double>> C(d, vector<double>(d, 0.0));
    for (int i = 0; i < N; ++i) {
        for (int a = 0; a < d; ++a) {
            double va = X[i][a] - m[a];
            for (int b = 0; b < d; ++b) {
                C[a][b] += va * (X[i][b] - m[b]);
            }
        }
    }
    double denom = (N > 1) ? (N - 1.0) : 1.0;
    for (int a = 0; a < d; ++a)
        for (int b = 0; b < d; ++b)
            C[a][b] /= denom;
    return C;
}

// Estimate Gaussian TC: 0.5 * (sum log variances - log det Sigma)
double gaussian_total_correlation(const vector<vector<double>>& Sigma) {
    int d = (int)Sigma.size();
    double sum_log_var = 0.0;
    for (int j = 0; j < d; ++j) {
        double v = Sigma[j][j];
        if (v <= 0.0) return NAN;
        sum_log_var += log(v);
    }
    auto [ldet, sign] = logdet(Sigma);
    if (sign <= 0) return NAN; // not SPD or numerical issue
    return 0.5 * (sum_log_var - ldet);
}

int main() {
    srand(7);
    // Define a 3x3 covariance with correlations (positive off-diagonals)
    vector<vector<double>> Sigma = {
        {1.0, 0.6, 0.3},
        {0.6, 1.5, 0.2},
        {0.3, 0.2, 2.0}
    };

    // Cholesky to generate correlated Gaussians: x = L * n, n ~ N(0, I)
    vector<vector<double>> L;
    if (!cholesky(Sigma, L)) {
        cerr << "Covariance not SPD.\n";
        return 0;
    }

    int N = 10000, d = 3;
    vector<vector<double>> X(N, vector<double>(d, 0.0));
    for (int i = 0; i < N; ++i) {
        vector<double> n(d);
        for (int j = 0; j < d; ++j) n[j] = randn();
        // x = L * n
        for (int r = 0; r < d; ++r) {
            double val = 0.0;
            for (int k = 0; k <= r; ++k) val += L[r][k] * n[k];
            X[i][r] = val;
        }
    }

    auto S_hat = covariance(X);
    double tc = gaussian_total_correlation(S_hat);
    cout.setf(ios::fixed); cout << setprecision(6);
    cout << "Estimated Gaussian TC: " << tc << "\n";

    // For comparison, TC should be ~0 if Sigma is diagonal; try diagonal case
    vector<vector<double>> Sigma_diag = {{1.0,0.0,0.0},{0.0,1.5,0.0},{0.0,0.0,2.0}};
    vector<vector<double>> Ld; cholesky(Sigma_diag, Ld);
    for (int i = 0; i < N; ++i) {
        vector<double> n(d);
        for (int j = 0; j < d; ++j) n[j] = randn();
        for (int r = 0; r < d; ++r) {
            double val = 0.0;
            for (int k = 0; k <= r; ++k) val += Ld[r][k] * n[k];
            X[i][r] = val;
        }
    }
    auto S_hat_diag = covariance(X);
    double tc_diag = gaussian_total_correlation(S_hat_diag);
    cout << "TC for diagonal covariance (should be near 0): " << tc_diag << "\n";
    return 0;
}
```
This program estimates total correlation (TC) under a Gaussian approximation using sampled data. It generates correlated Gaussian samples via Cholesky decomposition, computes the empirical covariance, and then evaluates TC = 0.5 * (sum log marginal variances − log det covariance). A diagonal covariance yields TC ≈ 0, indicating independence; off-diagonal correlations produce positive TC, indicating dependence among latent dimensions.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Mat2 {
    double a11, a12, a21, a22; // row-major
};

Mat2 matmul(const Mat2& A, const Mat2& B) {
    return { A.a11*B.a11 + A.a12*B.a21, A.a11*B.a12 + A.a12*B.a22,
             A.a21*B.a11 + A.a22*B.a21, A.a21*B.a12 + A.a22*B.a22 };
}

Mat2 transpose(const Mat2& A) { return {A.a11, A.a21, A.a12, A.a22}; }

// Eigen-decomposition of 2x2 symmetric matrix [[a,b],[b,d]]
void eigen2x2(double a, double b, double d, double& l1, double& l2, Mat2& V) {
    double tr = a + d;
    double det = a*d - b*b;
    double disc = sqrt(max(0.0, tr*tr/4.0 - det));
    l1 = tr/2.0 + disc; l2 = tr/2.0 - disc;
    // eigenvector for l1: (b, l1 - a) or (l1 - d, b)
    double v1x = b, v1y = l1 - a;
    if (fabs(v1x) + fabs(v1y) < 1e-12) { v1x = l1 - d; v1y = b; }
    double n1 = sqrt(v1x*v1x + v1y*v1y); v1x /= n1; v1y /= n1;
    // eigenvector for l2: orthogonal to v1
    double v2x = -v1y, v2y = v1x;
    V = {v1x, v2x, v1y, v2y}; // columns are eigenvectors
}

// Center data (N x 2)
void center(vector<array<double,2>>& X) {
    int N = (int)X.size();
    double mx=0.0, my=0.0; for (auto& x : X){ mx += x[0]; my += x[1]; }
    mx/=N; my/=N; for (auto& x : X){ x[0]-=mx; x[1]-=my; }
}

// ZCA whitening matrix Ww = V * D^{-1/2} * V^T, so that Ww * x has identity covariance
Mat2 whiten_matrix(const vector<array<double,2>>& X) {
    int N = (int)X.size();
    // Covariance (data assumed centered)
    double c11=0, c12=0, c22=0;
    for (auto& x : X) { c11 += x[0]*x[0]; c12 += x[0]*x[1]; c22 += x[1]*x[1]; }
    c11 /= (N-1); c12 /= (N-1); c22 /= (N-1);
    double l1, l2; Mat2 V;
    eigen2x2(c11, c12, c22, l1, l2, V);
    Mat2 Dm = {1.0/sqrt(max(l1,1e-12)), 0, 0, 1.0/sqrt(max(l2,1e-12))}; // D^{-1/2}
    return matmul(matmul(V, Dm), transpose(V));
}

// FastICA fixed-point iteration with tanh nonlinearity for one component.
// If `deflate` is given, the estimate is re-orthogonalized against it on every
// iteration (deflation), so the second component cannot collapse onto the first.
array<double,2> fastica_one(const vector<array<double,2>>& Xw, array<double,2> w0,
                            const array<double,2>* deflate = nullptr, int iters = 200) {
    int N = (int)Xw.size();
    auto dot = [](const array<double,2>& a, const array<double,2>& b){ return a[0]*b[0]+a[1]*b[1]; };
    auto norm = [&](array<double,2> w){ double n=sqrt(w[0]*w[0]+w[1]*w[1]); w[0]/=n; w[1]/=n; return w; };
    array<double,2> w = norm(w0);
    for (int it=0; it<iters; ++it){
        double g1=0, g2=0, gp=0; // accumulate E[x*g(w^T x)] and E[g'(w^T x)]
        for (auto& x : Xw){
            double u = dot(w, x);
            double gu = tanh(u);
            g1 += x[0]*gu; g2 += x[1]*gu; gp += (1 - gu*gu);
        }
        g1/=N; g2/=N; gp/=N;
        // w_new = E[x g(w^T x)] - E[g'(w^T x)] w
        array<double,2> wn = { g1 - gp*w[0], g2 - gp*w[1] };
        if (deflate) { // Gram-Schmidt step against the previously found component
            double proj = dot(wn, *deflate);
            wn[0] -= proj * (*deflate)[0];
            wn[1] -= proj * (*deflate)[1];
        }
        w = norm(wn);
    }
    return w;
}

int main(){
    srand(1);
    int N = 4000;
    vector<array<double,2>> S(N); // true sources (independent)
    // s1 ~ Laplace(0,1): inverse CDF sampling from U(-0.5,0.5)
    auto laplace = [](){ double u = ((rand()/(double)RAND_MAX) - 0.5); return (u>=0?1:-1)*log(1-2*fabs(u)); };
    for (int i=0;i<N;++i){
        double s1 = laplace();
        double s2 = sin(2*M_PI*(i/(double)N)*7.0); // independent periodic source
        S[i] = {s1, s2};
    }
    // Mix with A
    Mat2 A = {1.0, 0.5, 0.7, 1.0};
    vector<array<double,2>> X(N);
    for (int i=0;i<N;++i){
        X[i] = { A.a11*S[i][0] + A.a12*S[i][1], A.a21*S[i][0] + A.a22*S[i][1] };
    }
    // Center and whiten
    center(X);
    Mat2 Ww = whiten_matrix(X);
    vector<array<double,2>> Xw(N);
    for (int i=0;i<N;++i){
        Xw[i] = { Ww.a11*X[i][0] + Ww.a12*X[i][1], Ww.a21*X[i][0] + Ww.a22*X[i][1] };
    }
    // Run FastICA: first component freely, second deflated against the first
    array<double,2> w1 = fastica_one(Xw, {1.0, 0.3});
    array<double,2> w2 = fastica_one(Xw, { -w1[1], w1[0] }, &w1);

    // Recover sources y = W * x_w (applied per sample; rows of W are components)
    vector<array<double,2>> Y(N);
    for (int i=0;i<N;++i){
        double y1 = w1[0]*Xw[i][0] + w1[1]*Xw[i][1];
        double y2 = w2[0]*Xw[i][0] + w2[1]*Xw[i][1];
        Y[i] = {y1, y2};
    }

    // Evaluate correlation magnitude between recovered and true sources (up to sign/permutation)
    auto corr = [&](int si, int yi){
        double ms=0,my=0; for (int i=0;i<N;++i){ ms+=S[i][si]; my+=Y[i][yi]; }
        ms/=N; my/=N; double num=0, ds=0, dy=0;
        for (int i=0;i<N;++i){ double as=S[i][si]-ms, ay=Y[i][yi]-my; num+=as*ay; ds+=as*as; dy+=ay*ay; }
        return num/sqrt(ds*dy+1e-12);
    };
    double c11=fabs(corr(0,0)), c12=fabs(corr(0,1)), c21=fabs(corr(1,0)), c22=fabs(corr(1,1));
    cout.setf(ios::fixed); cout<<setprecision(4);
    cout << "|corr(s1, y1)|="<<c11<<" |corr(s1, y2)|="<<c12<<"\n";
    cout << "|corr(s2, y1)|="<<c21<<" |corr(s2, y2)|="<<c22<<"\n";
    cout << "Note: components are identifiable up to sign and permutation.\n";
    return 0;
}
```
This example performs 2D FastICA to separate two independent sources from their linear mixtures. It centers and whitens the data, then applies a fixed-point iteration with a tanh nonlinearity to estimate independent components. The output correlations show that each recovered component strongly aligns with one true source (up to sign and permutation), illustrating linear disentanglement.
```cpp
#include <bits/stdc++.h>
using namespace std;

// Computes the beta-VAE objective for a batch with diagonal Gaussian posterior
// q(z|x) = N(mu, diag(sigma^2)).
// Inputs: recon (predictions), x (targets), mu, logvar, beta.
// Shapes: mu/logvar are B x D (latents); recon/x are B x P.
// Returns: pair(reconstruction term (as -MSE), total objective)
pair<double,double> beta_vae_loss(const vector<vector<double>>& recon,
                                  const vector<vector<double>>& x,
                                  const vector<vector<double>>& mu,
                                  const vector<vector<double>>& logvar,
                                  double beta) {
    int B = (int)x.size();
    int P = (int)x[0].size();
    int D = (int)mu[0].size();
    // Reconstruction term: negative MSE as a proxy for log-likelihood (up to a constant)
    double mse = 0.0;
    for (int i=0;i<B;++i)
        for (int j=0;j<P;++j) {
            double diff = x[i][j] - recon[i][j];
            mse += diff*diff;
        }
    mse /= (B*P);
    double recon_term = -mse; // higher is better

    // KL term to standard normal: 0.5 * sum( exp(logvar) + mu^2 - 1 - logvar )
    double kl = 0.0;
    for (int i=0;i<B;++i)
        for (int j=0;j<D;++j) {
            double lv = logvar[i][j];
            double var = exp(lv);
            kl += (var + mu[i][j]*mu[i][j] - 1.0 - lv);
        }
    kl *= 0.5; kl /= B; // average per example

    double total = recon_term - beta * kl; // ELBO with beta weighting
    return {recon_term, total};
}

int main(){
    // Toy batch: B=2 samples, P=3-dim recon, D=2-dim latent
    vector<vector<double>> x = {{1.0, 0.0, -1.0}, {0.5, -0.2, 0.3}};
    vector<vector<double>> recon = {{0.8, -0.1, -0.9}, {0.6, -0.1, 0.4}};
    vector<vector<double>> mu = {{0.2, -0.1}, {0.0, 0.3}};
    vector<vector<double>> logvar = {{-0.2, -0.5}, {-0.3, -0.1}}; // variances < 1 encourage compact latents
    double beta = 4.0; // stronger pressure for factorized latents

    auto [recon_term, total] = beta_vae_loss(recon, x, mu, logvar, beta);
    cout.setf(ios::fixed); cout<<setprecision(6);
    cout << "Reconstruction term (-MSE): " << recon_term << "\n";
    cout << "Total beta-VAE objective: " << total << "\n";
    cout << "Note: Larger beta encourages independence but may hurt reconstructions.\n";
    return 0;
}
```
This program computes the beta-VAE loss for a small batch using a diagonal Gaussian posterior. It calculates a reconstruction term (here, negative MSE as a proxy) and the closed-form KL divergence to a standard normal prior, then combines them with a user-specified beta to encourage factorized latents.