Self-Supervised Learning Theory
Key Points
- Self-supervised learning (SSL) teaches models to learn useful representations from unlabeled data by solving proxy tasks created directly from the data.
- The core idea is to enforce agreement between different views or augmentations of the same input while distinguishing them from other inputs.
- Contrastive methods (e.g., InfoNCE) use positives and negatives, while non-contrastive methods (e.g., autoencoders, BYOL, Barlow Twins, VICReg) avoid explicit negatives but include collapse-preventing terms.
- Theoretical lenses include mutual information bounds, alignment–uniformity trade-offs, redundancy reduction, and invariance/equivariance to transformations.
- Good augmentations define what invariances the representation should learn and are as important as the model or loss.
- Batch size, temperature, and normalization strongly affect optimization stability and downstream performance.
- Representations learned by SSL transfer well to many tasks (classification, retrieval, detection) with little or no labeled data.
- C++ implementations from scratch are feasible for toy SSL objectives (autoencoders, linear contrastive models), but practical systems use optimized libraries.
Prerequisites
- Linear Algebra — Embedding spaces, projections, matrix–vector multiplication, and normalization are central to SSL.
- Probability and Information Theory — Understanding expectations, softmax, and mutual information bounds clarifies SSL objectives.
- Optimization and Gradient Descent — Training SSL models requires computing gradients and updating parameters stably.
- Vector Calculus / Matrix Calculus — Backpropagation through normalization and similarity functions depends on Jacobians.
- C++ Programming (STL, numerics) — Implementing from scratch needs control over arrays, loops, and random number generation.
- Neural Network Basics — Encoders, projection heads, losses, and activations build on foundational NN concepts.
- Data Augmentation — Augmentations define which invariances SSL learns and must be chosen carefully.
Detailed Explanation
01 Overview
Self-supervised learning (SSL) is a family of techniques for learning data representations without human-provided labels. Instead of predicting external annotations, SSL creates internal supervision signals by transforming the data and asking the model to solve a proxy task. Typical examples include predicting masked parts of the input, reconstructing an input from a compressed code, or making different augmented views of the same input map to similar embeddings. The learned encoder captures structure and invariances present in the data that often transfer to many downstream tasks with minimal fine-tuning. SSL has been particularly successful in vision (SimCLR, BYOL, DINO), language (masked language modeling), and multimodal learning (CLIP), where labeling is expensive or ambiguous. Theoretically, SSL objectives can be framed using ideas like mutual information, alignment and uniformity of embeddings on the unit sphere, and redundancy reduction across feature dimensions. Practically, implementations fall into two camps: contrastive formulations that use negative examples via softmax normalization, and non-contrastive objectives that rely on architectural or regularization tricks to prevent trivial solutions. Key design choices include the family of augmentations, the embedding normalization, the temperature in the softmax, and the batch size.
02 Intuition & Analogies
Imagine learning to recognize a song by hearing it played on different instruments and at different tempos. Even though the sound changes, you still know it is the same melody. SSL builds this intuition into algorithms: it shows a model different “views” of the same object—like photos under different lighting or texts with some words masked—and asks the model to agree that they are the same underlying thing. At the same time, it should not confuse one song with another; contrastive SSL enforces this by pushing different songs apart. Another analogy is assembling a jigsaw puzzle. You learn the picture by figuring out how pieces fit together without anyone telling you what the final image is. Autoencoders do something similar: they try to compress and then reconstruct an input, discovering structure that makes reconstruction possible. Non-contrastive methods are like practicing with a partner who mirrors your moves; the goal is to match your partner’s motion across different viewpoints while avoiding the trivial solution of doing nothing. To avoid this, you might keep track of both the motion and its diversity—ensuring each move (feature) brings something unique and that not all moves are identical. In SSL, that corresponds to variance and covariance regularization, or using separate but related networks whose parameters evolve slowly. These intuitions guide the design: make same-object views align, different objects spread out, and features remain informative and non-redundant.
03 Formal Definition
04 When to Use
Use self-supervised learning when labeled data is scarce, expensive, or noisy, but unlabeled data is abundant. It is especially effective in pretraining encoders that will be fine-tuned for downstream tasks like classification, detection, retrieval, or segmentation. In computer vision, apply SSL when augmentations (cropping, color jitter, blur) capture semantic-preserving transformations. In NLP, use masking or next-token prediction to learn strong language representations. For multimodal tasks (e.g., image–text), contrastive objectives align modalities without dense labels. SSL is also useful for anomaly detection: learning a compact representation of normal data lets you flag unusual patterns. If you need representations invariant to specific nuisance factors (lighting, viewpoint) but sensitive to relevant changes, craft augmentations to encode these invariances. When compute or batch size is limited, prefer non-contrastive losses that do not require many negatives, or use memory banks/queues. For small-scale problems or education, simple linear autoencoders or linear contrastive models illustrate the principles and can be implemented from scratch in C++.
⚠️Common Mistakes
- Poor augmentations: If augmentations change semantics (e.g., random erasing applied too aggressively), the model is forced to match different objects, hurting representation quality. Design transformations that preserve the labels you care about downstream.
- Collapse: Non-contrastive objectives can converge to constant embeddings. Prevent this with variance or decorrelation penalties, stop-gradient and EMA targets, or architectural asymmetry (a prediction head).
- Insufficient negatives or small batches: Contrastive methods often need enough negatives to approximate the partition function. If batch size is small, raise the temperature, use a memory bank, or switch to non-contrastive methods.
- Ignoring normalization: Many SSL losses assume unit-norm embeddings. Without normalization, scale can dominate similarity and destabilize training.
- Overfitting to augmentations: If augmentations are too weak, the model memorizes superficial cues; if too strong, the task becomes impossible. Tune augmentation strength and diversity.
- Mis-tuned temperature and learning rate: Temperature controls the softness of the softmax; too low causes vanishing gradients for most pairs, too high weakens discrimination. Learning rate and weight decay also interact with normalization.
- Forgetting projection heads: Often a projection head g improves the SSL loss landscape while the encoder f retains transferable features. Skipping g can degrade transfer performance.
- Evaluating only by pretext loss: Low pretext loss does not guarantee good downstream performance. Always validate with linear probing or fine-tuning on labeled subsets.
Key Formulas
Generic SSL Objective
\mathcal{L} = \mathbb{E}_{x,\; t, t' \sim \mathcal{T}}\left[\, \ell\big(g(f(t(x))),\; g(f(t'(x)))\big) \,\right]
Explanation: The loss is computed between two transformed views of the same input after passing through an encoder and optional heads. This template covers contrastive and non-contrastive SSL.
InfoNCE / NT-Xent
\mathcal{L}_{\text{InfoNCE}} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(\mathrm{sim}(z_i, z_i^{+})/\tau)}{\sum_{j=1}^{N} \exp(\mathrm{sim}(z_i, z_j)/\tau)}
Explanation: Each positive pair competes against all other pairs in the batch via a softmax over similarities. Lower temperature sharpens the distribution and emphasizes the top match.
Mutual Information Lower Bound
I(Z_1; Z_2) \;\ge\; \log N - \mathcal{L}_{\text{InfoNCE}}
Explanation: The InfoNCE loss lower-bounds the mutual information between two views. As the loss decreases, the estimated mutual information increases.
Cosine Similarity
\mathrm{sim}(u, v) = \frac{u^{\top} v}{\lVert u \rVert\, \lVert v \rVert}
Explanation: Cosine similarity measures the angle between two vectors and is insensitive to scale. Normalizing embeddings ensures similarity lies in [-1, 1].
Softmax
p_i = \frac{\exp(\ell_i)}{\sum_{j} \exp(\ell_j)}
Explanation: Softmax converts logits into a probability distribution. In contrastive learning, logits are scaled similarities divided by temperature.
Autoencoder Loss
\mathcal{L}_{\text{AE}} = \frac{1}{N}\sum_{i=1}^{N} \lVert x_i - D(E(x_i)) \rVert^2
Explanation: An autoencoder minimizes reconstruction error, encouraging the latent code to capture structure sufficient to rebuild the input.
Negative Cosine with Stop-Gradient
\mathcal{L} = -\,\frac{p(z_1)^{\top}\, \mathrm{sg}(z_2)}{\lVert p(z_1) \rVert\, \lVert \mathrm{sg}(z_2) \rVert}
Explanation: BYOL minimizes the negative cosine similarity between an online prediction and a target embedding, with gradients stopped through the target.
Barlow Twins
\mathcal{L}_{\text{BT}} = \sum_{i} (1 - C_{ii})^2 + \lambda \sum_{i} \sum_{j \ne i} C_{ij}^2, \qquad C = \frac{1}{N}\, \hat{Z}_1^{\top} \hat{Z}_2
Explanation: The cross-correlation between two views is pushed towards the identity matrix, encouraging invariance (diagonal near 1) and decorrelation (off-diagonals near 0).
VICReg
\mathcal{L} = \lambda\, s(Z, Z') + \mu\, v(Z) + \nu\, c(Z)
Explanation: Combines an invariance term (match views), a variance term (keep per-dimension variance above a threshold), and a covariance penalty (reduce redundancy).
Normalization Jacobian
\frac{\partial z}{\partial u} = \frac{1}{\lVert u \rVert}\left(I - z z^{\top}\right), \qquad z = \frac{u}{\lVert u \rVert}
Explanation: When normalizing embeddings, gradients through the normalization step are projected onto the tangent plane of the unit sphere. This is needed to backpropagate InfoNCE with normalized vectors.
Complexity Analysis
Code Examples
#include <bits/stdc++.h>
using namespace std;

// Utility functions for small linear algebra
vector<double> matvec(const vector<vector<double>>& W, const vector<double>& x) {
    int r = (int)W.size();
    int c = (int)W[0].size();
    vector<double> y(r, 0.0);
    for (int i = 0; i < r; ++i) {
        double s = 0.0;
        for (int j = 0; j < c; ++j) s += W[i][j] * x[j];
        y[i] = s;
    }
    return y;
}

void axpy(vector<double>& y, const vector<double>& x, double a) {
    for (size_t i = 0; i < y.size(); ++i) y[i] += a * x[i];
}

// Outer product accumulation: A += a * (u v^T)
void outer_acc(vector<vector<double>>& A, const vector<double>& u, const vector<double>& v, double a = 1.0) {
    int r = (int)A.size();
    int c = (int)A[0].size();
    for (int i = 0; i < r; ++i)
        for (int j = 0; j < c; ++j)
            A[i][j] += a * u[i] * v[j];
}

vector<double> sub(const vector<double>& a, const vector<double>& b) {
    vector<double> c(a.size());
    for (size_t i = 0; i < a.size(); ++i) c[i] = a[i] - b[i];
    return c;
}

double l2(const vector<double>& x) {
    double s = 0.0; for (double v : x) s += v * v; return sqrt(max(1e-12, s));
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Synthetic data: 3D points lying near a 2D plane (intrinsic dim = 2)
    const int d = 3;        // input dimension
    const int k = 2;        // latent dimension
    const int N = 2000;     // number of samples
    const int epochs = 200;
    const int batch = 64;
    const double lr = 1e-2; // learning rate
    const double wd = 1e-4; // weight decay (L2 regularization)

    mt19937 rng(42);
    normal_distribution<double> noise(0.0, 0.05);
    uniform_real_distribution<double> uni(-2.0, 2.0);

    vector<vector<double>> X(N, vector<double>(d));
    for (int i = 0; i < N; ++i) {
        double u = uni(rng), v = uni(rng);
        // Define a plane: x3 = 0.5*u + 0.3*v + small noise
        X[i][0] = u + noise(rng);
        X[i][1] = v + noise(rng);
        X[i][2] = 0.5*u + 0.3*v + noise(rng);
    }

    // Parameters: encoder Wenc (k x d), decoder Wdec (d x k)
    vector<vector<double>> Wenc(k, vector<double>(d));
    vector<vector<double>> Wdec(d, vector<double>(k));
    normal_distribution<double> init(0.0, 0.1);
    for (int i = 0; i < k; ++i) for (int j = 0; j < d; ++j) Wenc[i][j] = init(rng);
    for (int i = 0; i < d; ++i) for (int j = 0; j < k; ++j) Wdec[i][j] = init(rng);

    vector<int> idx(N); iota(idx.begin(), idx.end(), 0);

    for (int e = 1; e <= epochs; ++e) {
        shuffle(idx.begin(), idx.end(), rng);
        double epoch_loss = 0.0;
        int batches = 0; // count batches explicitly: N need not divide evenly by `batch`

        for (int s = 0; s < N; s += batch) {
            int b = min(batch, N - s);
            // Gradient accumulators (same shape as params)
            vector<vector<double>> gWenc(k, vector<double>(d, 0.0));
            vector<vector<double>> gWdec(d, vector<double>(k, 0.0));
            double batch_loss = 0.0;

            for (int t = 0; t < b; ++t) {
                const vector<double>& x = X[idx[s + t]];
                // Forward: z = Wenc x; x_hat = Wdec z
                vector<double> z = matvec(Wenc, x);    // k
                vector<double> xhat = matvec(Wdec, z); // d
                vector<double> err = sub(xhat, x);     // d
                // Loss: 0.5 * ||err||^2
                double l = 0.0; for (double v : err) l += 0.5 * v * v;
                batch_loss += l;

                // Backprop
                // dL/dWdec += err * z^T
                outer_acc(gWdec, err, z, 1.0);
                // dL/dz = Wdec^T err
                vector<double> dL_dz(k, 0.0);
                for (int i = 0; i < k; ++i) {
                    double ssum = 0.0;
                    for (int j = 0; j < d; ++j) ssum += Wdec[j][i] * err[j];
                    dL_dz[i] = ssum;
                }
                // dL/dWenc += (dL/dz) * x^T
                outer_acc(gWenc, dL_dz, x, 1.0);
            }

            // Average gradients over batch and add weight decay
            double scale = 1.0 / b;
            for (int i = 0; i < k; ++i)
                for (int j = 0; j < d; ++j)
                    gWenc[i][j] = scale * gWenc[i][j] + wd * Wenc[i][j];
            for (int i = 0; i < d; ++i)
                for (int j = 0; j < k; ++j)
                    gWdec[i][j] = scale * gWdec[i][j] + wd * Wdec[i][j];

            // Update parameters (SGD)
            for (int i = 0; i < k; ++i)
                for (int j = 0; j < d; ++j)
                    Wenc[i][j] -= lr * gWenc[i][j];
            for (int i = 0; i < d; ++i)
                for (int j = 0; j < k; ++j)
                    Wdec[i][j] -= lr * gWdec[i][j];

            epoch_loss += batch_loss / b;
            ++batches;
        }

        epoch_loss /= batches;
        if (e % 20 == 0 || e == 1)
            cout << "Epoch " << e << ": avg recon loss = " << epoch_loss << "\n";
    }

    // After training, print reconstruction error on a few samples
    cout << "Sample reconstructions (x vs x_hat):\n";
    for (int n = 0; n < 3; ++n) {
        const vector<double>& x = X[n];
        vector<double> z = matvec(Wenc, x);
        vector<double> xhat = matvec(Wdec, z);
        cout << fixed << setprecision(3);
        cout << "x:     "; for (double v : x) cout << v << ' '; cout << '\n';
        cout << "x_hat: "; for (double v : xhat) cout << v << ' '; cout << "\n\n";
    }

    return 0;
}
This program implements a linear autoencoder in C++ without external libraries. The encoder compresses 3D inputs to 2D, and the decoder reconstructs them. We generate synthetic data that lies near a 2D plane, so an effective autoencoder should learn a near-lossless 2D representation. The code performs forward and backward passes with explicit gradients for a squared-error loss and updates parameters using SGD with weight decay. Linear autoencoders on centered data approximate PCA; thus the learned 2D subspace captures the principal directions of variation.
#include <bits/stdc++.h>
using namespace std;

// Small linear algebra helpers
vector<double> matvec(const vector<vector<double>>& W, const vector<double>& x) {
    int r = (int)W.size(); int c = (int)W[0].size();
    vector<double> y(r, 0.0);
    for (int i = 0; i < r; ++i) {
        double s = 0.0; for (int j = 0; j < c; ++j) s += W[i][j] * x[j];
        y[i] = s;
    }
    return y;
}

double norm2(const vector<double>& v) {
    double s = 0.0; for (double x : v) s += x * x; return sqrt(max(1e-12, s));
}

vector<double> normalize(const vector<double>& v) {
    double r = norm2(v); vector<double> y = v; for (double &x : y) x /= r; return y;
}

void outer_acc(vector<vector<double>>& A, const vector<double>& u, const vector<double>& v, double a = 1.0) {
    int r = (int)A.size(); int c = (int)A[0].size();
    for (int i = 0; i < r; ++i)
        for (int j = 0; j < c; ++j)
            A[i][j] += a * u[i] * v[j];
}

vector<double> add(const vector<double>& a, const vector<double>& b, double alpha = 1.0) {
    vector<double> c(a.size());
    for (size_t i = 0; i < a.size(); ++i) c[i] = a[i] + alpha * b[i];
    return c;
}

// Apply a small random rotation and Gaussian noise to 2D inputs (augmentations)
vector<double> augment(const vector<double>& x, mt19937& rng) {
    normal_distribution<double> gauss(0.0, 0.05);
    uniform_real_distribution<double> angle(-0.25, 0.25); // radians
    double th = angle(rng);
    double c = cos(th), s = sin(th);
    vector<double> y(2);
    y[0] = c * x[0] - s * x[1];
    y[1] = s * x[0] + c * x[1];
    y[0] += gauss(rng); y[1] += gauss(rng);
    return y;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Synthetic 2D dataset on a ring (no labels used)
    const int d = 2;         // input dim
    const int k = 2;         // embedding dim (kept small for demo)
    const int N = 1024;      // samples
    const int epochs = 200;
    const int B = 64;        // batch size
    const double lr = 5e-2;  // learning rate
    const double wd = 1e-4;  // weight decay
    const double tau = 0.2;  // temperature for InfoNCE

    mt19937 rng(123);
    vector<vector<double>> X(N, vector<double>(d));
    // Points on a noisy circle of radius ~1.0
    uniform_real_distribution<double> ang(0.0, 2.0 * M_PI);
    normal_distribution<double> rad_noise(0.0, 0.05);
    for (int i = 0; i < N; ++i) {
        double a = ang(rng);
        double r = 1.0 + rad_noise(rng);
        X[i][0] = r * cos(a);
        X[i][1] = r * sin(a);
    }

    // Shared linear encoder W (k x d)
    vector<vector<double>> W(k, vector<double>(d));
    normal_distribution<double> init(0.0, 0.1);
    for (int i = 0; i < k; ++i) for (int j = 0; j < d; ++j) W[i][j] = init(rng);

    vector<int> idx(N); iota(idx.begin(), idx.end(), 0);

    for (int e = 1; e <= epochs; ++e) {
        shuffle(idx.begin(), idx.end(), rng);
        double epoch_loss = 0.0; int batches = 0;

        for (int s = 0; s < N; s += B) {
            int b = min(B, N - s);
            vector<vector<double>> v1(b, vector<double>(d)), v2(b, vector<double>(d));
            for (int i = 0; i < b; ++i) {
                v1[i] = augment(X[idx[s + i]], rng);
                v2[i] = augment(X[idx[s + i]], rng);
            }

            // Forward: embeddings z1, z2 (normalized)
            vector<vector<double>> u1(b, vector<double>(k)), u2(b, vector<double>(k));
            vector<vector<double>> z1(b, vector<double>(k)), z2(b, vector<double>(k));
            for (int i = 0; i < b; ++i) {
                u1[i] = matvec(W, v1[i]);
                u2[i] = matvec(W, v2[i]);
                z1[i] = normalize(u1[i]);
                z2[i] = normalize(u2[i]);
            }

            // Similarity logits l_{i,j} = (z1_i^T z2_j) / tau
            vector<vector<double>> logits(b, vector<double>(b));
            for (int i = 0; i < b; ++i)
                for (int j = 0; j < b; ++j) {
                    double dot = 0.0; for (int t = 0; t < k; ++t) dot += z1[i][t] * z2[j][t];
                    logits[i][j] = dot / tau;
                }

            // Row-wise softmax and loss
            vector<vector<double>> probs(b, vector<double>(b));
            double batch_loss = 0.0;
            for (int i = 0; i < b; ++i) {
                double m = *max_element(logits[i].begin(), logits[i].end());
                double denom = 0.0;
                for (int j = 0; j < b; ++j) denom += exp(logits[i][j] - m);
                for (int j = 0; j < b; ++j) probs[i][j] = exp(logits[i][j] - m) / denom;
                batch_loss += -log(max(1e-12, probs[i][i]));
            }
            batch_loss /= b;

            // Backprop through softmax cross-entropy: dL/dlogit = probs - one_hot
            vector<vector<double>> dL_dlogit = probs;
            for (int i = 0; i < b; ++i) dL_dlogit[i][i] -= 1.0;
            for (int i = 0; i < b; ++i)
                for (int j = 0; j < b; ++j)
                    dL_dlogit[i][j] /= b; // average over batch

            // Gradients w.r.t. z1 and z2: logit = (z1_i^T z2_j)/tau
            vector<vector<double>> dL_dz1(b, vector<double>(k, 0.0));
            vector<vector<double>> dL_dz2(b, vector<double>(k, 0.0));
            for (int i = 0; i < b; ++i) {
                for (int j = 0; j < b; ++j) {
                    double g = dL_dlogit[i][j] / tau; // derivative through scaling
                    for (int t = 0; t < k; ++t) {
                        dL_dz1[i][t] += g * z2[j][t];
                        dL_dz2[j][t] += g * z1[i][t];
                    }
                }
            }

            // Backprop through normalization: z = u/||u||  =>  dL/du = (I - z z^T) dL/dz / ||u||
            vector<vector<double>> gW(k, vector<double>(d, 0.0));
            for (int i = 0; i < b; ++i) {
                // branch 1
                double r1 = norm2(u1[i]);
                // (I - z1 z1^T) * dL_dz1
                vector<double> proj1(k);
                double dot1 = 0.0; for (int t = 0; t < k; ++t) dot1 += z1[i][t] * dL_dz1[i][t];
                for (int t = 0; t < k; ++t) proj1[t] = (dL_dz1[i][t] - dot1 * z1[i][t]) / r1;
                outer_acc(gW, proj1, v1[i], 1.0);
                // branch 2
                double r2 = norm2(u2[i]);
                vector<double> proj2(k);
                double dot2 = 0.0; for (int t = 0; t < k; ++t) dot2 += z2[i][t] * dL_dz2[i][t];
                for (int t = 0; t < k; ++t) proj2[t] = (dL_dz2[i][t] - dot2 * z2[i][t]) / r2;
                outer_acc(gW, proj2, v2[i], 1.0);
            }

            // Add weight decay (gradients were already batch-averaged via the /b in dL_dlogit)
            for (int i = 0; i < k; ++i)
                for (int j = 0; j < d; ++j)
                    gW[i][j] += wd * W[i][j];

            // SGD update
            for (int i = 0; i < k; ++i)
                for (int j = 0; j < d; ++j)
                    W[i][j] -= lr * gW[i][j];

            epoch_loss += batch_loss; ++batches;
        }

        if (e % 20 == 0 || e == 1)
            cout << "Epoch " << e << ": InfoNCE loss = " << (epoch_loss / batches) << "\n";
    }

    // Print a few normalized embeddings
    cout << fixed << setprecision(3);
    for (int i = 0; i < 5; ++i) {
        vector<double> u = matvec(W, X[i]);
        vector<double> z = normalize(u);
        cout << "Embedding[" << i << "]: (" << z[0] << ", " << z[1] << ")\n";
    }

    return 0;
}
This program trains a shared linear encoder with an InfoNCE loss on augmented 2D points. Each input produces two augmented views; embeddings are normalized and their cosine similarities form logits for a row-wise softmax against all positives in the batch. We derive gradients analytically: first through the softmax cross-entropy, then through the cosine-similarity logits, and finally through the L2 normalization Jacobian. Even with a linear encoder, the model learns to align augmented views while spreading representations on the unit circle, illustrating alignment and uniformity.