Wasserstein Distance & Optimal Transport
Key Points
- Wasserstein distance (Earth Mover's Distance) measures how much "work" is needed to transform one probability distribution into another by moving mass with minimal total cost.
- For discrete distributions, the problem is a linear program known as optimal transport; the cost matrix encodes pairwise moving costs between source and target bins.
- In one dimension with equal weights, W1 reduces to the average absolute difference between sorted samples, making it very fast to compute.
- The exact discrete EMD can be computed via min-cost flow on a bipartite graph where edges carry transported mass at a given cost.
- Entropic regularization (Sinkhorn) makes OT much faster by turning it into iterative matrix scaling with near-linear algebraic operations.
- Wasserstein distances are true metrics for p ≥ 1, satisfying symmetry, the triangle inequality, and identity of indiscernibles.
- They are robust to small translations and capture the geometry of the sample space, unlike many bin-by-bin divergences.
- Applications include computer vision, NLP, generative models (WGAN), domain adaptation, clustering, and histogram matching.
Prerequisites
- Probability distributions and histograms: Wasserstein compares probability measures; understanding mass, normalization, and empirical measures is essential.
- Metric spaces and norms: The ground cost uses distances like Euclidean or Manhattan; properties of metrics affect Wasserstein behavior.
- Linear programming and duality: Discrete OT is an LP; dual variables and constraints explain its structure and algorithms.
- Graph theory and network flow: Exact EMD can be solved by min-cost flow on bipartite graphs.
- Numerical linear algebra: Sinkhorn requires stable matrix operations and an understanding of scaling and convergence.
- Sorting and prefix sums: The 1D W1 shortcuts rely on sorting and cumulative differences.
- Exponentials and logarithms: Sinkhorn constructs K = exp(-C/ε); stability often benefits from log-domain reasoning.
Detailed Explanation
01 Overview
The Wasserstein distance, also known as the Earth Mover's Distance (EMD), quantifies how different two probability distributions are by asking: how much total effort is required to move mass from one distribution to match the other? The effort is measured as mass times distance moved. This view leads to optimal transport, a mathematical framework where you choose a transport plan that specifies how much mass to move from each source location to each target location to minimize total cost. In continuous settings, the distance depends on the geometry of the space (e.g., Euclidean distance in R^d). In discrete settings, it becomes a linear optimization problem over a cost matrix. A special and highly practical case is one-dimensional distributions, where Wasserstein-1 can be computed very efficiently using sorted samples or cumulative distribution functions. To accelerate computations in higher dimensions, entropic regularization yields the Sinkhorn divergence, which is smooth and computable by simple iterative matrix rescaling. Compared to other divergences (like KL or JS), Wasserstein captures spatial relationships: moving mass across nearby bins is cheaper than across distant bins, which makes it robust and semantically meaningful in applications like color histogram matching, word embeddings, and probability density comparison in machine learning.
02 Intuition & Analogies
Imagine two piles of sand on a beach, arranged differently. To make one pile look exactly like the other, you scoop sand from some spots and carry it to others. The total effort is how much sand you carry times how far you carry it, summed over all scoops. If the two piles are already close in shape, little sand moves or it moves only short distances, so the effort is small; if they're far apart, the effort is large. That total minimal effort is the Earth Mover's Distance. Now think of data histograms instead of sand piles: bars indicate how much probability mass sits at different locations (colors in an image, words in a sentence embedding, or pixel intensities). A bin-by-bin difference ignores that nearby bins are similar. EMD fixes this: moving probability from a nearby bin costs less than from a distant bin, so if two histograms differ mostly by a small shift, the distance is small. In one dimension, the story is even simpler: line up the sand grains of both piles from left to right, and pair the i-th grain in one pile with the i-th grain in the other. The average distance between paired grains is exactly W1 when weights are equal. In higher dimensions, pairing isn't so straightforward: you must choose a transport plan across many possible routes. This becomes a network flow problem: sources push mass through edges to sinks, each edge charging a fee proportional to distance. The cheapest way to satisfy all demands is optimal transport. When exact solutions are too slow, you can slightly relax the problem by adding a tiny "entropy bonus" for spread-out plans, which enables extremely fast matrix scaling iterations (Sinkhorn).
03 Formal Definition
Let μ and ν be probability measures on a metric space (X, d), and let Π(μ, ν) denote the set of couplings: joint distributions γ on X × X whose marginals are μ and ν. The p-Wasserstein distance is W_p(μ, ν) = ( inf_{γ ∈ Π(μ, ν)} ∫ d(x, y)^p dγ(x, y) )^{1/p} for p ≥ 1. In the discrete case, μ and ν are weight vectors a and b over finite supports, a coupling becomes a nonnegative matrix Γ with row sums a and column sums b, and the infimum is a linear program over the cost matrix C_ij = d(x_i, y_j)^p.
04 When to Use
Use Wasserstein distance when the geometry of your sample space matters. If nearby points should be considered more similar than distant ones (e.g., pixel intensities, colors, word embeddings, spatial coordinates), Wasserstein captures this naturally. In image processing, apply EMD to compare color or texture histograms; it recognizes palette shifts better than L1 or L2. In NLP, compare distributions of word embeddings across documents or align topics. In generative modeling, Wasserstein distances stabilize training (e.g., WGAN) by providing meaningful gradients even when supports do not overlap. For 1D signals (audio intensities, greyscale rows), W1 can be computed quickly via sorting or cumulative differences, making it ideal for fast comparisons at scale. If exact OT is too expensive in higher dimensions or large supports, use entropic regularization (Sinkhorn) to get a smooth, fast approximation that is differentiable and GPU-friendly. When distributions are histograms over a grid, min-cost flow with a metric cost (e.g., Manhattan or Euclidean) provides exact solutions for moderate sizes. If your application needs a true metric with triangle inequality and robustness to small translations, prefer Wasserstein over divergences like KL that can be infinite with non-overlapping supports.
⚠️ Common Mistakes
- Ignoring mass normalization: Wasserstein assumes total mass matches. If histograms do not sum to the same value, either normalize them to probability distributions or use unbalanced OT with penalties; otherwise, results are incorrect or infeasible.
- Confusing bin-by-bin distances with OT: L1 on histogram bins does not account for geometry and can overstate differences when distributions are shifted; OT explicitly moves mass across bins with costs.
- Misusing the 1D shortcut: The "sort and average absolute differences" formula for W1 only holds for equal total mass and equal-weight samples. For unequal weights or different sample sizes, use quantiles or weighted algorithms, not naive pairing.
- Choosing an inappropriate ground metric: Costs must reflect meaningful geometry (e.g., Euclidean for coordinates, cosine distance for normalized embeddings). A poor choice yields misleading distances.
- Numerical instability in Sinkhorn: Using too small an ε or too many iterations without stabilization can cause underflow/overflow. Use log-domain stabilization or damping, and add small epsilons in divisions.
- Overlooking complexity: Exact OT with dense cost matrices is O(n^3) in the worst-case LP, or O(T·E log V) with min-cost flow; for large datasets prefer approximate methods (Sinkhorn, low-rank approximations, multi-scale).
- Forgetting dual feasibility in W1: When using dual potentials or Lipschitz critics (e.g., WGAN), enforce the 1-Lipschitz constraint (gradient penalty or spectral norm), otherwise the estimator is biased.
- Misinterpreting units: W_p has units of distance, but costs often use d^p; remember to take the p-th root for p > 1 when comparing across settings.
Key Formulas
Wasserstein-p Distance
W_p(μ, ν) = ( inf_{γ ∈ Π(μ, ν)} ∫ d(x, y)^p dγ(x, y) )^{1/p}
Explanation: Defines the p-Wasserstein distance as the p-th root of the minimal expected p-th power transport cost over all couplings. Use this as the core definition for continuous spaces.
Discrete OT Linear Program
min_{Γ ≥ 0} Σ_{i,j} C_ij Γ_ij subject to Γ 1 = a, Γ^T 1 = b
Explanation: Optimizes a nonnegative transport matrix Γ to match the given marginals a and b while minimizing total cost against the cost matrix C. This is the standard discrete formulation used in practice.
Kantorovich–Rubinstein Dual
W_1(μ, ν) = sup_{‖f‖_Lip ≤ 1} E_{x~μ}[f(x)] - E_{y~ν}[f(y)]
Explanation: Expresses W1 as a supremum over all 1-Lipschitz functions. Useful in algorithms like WGAN where the critic approximates the optimal potential.
1D CDF Formula
W_1(μ, ν) = ∫ |F(x) - G(x)| dx
Explanation: In one dimension, W1 equals the area between the two cumulative distribution functions F and G. This enables fast computation using cumulative sums.
Quantile Representation
W_p(μ, ν)^p = ∫_0^1 |F^{-1}(t) - G^{-1}(t)|^p dt
Explanation: Pairs quantiles at equal cumulative probabilities. For empirical equal-weight samples, sorting and pairing implements this formula directly.
Entropic Regularized OT
min_{Γ ≥ 0, Γ 1 = a, Γ^T 1 = b} <C, Γ> + ε KL(Γ ‖ a b^T)
Explanation: Adds an entropy-based KL penalty with strength ε that smooths the problem and enables fast Sinkhorn iterations. As ε → 0, it approaches classic OT.
Sinkhorn Updates
K = exp(-C/ε); repeat u ← a / (K v), v ← b / (K^T u); finally Γ = diag(u) K diag(v)
Explanation: Core iterative scaling steps to enforce row and column sums for the regularized problem. Converges rapidly for moderate ε and well-conditioned K.
Triangle Inequality
W_p(μ, ρ) ≤ W_p(μ, ν) + W_p(ν, ρ)
Explanation: Shows that Wasserstein distances form a metric for p ≥ 1. This guarantees consistency when comparing multiple distributions.
Equal-Weight 1D Empirical
W_1 = (1/n) Σ_i |x_(i) - y_(i)|, where x_(i), y_(i) are the sorted samples
Explanation: For two empirical distributions with n equal-weight samples each, W1 equals the average absolute difference of sorted samples. This provides an O(n log n) algorithm.
1D Histogram Flow
W_1 = Δx Σ_k |Σ_{i ≤ k} (a_i - b_i)|
Explanation: For 1D histograms on a uniform grid with bin width Δx, W1 equals the sum of absolute cumulative differences times Δx. Useful for fast histogram comparisons.
Generalized KL
KL(Γ ‖ a b^T) = Σ_{i,j} [ Γ_ij log(Γ_ij / (a_i b_j)) - Γ_ij + a_i b_j ]
Explanation: The generalized KL divergence used in entropic OT encourages Γ to remain close to the independent coupling a b^T while meeting the marginal constraints.
Complexity Analysis
- Exact discrete OT solved as a dense LP takes O(n^3) time in the worst case; min-cost flow formulations run in roughly O(T·E log V) for T augmentations on a graph with E edges and V nodes.
- The 1D equal-weight case costs O(n log n), dominated by sorting; the 1D histogram formula costs O(n) on binned data.
- Sinkhorn costs O(n·m) per iteration (one multiplication by K and one by K^T), with the iteration count growing as ε shrinks.
Code Examples
#include <bits/stdc++.h>
using namespace std;

// Compute W1 between two sets of equal-weight 1D samples.
// Assumes P.size() == Q.size() and both represent empirical distributions
// with weights 1/n on the real line.
double wasserstein1D_equal_weights(vector<double> P, vector<double> Q) {
    size_t n = P.size();
    if (n != Q.size() || n == 0) {
        throw invalid_argument("P and Q must have the same nonzero size");
    }
    sort(P.begin(), P.end());
    sort(Q.begin(), Q.end());
    long double sum_abs = 0.0L;
    for (size_t i = 0; i < n; ++i) {
        sum_abs += fabsl((long double)P[i] - (long double)Q[i]);
    }
    // Average absolute difference equals W1 for equal weights in 1D
    return (double)(sum_abs / (long double)n);
}

int main() {
    // Example: two shifted point sets on the line
    vector<double> P = {0.0, 1.0, 2.0, 3.0};
    vector<double> Q = {0.5, 1.5, 2.5, 3.5};
    double w1 = wasserstein1D_equal_weights(P, Q);
    cout << fixed << setprecision(6);
    cout << "W1(P, Q) = " << w1 << "\n"; // Expected 0.5 average shift
    return 0;
}
This program computes W1 for two 1D empirical distributions with equal weights by sorting both sample arrays and averaging the absolute differences element-wise. In one dimension, pairing sorted samples implements the quantile formula exactly for equal weights.
#include <bits/stdc++.h>
using namespace std;

struct Edge {
    int to, rev;
    double cap, cost;
};

struct MinCostMaxFlow {
    int N;
    vector<vector<Edge>> G;
    MinCostMaxFlow(int n) : N(n), G(n) {}

    void add_edge(int u, int v, double cap, double cost) {
        Edge a{v, (int)G[v].size(), cap, cost};
        Edge b{u, (int)G[u].size(), 0.0, -cost};
        G[u].push_back(a);
        G[v].push_back(b);
    }

    pair<double,double> min_cost_max_flow(int s, int t, double maxf = numeric_limits<double>::infinity()) {
        const double INF = 1e100;
        double flow = 0.0, cost = 0.0;
        vector<double> pot(N, 0.0), dist(N);
        vector<int> pv(N), pe(N);
        while (flow + 1e-15 < maxf) {
            // Dijkstra on reduced costs
            fill(dist.begin(), dist.end(), INF);
            dist[s] = 0.0;
            priority_queue<pair<double,int>, vector<pair<double,int>>, greater<pair<double,int>>> pq;
            pq.push({0.0, s});
            while (!pq.empty()) {
                auto [d, u] = pq.top(); pq.pop();
                if (d > dist[u]) continue;
                for (int i = 0; i < (int)G[u].size(); ++i) {
                    const Edge &e = G[u][i];
                    if (e.cap <= 1e-15) continue;
                    double rc = e.cost + pot[u] - pot[e.to]; // reduced cost
                    if (dist[e.to] > dist[u] + rc + 1e-18) {
                        dist[e.to] = dist[u] + rc;
                        pv[e.to] = u; pe[e.to] = i;
                        pq.push({dist[e.to], e.to});
                    }
                }
            }
            if (dist[t] >= INF/2) break; // no more augmenting paths
            for (int v = 0; v < N; ++v) if (dist[v] < INF/2) pot[v] += dist[v];
            // Find bottleneck capacity along the shortest path
            double addf = maxf - flow;
            int v = t;
            while (v != s) {
                int u = pv[v]; int i = pe[v];
                addf = min(addf, G[u][i].cap);
                v = u;
            }
            // Augment along the path
            v = t;
            while (v != s) {
                int u = pv[v]; int i = pe[v];
                Edge &e = G[u][i]; Edge &er = G[v][e.rev];
                e.cap -= addf; er.cap += addf;
                v = u;
            }
            flow += addf;
            cost += addf * pot[t]; // pot[t] equals shortest path distance in original costs
        }
        return {flow, cost};
    }
};

// Compute EMD (W1) between two 1D histograms with positions x (sources) and y (targets)
// and masses a, b (sums equal). Cost is |x_i - y_j|.
double emd_min_cost_flow(const vector<double>& x, const vector<double>& a,
                         const vector<double>& y, const vector<double>& b) {
    int n = (int)x.size();
    int m = (int)y.size();
    if (n == 0 || m == 0) throw invalid_argument("Empty supports");
    auto sum_vec = [](const vector<double>& v){ return accumulate(v.begin(), v.end(), 0.0); };
    double sa = sum_vec(a), sb = sum_vec(b);
    if (fabs(sa - sb) > 1e-9) throw invalid_argument("Total mass must match for balanced OT");

    int S = 0;
    int offsetX = 1;
    int offsetY = offsetX + n;
    int T = offsetY + m;
    MinCostMaxFlow mcmf(T + 1);

    // Source to X nodes
    for (int i = 0; i < n; ++i) {
        if (a[i] < -1e-15) throw invalid_argument("Negative mass in a");
        mcmf.add_edge(S, offsetX + i, a[i], 0.0);
    }
    // X to Y edges with cost = |x_i - y_j|
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            double c = fabs(x[i] - y[j]);
            mcmf.add_edge(offsetX + i, offsetY + j, numeric_limits<double>::infinity(), c);
        }
    }
    // Y nodes to sink
    for (int j = 0; j < m; ++j) {
        if (b[j] < -1e-15) throw invalid_argument("Negative mass in b");
        mcmf.add_edge(offsetY + j, T, b[j], 0.0);
    }

    auto [flow, cost] = mcmf.min_cost_max_flow(S, T, sa);
    if (fabs(flow - sa) > 1e-6) throw runtime_error("Could not send all mass; graph may be disconnected");
    return cost; // For W1 with |.| ground metric, this equals EMD
}

