Mean Field Theory of Neural Networks
Key Points
- Mean field theory treats very wide randomly initialized neural networks as averaging machines where each neuron behaves like a sample from a common distribution.
- As width grows, pre-activations in each layer become Gaussian by the Central Limit Theorem, letting us predict variances and correlations layer by layer.
- In the infinite-width limit, the network's output distribution becomes a Gaussian process with a kernel computable by a simple recurrence (the NNGP kernel).
- For ReLU networks, variance and correlation propagate via closed-form arc-cosine formulas, enabling exact predictions at initialization.
- Choosing weight variance using mean field (e.g., He or Xavier) prevents signal explosion/decay and keeps variances stable across depth.
- The "edge of chaos" is characterized by a slope parameter \(\chi\); when \(\chi = 1\), signals (and gradients) neither explode nor vanish.
- Mean field predictions are most accurate at large width and at initialization; finite-width effects and training dynamics introduce deviations.
- These tools guide hyperparameters, diagnose vanishing/exploding behavior, and connect deep nets to kernels for theory and practice.
Prerequisites
- Gaussian random variables: Mean field recurrences assume Gaussian pre-activations; understanding variance, covariance, and bivariate Gaussians is essential.
- Central Limit Theorem and Law of Large Numbers: They justify Gaussian pre-activations and self-averaging across wide layers.
- Inner products and norms: The base kernel \(K^{(0)}\) is the normalized inner product; covariance depends on norms and angles.
- Feedforward neural networks: You need the structure of layers, activations, weights, and biases to apply mean-field ideas.
- Derivatives of activations: NTK and edge-of-chaos analyses involve expectations over activation derivatives.
- Random number generation: Simulations use Gaussian RNGs for weights, biases, and inputs.
- Kernel methods: NNGP and NTK connect deep nets to kernel regression and learning dynamics.
- Trigonometric functions and angles: ReLU arc-cosine formulas use angles between inputs via arccos and trigonometric identities.
Detailed Explanation
01 Overview
Mean field theory (MFT) of neural networks studies randomly initialized, very wide networks by replacing complicated, high-dimensional randomness with simpler aggregate behavior. The key idea is self-averaging: when each neuron receives a sum of many small, independent contributions, the Central Limit Theorem implies the pre-activation becomes approximately Gaussian. This lets us track only low-order statistics (means, variances, and correlations) through layers, rather than the entire weight matrices. Practically, MFT yields recurrences describing how variance and input correlations propagate through depth, providing principled rules for weight scaling (e.g., He or Xavier initialization). In the infinite-width limit, outputs of deep networks converge to a Gaussian process (NNGP), with a kernel defined by these recurrences. Closely related, the Neural Tangent Kernel (NTK) characterizes training dynamics for infinitesimal learning rates or at infinite width. These perspectives explain stability at initialization, predict signal propagation, and clarify when and why gradients vanish or explode. For popular activations like ReLU or tanh, the recurrences admit closed-form or efficiently computable expressions. Mean field tools thus bridge deep learning practice and probabilistic/statistical-physics methods, yielding both intuition and concrete formulas usable for model design.
02 Intuition & Analogies
Imagine each neuron as a weather station summing many small, independent sensor readings (its inputs), each with tiny random weightings. For a single reading, noise dominates; but average enough sensors and the result becomes reliably bell-shaped (Gaussian) due to the Central Limit Theorem. In a very wide layer, each neuron receives a large number of such inputs, so its pre-activation is like a noisy average, well-approximated by a Gaussian with predictable mean and variance. Push this through a nonlinearity (like ReLU), and you get a new distribution with a new variance; feed that into the next layer and repeat. Instead of tracking every station's detailed wiring, you track the weather pattern's key summaries: variance (how turbulent) and correlation (how similar two cities' weather is). If two input points are similar (highly correlated), their weather patterns remain similar after passing through layers, governed by a simple correlation update rule. When layers are immensely wide, randomness across different neurons averages out so thoroughly that the entire network acts like sampling from a single Gaussian process: for any set of inputs, the output vector is jointly Gaussian with a kernel that's easy to compute. Tuning the weight scale is like adjusting sensor gain: too small and signals die out (calm weather everywhere); too large and everything saturates or becomes chaotic (storms everywhere). The sweet spot, the edge of chaos, keeps distinctions among inputs alive across many layers, making learning easier.
03 Formal Definition
04 When to Use
- Selecting initialization scales: Use the variance recursion to choose \(\sigma_w^2\) and \(\sigma_b^2\) (e.g., He for ReLU, Xavier for tanh) so that \(q^{(\ell)}\) remains stable across depth.
- Diagnosing vanishing/exploding signals or gradients: If \(q^{(\ell)}\) shrinks or grows, activations collapse or saturate; if the correlation map's slope \(\chi\) is far from 1, gradients vanish or explode. Adjust \(\sigma_w^2\), \(\sigma_b^2\), or the activation.
- Comparing architectures/activations: Predict which nonlinearity preserves information (correlation) best over many layers; choose depth accordingly.
- Kernel viewpoints: For infinite width, use the NNGP to compute predictive distributions without training, or the NTK to approximate gradient descent dynamics. These give baselines and theoretical bounds.
- Sanity checks for inputs: Normalize inputs so \(K^{(0)}(x,x)\) is \(\mathcal{O}(1)\) to match the assumptions of the recurrences.
- Understanding batch norm or residual connections: Although classical MFT assumes i.i.d. layers without normalization, extensions analyze how these components affect \(q\), correlations, and \(\chi\).
⚠️ Common Mistakes
- Confusing fan-in scaling: Mean field requires \(\operatorname{Var}(W_{ij}) = \sigma_w^2/n_{\ell-1}\), not a constant variance; otherwise variance blows up with width.
- Ignoring biases in recurrences: \(\sigma_b^2\) shifts variances and correlations; setting it incorrectly changes fixed points and can induce unwanted drift.
- Overtrusting infinite-width predictions at small width: Finite networks deviate due to higher-order correlations; treat MFT as a guide, not gospel.
- Assuming results hold during training: The NNGP characterizes initialization; the NTK gives a training approximation in specific regimes (small learning rate, wide layers). General training can drift away from mean-field assumptions.
- Misusing activation formulas: Closed-form arc-cosine kernels apply to ReLU and some homogeneous activations; tanh/sigmoid require numerical expectations.
- Forgetting input scaling: If inputs are not normalized, the base kernel \(K^{(0)}\) is off, invalidating later layers' predictions.
- Mixing fan-in/fan-out variants (Xavier vs He): ReLU typically uses He (variance \(2/\text{fan-in}\)); tanh uses Xavier (\(1/\text{fan-in}\)).
- Overlooking the correlation slope \(\chi\): Stability around the fixed point depends on \(\chi\); setting \(\sigma_w^2\) without checking \(\chi\) can cause vanishing/exploding gradients.
Key Formulas
Pre-activation
\[ z_i^{(\ell)} = \sum_{j=1}^{n_{\ell-1}} W_{ij}^{(\ell)} a_j^{(\ell-1)} + b_i^{(\ell)} \]
Explanation: Each neuron's pre-activation is a sum of many small random contributions plus a bias. With i.i.d. weights and large fan-in, this sum becomes approximately Gaussian.
Variance Recurrence
\[ q^{(\ell)} = \sigma_w^2 \, \mathbb{E}_{z \sim \mathcal{N}(0,\, q^{(\ell-1)})}\!\left[\phi(z)^2\right] + \sigma_b^2 \]
Explanation: Layer-wise variance evolves by pushing a Gaussian through the activation, rescaling by the weight variance, and adding the bias variance. It predicts signal growth or decay across depth.
NNGP Kernel Recurrence
\[ K^{(\ell+1)}(x,x') = \sigma_w^2 \, \mathbb{E}_{(u,v) \sim \mathcal{N}(0,\, \Lambda^{(\ell)})}\!\left[\phi(u)\,\phi(v)\right] + \sigma_b^2, \qquad K^{(0)}(x,x') = \frac{x \cdot x'}{d} \]
Explanation: Starting from the input inner product (normalized by the input dimension), the kernel at the next layer is the expected product of activations under a joint Gaussian with covariance from the previous layer.
Layer Covariance
\[ \Lambda^{(\ell)} = \begin{pmatrix} K^{(\ell)}(x,x) & K^{(\ell)}(x,x') \\ K^{(\ell)}(x,x') & K^{(\ell)}(x',x') \end{pmatrix} \]
Explanation: This matrix captures the variances and covariance of pre-activations for a pair of inputs. It parameterizes the bivariate Gaussian used in the kernel expectation.
ReLU Arc-cosine Expectation
\[ \mathbb{E}\!\left[\mathrm{ReLU}(u)\,\mathrm{ReLU}(v)\right] = \frac{\sqrt{q_1 q_2}}{2\pi}\left(\sin\theta + (\pi - \theta)\cos\theta\right), \qquad \theta = \arccos\frac{c_{12}}{\sqrt{q_1 q_2}} \]
Explanation: For ReLU, the expectation of the product of activations has a closed form depending only on the variances and correlation (encoded by the angle \(\theta\)). This yields exact NNGP recurrences.
ReLU Second Moment
\[ \mathbb{E}\!\left[\mathrm{ReLU}(z)^2\right] = \frac{q}{2} \quad \text{for } z \sim \mathcal{N}(0, q) \]
Explanation: A zero-mean Gaussian pushed through ReLU keeps half of its second moment. This simplifies the variance recurrence.
Mean-field Scaling
\[ \operatorname{Var}\!\left(W_{ij}^{(\ell)}\right) = \frac{\sigma_w^2}{n_{\ell-1}}, \qquad \operatorname{Var}\!\left(b_i^{(\ell)}\right) = \sigma_b^2 \]
Explanation: Weights are scaled by fan-in so that the pre-activation variance remains \(\mathcal{O}(1)\) as width grows. Bias variance is typically small or zero.
Edge-of-Chaos Slope
\[ \chi = \sigma_w^2 \, \mathbb{E}_{z \sim \mathcal{N}(0,\, q^*)}\!\left[\phi'(z)^2\right] \]
Explanation: The slope of the correlation map at the fixed point \(q^*\) controls stability. If \(\chi < 1\), signals contract; if \(\chi > 1\), they expand; \(\chi = 1\) is critical.
ReLU Criticality
\[ \mathbb{E}\!\left[\phi'(z)^2\right] = \frac{1}{2} \quad \Rightarrow \quad \chi = \frac{\sigma_w^2}{2} \]
Explanation: For ReLU, \(\mathbb{E}[\phi'(z)^2] = 1/2\). Thus the edge of chaos occurs at \(\sigma_w^2 = 2\).
NNGP Limit
\[ f(x) \xrightarrow{\; n \to \infty \;} \mathcal{GP}\!\left(0,\, K^{(L)}\right) \]
Explanation: At infinite width, the network induces a Gaussian process whose kernel is the depth-\(L\) kernel computed by the recurrence. Predictions can be made by kernel regression.
NTK Recurrence (schematic)
\[ \Theta^{(\ell+1)}(x,x') = K^{(\ell+1)}(x,x') + \dot{K}^{(\ell+1)}(x,x')\, \Theta^{(\ell)}(x,x'), \qquad \dot{K}^{(\ell+1)} = \sigma_w^2 \, \mathbb{E}\!\left[\phi'(u)\,\phi'(v)\right] \]
Explanation: The NTK evolves across layers via a derivative kernel together with the NNGP kernel. At infinite width, \(\Theta\) remains constant during training, predicting linearized dynamics.
Central Limit Theorem
\[ \frac{1}{\sqrt{n}} \sum_{i=1}^{n} X_i \xrightarrow{d} \mathcal{N}(0, \sigma^2) \quad \text{for i.i.d. } X_i \text{ with mean } 0 \text{ and variance } \sigma^2 \]
Explanation: Sums of many i.i.d. contributions tend to a Gaussian. This underpins Gaussian pre-activations and the mean-field recurrences.
Normalized Correlation Map (ReLU, He)
\[ c^{(\ell+1)} = \frac{1}{\pi}\left(\sin\theta_\ell + (\pi - \theta_\ell)\cos\theta_\ell\right), \qquad \theta_\ell = \arccos c^{(\ell)} \]
Explanation: With ReLU and variance-preserving weights (\(\sigma_w^2 = 2,\ \sigma_b^2 = 0\)), the next-layer correlation is a closed-form function of the current correlation.
Complexity Analysis
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// ReLU activation
inline double relu(double x) { return x > 0.0 ? x : 0.0; }

// Compute sample mean and variance of a vector
pair<double,double> mean_var(const vector<double>& v) {
    double m = 0.0;
    for (double x : v) m += x;
    m /= (double)v.size();
    double s2 = 0.0;
    for (double x : v) { double d = x - m; s2 += d * d; }
    s2 /= (double)v.size();
    return {m, s2};
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Settings
    int d0 = 512;          // input dimension (fan-in of first layer)
    int width = 4000;      // width for hidden layers (wide)
    int L = 5;             // number of layers to simulate
    double sigma_w2 = 2.0; // He scaling for ReLU in mean field: Var(W_ij) = 2 / fan_in
    double sigma_b2 = 0.0; // bias variance

    // Random generators
    std::mt19937_64 rng(12345);
    std::normal_distribution<double> stdn(0.0, 1.0);

    // Gaussian input with unit-variance entries
    vector<double> a_prev(d0);
    for (int j = 0; j < d0; ++j) a_prev[j] = stdn(rng); // mean 0, var 1

    // Theoretical variance recursion q^{(l)}; for N(0,1) inputs, q^{(0)} = 1
    double q_theory = 1.0;

    cout << fixed << setprecision(6);
    cout << "Layer 0 (input): q_theory= " << q_theory << "\n";

    for (int ell = 1; ell <= L; ++ell) {
        int n_in = (ell == 1) ? d0 : width;
        int n_out = width; // keep width constant
        double w_scale = sqrt(sigma_w2 / (double)n_in);

        vector<double> z(n_out, 0.0), a(n_out, 0.0);

        // Stream weights: for each output neuron, accumulate its dot product
        for (int i = 0; i < n_out; ++i) {
            double zi = 0.0;
            for (int j = 0; j < n_in; ++j) {
                double wij = w_scale * stdn(rng); // W_ij ~ N(0, sigma_w2 / n_in)
                zi += wij * a_prev[j];
            }
            double bi = sqrt(sigma_b2) * stdn(rng);
            zi += bi;
            z[i] = zi;
            a[i] = relu(zi);
        }

        auto mvz = mean_var(z);
        auto mva = mean_var(a);

        // Theoretical update for ReLU: E[ReLU(z)^2] = q/2 when z ~ N(0, q)
        double q_next = sigma_w2 * (q_theory / 2.0) + sigma_b2;

        cout << "Layer " << ell << ": preact mean= " << mvz.first
             << ", preact var(empirical)= " << mvz.second
             << ", act var(empirical)= " << mva.second
             << ", q_theory(next)= " << q_next << "\n";

        // Prepare next layer
        a_prev.swap(a);
        q_theory = q_next;
    }

    // Note: Empirical pre-activations should have near-zero mean and variance
    // close to q_theory at each layer.
    return 0;
}
```
This program simulates forward propagation through several wide ReLU layers with He scaling (Var(W)=2/fan_in). It streams random weights (no large matrices stored), computes pre-activations and activations, and reports empirical pre-activation variance alongside the mean-field theoretical variance computed by q^{l+1} = sigma_w^2 * (q^l / 2) + sigma_b^2. As width increases, empirical and theoretical variances match closely; pre-activation means approach zero and distributions look Gaussian by aggregation.
```cpp
#include <bits/stdc++.h>
using namespace std;

inline double relu(double x) { return x > 0.0 ? x : 0.0; }

// Arc-cosine expectation for ReLU: E[ReLU(u) ReLU(v)] given variances q1, q2 and covariance c12
double relu_expectation(double q1, double q2, double c12) {
    if (q1 <= 0 || q2 <= 0) return 0.0;
    double corr = c12 / sqrt(q1 * q2);
    corr = max(-1.0, min(1.0, corr));
    double theta = acos(corr);
    return sqrt(q1 * q2) * (sin(theta) + (M_PI - theta) * cos(theta)) / (2.0 * M_PI);
}

// Theoretical NNGP kernel recursion for ReLU
struct KernelState { double Kxx, Kxpxp, Kxxp; };

KernelState relu_kernel_theory(const vector<double>& x, const vector<double>& xp,
                               int L, double sigma_w2, double sigma_b2) {
    int d = (int)x.size();
    auto dot = [&](const vector<double>& a, const vector<double>& b){
        long double s = 0; for (int i = 0; i < d; ++i) s += (long double)a[i] * b[i];
        return (double)s;
    };
    double Kxx = dot(x, x) / (double)d;
    double Kxpxp = dot(xp, xp) / (double)d;
    double Kxxp = dot(x, xp) / (double)d;
    for (int ell = 0; ell < L; ++ell) {
        double q1 = Kxx, q2 = Kxpxp, c12 = Kxxp;
        double Exx = sigma_w2 * (q1 / 2.0) + sigma_b2; // E[ReLU(u)^2] = q/2
        double Expxp = sigma_w2 * (q2 / 2.0) + sigma_b2;
        double Exxp = sigma_w2 * relu_expectation(q1, q2, c12) + sigma_b2;
        Kxx = Exx; Kxpxp = Expxp; Kxxp = Exxp;
    }
    return {Kxx, Kxpxp, Kxxp};
}

// Empirical kernel via a single wide random network (streamed weights)
KernelState relu_kernel_empirical(const vector<double>& x, const vector<double>& xp,
                                  int L, int width, double sigma_w2, double sigma_b2,
                                  uint64_t seed) {
    std::mt19937_64 rng(seed);
    std::normal_distribution<double> stdn(0.0, 1.0);

    vector<double> a = x, ap = xp;

    auto layer = [&](const vector<double>& ain, const vector<double>& apin,
                     int n_in, int n_out) {
        double w_scale = sqrt(sigma_w2 / (double)n_in);
        vector<double> aout(n_out), aoutp(n_out);
        for (int i = 0; i < n_out; ++i) {
            double zi = 0.0, zip = 0.0;
            for (int j = 0; j < n_in; ++j) {
                double wij = w_scale * stdn(rng);
                zi += wij * ain[j];
                zip += wij * apin[j]; // same weights for both inputs
            }
            double bi = sqrt(sigma_b2) * stdn(rng);
            zi += bi; zip += bi; // same bias shared across inputs
            aout[i] = relu(zi);
            aoutp[i] = relu(zip);
        }
        return pair<vector<double>, vector<double>>(move(aout), move(aoutp));
    };

    int n_in = (int)x.size();
    for (int ell = 0; ell < L; ++ell) {
        auto [anext, apnext] = layer(a, ap, n_in, width);
        a.swap(anext); ap.swap(apnext);
        n_in = width;
    }

    auto dot = [&](const vector<double>& u, const vector<double>& v){
        long double s = 0; for (int i = 0; i < (int)u.size(); ++i) s += (long double)u[i] * v[i];
        return (double)s;
    };
    double Kxx = dot(a, a) / (double)a.size();
    double Kxpxp = dot(ap, ap) / (double)ap.size();
    double Kxxp = dot(a, ap) / (double)a.size();
    return {Kxx, Kxpxp, Kxxp};
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int d0 = 512;          // input dimension
    int L = 3;             // number of ReLU layers
    int width = 4000;      // wide hidden layers
    double sigma_w2 = 2.0; // He scaling in mean field
    double sigma_b2 = 0.0; // no bias

    // Build two inputs with controlled correlation
    std::mt19937_64 rng(4242);
    std::normal_distribution<double> stdn(0.0, 1.0);
    vector<double> x(d0), xp(d0);
    double target_corr = 0.5; // desired cosine similarity
    for (int i = 0; i < d0; ++i) {
        double u = stdn(rng), v = stdn(rng);
        x[i] = u;
        xp[i] = target_corr * u + sqrt(1.0 - target_corr * target_corr) * v;
    }

    auto th = relu_kernel_theory(x, xp, L, sigma_w2, sigma_b2);
    auto emp = relu_kernel_empirical(x, xp, L, width, sigma_w2, sigma_b2, 2024ULL);

    cout << fixed << setprecision(6);
    cout << "Theoretical K(x,x)= " << th.Kxx << ", K(x',x')= " << th.Kxpxp
         << ", K(x,x')= " << th.Kxxp << "\n";
    cout << "  Empirical K(x,x)= " << emp.Kxx << ", K(x',x')= " << emp.Kxpxp
         << ", K(x,x')= " << emp.Kxxp << "\n";

    double c_th = th.Kxxp / sqrt(th.Kxx * th.Kxpxp);
    double c_emp = emp.Kxxp / sqrt(emp.Kxx * emp.Kxpxp);
    cout << "Correlation (theory)= " << c_th << ", (empirical)= " << c_emp << "\n";

    return 0;
}
```
This code compares the NNGP kernel predicted by mean-field theory to an empirical estimate computed by forward-propagating two inputs through the same wide random ReLU network. The theoretical kernel uses the arc-cosine expectation to update variances and covariance across layers. The empirical kernel is the average product of activations across the final layer. With sufficiently large width, the two match closely, illustrating the Gaussian process limit.
```cpp
#include <bits/stdc++.h>
using namespace std;

// ReLU correlation map under mean field with given sigma_w^2, sigma_b^2.
// We propagate (q, c) pairs assuming identical variances for the two inputs.
struct QC { double q; double c; };

// E[ReLU(u) ReLU(v)] with angle theta (derived from correlation c)
static inline double relu_E(double q1, double q2, double c) {
    if (q1 <= 0 || q2 <= 0) return 0.0;
    c = max(-1.0, min(1.0, c));
    double theta = acos(c);
    return sqrt(q1 * q2) * (sin(theta) + (M_PI - theta) * cos(theta)) / (2.0 * M_PI);
}

QC step_relu(double q, double c, double sigma_w2, double sigma_b2) {
    double q_next = sigma_w2 * (q / 2.0) + sigma_b2;         // ReLU second moment q/2
    double cov_next = sigma_w2 * relu_E(q, q, c) + sigma_b2;
    double c_next = cov_next / q_next;                        // normalize to correlation
    return {q_next, c_next};
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Settings
    double sigma_b2 = 0.0;

    // Case A: subcritical (signals contract)
    double sigma_w2_A = 1.5; // chi = sigma_w2 / 2 = 0.75 < 1
    // Case B: supercritical (signals expand/chaotic)
    double sigma_w2_B = 3.0; // chi = 1.5 > 1
    // Case C: edge of chaos
    double sigma_w2_C = 2.0; // chi = 1

    // Initial variance and correlation
    double q0 = 1.0; // normalized inputs
    double c0 = 0.5; // initial cosine similarity

    auto run_case = [&](double sigma_w2, const string& name){
        double q = q0, c = c0;
        cout << "\n" << name << ": sigma_w^2= " << sigma_w2
             << ", chi= " << (sigma_w2 / 2.0) << "\n";
        for (int ell = 0; ell < 10; ++ell) {
            cout << "Layer " << ell << ": q= " << q << ", c= " << c << "\n";
            auto qc = step_relu(q, c, sigma_w2, sigma_b2);
            q = qc.q; c = qc.c;
        }
    };

    run_case(sigma_w2_A, "Subcritical");
    run_case(sigma_w2_B, "Supercritical");
    run_case(sigma_w2_C, "Edge of Chaos");

    return 0;
}
```
This analytic simulator iterates the mean-field recurrences for ReLU, reporting how variance q and correlation c evolve with depth under three regimes: subcritical (chi < 1), supercritical (chi > 1), and critical (edge of chaos, chi = 1), where chi = sigma_w^2 / 2 for ReLU. The variance contracts, expands, or stays fixed according to chi; meanwhile (with zero bias) the normalized correlation follows the same update in every regime, drifting slowly toward its fixed point at c = 1. Operating at the edge of chaos keeps q stable across depth, so distinctions among inputs survive the most layers.