How I Study AI - Learn AI Papers & Lectures the Easy Way
📚 Theory · Intermediate

Metric Learning

Key Points

  • Metric learning automatically learns a distance function so that similar items are close and dissimilar items are far apart in a feature space.
  • A common approach is to learn a Mahalanobis metric d_M(x, y) = sqrt((x - y)^T M (x - y)), where M is positive semidefinite (PSD).
  • Loss functions like contrastive loss and triplet loss convert similarity constraints into optimization objectives for the metric.
  • Enforcing M to be PSD is essential; a Cholesky-like factorization M = L^T L or a nonnegative diagonal keeps the metric valid.
  • Pair- and triplet-based training scales quadratically or cubically with data, so smart sampling and batching are important.
  • Learned metrics can significantly improve k-NN, clustering, retrieval, face recognition, and anomaly detection.
  • Regularization (e.g., the Frobenius norm) and feature scaling prevent overfitting and stabilize learning.
  • In C++, a practical starting point is learning a diagonal Mahalanobis metric with gradient descent and using it in k-NN.

Prerequisites

  • →Linear algebra (vectors, matrices, eigenvalues) — Understanding Mahalanobis metrics, PSD matrices, and factorizations like M = L^T L requires linear algebra.
  • →Calculus and basic optimization — Gradient-based learning of metric parameters relies on derivatives and optimization concepts.
  • →Supervised learning and loss functions — Pairwise/triplet losses and regularization are core to metric learning objectives.
  • →k-Nearest Neighbors (k-NN) — Metric learning is often applied to improve neighbor-based classification and retrieval.
  • →Data preprocessing (scaling/normalization) — Feature scaling interacts strongly with distance computations and stability of training.
  • →Algorithmic complexity — Pair/triplet enumeration is expensive; knowing complexities guides sampling and batching.

Detailed Explanation


01 Overview

Hook: Have you ever noticed that the default Euclidean distance sometimes groups unlike things together? For example, two faces under different lighting may look far apart in raw pixels even if they are the same person. Concept: Metric learning aims to fix this automatically by learning a distance function tuned to your task, so relevant examples are pulled together and irrelevant ones are pushed apart. Instead of handcrafting features or accepting Euclidean distance, we learn parameters of a metric directly from data with supervision in the form of labels, similar/dissimilar pairs, or triplets. The most popular family is Mahalanobis distances, parameterized by a positive semidefinite (PSD) matrix M, which re-weights and correlates feature dimensions. Example: In a product recommendation system, if buyers care more about price than color, metric learning can learn a matrix that makes price differences count more in distance computations, improving nearest-neighbor retrieval of similar products.

02 Intuition & Analogies

Hook: Imagine packing a suitcase with items of different fragility. You’d want fragile items closer together and protected, and sturdy items can be separated more. If you pack naively, you might crush your glasses under a heavy book. Concept: Default Euclidean distance is like packing without thinking—it treats all directions (features) equally and independently. Metric learning is like customizing the suitcase padding: you decide which directions are important (heavier weights), which are correlated (tilted padding), and which can be ignored (near-zero weights). The Mahalanobis metric does exactly this—by applying a linear transform to the space before measuring ordinary Euclidean distance, it stretches, shrinks, and rotates the space so that meaningful neighbors become closer. Example: Suppose your data has height in centimeters and income in dollars. In raw scale, income dwarfs height, so Euclidean distance mostly measures income. A learned metric can down-weight income or up-weight height depending on which better predicts similarity, effectively normalizing and reorienting the space to better reflect your task.

03 Formal Definition

Hook: What makes a distance function legitimate, and how do we parametrize one we can learn efficiently? Concept: A metric d on a set X satisfies non-negativity, identity of indiscernibles, symmetry, and the triangle inequality. In vector spaces, a widely used parametric family is the Mahalanobis metric d_M(x, y) = sqrt((x - y)^T M (x - y)), where M is symmetric positive semidefinite (PSD). The PSD constraint ensures non-negativity and symmetry, and via M = L^T L we see that d_M(x, y) equals the Euclidean distance after the linear mapping x -> Lx. Learning M is cast as an optimization problem using supervision: we minimize an empirical loss that penalizes large distances for similar pairs and small distances for dissimilar pairs, often with margin-based hinge terms and regularization (e.g., the Frobenius norm of M). Example: Contrastive loss for a pair (x_i, x_j, y_ij) with y_ij in {1 (similar), 0 (dissimilar)}: L = y_ij d_M(x_i, x_j)^2 + (1 - y_ij) [m - d_M(x_i, x_j)]_+^2, plus a regularizer on M, subject to M ⪰ 0.