int main() {
    // Example histograms: mass at positions x and y
    vector<double> x = {0.0, 2.0};
    vector<double> a = {0.6, 0.4}; // total 1.0
    vector<double> y = {1.0, 3.0};
    vector<double> b = {0.5, 0.5}; // total 1.0

    double emd = emd_min_cost_flow(x, a, y, b);
    cout << fixed << setprecision(6);
    cout << "EMD (W1) via min-cost flow = " << emd << "\n";
    return 0;
}
This program constructs a bipartite flow network representing transport from source masses a at positions x to target masses b at positions y, with edge cost equal to the absolute distance. It then runs a min-cost max-flow solver using Dijkstra with Johnson potentials to find the minimal total cost, which equals the Earth Mover's Distance for the |·| ground metric.
#include <bits/stdc++.h>
using namespace std;

// Compute Sinkhorn distance approximation: solves
// min_{Gamma} <C, Gamma> + eps * KL(Gamma || a b^T) s.t. Gamma 1 = a, Gamma^T 1 = b
// using iterative matrix scaling. Returns (Gamma, cost).

struct SinkhornResult {
    vector<vector<double>> Gamma; // transport plan (approx.)
    double cost;                  // <C, Gamma>
    int iters;                    // iterations performed
};

SinkhornResult sinkhorn(const vector<vector<double>>& C,
                        const vector<double>& a,
                        const vector<double>& b,
                        double eps = 0.05,
                        int max_iters = 5000,
                        double tol = 1e-9) {
    int n = (int)C.size();
    int m = (int)C[0].size();
    auto sum_vec = [](const vector<double>& v){ return accumulate(v.begin(), v.end(), 0.0); };
    double sa = sum_vec(a), sb = sum_vec(b);
    if (fabs(sa - 1.0) > 1e-6 || fabs(sb - 1.0) > 1e-6)
        cerr << "Warning: a and b are expected to sum to 1; they will be used as given.\n";

    // Build kernel K = exp(-C / eps)
    vector<vector<double>> K(n, vector<double>(m));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
            K[i][j] = exp(-C[i][j] / eps);

    vector<double> u(n, 1.0), v(m, 1.0);

    auto K_times_v = [&](const vector<double>& v_, int i){
        long double s = 0.0L;
        for (int j = 0; j < m; ++j) s += (long double)K[i][j] * (long double)v_[j];
        return (double)s;
    };
    auto KT_times_u = [&](const vector<double>& u_, int j){
        long double s = 0.0L;
        for (int i = 0; i < n; ++i) s += (long double)K[i][j] * (long double)u_[i];
        return (double)s;
    };

    int it = 0;
    for (; it < max_iters; ++it) {
        // u update: u = a / (K v)
        double max_rel_err = 0.0;
        vector<double> Kv(n);
        for (int i = 0; i < n; ++i) {
            Kv[i] = K_times_v(v, i) + 1e-300; // numerical safety
            double new_u = a[i] / Kv[i];
            max_rel_err = max(max_rel_err, fabs(new_u - u[i]) / (fabs(u[i]) + 1e-12));
            u[i] = new_u;
        }
        // v update: v = b / (K^T u)
        vector<double> KTu(m);
        for (int j = 0; j < m; ++j) {
            KTu[j] = KT_times_u(u, j) + 1e-300;
            double new_v = b[j] / KTu[j];
            max_rel_err = max(max_rel_err, fabs(new_v - v[j]) / (fabs(v[j]) + 1e-12));
            v[j] = new_v;
        }
        if (max_rel_err < tol) break;
    }

    // Construct Gamma = diag(u) * K * diag(v)
    vector<vector<double>> G(n, vector<double>(m));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
            G[i][j] = u[i] * K[i][j] * v[j];

    // Compute transport cost <C, Gamma>
    long double cost = 0.0L;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
            cost += (long double)C[i][j] * (long double)G[i][j];

    return {G, (double)cost, it};
}

