
Knowledge Distillation Loss

Key Points

  • Knowledge distillation loss blends standard hard-label cross-entropy with a soft distribution match from a teacher using a temperature parameter.
  • The soft part is typically a KL divergence between teacher and student softmaxes at temperature τ, scaled by τ² for proper gradient magnitude.
  • Temperature τ > 1 softens logits so the teacher reveals relative similarities (“dark knowledge”) among classes.
  • The full loss is L = α L_hard + (1−α) τ² L_soft, where α balances the two terms.
  • Computing the loss stably requires the log-sum-exp trick to avoid overflow in softmax and cross-entropy.
  • The KL divergence must be taken in the correct direction: D_KL(teacher ∥ student), not the reverse.
  • Implementation is O(B⋅C) per batch (B = batch size, C = number of classes) with small memory overhead.
  • In C++, you can implement temperature softmax, cross-entropy, and KL divergence directly and train a simple linear classifier with SGD.

Prerequisites

  • Softmax and Cross-Entropy — KD builds on softmax probabilities and the cross-entropy loss for hard labels.
  • KL Divergence and Basic Information Theory — The soft loss uses KL divergence to match teacher and student distributions.
  • Gradient Descent and Chain Rule — Training requires computing gradients with respect to logits and parameters.
  • Numerical Stability (Log-Sum-Exp) — Stable softmax/log-softmax are essential to avoid overflow/underflow at higher temperatures.
  • Linear Classifiers and Matrix Multiplication — The training example backpropagates gradients through a linear model.

Detailed Explanation


01 Overview

Knowledge distillation (KD) is a training technique where a smaller, faster “student” model learns not only from ground-truth labels but also from the output probabilities of a larger, more accurate “teacher” model. The key idea is that the teacher’s output distribution over classes contains richer information than a one-hot label. For instance, even when the correct class is “cat,” a good teacher might assign non-zero probabilities to “lynx” or “tiger,” revealing semantic similarities among classes. KD captures this by defining a loss that combines two parts: (1) the standard cross-entropy with true labels (hard labels), and (2) a divergence term that encourages the student’s predictions to match the teacher’s softened probabilities (soft labels) at a higher temperature τ. The temperature smooths the distributions, making it easier for the student to learn relative class preferences. The overall training objective is a weighted sum of these two components, controlled by α. KD is widely used to compress models for deployment on limited hardware, improve calibration, and sometimes even to regularize models for better generalization. The method is simple to implement and can be applied to classification tasks, sequence models, and beyond, making it a practical and powerful tool in modern machine learning.

02 Intuition & Analogies

Imagine taking a multiple-choice test with a mentor. If the mentor only tells you which option is correct, you learn the final answer but not the reasoning. If instead the mentor also says, “The correct answer is A, but B is close, C is plausible in some contexts, and D is very unlikely,” you gain a deeper sense of how similar each option is to the truth. Knowledge distillation works the same way. The teacher’s prediction vector is like the mentor’s nuanced commentary: it encodes not just the winner but how the model ranks all options. The temperature τ is the dial that controls how revealing this commentary is. With τ = 1, the probabilities may be very peaky, hiding relationships among classes. Turning τ up (e.g., τ = 2, 4) softens the peaks, making secondary choices more visible to the student. The student then learns not only to get the right answer but also to understand the landscape of near-misses, which often leads to better generalization. The α parameter is like deciding how much to trust the official answer sheet (hard labels) versus the mentor’s hints (soft labels). If you set α too high, you might ignore valuable hints; too low, and you might drift from the actual answer key. A good balance helps the student become both accurate and well-calibrated, often achieving performance close to the teacher while being much smaller and faster.

03 Formal Definition

Consider a classification problem with C classes. For an input x, let z^(s) ∈ ℝ^C and z^(t) ∈ ℝ^C be the student and teacher logits, respectively. Define the temperature-softened probabilities p^(s,τ) = softmax(z^(s)/τ) and p^(t,τ) = softmax(z^(t)/τ), where typically τ > 1. Let y be the ground-truth class index and e_y its one-hot vector. The hard-label loss is the standard cross-entropy L_hard = −log p_y^(s,1). The soft-label loss is the Kullback–Leibler divergence from teacher to student at temperature τ: L_soft = D_KL(p^(t,τ) ∥ p^(s,τ)) = ∑_{c=1}^{C} p_c^(t,τ) log( p_c^(t,τ) / p_c^(s,τ) ). The full knowledge distillation objective for a single sample is L = α L_hard + (1−α) τ² L_soft, with α ∈ [0,1]. The τ² factor maintains comparable gradient magnitudes across temperatures. For a batch of size B, losses are averaged over samples. The common training practice is to stop gradients through the teacher (treating p^(t,τ) as constants) and to use numerically stable softmax/log-softmax computations (e.g., via log-sum-exp).
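To make the definition concrete, here is a worked instance with made-up values: C = 2, z^(s) = (2, 0), z^(t) = (3, 0), y = 1 (the first class), τ = 2, α = 0.5.

```latex
\begin{aligned}
p^{(s,1)} &= \operatorname{softmax}(2, 0) \approx (0.881,\, 0.119),
\qquad L_{\mathrm{hard}} = -\log 0.881 \approx 0.127 \\
p^{(s,\tau)} &= \operatorname{softmax}(1, 0) \approx (0.731,\, 0.269),
\qquad p^{(t,\tau)} = \operatorname{softmax}(1.5, 0) \approx (0.818,\, 0.182) \\
L_{\mathrm{soft}} &= 0.818 \log\tfrac{0.818}{0.731}
  + 0.182 \log\tfrac{0.182}{0.269} \approx 0.021 \\
L &= 0.5 \cdot 0.127 + 0.5 \cdot 2^{2} \cdot 0.021 \approx 0.105
\end{aligned}
```

Note how the τ² factor (here 4) amplifies the small KL term so the soft signal stays comparable in size to the hard cross-entropy.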

04 When to Use

Use knowledge distillation when you need to compress a large, accurate teacher into a smaller, faster student for deployment on resource-limited devices (mobile, edge, real-time systems). It is effective when latency, memory, or energy constraints prevent using the teacher directly. KD can also help when labeled data are scarce or noisy: the teacher’s soft labels act as a regularizer, guiding the student away from overfitting and toward better-calibrated probabilities. In multiclass problems with many classes or class confusion (e.g., fine-grained recognition), the teacher’s softened distribution conveys informative structure that one-hot labels miss. KD is also useful for self-distillation (teacher and student share the same architecture), continuing training with better calibration, or distilling across modalities (e.g., from an ensemble or a vision-language model to a unimodal student). Additionally, if you already have a strong model serving offline, you can train a lighter student to run online, syncing periodically to keep quality high while controlling costs. Choose KD especially when you observe overconfident outputs, desire smoother decision boundaries, or want to blend knowledge from ensembles into a single deployable model.

⚠️Common Mistakes

  • Using the wrong KL direction. The soft loss should be D_KL(teacher ∥ student), not D_KL(student ∥ teacher). Reversing it changes gradients and can harm learning.
  • Forgetting τ² scaling. If you omit the τ² factor, gradients shrink with larger τ, weakening the soft-target learning signal.
  • Mixing logits and probabilities. Apply softmax to logits at the correct temperature before computing cross-entropy or KL; do not feed raw logits directly into KL.
  • Inconsistent temperatures. Both teacher and student soft distributions in the soft loss must use the same τ; do not softmax the teacher at τ and the student at 1.
  • Numerical instability. Compute softmax/log-softmax with max subtraction and log-sum-exp to avoid overflow/underflow, especially with large |logits| or large τ.
  • Not averaging over the batch. Always average the loss (and gradients) over the batch to keep learning-rate behavior consistent across batch sizes.
  • Backpropagating through the teacher. Typically the teacher is fixed; ensure its outputs are treated as constants during student training.
  • Over- or under-weighting α. Extreme α values can ignore either ground truth (too small) or teacher guidance (too large). Tune α and τ jointly.

Key Formulas

Temperature-Scaled Softmax

p_i^(τ) = softmax(z/τ)_i = exp(z_i/τ) / ∑_{j=1}^{C} exp(z_j/τ)

Explanation: Softmax with temperature τ. Larger τ makes the distribution flatter; τ=1 recovers the standard softmax.

Hard-Label Cross-Entropy

L_hard = −log p_y^(1)

Explanation: The cross-entropy loss for a one-hot target equals the negative log probability assigned to the true class.

KL Divergence

D_KL(q ∥ p) = ∑_{i=1}^{C} q_i log(q_i / p_i)

Explanation: Measures how distribution q differs from p. It is asymmetric; swapping q and p changes the value and gradients.

Knowledge Distillation Loss

L = α L_hard + (1−α) τ² D_KL(p^(t,τ) ∥ p^(s,τ))

Explanation: The total loss combines hard-label cross-entropy with a temperature-scaled KL divergence from teacher to student. α balances the two parts and τ² preserves gradient scale.

Gradient of Hard CE w.r.t. Student Logits

∇_{z^(s)} L_hard = p^(s,1) − e_y

Explanation: For standard cross-entropy with one-hot targets, the gradient equals the difference between predicted probabilities and the one-hot vector.

Gradient of Soft KD Term

∇_{z^(s)} [ τ² D_KL(p^(t,τ) ∥ p^(s,τ)) ] = τ (p^(s,τ) − p^(t,τ))

Explanation: With both teacher and student evaluated at the same temperature τ, the gradient simplifies to τ times the difference of their softened probabilities.

Cross-Entropy–KL Identity

H(q, p) = −∑_{i=1}^{C} q_i log p_i = H(q) + D_KL(q ∥ p)

Explanation: Cross-entropy decomposes into the entropy of q (constant if q is fixed) plus the KL divergence from q to p. Useful for understanding CE as a divergence.

Log-Sum-Exp Trick

log ∑_{i=1}^{C} e^{z_i} = m + log ∑_{i=1}^{C} e^{z_i − m},  where m = max_i z_i

Explanation: Subtracting the maximum improves numerical stability when computing log-sum-exp for softmax or log-softmax calculations.

Batch-Averaged KD Loss

L_batch = (1/B) ∑_{b=1}^{B} [ α L_hard^(b) + (1−α) τ² D_KL(p_b^(t,τ) ∥ p_b^(s,τ)) ]

Explanation: The KD loss is commonly averaged over the batch to keep gradients scale-invariant with respect to batch size.

Complexity Analysis

Let B be the batch size and C the number of classes. Computing temperature-scaled softmax for the student and teacher requires O(B⋅C) operations, as each sample computes exponentials and a sum over C classes. Using numerically stable log-sum-exp adds only constant overhead per class (one subtraction and a max scan that is also O(C)), so the asymptotic time is unchanged. The hard-label cross-entropy requires O(B) once the student probabilities or log-probabilities are available. The KL divergence term requires O(B⋅C) to accumulate q_i log(q_i/p_i) across classes for each sample. Therefore, the total forward pass for the KD loss is O(B⋅C). If you additionally compute gradients with respect to the student logits, you again perform O(B⋅C) work: computing p^(s,1), p^(s,τ), p^(t,τ), and forming the difference terms used in the gradients. Space-wise, if you stream per-sample computations, you can keep O(C) working memory per sample. In a vectorized batch implementation, you typically store several B×C matrices (student logits, teacher logits, probabilities at τ, probabilities at 1, and possibly intermediate buffers), leading to O(B⋅C) memory. Constants include the parameters α and τ and temporary scalars for maxima and sums. Precision can matter: double precision reduces underflow/overflow risk at high τ but may be slower; single precision usually suffices with log-sum-exp. Overall, the KD loss adds minimal overhead compared to standard cross-entropy, making it practical even for large-class problems. The dominant cost in training is usually the forward/backward passes of the models themselves, not the KD arithmetic, which remains linear in B and C.

Code Examples

Compute Knowledge Distillation Loss (single and batch) with stable softmax and KL
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <numeric>
#include <cassert>

// Numerically stable log-sum-exp
double log_sum_exp(const std::vector<double>& z_over_tau) {
    double m = *std::max_element(z_over_tau.begin(), z_over_tau.end());
    double sum = 0.0;
    for (double v : z_over_tau) sum += std::exp(v - m);
    return m + std::log(sum);
}

// Compute softmax(z / tau) in a numerically stable way
std::vector<double> softmax_temp(const std::vector<double>& z, double tau) {
    std::vector<double> z_over_tau(z.size());
    for (size_t i = 0; i < z.size(); ++i) z_over_tau[i] = z[i] / tau;
    double lse = log_sum_exp(z_over_tau);
    std::vector<double> p(z.size());
    for (size_t i = 0; i < z.size(); ++i) p[i] = std::exp(z_over_tau[i] - lse);
    return p; // sums to 1
}

// Cross-entropy with hard label y using student probabilities at tau=1
// CE = -log p_y; we compute via log-softmax for stability
double cross_entropy_hard(const std::vector<double>& logits, int y) {
    double lse = log_sum_exp(logits); // tau = 1, so logits are used as-is
    double log_py = logits[y] - lse;  // log-softmax_y
    return -log_py;
}

// KL divergence KL(q || p) where q and p are probability vectors.
// Assumes q_i >= 0, p_i > 0, both sum to 1. Adds a small epsilon for safety.
double kl_divergence(const std::vector<double>& q, const std::vector<double>& p) {
    const double eps = 1e-12;
    assert(q.size() == p.size());
    double kl = 0.0;
    for (size_t i = 0; i < q.size(); ++i) {
        double qi = std::max(q[i], 0.0);
        double pi = std::max(p[i], eps);
        if (qi > 0.0) kl += qi * (std::log(qi + eps) - std::log(pi));
    }
    return kl;
}

// Knowledge distillation loss for a single example:
// L = alpha * CE_hard + (1 - alpha) * tau^2 * KL(teacher_tau || student_tau)
double kd_loss_single(const std::vector<double>& student_logits,
                      const std::vector<double>& teacher_logits,
                      int y, double alpha, double tau) {
    double L_hard = cross_entropy_hard(student_logits, y);
    std::vector<double> p_s_tau = softmax_temp(student_logits, tau);
    std::vector<double> p_t_tau = softmax_temp(teacher_logits, tau);
    double L_soft = kl_divergence(p_t_tau, p_s_tau);
    return alpha * L_hard + (1.0 - alpha) * (tau * tau) * L_soft;
}

// Batch KD loss: average over the batch
double kd_loss_batch(const std::vector<std::vector<double>>& S_logits,
                     const std::vector<std::vector<double>>& T_logits,
                     const std::vector<int>& y, double alpha, double tau) {
    size_t B = S_logits.size();
    double sumL = 0.0;
    for (size_t b = 0; b < B; ++b) {
        sumL += kd_loss_single(S_logits[b], T_logits[b], y[b], alpha, tau);
    }
    return sumL / static_cast<double>(B);
}

int main() {
    // Example 1: single sample, 3 classes
    std::vector<double> student_logits = {2.0, 0.5, -1.0};
    std::vector<double> teacher_logits = {3.0, 1.0, -0.5};
    int y = 0;          // ground-truth class
    double alpha = 0.5;
    double tau = 3.0;

    double L_single = kd_loss_single(student_logits, teacher_logits, y, alpha, tau);
    std::cout << "Single-sample KD loss = " << L_single << "\n";

    // Example 2: batch of 2
    std::vector<std::vector<double>> S_logits = {
        {2.0, 0.5, -1.0},
        {0.2, -0.1, 1.2}
    };
    std::vector<std::vector<double>> T_logits = {
        {3.0, 1.0, -0.5},
        {-0.2, 0.3, 2.0}
    };
    std::vector<int> labels = {0, 2};

    double L_batch = kd_loss_batch(S_logits, T_logits, labels, alpha, tau);
    std::cout << "Batch KD loss = " << L_batch << "\n";
    return 0;
}

This program implements numerically stable softmax with temperature, hard-label cross-entropy, KL divergence, and combines them into the standard KD loss L = α L_hard + (1−α) τ² KL(teacher_τ || student_τ). The single-sample and batch functions demonstrate how to compute the loss. Stability is ensured via the log-sum-exp trick for log-softmax computations.

Time: O(B⋅C) for a batch with B samples and C classes. Space: O(C) working memory per sample; O(B⋅C) if storing full batch matrices.
Train a linear student with KD against a fixed linear teacher (SGD)
#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <cassert>

// Utility: stable log-sum-exp for a vector
double log_sum_exp(const std::vector<double>& v) {
    double m = *std::max_element(v.begin(), v.end());
    double sum = 0.0;
    for (double x : v) sum += std::exp(x - m);
    return m + std::log(sum);
}

std::vector<double> softmax_tau(const std::vector<double>& z, double tau) {
    std::vector<double> zt(z.size());
    for (size_t i = 0; i < z.size(); ++i) zt[i] = z[i] / tau;
    double lse = log_sum_exp(zt);
    std::vector<double> p(z.size());
    for (size_t i = 0; i < z.size(); ++i) p[i] = std::exp(zt[i] - lse);
    return p;
}

// Compute logits: z = x * W, where x is (D), W is (D x C), result is (C)
std::vector<double> logits_row(const std::vector<double>& x,
                               const std::vector<std::vector<double>>& W) {
    size_t D = x.size(), C = W[0].size();
    std::vector<double> z(C, 0.0);
    for (size_t c = 0; c < C; ++c) {
        double s = 0.0;
        for (size_t d = 0; d < D; ++d) s += x[d] * W[d][c];
        z[c] = s;
    }
    return z;
}

// Build a one-hot vector of size C for label y
std::vector<double> one_hot(size_t C, int y) {
    std::vector<double> v(C, 0.0);
    v[y] = 1.0;
    return v;
}

// Compute KD loss (batch-averaged) and gradient w.r.t. student logits:
// grad_z = alpha*(p1 - y_onehot) + (1-alpha)*tau*(p_tau - t_tau)
struct KDResult {
    double loss;
    std::vector<std::vector<double>> grad_logits;
};

KDResult kd_loss_and_grad(const std::vector<std::vector<double>>& Zs, // B x C student logits
                          const std::vector<std::vector<double>>& Zt, // B x C teacher logits
                          const std::vector<int>& y, double alpha, double tau) {
    size_t B = Zs.size(), C = Zs[0].size();
    double loss_sum = 0.0;
    std::vector<std::vector<double>> grad(B, std::vector<double>(C, 0.0));

    for (size_t b = 0; b < B; ++b) {
        // Hard CE via log-softmax
        double lse1 = log_sum_exp(Zs[b]);
        double log_py = Zs[b][y[b]] - lse1;
        double Lhard = -log_py;

        // Probabilities
        std::vector<double> p1(C), ptau_s(C), ptau_t(C), oh = one_hot(C, y[b]);
        for (size_t c = 0; c < C; ++c) p1[c] = std::exp(Zs[b][c] - lse1); // softmax at tau=1
        ptau_s = softmax_tau(Zs[b], tau);
        ptau_t = softmax_tau(Zt[b], tau);

        // KL(teacher || student) at temperature tau
        double Lsoft = 0.0;
        const double eps = 1e-12;
        for (size_t c = 0; c < C; ++c) {
            double q = std::max(ptau_t[c], 0.0);
            double p = std::max(ptau_s[c], eps);
            if (q > 0.0) Lsoft += q * (std::log(q + eps) - std::log(p));
        }

        // Total per-sample loss
        double L = alpha * Lhard + (1.0 - alpha) * (tau * tau) * Lsoft;
        loss_sum += L;

        // Gradient w.r.t. student logits
        for (size_t c = 0; c < C; ++c) {
            double g_hard = p1[c] - oh[c];
            double g_soft = (ptau_s[c] - ptau_t[c]) * tau; // from tau^2 * KL
            grad[b][c] = alpha * g_hard + (1.0 - alpha) * g_soft;
        }
    }

    // Average over the batch
    for (size_t b = 0; b < B; ++b)
        for (size_t c = 0; c < C; ++c)
            grad[b][c] /= static_cast<double>(B);

    KDResult res;
    res.loss = loss_sum / static_cast<double>(B);
    res.grad_logits = std::move(grad);
    return res;
}

int main() {
    // Synthetic data: D=4 features, C=3 classes, N=200 samples
    size_t D = 4, C = 3, N = 200;
    std::mt19937 rng(42);
    std::normal_distribution<double> nd(0.0, 1.0);

    // Generate features X (N x D) and labels y by a hidden linear rule
    std::vector<std::vector<double>> X(N, std::vector<double>(D));
    std::vector<int> y(N);
    std::vector<std::vector<double>> W_hidden(D, std::vector<double>(C));
    for (size_t d = 0; d < D; ++d)
        for (size_t c = 0; c < C; ++c)
            W_hidden[d][c] = nd(rng);
    for (size_t i = 0; i < N; ++i) {
        for (size_t d = 0; d < D; ++d) X[i][d] = nd(rng);
        // Label from argmax of X * W_hidden (linearly separable data)
        std::vector<double> z = logits_row(X[i], W_hidden);
        y[i] = int(std::max_element(z.begin(), z.end()) - z.begin());
    }

    // Teacher: linear model with weights equal to the hidden rule (idealized teacher)
    std::vector<std::vector<double>> W_t = W_hidden;

    // Student: linear model with small random weights
    std::vector<std::vector<double>> W_s(D, std::vector<double>(C));
    std::uniform_real_distribution<double> ur(-0.1, 0.1);
    for (size_t d = 0; d < D; ++d)
        for (size_t c = 0; c < C; ++c)
            W_s[d][c] = ur(rng);

    // Hyperparameters
    double alpha = 0.5, tau = 4.0, lr = 0.1;
    size_t epochs = 50, B = 20; // batch size

    for (size_t epoch = 0; epoch < epochs; ++epoch) {
        // Simple SGD over mini-batches
        double running_loss = 0.0;
        size_t batches = 0;
        for (size_t start = 0; start < N; start += B) {
            size_t end = std::min(N, start + B);
            size_t bsz = end - start;
            std::vector<std::vector<double>> Zs(bsz, std::vector<double>(C));
            std::vector<std::vector<double>> Zt(bsz, std::vector<double>(C));

            for (size_t i = 0; i < bsz; ++i) {
                Zs[i] = logits_row(X[start + i], W_s);
                Zt[i] = logits_row(X[start + i], W_t); // teacher is fixed
            }

            // KD loss and gradient w.r.t. student logits
            std::vector<int> y_batch(y.begin() + start, y.begin() + end);
            KDResult res = kd_loss_and_grad(Zs, Zt, y_batch, alpha, tau);
            running_loss += res.loss;
            ++batches;

            // Backprop to weights: grad_W = X^T * grad_logits
            // (grad_logits is already batch-averaged inside kd_loss_and_grad)
            std::vector<std::vector<double>> gradW(D, std::vector<double>(C, 0.0));
            for (size_t i = 0; i < bsz; ++i)
                for (size_t d = 0; d < D; ++d)
                    for (size_t c = 0; c < C; ++c)
                        gradW[d][c] += X[start + i][d] * res.grad_logits[i][c];

            // SGD update
            for (size_t d = 0; d < D; ++d)
                for (size_t c = 0; c < C; ++c)
                    W_s[d][c] -= lr * gradW[d][c];
        }
        std::cout << "Epoch " << epoch + 1 << ": avg KD loss = "
                  << (running_loss / batches) << "\n";
    }

    // Quick evaluation: training-set accuracy of the student
    size_t correct = 0;
    for (size_t i = 0; i < N; ++i) {
        std::vector<double> z = logits_row(X[i], W_s);
        int pred = int(std::max_element(z.begin(), z.end()) - z.begin());
        if (pred == y[i]) ++correct;
    }
    std::cout << "Training-set accuracy (student) = " << (100.0 * correct / N) << "%\n";
    return 0;
}

This example trains a linear student classifier to mimic a fixed linear teacher using the KD objective. It computes both the forward loss and the analytic gradient with respect to student logits. The gradient formula used is grad_z = α (p_s − one_hot(y)) + (1−α) τ (p_s^τ − p_t^τ). The gradient is backpropagated to weights via grad_W = X^T ⋅ grad_logits, and SGD updates the student. The code illustrates how τ and α affect the balance between matching ground truth and following the teacher.

Time: per batch, O(B⋅(D⋅C + C)) to form logits and probabilities; end-to-end per epoch, O(N⋅D⋅C). Space: O(B⋅C) for logits and gradients, plus O(D⋅C) for model weights.