04 When to Use

Hook: If your nearest neighbors don’t look like the right neighbors, your distance is probably wrong. Concept: Use metric learning when the downstream method depends on distances or similarities: k-nearest neighbors, k-means clustering, information retrieval, re-identification, and verification tasks. It’s particularly helpful when raw features have different scales, correlations matter, or you have weak supervision such as pairwise constraints. It can also act as a dimensionality reduction when you factor M = L^T L with L of reduced rank. Example: In face verification, triplet loss encourages an anchor to be closer to a positive (same identity) than to a negative (different identity) by a margin. In e-commerce, product retrieval benefits from a metric that weighs price and brand more than color. In anomaly detection, a learned metric that tightens normal clusters makes outliers stand out more clearly.

⚠️Common Mistakes

Hook: Why does my learned distance make things worse? Concept: Several pitfalls are common. (1) Not enforcing PSD: updating an unconstrained symmetric M can break metric properties; use M = L^T L or restrict to nonnegative diagonal. (2) Ignoring feature scaling: without normalization, some features dominate; standardize features before learning. (3) Overfitting with too many pairs/triplets or too flexible M: use regularization and validation, and consider diagonal or low-rank M for high dimensions. (4) Poor sampling: naively using all pairs is O(n^2); instead, mine informative (hard) positives/negatives and balance classes. (5) Bad margins or learning rates: margins too large make optimization infeasible; too small yields weak separation; tune hyperparameters. (6) Data leakage in evaluation: evaluating on the same pairs/triplets used for training inflates performance; separate train/validation/test. Example: A practitioner updates a full M without projection and ends up with negative eigenvalues; distances become invalid and nearest neighbors flip unpredictably. Projecting onto the PSD cone or parameterizing by L prevents this.

Key Formulas

Metric Axioms

d(x, y) ≥ 0, d(x, y) = 0 ⟺ x = y, d(x, y) = d(y, x), d(x, z) ≤ d(x, y) + d(y, z)

Explanation: These four properties define a valid metric: non-negativity, identity of indiscernibles, symmetry, and triangle inequality. Any learned distance should satisfy them to be mathematically consistent.

Mahalanobis Distance

d_M(x, y) = sqrt((x − y)^T M (x − y))

Explanation: This parameterizes distances with a symmetric PSD matrix M. It equals Euclidean distance after a linear transform determined by M.

PSD Equivalences

M ⪰ 0 ⟺ x^T M x ≥ 0 for all x ⟺ M = L^T L for some L

Explanation: A matrix is PSD iff all quadratic forms are nonnegative, which is equivalent to being representable as L^T L. This guarantees valid squared distances.

Contrastive Loss

L_contrast = y d_M(x_i, x_j)^2 + (1 − y) [m − d_M(x_i, x_j)]_+^2

Explanation: Similar pairs (y=1) are pulled together by minimizing squared distance; dissimilar pairs (y=0) are pushed apart to be at least margin m away.

Triplet Loss

L_triplet = [d_M(a, p) − d_M(a, n) + α]_+

Explanation: Enforces that an anchor a is closer to a positive p than to a negative n by at least margin alpha. Only violated triplets contribute to the loss.

LMNN Objective

L_LMNN = Σ_i Σ_{j ∈ N(i)} d_M(x_i, x_j)^2 + c Σ_i Σ_{j ∈ N(i)} Σ_{l : y_l ≠ y_i} [1 + d_M(x_i, x_j)^2 − d_M(x_i, x_l)^2]_+

Explanation: Pulls target neighbors close while pushing impostors away by a margin. This convex objective can be minimized under PSD constraints on M.

Frobenius Regularization

||M||_F^2 = Σ_{i=1}^d Σ_{j=1}^d M_ij^2