int main() {
    // Example: two small 2D supports with Euclidean distances
    vector<pair<double,double>> X = {{0,0}, {1,0}, {0,1}}; // n=3
    vector<pair<double,double>> Y = {{1,1}, {2,0}};        // m=2
    vector<double> a = {0.3, 0.4, 0.3}; // sums to 1
    vector<double> b = {0.6, 0.4};      // sums to 1

    int n = (int)X.size(), m = (int)Y.size();
    vector<vector<double>> C(n, vector<double>(m));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j) {
            double dx = X[i].first - Y[j].first;
            double dy = X[i].second - Y[j].second;
            C[i][j] = sqrt(dx*dx + dy*dy); // ground metric cost (p=1)
        }

    double eps = 0.1; // regularization strength
    auto res = sinkhorn(C, a, b, eps, 10000, 1e-10);

    cout << fixed << setprecision(6);
    cout << "Sinkhorn cost ~= " << res.cost << ", iterations = " << res.iters << "\n";
    // Optionally, print a few entries of Gamma
    cout << "Gamma matrix:\n";
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) cout << res.Gamma[i][j] << (j+1==m?'\n':' ');
    }
    return 0;
}
This implementation solves the entropic-regularized OT problem using Sinkhorn iterations. It builds the Gibbs kernel K = exp(-C/ε), alternately rescales rows and columns to match the marginals a and b, and outputs the approximate plan Γ along with its transport cost. It is much faster than exact OT for moderate to large problems.