
Neural Collapse

Key Points

  • Neural Collapse describes what happens at the end of training: the penultimate-layer features of each class concentrate tightly around a class mean.
  • Those class means arrange themselves as the vertices of a perfectly symmetric simplex, forming an Equiangular Tight Frame (ETF).
  • The last-layer classifier weights align with these class means, making linear classification equivalent to nearest-class-mean classification.
  • Off-diagonal cosine similarities between centered class means converge to -1/(K-1), where K is the number of classes.
  • Within-class variability shrinks toward zero, while between-class structure becomes highly symmetric and balanced.
  • This behavior explains why simple prototype-based rules can match softmax at the terminal phase of training.
  • You can verify Neural Collapse numerically by computing class means, cosine similarity matrices, and weight–mean alignment angles.
  • Even simple C++ simulations with synthetic data and softmax regression show the core patterns of Neural Collapse.

Prerequisites

  • Linear algebra (vectors, inner products, norms, Gram matrices) — Neural Collapse is defined geometrically using inner products, norms, and cosine similarities.
  • Probability and statistics (means, covariance) — Class means and within-/between-class covariances are central to NC1–NC2.
  • Softmax regression and cross-entropy — NC3–NC4 concern the last-layer linear classifier trained with cross-entropy.
  • Gradient descent and regularization — Terminal-phase behavior is observed after optimizing with weight decay; understanding updates helps interpret alignment.
  • C++ programming (vectors, loops, numerics) — To implement simulations and compute NC metrics efficiently.

Detailed Explanation


01 Overview

Imagine a classroom competition with K teams. As practice goes on, each team’s members start performing more and more like their strongest representative. By the final round, every member of a team behaves almost the same way, and the representatives of different teams are spaced out as evenly as possible. Neural Collapse is the analogous phenomenon for deep networks near the end of training: feature vectors (the outputs of the penultimate layer) from the same class cluster tightly, and the centers of these clusters arrange themselves in a maximally symmetric way, like equally spaced points on a simplex. At the same time, the last-layer classifier weights align with these class centers, so prediction by the linear classifier becomes nearly the same as picking the nearest class center.

Conceptually, Neural Collapse gives a geometric picture of how discriminative representations look when the training loss is driven close to zero under standard settings (e.g., cross-entropy with weight decay). Four hallmark properties (often called NC1–NC4) capture this: within-class collapse (NC1), simplex ETF geometry of class means (NC2), alignment of classifier weights with class means (NC3), and equivalence of the linear classifier with a nearest-class-mean classifier (NC4). This picture helps explain why simple prototype-based methods can work surprisingly well at the end of training and provides guidance for designing classifier heads and analyzing generalization.

Example: if you train a softmax classifier on a synthetic dataset where each class is a Gaussian around a symmetric set of points, you will find that the learned weight vectors point in nearly the same directions as the class means, and the pairwise angles between centered class means are almost equal, matching the ETF pattern.

02 Intuition & Analogies

Think of each class as a sports team and each data point as a player’s performance vector. Early in the season, performances vary widely within each team. By the final games, coaching and practice reduce variability: players on the same team execute a very similar playbook. That’s the within-class collapse.

Now look at the team captains—the average player of each team. If you’re trying to make the teams as easily distinguishable as possible, you’d spread those captains out evenly around a circle (in 2D), a sphere (in 3D), or, more generally, on a regular simplex in higher dimensions. That even spacing ensures that no pair of teams is too close and every team is equally separated from all others. That’s the ETF geometry.

The final piece is the referee’s rule for deciding who wins a play: a linear scoring rule (the last layer). Near the end of training, these scoring rules become almost perfectly aligned with the team captains. In other words, scoring a player is nearly equivalent to measuring how close they are to their team’s captain. Thus, the sophisticated softmax decision and the simple nearest-captain rule mostly agree.

A concrete everyday analogy: arranging K loudspeakers so that sound power is spread as evenly as possible. You wouldn’t cluster some speakers together—you’d place them uniformly. Similarly, Neural Collapse arranges class centers evenly, like speaker positions on a sphere. Within each speaker’s coverage area (a class), individual sounds (examples) become very similar, so a single representative suffices to describe them.

03 Formal Definition

Let h(x) ∈ R^d denote the penultimate-layer feature for input x, and let y ∈ {1, …, K} be the class label. Define the class mean (feature centroid) μ_k = E[h(x) | y = k] and the global mean μ̄ = (1/K) Σ_{k=1}^K μ_k. Neural Collapse at terminal training is characterized by four properties:

NC1 (Within-class collapse): The within-class covariance vanishes, Σ_W = (1/K) Σ_{k=1}^K Cov(h(x) | y = k) → 0.

NC2 (Simplex ETF of centered means): The centered means μ̃_k = μ_k − μ̄ have equal norms, and for i ≠ j the cosine similarity is constant and equals −1/(K−1). Equivalently, after normalization, their Gram matrix equals the ETF Gram G_ETF = (K/(K−1)) (I_K − (1/K) 11^⊤).

NC3 (Alignment): The last-layer classifier weights w_k (for softmax) align with the centered class means: w_k ∥ μ̃_k (proportional up to a common scaling) and share equal norms.

NC4 (Nearest-class-mean equivalence): Because of NC2–NC3 and centering, the linear classifier f(x) = argmax_k w_k^⊤ h(x) agrees with the nearest-class-mean rule f_NCM(x) = argmin_k ‖h(x) − μ_k‖_2.

These properties are typically observed empirically when minimizing cross-entropy with weight decay in balanced, overparameterized regimes, as training loss approaches zero.
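The value −1/(K−1) in NC2 is not arbitrary; it follows from the centered-basis construction v_i = e_i − (1/K)1 (the same construction the code examples below use to build ETF centers):

```latex
\|v_i\|^2 = \Big(1-\tfrac{1}{K}\Big)^2 + (K-1)\,\tfrac{1}{K^2} = 1-\tfrac{1}{K},
\qquad
v_i^\top v_j = -\tfrac{2}{K}\Big(1-\tfrac{1}{K}\Big) + \tfrac{K-2}{K^2} = -\tfrac{1}{K}\ \ (i\neq j),
\qquad
\frac{v_i^\top v_j}{\|v_i\|\,\|v_j\|} = \frac{-1/K}{1-1/K} = -\frac{1}{K-1}.
```

So after centering the K standard basis vectors and normalizing, every pair of distinct vectors has exactly the ETF cosine.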

04 When to Use

Use Neural Collapse as a lens when analyzing trained classifiers near zero training error, especially with cross-entropy and weight decay. It is valuable for:

1. Diagnosing training saturation: if within-class scatter stops decreasing or the cosine matrix of class means deviates strongly from an ETF, your model may not be in the terminal phase or the data may be unbalanced.
2. Designing classifier heads: fixed-ETF classifiers or prototype-based heads can be justified by NC2–NC3 and often reduce parameters without hurting accuracy late in training.
3. Few-shot or incremental learning: nearest-class-mean rules become compelling if features already satisfy NC1–NC2, so adding a new class requires only computing its mean.
4. Understanding robustness and calibration: the symmetric geometry hints at how margins and confidence distribute across classes.

You should also apply these ideas to synthetic experiments to build intuition: create K Gaussian clusters around a simplex and train softmax regression; watch weight vectors align with class means and check that off-diagonal cosines approach -1/(K-1). In practical deep networks, extract penultimate features on a balanced validation set and compute NC metrics to decide whether a simpler decision rule (e.g., NCM) would suffice or whether further feature learning is necessary.

⚠️Common Mistakes

• Confusing logits with features: Neural Collapse is primarily about penultimate-layer features h(x), not the raw logits before softmax. Always compute class means in feature space, not logit space.
• Ignoring centering: NC2 concerns centered class means (subtract the global mean). Failing to center can hide the ETF pattern.
• Expecting ETF early in training: Neural Collapse is a terminal-phase phenomenon, typically appearing after training loss is very small. Checking too soon may give misleading results.
• Unbalanced data: Class imbalance distorts the geometry (different mean norms and covariances). Use balanced subsets or weight corrections before drawing NC conclusions.
• Numerical issues: When estimating cosines and covariances, normalize vectors and guard against division by tiny norms. Use stable softmax (subtract max logit) in training.
• Overinterpreting noise: Real datasets will not perfectly satisfy ETF; look for trends (e.g., average off-diagonal cosine near -1/(K-1)), not exact equalities.
• Dimension confusion: The K class means span at most K-1 dimensions after centering. Do not expect them to be linearly independent in R^K; they lie in the subspace where 1^⊤ x = 0.
• Forgetting regularization’s role: Weight decay helps enforce equal-norm weights and alignment; without it, patterns may be weaker or slower to emerge.

Key Formulas

Class and Global Means

\mu_k = \mathbb{E}[\,h(x) \mid y = k\,], \qquad \bar\mu = \frac{1}{K}\sum_{k=1}^{K}\mu_k

Explanation: These define the average feature for each class and the overall average across classes. Centering class means uses μk​ - μˉ​.

NC1: Within-Class Collapse

\Sigma_W = \frac{1}{K}\sum_{k=1}^{K}\operatorname{Cov}(h(x) \mid y = k) \to 0

Explanation: The average covariance of features within each class tends to zero at terminal phase. It means all samples of a class concentrate around a single point.

NC2: Simplex ETF Cosines

\tilde\mu_k = \mu_k - \bar\mu, \qquad \frac{\tilde\mu_i^\top \tilde\mu_j}{\|\tilde\mu_i\|\,\|\tilde\mu_j\|} = -\frac{1}{K-1} \quad (i \neq j)

Explanation: After centering and normalizing, all off-diagonal cosine similarities between class means are identical at −1/(K−1). This is the signature of a regular simplex arrangement.

ETF Gram Matrix

G_{\mathrm{ETF}} = \frac{K}{K-1}\left(I_K - \frac{1}{K}\mathbf{1}\mathbf{1}^\top\right)

Explanation: This is the Gram matrix of K unit-norm vectors forming a centered simplex ETF. Its diagonal entries are 1 and off-diagonals are −1/(K−1).

NC3: Alignment and Equal Norms

w_k \parallel \tilde\mu_k, \qquad \|w_k\| = \|w_j\| \quad \forall\, j, k

Explanation: Classifier weights become parallel to the centered class means, and all have the same norm (often encouraged by weight decay).

NC4: NCM Equivalence

f(x) = \operatorname*{arg\,max}_k\, w_k^\top h(x) = \operatorname*{arg\,min}_k\, \|h(x) - \mu_k\|_2^2

Explanation: Under centering and alignment, the linear classifier’s decision is the same as choosing the nearest class mean. The quadratic terms become constant across classes and cancel.
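A one-line expansion makes the cancellation explicit:

```latex
\operatorname*{arg\,min}_k \|h-\mu_k\|_2^2
= \operatorname*{arg\,min}_k \left(\|h\|^2 - 2\,\mu_k^\top h + \|\mu_k\|^2\right)
= \operatorname*{arg\,max}_k \left(\mu_k^\top h - \tfrac{1}{2}\|\mu_k\|^2\right).
```

The term ‖h‖² is shared across classes and never affects the arg max, and when all ‖μ_k‖ are equal (NC2) the quadratic term is constant too, leaving arg max_k μ_k^⊤ h — the same decision as the aligned linear classifier of NC3.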

Softmax Probabilities

p_k(h) = \frac{\exp(w_k^\top h + b_k)}{\sum_{j=1}^{K}\exp(w_j^\top h + b_j)}

Explanation: This converts linear scores into class probabilities. It is the standard output layer for multiclass classification.

Cross-Entropy with Weight Decay

\mathcal{L}(W, b) = -\frac{1}{N}\sum_{i=1}^{N}\log p_{y_i}(h(x_i)) + \frac{\lambda}{2}\|W\|_F^2

Explanation: The training objective combines negative log-likelihood with L2 regularization. Weight decay supports equal-norm, well-conditioned solutions.
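Differentiating this objective gives the gradients that the full-batch training loop in the code below accumulates (the familiar "probabilities minus one-hot" form, plus the weight-decay term):

```latex
\nabla_{w_k}\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\big(p_k(h(x_i)) - \mathbb{1}\{y_i = k\}\big)\, h(x_i) + \lambda\, w_k,
\qquad
\frac{\partial \mathcal{L}}{\partial b_k} = \frac{1}{N}\sum_{i=1}^{N}\big(p_k(h(x_i)) - \mathbb{1}\{y_i = k\}\big).
```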

Cosine Similarity

\cos(\theta) = \frac{a^\top b}{\|a\|_2\,\|b\|_2}

Explanation: Measures alignment between two vectors. Values near 1 indicate strong alignment; values near −1/(K−1) between class means indicate ETF structure.

ETF Deviation (Frobenius)

E_{\mathrm{ETF}} = \frac{\|C - C_{\mathrm{ETF}}\|_F}{\|C_{\mathrm{ETF}}\|_F}

Explanation: A practical metric: compare the empirical cosine matrix C of centered class means to the ideal ETF cosine matrix. Smaller values indicate stronger Neural Collapse structure.

Complexity Analysis

In the provided simulations, let N be the number of samples, K the number of classes, and d the feature dimension.

• Data generation and ETF construction: Building K simplex-ETF centers in R^K (or R^{K-1}) is O(K^2) due to normalization and pairwise checks, though the constructive formula is O(K). Sampling N points with Gaussian noise is O(Nd).
• Metric computation: Class means require a single pass, O(Nd). Computing the cosine similarity matrix across K means is O(K^2 d). Within-class and between-class scatter traces are O(Nd) and O(Kd), respectively. These are all linear or quadratic in K and linear in N and d, so they are inexpensive for moderate K.
• Softmax regression (full-batch gradient descent): Each epoch computes scores and gradients in O(NKd). Specifically, for each sample, we compute K linear scores (O(Kd)) and update K class-specific gradients (O(Kd)). The memory footprint is O(Nd) for the dataset and O(Kd) for the parameters (W) plus O(K) for biases, making the approach scalable to thousands of samples and moderate K and d.
• Verification steps after training: Alignment angles between K weight vectors and K means cost O(Kd). Nearest-class-mean classification on N samples is O(NKd) if done naively by checking all K means per sample. Agreement evaluation between NCM and softmax has the same complexity.

In practice, the dominant term is the training loop, O(E N K d), where E is the number of epochs. Using mini-batches reduces per-epoch cost proportionally to the batch size but requires more iterations; asymptotically, the total work remains O(N K d) per pass over the data. Space complexity is dominated by storing X (O(Nd)) and the model (O(Kd)), which are modest for the synthetic examples here.

Code Examples

Construct and Verify a Simplex ETF; Sample Data and Measure Cosines
#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <numeric>
#include <algorithm>

// Utility: dot product
double dot(const std::vector<double>& a, const std::vector<double>& b) {
    double s = 0.0;
    for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}

// Utility: L2 norm
double norm(const std::vector<double>& a) {
    return std::sqrt(dot(a, a));
}

// Create K ETF centers in R^K by projecting the standard basis onto the 1^T x = 0 subspace and normalizing.
// v_i = e_i - (1/K) 1; then normalize to unit norm. Pairwise cosines become -1/(K-1).
std::vector<std::vector<double>> make_simplex_etf(int K) {
    std::vector<std::vector<double>> centers(K, std::vector<double>(K, 0.0));
    double one_over_K = 1.0 / K;
    for (int i = 0; i < K; ++i) {
        // Start from e_i
        for (int j = 0; j < K; ++j) centers[i][j] = (j == i ? 1.0 : 0.0) - one_over_K;
        // Normalize
        double nrm = norm(centers[i]);
        for (int j = 0; j < K; ++j) centers[i][j] /= nrm;
    }
    return centers; // dimension K (effective rank K-1)
}

// Sample N_per_class points around each center with Gaussian noise sigma
std::vector<std::vector<double>> sample_points(const std::vector<std::vector<double>>& centers,
                                               int N_per_class, double sigma,
                                               std::vector<int>& labels,
                                               std::mt19937& rng) {
    int K = (int)centers.size();
    int d = (int)centers[0].size();
    std::normal_distribution<double> gauss(0.0, sigma);
    std::vector<std::vector<double>> X;
    X.reserve(K * N_per_class);
    labels.clear();
    labels.reserve(K * N_per_class);
    for (int k = 0; k < K; ++k) {
        for (int n = 0; n < N_per_class; ++n) {
            std::vector<double> x(d);
            for (int j = 0; j < d; ++j) x[j] = centers[k][j] + gauss(rng);
            X.push_back(std::move(x));
            labels.push_back(k);
        }
    }
    return X;
}

// Compute class means from data and labels
std::vector<std::vector<double>> compute_class_means(const std::vector<std::vector<double>>& X,
                                                     const std::vector<int>& y, int K) {
    int N = (int)X.size();
    int d = (int)X[0].size();
    std::vector<std::vector<double>> mu(K, std::vector<double>(d, 0.0));
    std::vector<int> cnt(K, 0);
    for (int i = 0; i < N; ++i) {
        int k = y[i];
        ++cnt[k];
        for (int j = 0; j < d; ++j) mu[k][j] += X[i][j];
    }
    for (int k = 0; k < K; ++k) {
        for (int j = 0; j < d; ++j) mu[k][j] /= std::max(1, cnt[k]);
    }
    return mu;
}

// Center class means by subtracting their global mean
void center_means(std::vector<std::vector<double>>& mu) {
    int K = (int)mu.size();
    int d = (int)mu[0].size();
    std::vector<double> mbar(d, 0.0);
    for (int k = 0; k < K; ++k)
        for (int j = 0; j < d; ++j) mbar[j] += mu[k][j];
    for (int j = 0; j < d; ++j) mbar[j] /= K;
    for (int k = 0; k < K; ++k)
        for (int j = 0; j < d; ++j) mu[k][j] -= mbar[j];
}

// Compute average off-diagonal cosine similarity among class means
double average_offdiag_cosine(const std::vector<std::vector<double>>& mu) {
    int K = (int)mu.size();
    double s = 0.0;
    int c = 0;
    for (int i = 0; i < K; ++i) {
        for (int j = i + 1; j < K; ++j) {
            double ci = dot(mu[i], mu[j]) / (norm(mu[i]) * norm(mu[j]) + 1e-12);
            s += ci;
            ++c;
        }
    }
    return (c > 0 ? s / c : 0.0);
}

// Compute within-class and between-class scatter traces
void scatter_traces(const std::vector<std::vector<double>>& X,
                    const std::vector<int>& y,
                    const std::vector<std::vector<double>>& mu,
                    double& tr_Sw, double& tr_Sb) {
    int N = (int)X.size();
    int d = (int)X[0].size();
    int K = (int)mu.size();
    // Center a copy of the class means
    std::vector<std::vector<double>> cm = mu;
    center_means(cm);
    // tr(S_b) = (1/K) * sum_k ||mu_k - mbar||^2; since cm are centered, that's (1/K) * sum_k ||cm_k||^2
    double sb = 0.0;
    for (int k = 0; k < K; ++k) sb += dot(cm[k], cm[k]);
    tr_Sb = sb / K;
    // tr(S_w) = (1/N) * sum_i ||x_i - mu_{y_i}||^2
    double sw = 0.0;
    for (int i = 0; i < N; ++i) {
        int k = y[i];
        double s = 0.0;
        for (int j = 0; j < d; ++j) {
            double r = X[i][j] - mu[k][j];
            s += r * r;
        }
        sw += s;
    }
    tr_Sw = sw / N;
}

int main() {
    int K = 5;              // number of classes
    int N_per_class = 200;  // samples per class
    double sigma = 0.2;     // noise level
    std::mt19937 rng(42);

    // 1) Build ETF centers and verify analytical cosines
    auto centers = make_simplex_etf(K);
    double target_cos = -1.0 / (K - 1);

    // 2) Sample data and compute empirical class means
    std::vector<int> y;
    auto X = sample_points(centers, N_per_class, sigma, y, rng);
    auto mu = compute_class_means(X, y, K);

    // 3) Measure off-diagonal cosine on a centered copy of the means
    //    (keep mu uncentered so the within-class scatter below uses the true class means)
    auto centered = mu;
    center_means(centered);
    double avg_cos = average_offdiag_cosine(centered);

    // 4) Scatter traces
    double tr_Sw = 0.0, tr_Sb = 0.0;
    scatter_traces(X, y, mu, tr_Sw, tr_Sb);

    std::cout << "K = " << K << "\n";
    std::cout << "Target ETF off-diagonal cosine = " << target_cos << "\n";
    std::cout << "Empirical average off-diagonal cosine (centered means) = " << avg_cos << "\n";
    std::cout << "Within-class trace tr(Sw) = " << tr_Sw << ", Between-class trace tr(Sb) = " << tr_Sb << "\n";
    std::cout << "Ratio tr(Sw)/tr(Sb) = " << (tr_Sw / (tr_Sb + 1e-12)) << " (smaller is stronger collapse)\n";

    return 0;
}

This program constructs K simplex ETF centers using a closed-form projection of the standard basis onto the 1^T x = 0 subspace and normalizes them. It then samples Gaussian data around each center, computes empirical class means, centers them, and measures the average off-diagonal cosine similarity. It also computes within-class and between-class scatter traces to quantify collapse. For an ideal ETF, off-diagonal cosine equals −1/(K−1). With moderate noise, the empirical value approaches that target as sample size grows.

Time: O(N d + K^2 d), where N = K · N_per_class and d = K. Space: O(N d + K d).
Train Softmax Regression with Weight Decay; Verify NC1–NC4 Metrics
#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <numeric>
#include <algorithm>

// Basic vector utilities
double dot(const std::vector<double>& a, const std::vector<double>& b) {
    double s = 0;
    for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}
double norm(const std::vector<double>& a) { return std::sqrt(dot(a, a)); }

// Simplex ETF centers in R^K
std::vector<std::vector<double>> make_simplex_etf(int K) {
    std::vector<std::vector<double>> c(K, std::vector<double>(K, 0.0));
    double one_over_K = 1.0 / K;
    for (int i = 0; i < K; ++i) {
        for (int j = 0; j < K; ++j) c[i][j] = (j == i ? 1.0 : 0.0) - one_over_K;
        double nrm = norm(c[i]);
        for (int j = 0; j < K; ++j) c[i][j] /= nrm;
    }
    return c;
}

struct Dataset {
    std::vector<std::vector<double>> X; // N x d
    std::vector<int> y;                 // N labels
    int N, d, K;
};

Dataset make_dataset(int K, int N_per_class, double sigma, std::mt19937& rng) {
    auto centers = make_simplex_etf(K);
    int d = (int)centers[0].size();
    std::normal_distribution<double> gauss(0.0, sigma);
    Dataset ds;
    ds.N = K * N_per_class; ds.d = d; ds.K = K;
    ds.X.reserve(ds.N); ds.y.reserve(ds.N);
    for (int k = 0; k < K; ++k) {
        for (int n = 0; n < N_per_class; ++n) {
            std::vector<double> x(d);
            for (int j = 0; j < d; ++j) x[j] = centers[k][j] + gauss(rng);
            ds.X.push_back(std::move(x));
            ds.y.push_back(k);
        }
    }
    return ds;
}

// Softmax with the stable max-subtraction trick
std::vector<double> softmax(const std::vector<double>& s) {
    double m = *std::max_element(s.begin(), s.end());
    double Z = 0.0;
    std::vector<double> p(s.size());
    for (size_t k = 0; k < s.size(); ++k) { p[k] = std::exp(s[k] - m); Z += p[k]; }
    for (size_t k = 0; k < s.size(); ++k) p[k] /= (Z + 1e-12);
    return p;
}

// Softmax regression trained by full-batch GD with L2 regularization
struct SoftmaxReg {
    int d, K;
    std::vector<std::vector<double>> W; // K x d (class-major)
    std::vector<double> b;              // K
    SoftmaxReg(int d_, int K_) : d(d_), K(K_), W(K_, std::vector<double>(d_, 0.0)), b(K_, 0.0) {}

    // Forward: scores for one sample
    void scores(const std::vector<double>& x, std::vector<double>& s) const {
        s.assign(K, 0.0);
        for (int k = 0; k < K; ++k) s[k] = dot(W[k], x) + b[k];
    }

    double train_epoch(const Dataset& ds, double lr, double lambda) {
        std::vector<std::vector<double>> gW(K, std::vector<double>(d, 0.0));
        std::vector<double> gb(K, 0.0);
        double loss = 0.0;
        std::vector<double> s(K), p(K);
        for (int i = 0; i < ds.N; ++i) {
            scores(ds.X[i], s);
            p = softmax(s);
            // Loss contribution: -log p_{y_i}
            loss += -std::log(std::max(1e-12, p[ds.y[i]]));
            // Gradient: (p - y_onehot) * x^T for weights and (p - y_onehot) for biases
            for (int k = 0; k < K; ++k) {
                double g = p[k] - (k == ds.y[i] ? 1.0 : 0.0);
                gb[k] += g;
                for (int j = 0; j < d; ++j) gW[k][j] += g * ds.X[i][j];
            }
        }
        // Average the loss and add the L2 penalty
        double w2 = 0.0;
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < d; ++j) w2 += W[k][j] * W[k][j];
        loss = loss / ds.N + 0.5 * lambda * w2;
        // Average gradients and add L2 gradients
        for (int k = 0; k < K; ++k) {
            gb[k] /= ds.N;
            for (int j = 0; j < d; ++j) gW[k][j] = gW[k][j] / ds.N + lambda * W[k][j];
        }
        // Update
        for (int k = 0; k < K; ++k) {
            b[k] -= lr * gb[k];
            for (int j = 0; j < d; ++j) W[k][j] -= lr * gW[k][j];
        }
        return loss;
    }

    int predict(const std::vector<double>& x) const {
        std::vector<double> s(K);
        scores(x, s);
        int arg = 0;
        for (int k = 1; k < K; ++k) if (s[k] > s[arg]) arg = k;
        return arg;
    }
};

// Compute class means from raw inputs (our features here)
std::vector<std::vector<double>> class_means(const Dataset& ds) {
    std::vector<std::vector<double>> mu(ds.K, std::vector<double>(ds.d, 0.0));
    std::vector<int> cnt(ds.K, 0);
    for (int i = 0; i < ds.N; ++i) {
        int k = ds.y[i];
        ++cnt[k];
        for (int j = 0; j < ds.d; ++j) mu[k][j] += ds.X[i][j];
    }
    for (int k = 0; k < ds.K; ++k)
        for (int j = 0; j < ds.d; ++j) mu[k][j] /= std::max(1, cnt[k]);
    return mu;
}

// Center class means by subtracting their global mean
void center_means(std::vector<std::vector<double>>& mu) {
    int K = (int)mu.size();
    int d = (int)mu[0].size();
    std::vector<double> mbar(d, 0.0);
    for (int k = 0; k < K; ++k)
        for (int j = 0; j < d; ++j) mbar[j] += mu[k][j];
    for (int j = 0; j < d; ++j) mbar[j] /= K;
    for (int k = 0; k < K; ++k)
        for (int j = 0; j < d; ++j) mu[k][j] -= mbar[j];
}

// NCM prediction using Euclidean distance to class means
int predict_ncm(const std::vector<double>& x, const std::vector<std::vector<double>>& mu) {
    int K = (int)mu.size();
    int best = 0;
    double bestd = 1e300;
    for (int k = 0; k < K; ++k) {
        double d2 = 0.0;
        for (size_t j = 0; j < x.size(); ++j) { double r = x[j] - mu[k][j]; d2 += r * r; }
        if (d2 < bestd) { bestd = d2; best = k; }
    }
    return best;
}

int main() {
    // Settings
    int K = 5;
    int N_per_class = 300;
    double sigma = 0.25;
    std::mt19937 rng(123);
    Dataset ds = make_dataset(K, N_per_class, sigma, rng);

    // Model and training
    SoftmaxReg model(ds.d, ds.K);
    double lr = 0.5, lambda = 1e-2;
    int epochs = 800;
    for (int e = 1; e <= epochs; ++e) {
        double L = model.train_epoch(ds, lr, lambda);
        if (e % 100 == 0) std::cout << "Epoch " << e << " loss = " << L << "\n";
    }

    // Accuracy
    int correct = 0;
    for (int i = 0; i < ds.N; ++i) correct += (model.predict(ds.X[i]) == ds.y[i]);
    std::cout << "Train accuracy = " << (100.0 * correct / ds.N) << "%\n";

    // NC metrics
    auto mu = class_means(ds);
    auto centered = mu;
    center_means(centered);

    // NC1: within-class trace vs between-class trace
    // tr(S_b) = (1/K) sum_k ||centered_mu_k||^2
    double tr_Sb = 0.0;
    for (int k = 0; k < K; ++k) tr_Sb += dot(centered[k], centered[k]);
    tr_Sb /= K;
    // tr(S_w) = (1/N) sum_i ||x_i - mu_{y_i}||^2
    double tr_Sw = 0.0;
    for (int i = 0; i < ds.N; ++i) {
        int k = ds.y[i];
        double s = 0.0;
        for (int j = 0; j < ds.d; ++j) { double r = ds.X[i][j] - mu[k][j]; s += r * r; }
        tr_Sw += s;
    }
    tr_Sw /= ds.N;
    std::cout << "NC1: tr(Sw)/tr(Sb) = " << (tr_Sw / (tr_Sb + 1e-12)) << " (smaller is better)\n";

    // NC2: off-diagonal cosine among centered class means
    double off = 0.0;
    int cnt = 0;
    for (int i = 0; i < K; ++i)
        for (int j = i + 1; j < K; ++j) {
            off += dot(centered[i], centered[j]) / (norm(centered[i]) * norm(centered[j]) + 1e-12);
            ++cnt;
        }
    double avg_off = off / std::max(1, cnt);
    std::cout << "NC2: avg off-diagonal cosine (target = " << (-1.0 / (K - 1)) << ") = " << avg_off << "\n";

    // NC3: alignment angles between w_k and class means mu_k (in degrees)
    const double PI = std::acos(-1.0); // portable alternative to the nonstandard M_PI
    double sum_angle = 0.0, max_angle = 0.0;
    for (int k = 0; k < K; ++k) {
        double c = dot(model.W[k], mu[k]) / (norm(model.W[k]) * norm(mu[k]) + 1e-12);
        c = std::max(-1.0, std::min(1.0, c));
        double ang = std::acos(c) * 180.0 / PI;
        sum_angle += ang;
        if (ang > max_angle) max_angle = ang;
    }
    std::cout << "NC3: mean alignment angle (degrees) = " << (sum_angle / K) << ", max = " << max_angle << "\n";

    // NC4: NCM vs softmax decision agreement
    int agree = 0;
    for (int i = 0; i < ds.N; ++i) {
        int a = model.predict(ds.X[i]);
        int b = predict_ncm(ds.X[i], mu);
        if (a == b) ++agree;
    }
    std::cout << "NC4: NCM-softmax agreement = " << (100.0 * agree / ds.N) << "%\n";

    return 0;
}

This end-to-end example trains a multinomial logistic regression model with L2 weight decay on synthetic data generated from simplex-ETF centers with Gaussian noise. After training to low loss, it computes: (NC1) the ratio of within-class to between-class scatter traces; (NC2) the average off-diagonal cosine of centered class means; (NC3) alignment angles between each class weight vector and its class mean; and (NC4) the agreement rate between softmax predictions and nearest-class-mean predictions. On balanced data and sufficient training, you should observe small tr(Sw)/tr(Sb), off-diagonal cosines near −1/(K−1), small alignment angles, and high NCM–softmax agreement.

Time: O(E N K d) for E epochs; each epoch is O(N K d). Space: O(N d + K d).
Tags: neural collapse, simplex ETF, equiangular tight frame, penultimate features, class means, softmax regression, cross-entropy, nearest class mean, cosine similarity, within-class covariance, between-class scatter, alignment, prototype classifier, weight decay, terminal phase