Explanation: Penalizes large entries in M to reduce overfitting. Often multiplied by a coefficient λ and added to the training loss.

PSD Projection

M_+ = Q diag(max(0, λ_1), …, max(0, λ_d)) Q^T

Explanation: Given the eigendecomposition M = Q diag(λ) Q^T, projecting onto the PSD cone zeroes negative eigenvalues to restore metric validity.

Pairs and Triplets Count

C(n, 2) = n(n − 1)/2,  C(n, 3) = n(n − 1)(n − 2)/6

Explanation: The number of possible pairs and triplets grows quadratically and cubically with n, motivating sampling strategies for scalable training.

Diagonal Mahalanobis

d_diag(x, y)^2 = Σ_{k=1}^d w_k (x_k − y_k)^2, with w_k ≥ 0

Explanation: A simple, PSD-constrained metric that reweights each feature independently. Useful as a fast baseline and easy to learn with gradient descent.

Complexity Analysis

Training complexity in metric learning depends on the parameterization and the granularity of supervision. With pairwise supervision, the number of potential pairs scales as O(n^2) (specifically, n(n−1)/2), and with triplets it scales as O(n^3). Enumerating all pairs or triplets quickly becomes infeasible, so practical systems sample mini-batches, mine hard examples, or restrict neighborhoods.

For a diagonal Mahalanobis metric with parameters w ∈ R^d, one gradient step over B pairs costs O(B d): computing squared distances and gradients is linear in dimension and batch size, with O(d) memory for parameters. For a full d×d Mahalanobis M, storing parameters costs O(d^2), and each gradient evaluation typically costs O(B d^2) due to the matrix–vector operations in (x−y)^T M (x−y); projecting M onto the PSD cone via eigendecomposition costs O(d^3) per projection. Factorizing M = L^T L with L ∈ R^{r×d} (r ≪ d) brings a forward pass down to O(B r d) and reduces memory to O(r d).

At inference, computing one Mahalanobis distance naively is O(d^2) for full M and O(d) for diagonal M; for k-NN queries against n points, a brute-force search is O(n d) (diagonal) or O(n d^2) (full). Precomputing pairwise distances among n items costs O(n^2 d) (diagonal) or O(n^2 d^2) (full), with O(n^2) memory if you store the full matrix. To scale retrieval, one may combine learned metrics with approximate nearest neighbor indexes, reducing query time to sublinear at the cost of approximation.

Code Examples

Compute Mahalanobis distances and a pairwise distance matrix
#include <bits/stdc++.h>
using namespace std;

// Compute y = M * v for symmetric M (d x d) and vector v (d)
vector<double> matVec(const vector<vector<double>>& M, const vector<double>& v) {
    int d = (int)v.size();
    vector<double> y(d, 0.0);
    for (int i = 0; i < d; ++i) {
        double sum = 0.0;
        for (int j = 0; j < d; ++j) sum += M[i][j] * v[j];
        y[i] = sum;
    }
    return y;
}

// Compute squared Mahalanobis distance (x - y)^T M (x - y)
double mahalanobisSquared(const vector<double>& x, const vector<double>& y,
                          const vector<vector<double>>& M) {
    int d = (int)x.size();
    vector<double> diff(d);
    for (int i = 0; i < d; ++i) diff[i] = x[i] - y[i];
    vector<double> Md = matVec(M, diff);
    double val = 0.0;
    for (int i = 0; i < d; ++i) val += diff[i] * Md[i];
    return val; // nonnegative if M is PSD
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Example PSD matrix M (2x2): SPD since its leading minors are positive
    vector<vector<double>> M = {
        {3.0, 1.0},
        {1.0, 2.0}
    };

    // Small dataset: 3 points in R^2
    vector<vector<double>> X = {
        {0.0, 0.0},
        {1.0, 0.5},
        {3.0, 2.0}
    };

    int n = (int)X.size();
    vector<vector<double>> D(n, vector<double>(n, 0.0));

    // Compute pairwise squared Mahalanobis distances
    for (int i = 0; i < n; ++i) {
        for (int j = i; j < n; ++j) {
            double d2 = mahalanobisSquared(X[i], X[j], M);
            D[i][j] = D[j][i] = d2; // symmetric
        }
    }

    cout << fixed << setprecision(4);
    cout << "Pairwise squared Mahalanobis distances (using M):\n";
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) cout << setw(8) << D[i][j] << ' ';
        cout << '\n';
    }

    return 0;
}

