Metric Learning
Key Points
- Metric learning automatically learns a distance function so that similar items are close and dissimilar items are far in a feature space.
- A common approach is to learn a Mahalanobis metric d_M(x, y) = sqrt((x - y)^T M (x - y)) where M is positive semidefinite (PSD).
- Loss functions like contrastive loss and triplet loss convert similarity constraints into optimization objectives for the metric.
- Enforcing that M is PSD is essential; parameterizing M = L^T L (a Cholesky-like factorization) or restricting M to a nonnegative diagonal keeps the metric valid.
- Pair- and triplet-based training scales quadratically or cubically with data, so smart sampling and batching are important.
- Learned metrics can significantly improve k-NN, clustering, retrieval, face recognition, and anomaly detection.
- Regularization (e.g., Frobenius norm) and feature scaling prevent overfitting and stabilize learning.
- In C++, a practical starting point is learning a diagonal Mahalanobis metric with gradient descent and using it in k-NN.
Prerequisites
- Linear algebra (vectors, matrices, eigenvalues) — Understanding Mahalanobis metrics, PSD matrices, and factorizations like M = L^T L requires linear algebra.
- Calculus and basic optimization — Gradient-based learning of metric parameters relies on derivatives and optimization concepts.
- Supervised learning and loss functions — Pairwise/triplet losses and regularization are core to metric learning objectives.
- k-Nearest Neighbors (k-NN) — Metric learning is often applied to improve neighbor-based classification and retrieval.
- Data preprocessing (scaling/normalization) — Feature scaling interacts strongly with distance computations and stability of training.
- Algorithmic complexity — Pair/triplet enumeration is expensive; knowing complexities guides sampling and batching.
Detailed Explanation
Overview
Hook: Have you ever noticed that the default Euclidean distance sometimes groups unlike things together? For example, two faces under different lighting may look far apart in raw pixels even if they are the same person.

Concept: Metric learning aims to fix this automatically by learning a distance function tuned to your task, so relevant examples are pulled together and irrelevant ones are pushed apart. Instead of handcrafting features or accepting Euclidean distance, we learn parameters of a metric directly from data with supervision in the form of labels, similar/dissimilar pairs, or triplets. The most popular family is Mahalanobis distances, parameterized by a positive semidefinite (PSD) matrix M, which re-weights and correlates feature dimensions.

Example: In a product recommendation system, if buyers care more about price than color, metric learning can learn a matrix that makes price differences count more in distance computations, improving nearest-neighbor retrieval of similar products.
Intuition & Analogies
Hook: Imagine packing a suitcase with items of different fragility. You’d want fragile items closer together and protected, while sturdy items can be separated more. If you pack naively, you might crush your glasses under a heavy book.

Concept: Default Euclidean distance is like packing without thinking—it treats all directions (features) equally and independently. Metric learning is like customizing the suitcase padding: you decide which directions are important (heavier weights), which are correlated (tilted padding), and which can be ignored (near-zero weights). The Mahalanobis metric does exactly this—by applying a linear transform to the space before measuring ordinary Euclidean distance, it stretches, shrinks, and rotates the space so that meaningful neighbors become closer.

Example: Suppose your data has height in centimeters and income in dollars. In raw scale, income dwarfs height, so Euclidean distance mostly measures income. A learned metric can down-weight income or up-weight height depending on which better predicts similarity, effectively normalizing and reorienting the space to better reflect your task.
Formal Definition
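The family of distances used throughout this article is the Mahalanobis distance; in standard notation:

```latex
d_M(x, y) = \sqrt{(x - y)^\top M (x - y)},
\qquad M \in \mathbb{R}^{d \times d},\; M = M^\top,\; M \succeq 0.
% Factoring M = L^\top L shows d_M is Euclidean distance after the linear map L:
%   d_M(x, y) = \lVert Lx - Ly \rVert_2.
% M = I recovers plain Euclidean distance; a diagonal M reweights features;
% an L with r < d rows additionally performs dimensionality reduction.
```

The PSD requirement on M is what guarantees that squared distances are nonnegative and that d_M satisfies the metric axioms listed under Key Formulas (up to the identity-of-indiscernibles property, which requires M to be strictly positive definite).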
When to Use
Hook: If your nearest neighbors don’t look like the right neighbors, your distance is probably wrong.

Concept: Use metric learning when the downstream method depends on distances or similarities: k-nearest neighbors, k-means clustering, information retrieval, re-identification, and verification tasks. It’s particularly helpful when raw features have different scales, correlations matter, or you have weak supervision such as pairwise constraints. It can also act as a dimensionality reduction when you factor M = L^T L with L of reduced rank.

Example: In face verification, triplet loss encourages an anchor to be closer to a positive (same identity) than to a negative (different identity) by a margin. In e-commerce, product retrieval benefits from a metric that weighs price and brand more than color. In anomaly detection, a learned metric that tightens normal clusters makes outliers stand out more clearly.
⚠️ Common Mistakes
Hook: Why does my learned distance make things worse?

Concept: Several pitfalls are common.
- Not enforcing PSD: updating an unconstrained symmetric M can break metric properties; use M = L^T L or restrict to a nonnegative diagonal.
- Ignoring feature scaling: without normalization, some features dominate; standardize features before learning.
- Overfitting with too many pairs/triplets or too flexible an M: use regularization and validation, and consider diagonal or low-rank M for high dimensions.
- Poor sampling: naively using all pairs is O(n^2); instead, mine informative (hard) positives/negatives and balance classes.
- Bad margins or learning rates: margins too large make optimization infeasible; too small yields weak separation; tune hyperparameters.
- Data leakage in evaluation: evaluating on the same pairs/triplets used for training inflates performance; separate train/validation/test.

Example: A practitioner updates a full M without projection and ends up with negative eigenvalues; distances become invalid and nearest neighbors flip unpredictably. Projecting onto the PSD cone or parameterizing by L prevents this.
Key Formulas
Metric Axioms
d(x, y) >= 0;   d(x, y) = 0 iff x = y;   d(x, y) = d(y, x);   d(x, z) <= d(x, y) + d(y, z)
Explanation: These four properties define a valid metric: non-negativity, identity of indiscernibles, symmetry, and the triangle inequality. Any learned distance should satisfy them to be mathematically consistent.
Mahalanobis Distance
d_M(x, y) = sqrt((x - y)^T M (x - y)),   M symmetric PSD
Explanation: This parameterizes distances with a symmetric PSD matrix M. It equals Euclidean distance after a linear transform determined by M.
PSD Equivalences
M is PSD  <=>  z^T M z >= 0 for all z  <=>  M = L^T L for some matrix L
Explanation: A matrix is PSD iff all quadratic forms are nonnegative, which is equivalent to being representable as M = L^T L. This guarantees valid (nonnegative) squared distances.
Contrastive Loss
L(x_i, x_j, y) = y * d(x_i, x_j)^2 + (1 - y) * max(0, m - d(x_i, x_j))^2
Explanation: Similar pairs (y=1) are pulled together by minimizing squared distance; dissimilar pairs (y=0) are pushed apart until they are at least margin m away.
Triplet Loss
L(a, p, n) = max(0, d(a, p)^2 - d(a, n)^2 + alpha)
Explanation: Enforces that an anchor a is closer to a positive p than to a negative n by at least margin alpha. Only violated triplets contribute to the loss.
LMNN Objective
min_M  sum_{j~>i} d_M(x_i, x_j)^2 + mu * sum_{j~>i, l} max(0, 1 + d_M(x_i, x_j)^2 - d_M(x_i, x_l)^2)
(where j~>i means x_j is a target neighbor of x_i, and l ranges over differently labeled impostors)
Explanation: Pulls target neighbors close while pushing impostors away by a margin. This convex objective can be minimized under PSD constraints on M.
Frobenius Regularization
R(M) = ||M||_F^2 = sum_{i,j} M_{ij}^2
Explanation: Penalizes large entries in M to reduce overfitting. Often multiplied by a coefficient lambda and added to the training loss.
PSD Projection
M = Q diag(lambda_1, ..., lambda_d) Q^T   =>   Pi_PSD(M) = Q diag(max(lambda_1, 0), ..., max(lambda_d, 0)) Q^T
Explanation: Given the eigen-decomposition M = Q diag(lambda_1, ..., lambda_d) Q^T, projecting onto the PSD cone zeroes the negative eigenvalues to restore metric validity.
Pairs and Triplets Count
#pairs = n(n-1)/2 = O(n^2),   #triplets = O(n^3)
Explanation: The number of possible pairs and triplets grows quadratically and cubically with n, motivating sampling strategies for scalable training.
Diagonal Mahalanobis
d_w(x, y)^2 = sum_k w_k (x_k - y_k)^2,   w_k >= 0
Explanation: A simple, PSD-constrained metric that reweights each feature independently. Useful as a fast baseline and easy to learn with gradient descent.
Complexity Analysis
Evaluating one Mahalanobis distance costs O(d^2) for a full matrix M and O(d) for a diagonal one. A full training pass over all pairs requires O(n^2) distance evaluations (O(n^3) for all triplets), so practical implementations cap or sample pairs, as the code examples below do.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Compute y = M * v for symmetric M (d x d) and vector v (d)
vector<double> matVec(const vector<vector<double>>& M, const vector<double>& v) {
    int d = (int)v.size();
    vector<double> y(d, 0.0);
    for (int i = 0; i < d; ++i) {
        double sum = 0.0;
        for (int j = 0; j < d; ++j) sum += M[i][j] * v[j];
        y[i] = sum;
    }
    return y;
}

// Compute squared Mahalanobis distance (x - y)^T M (x - y)
double mahalanobisSquared(const vector<double>& x, const vector<double>& y,
                          const vector<vector<double>>& M) {
    int d = (int)x.size();
    vector<double> diff(d);
    for (int i = 0; i < d; ++i) diff[i] = x[i] - y[i];
    vector<double> Md = matVec(M, diff);
    double val = 0.0;
    for (int i = 0; i < d; ++i) val += diff[i] * Md[i];
    return val; // nonnegative if M is PSD
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Example PSD matrix M (2x2): SPD since both leading minors are positive
    vector<vector<double>> M = {
        {3.0, 1.0},
        {1.0, 2.0}
    };

    // Small dataset: 3 points in R^2
    vector<vector<double>> X = {
        {0.0, 0.0},
        {1.0, 0.5},
        {3.0, 2.0}
    };

    int n = (int)X.size();
    vector<vector<double>> D(n, vector<double>(n, 0.0));

    // Compute pairwise squared Mahalanobis distances
    for (int i = 0; i < n; ++i) {
        for (int j = i; j < n; ++j) {
            double d2 = mahalanobisSquared(X[i], X[j], M);
            D[i][j] = D[j][i] = d2; // symmetric
        }
    }

    cout << fixed << setprecision(4);
    cout << "Pairwise squared Mahalanobis distances (using M):\n";
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) cout << setw(8) << D[i][j] << ' ';
        cout << '\n';
    }

    return 0;
}
```
This program defines a symmetric matrix M and computes squared Mahalanobis distances between all pairs of points. The squared form avoids an extra sqrt and is sufficient for ranking neighbors. If M is PSD, all distances are nonnegative and symmetric.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Pair { int i, j; int y; }; // y=1 similar, y=0 dissimilar

// Squared distance under diagonal metric: sum_k w[k]*(x_k - y_k)^2
double dsq_diag(const vector<double>& a, const vector<double>& b, const vector<double>& w) {
    double s = 0.0;
    for (size_t k = 0; k < w.size(); ++k) {
        double d = a[k] - b[k];
        s += w[k] * d * d;
    }
    return s;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Synthetic 2D data: two clusters
    vector<vector<double>> X = {
        {0.0, 0.0}, {0.2, -0.1}, {0.1, 0.1}, // class 0
        {3.0, 2.0}, {3.2, 1.9}, {2.9, 2.1}   // class 1
    };
    vector<int> y = {0,0,0, 1,1,1};

    int n = (int)X.size();
    int d = (int)X[0].size();

    // Build labeled pairs: a small capped set
    vector<Pair> pairs;
    for (int i = 0; i < n; ++i) {
        for (int j = i+1; j < n; ++j) {
            if ((int)pairs.size() > 60) break; // limit pair count
            pairs.push_back({i, j, y[i] == y[j] ? 1 : 0});
        }
    }

    // Parameters: diagonal weights w[k] >= 0 (keeps the metric PSD)
    vector<double> w(d, 1.0); // initialize as identity weights

    double margin = 1.0;  // contrastive margin m
    double lr = 0.1;      // learning rate
    double lambda = 1e-3; // L2 regularization strength
    int epochs = 200;

    auto total_loss = [&](const vector<double>& wcur){
        double L = 0.0;
        for (const auto& p : pairs) {
            double s = dsq_diag(X[p.i], X[p.j], wcur);
            double dxy = sqrt(max(1e-12, s));
            if (p.y == 1) {
                L += s; // pull similar pairs together
            } else {
                double h = max(0.0, margin - dxy);
                L += h * h; // push dissimilar pairs apart
            }
        }
        // L2 regularization on w
        double reg = 0.0;
        for (double wk : wcur) reg += wk * wk;
        return L + lambda * reg;
    };

    // Training: simple full-batch gradient descent
    for (int e = 0; e < epochs; ++e) {
        vector<double> grad(d, 0.0);
        for (const auto& p : pairs) {
            // compute per-pair gradient wrt w_k
            double s = dsq_diag(X[p.i], X[p.j], w); // s = sum_k w_k * (dx_k)^2
            double dxy = sqrt(max(1e-12, s));
            for (int k = 0; k < d; ++k) {
                double dx = X[p.i][k] - X[p.j][k];
                double gk = dx * dx; // d/dw_k of s
                if (p.y == 1) {
                    grad[k] += gk; // derivative of the pull term s
                } else {
                    double h = margin - dxy;
                    if (h > 0) {
                        // d/dw_k ( (m - sqrt(s))^2 ) = - (m - sqrt(s)) / sqrt(s) * dx^2
                        grad[k] += -(h / dxy) * gk;
                    }
                }
            }
        }
        // Add L2 gradient and update with learning rate
        for (int k = 0; k < d; ++k) {
            grad[k] += 2.0 * lambda * w[k];
            w[k] -= lr * grad[k];
            // Enforce PSD (nonnegative diagonal)
            if (w[k] < 0.0) w[k] = 0.0;
        }
        if ((e+1) % 50 == 0) {
            cout << "Epoch " << (e+1) << ", loss = " << total_loss(w)
                 << ", w = [" << w[0] << ", " << w[1] << "]\n";
        }
    }

    // Show distances after learning
    cout << fixed << setprecision(4);
    cout << "\nLearned diagonal weights w: [" << w[0] << ", " << w[1] << "]\n";
    cout << "Sample squared distances (within vs across classes):\n";
    auto print_pair = [&](int i, int j){
        cout << "d^2(x" << i << ", x" << j << ") = " << dsq_diag(X[i], X[j], w)
             << " (label sim? " << (y[i]==y[j] ? "yes" : "no") << ")\n";
    };
    print_pair(0, 1); // similar
    print_pair(0, 3); // dissimilar

    return 0;
}
```
This program learns nonnegative diagonal weights w for a Mahalanobis metric using contrastive loss on a small synthetic dataset. Similar pairs are pulled together by minimizing squared distance; dissimilar pairs are pushed apart to be at least a margin apart. We enforce PSD by clamping weights to be nonnegative. The result is a task-tuned per-feature reweighting.
```cpp
#include <bits/stdc++.h>
using namespace std;

// Squared distance under diagonal metric
double dsq_diag(const vector<double>& a, const vector<double>& b, const vector<double>& w) {
    double s = 0.0;
    for (size_t k = 0; k < w.size(); ++k) {
        double d = a[k] - b[k];
        s += w[k] * d * d;
    }
    return s;
}

int predictKNN(const vector<vector<double>>& Xtr, const vector<int>& ytr,
               const vector<double>& xq, int K, const vector<double>& w) {
    vector<pair<double,int>> dist;
    dist.reserve(Xtr.size());
    for (size_t i = 0; i < Xtr.size(); ++i) {
        dist.push_back({ dsq_diag(Xtr[i], xq, w), (int)i }); // squared distances suffice for ranking
    }
    nth_element(dist.begin(), dist.begin() + K, dist.end());
    unordered_map<int, int> vote;
    for (int i = 0; i < K; ++i) vote[ytr[dist[i].second]]++;
    // majority vote
    int bestLabel = -1, bestCnt = -1;
    for (auto& kv : vote) {
        if (kv.second > bestCnt) { bestCnt = kv.second; bestLabel = kv.first; }
    }
    return bestLabel;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Tiny train set in 2D
    vector<vector<double>> Xtr = {
        {0.0, 0.0}, {0.3, -0.1}, {0.1, 0.2}, // class 0
        {2.5, 2.0}, {2.9, 2.2}, {3.1, 1.9}   // class 1
    };
    vector<int> ytr = {0,0,0, 1,1,1};

    // Two metrics: Euclidean (w=[1,1]) vs learned emphasis on x-dimension (w=[4,1])
    vector<double> w_euclid  = {1.0, 1.0};
    vector<double> w_learned = {4.0, 1.0};

    vector<vector<double>> Xtest = {{0.2, 0.1}, {2.8, 2.1}, {1.5, 1.0}};
    int K = 3;

    cout << "Comparing k-NN predictions (K=3) under two metrics:\n";
    for (size_t i = 0; i < Xtest.size(); ++i) {
        int pe = predictKNN(Xtr, ytr, Xtest[i], K, w_euclid);
        int pl = predictKNN(Xtr, ytr, Xtest[i], K, w_learned);
        cout << "x_test[" << i << "] => Euclid: " << pe << ", Learned-diag: " << pl << "\n";
    }

    return 0;
}
```
This example shows how a learned diagonal metric (here set by hand to emphasize the first feature) changes k-NN decisions. Ranking by squared distances is equivalent to ranking by the distances themselves, since sqrt is monotonic. In practice, use the weights learned during training (as in the previous example).