This program defines a symmetric matrix M and computes squared Mahalanobis distances between all pairs of points. The squared form avoids an extra sqrt and is sufficient for ranking neighbors. If M is PSD, all distances are nonnegative and symmetric.

Time: O(n^2 d^2) to fill the pairwise matrix with a full d×d M (O(d^2) per pair). Space: O(n^2) to store the distance matrix plus O(d^2) for M.
Learn a diagonal Mahalanobis metric with contrastive loss (gradient descent)
#include <bits/stdc++.h>
using namespace std;

struct Pair { int i, j; int y; }; // y=1 similar, y=0 dissimilar

// Squared distance under diagonal metric: sum_k w[k]*(x_k - y_k)^2
double dsq_diag(const vector<double>& a, const vector<double>& b, const vector<double>& w) {
    double s = 0.0;
    for (size_t k = 0; k < w.size(); ++k) {
        double d = a[k] - b[k];
        s += w[k] * d * d;
    }
    return s;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Synthetic 2D data: two clusters
    vector<vector<double>> X = {
        {0.0, 0.0}, {0.2, -0.1}, {0.1, 0.1},  // class 0
        {3.0, 2.0}, {3.2, 1.9}, {2.9, 2.1}    // class 1
    };
    vector<int> y = {0, 0, 0, 1, 1, 1};

    int n = (int)X.size();
    int d = (int)X[0].size();

    // Build labeled pairs: a small balanced set
    vector<Pair> pairs;
    for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
            if ((int)pairs.size() > 60) break; // limit
            pairs.push_back({i, j, y[i] == y[j] ? 1 : 0});
        }
    }

    // Parameters: diagonal weights w[k] >= 0 (PSD)
    vector<double> w(d, 1.0); // initialize as identity weights

    double margin = 1.0;  // contrastive margin m
    double lr = 0.1;      // learning rate
    double lambda = 1e-3; // L2 regularization strength
    int epochs = 200;

    auto total_loss = [&](const vector<double>& wcur) {
        double L = 0.0;
        for (const auto& p : pairs) {
            double s = dsq_diag(X[p.i], X[p.j], wcur);
            double dxy = sqrt(max(1e-12, s));
            if (p.y == 1) {
                L += s; // pull similar pairs together
            } else {
                double h = max(0.0, margin - dxy);
                L += h * h; // push dissimilar pairs apart
            }
        }
        // L2 regularization on w
        double reg = 0.0;
        for (double wk : wcur) reg += wk * wk;
        return L + lambda * reg;
    };

    // Training: simple full-batch gradient descent
    for (int e = 0; e < epochs; ++e) {
        vector<double> grad(d, 0.0);
        for (const auto& p : pairs) {
            // compute per-pair gradient wrt w_k
            double s = dsq_diag(X[p.i], X[p.j], w); // s = sum_k w_k * (dx_k)^2
            double dxy = sqrt(max(1e-12, s));
            for (int k = 0; k < d; ++k) {
                double dx = X[p.i][k] - X[p.j][k];
                double gk_sim = dx * dx; // d/dw_k of s
                if (p.y == 1) {
                    grad[k] += gk_sim; // derivative of s
                } else {
                    double h = margin - dxy;
                    if (h > 0) {
                        // d/dw_k ( (m - sqrt(s))^2 ) = - (m - sqrt(s)) / sqrt(s) * (dx^2)
                        grad[k] += -(h / dxy) * gk_sim;
                    }
                }
            }
        }
        // Add L2 gradient and update with learning rate
        for (int k = 0; k < d; ++k) {
            grad[k] += 2.0 * lambda * w[k];
            w[k] -= lr * grad[k];
            // Enforce PSD (nonnegative diagonal)
            if (w[k] < 0.0) w[k] = 0.0;
        }
        if ((e + 1) % 50 == 0) {
            cout << "Epoch " << (e + 1) << ", loss = " << total_loss(w)
                 << ", w = [" << w[0] << ", " << w[1] << "]\n";
        }
    }

    // Show distances after learning
    cout << fixed << setprecision(4);
    cout << "\nLearned diagonal weights w: [" << w[0] << ", " << w[1] << "]\n";
    cout << "Sample squared distances (within vs across classes):\n";
    auto print_pair = [&](int i, int j) {
        cout << "d^2(x" << i << ", x" << j << ") = " << dsq_diag(X[i], X[j], w)
             << " (label sim? " << (y[i] == y[j] ? "yes" : "no") << ")\n";
    };
    print_pair(0, 1); // similar
    print_pair(0, 3); // dissimilar

    return 0;
}

This program learns nonnegative diagonal weights w for a Mahalanobis metric using contrastive loss on a small synthetic dataset. Similar pairs are pulled together by minimizing squared distance; dissimilar pairs are pushed apart to be at least a margin apart. We enforce PSD by clamping weights to be nonnegative. The result is a task-tuned per-feature reweighting.

Time: O(E · P · d), where E is the number of epochs, P the number of pairs, and d the dimension. Space: O(d) for parameters and gradients, plus O(n d) for the data.
k-NN classification with a learned (diagonal) metric
#include <bits/stdc++.h>
using namespace std;

// Squared distance under diagonal metric
double dsq_diag(const vector<double>& a, const vector<double>& b, const vector<double>& w) {
    double s = 0.0;
    for (size_t k = 0; k < w.size(); ++k) {
        double d = a[k] - b[k];
        s += w[k] * d * d;
    }
    return s;
}

int predictKNN(const vector<vector<double>>& Xtr, const vector<int>& ytr,
               const vector<double>& xq, int K, const vector<double>& w) {
    vector<pair<double, int>> dist;
    dist.reserve(Xtr.size());
    for (size_t i = 0; i < Xtr.size(); ++i) {
        dist.push_back({dsq_diag(Xtr[i], xq, w), (int)i}); // squared distances suffice for ranking
    }
    nth_element(dist.begin(), dist.begin() + K, dist.end());
    unordered_map<int, int> vote;
    for (int i = 0; i < K; ++i) vote[ytr[dist[i].second]]++;
    // majority vote
    int bestLabel = -1, bestCnt = -1;
    for (auto& kv : vote) {
        if (kv.second > bestCnt) { bestCnt = kv.second; bestLabel = kv.first; }
    }
    return bestLabel;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Tiny train set in 2D
    vector<vector<double>> Xtr = {
        {0.0, 0.0}, {0.3, -0.1}, {0.1, 0.2}, // class 0
        {2.5, 2.0}, {2.9, 2.2}, {3.1, 1.9}   // class 1
    };
    vector<int> ytr = {0, 0, 0, 1, 1, 1};

    // Two metrics: Euclidean (w=[1,1]) vs learned emphasis on x-dimension (w=[4,1])
    vector<double> w_euclid = {1.0, 1.0};
    vector<double> w_learned = {4.0, 1.0};

    vector<vector<double>> Xtest = {{0.2, 0.1}, {2.8, 2.1}, {1.5, 1.0}};
    int K = 3;

    cout << "Comparing k-NN predictions (K=3) under two metrics:\n";
    for (size_t i = 0; i < Xtest.size(); ++i) {
        int pe = predictKNN(Xtr, ytr, Xtest[i], K, w_euclid);
        int pl = predictKNN(Xtr, ytr, Xtest[i], K, w_learned);
        cout << "x_test[" << i << "] => Euclid: " << pe << ", Learned-diag: " << pl << "\n";
    }

    return 0;
}

This example shows how a learned diagonal metric (here manually set to emphasize the first feature) changes k-NN decisions. Squared distances are used for ranking, which is equivalent to using distances. In practice, use the learned weights from training (as in the previous example).

Time: O(n d) per query for brute-force k-NN with a diagonal metric, plus O(n) for partial selection (nth_element). Space: O(n d) for the dataset and O(n) for temporary distances.
#metric learning#mahalanobis distance#contrastive loss#triplet loss#knn#psd matrix#regularization#hard negative mining#embedding#similarity learning#lmnn#distance function#pairwise constraints#feature scaling#fisher information