Key-Value Memory Systems
Key Points
- Key-value memory systems store information as pairs where keys are used to look up values by similarity rather than exact match.
- Attention implements a differentiable, soft lookup by turning similarities between a query and keys into a probability distribution over memory slots.
- Scaled dot-product attention is the standard formulation used in Transformers: softmax weights computed from Q and K multiply the value matrix V.
- Temperature scaling and proper normalization prevent softmax saturation and improve numerical stability.
- The computational bottleneck of attention is the O(n_q \times n_k) pairwise similarity matrix, which limits very long-context usage.
- Masks allow selective reading from memory by disallowing certain keys (e.g., future tokens in causal models).
- Cosine similarity and additive (Bahdanau) scoring are common alternatives to dot-product for computing query-key scores.
- Differentiable write operations (erase-add) enable models to update memory contents with gradient-based learning.
Prerequisites
- Linear Algebra (vectors, matrices, dot products): attention is expressed as matrix multiplications QK^T and AV; understanding vectors and norms is essential.
- Probability and Softmax: attention weights are probabilities derived from softmax, which requires understanding normalization and distributions.
- Calculus and Automatic Differentiation: differentiable memory relies on gradients through similarity, softmax, and writes.
- Numerical Stability Techniques: stable softmax and safe normalization prevent overflow/underflow in practice.
- Neural Networks and Representations: Q, K, V are typically learned projections within neural architectures like Transformers.
Detailed Explanation
01 Overview
Key-value memory systems are a way for models to store and retrieve information using pairs of vectors: a key that represents how to find the information, and a value that is the information itself. When the model has a question (a query vector), it measures how similar that query is to each key. Instead of picking just one best key, attention turns these similarities into a soft probability distribution and takes a weighted sum of the values. This makes the lookup differentiable, so the whole process can be learned end-to-end with gradient descent.

In modern deep learning, this idea appears most famously as attention in Transformers. There, the query (Q), key (K), and value (V) are learned linear projections of the same or different sequences. The attention weights are computed from Q and K, and the resulting distribution mixes the values V to produce outputs. Because this process is differentiable, the network can learn representations that make relevant keys highly similar to the right queries.

Beyond Transformers, key-value memory shows up in memory networks, Neural Turing Machines, retrieval-augmented generation, and key-value caches used to speed up autoregressive inference.
02 Intuition & Analogies
Imagine a giant, well-organized toolbox. Each tool has a label (the key) and the tool itself (the value). When you need to fix something, you don't just randomly pick a tool; you read the labels to find what's most appropriate. In human memory, when someone asks you a question, your brain doesn't search every memory exactly; it activates related memories based on association strength: this is content-based addressing. Key-value memory systems do something similar: a query activates keys in proportion to how related they are, and you blend the corresponding values.

Why a soft blend instead of picking just one? Think of a music recommendation: you may like songs that are similar to multiple genres you enjoy. A soft choice can combine multiple influences. Softness also makes it easy to learn: because your selection isn't a hard yes/no, small changes in the model's parameters smoothly change the output, letting gradients flow.

Temperature is like your pickiness. If you're very picky (low temperature), you almost choose one key (sharp distribution). If you're easygoing (high temperature), you consider many keys (broad distribution). Masks are like sticky notes saying "don't use these tools right now," such as not looking into the future when predicting the next word. Finally, writing to memory can be gentle: rather than replacing a value outright, you can partially erase components and add new information, allowing incremental updates that remain differentiable and trainable.
03 Formal Definition
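The core operation can be stated compactly. Given queries $Q \in \mathbb{R}^{n_q \times d_k}$, keys $K \in \mathbb{R}^{n_k \times d_k}$, and values $V \in \mathbb{R}^{n_k \times d_v}$, scaled dot-product attention is:

```latex
\mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left( \frac{QK^\top}{\sqrt{d_k}} \right) V
```

The softmax is applied row-wise, so $A = \operatorname{softmax}(QK^\top/\sqrt{d_k})$ is an $n_q \times n_k$ matrix of attention weights whose rows are probability distributions over the memory slots, and the output $AV$ is each query's weighted mix of the values.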
04 When to Use
Use key-value memory and attention when relationships depend on content similarity rather than fixed positions. Classic examples include machine translation (aligning target words to relevant source words), document question answering (finding relevant passages), and summarization (selecting salient tokens). In language models, self-attention lets each token attend to previous tokens to capture long-range dependencies and compositional structure. Beyond sequences, use key-value memories for retrieval-augmented systems: encode a knowledge base into keys/values and let queries softly retrieve facts. In reinforcement learning or program induction, differentiable external memory allows algorithms to store and recall intermediate results. If you need structured access patterns (e.g., differentiable stacks or tapes), content-based addressing can be combined with location-based addressing. Choose dot-product attention for efficiency on GPUs/TPUs and compatibility with multi-head variants. Prefer cosine similarity when scale invariance is desired. Apply masks when you must restrict attention (causal decoding, padding). If your context is very long and O(n^2) is prohibitive, consider approximate attention (local windows, sparsity, low-rank kernels) or retrieval that narrows candidate keys before soft attention.
⚠️ Common Mistakes
- Missing scaling in dot-product attention: without dividing by \sqrt{d_k}, logits grow with dimensionality, making softmax overly peaky and gradients unstable.
- Unstable softmax: computing \operatorname{softmax}(x) via \exp(x)/\sum \exp(x) without subtracting \max(x) risks overflow; always use the log-sum-exp trick.
- Shape confusion between Q, K, V: mixing up row-major conventions and dimensions often leads to silent logic errors. Clearly document shapes and check them at runtime.
- Forgetting masks or using the wrong mask polarity: adding 0 for masked positions does nothing; you must add large negative numbers (or -\infty) to logits so softmax produces zeros.
- Ignoring normalization in cosine similarity: if you compute a raw dot product and then divide by the norms of near-zero vectors, you can get NaNs. Add \epsilon to denominators and validate inputs.
- Memory blow-up: materializing the full attention matrix A of size n_q \times n_k can be infeasible for long sequences. Consider chunking, flash-attention-style kernels, or streaming KV caches.
- Overly low temperatures: setting the temperature too low effectively performs argmax, killing gradients and slowing learning; tune \tau or learn it.
- Writing without bounds: erase-add writes require erase gates in [0,1]; failing to clamp or parameterize them can destabilize training.
Key Formulas
Scaled Dot-Product Attention
\mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left( \frac{QK^\top}{\sqrt{d_k}} \right) V
Explanation: Compute pairwise similarities between queries and keys, scale by the square root of the key dimension, turn them into probabilities with softmax, and mix values accordingly. This is the core operation in Transformer layers.
Softmax with Temperature
\operatorname{softmax}(s/\tau)_i = \frac{\exp(s_i/\tau)}{\sum_j \exp(s_j/\tau)}
Explanation: Given a row of scores s, softmax produces nonnegative weights that sum to 1. The temperature \tau controls sharpness: small \tau yields peaky distributions; large \tau yields smoother ones.
Cosine Similarity (Stabilized)
\operatorname{cos}(q, k) = \frac{q \cdot k}{\lVert q \rVert \, \lVert k \rVert + \epsilon}
Explanation: Cosine similarity measures angle-based similarity and is scale-invariant. Adding \epsilon avoids division by zero in practice.
Additive (Bahdanau) Score
\operatorname{score}(q, k) = v^\top \tanh(W_q q + W_k k)
Explanation: An MLP-based similarity that can capture complex relations between q and k. It is often used with smaller dimensions and recurrent models.
Masked Attention
A = \operatorname{softmax}\!\left( \frac{QK^\top}{\sqrt{d_k}} + M \right), \quad M_{ij} \in \{0, -\infty\}
Explanation: Adding a mask M with -\infty entries forces softmax to assign zero probability to forbidden positions. This is essential for causal and padded attention.
Erase-Add Memory Write
V_i \leftarrow V_i \odot (1 - w_i e) + w_i a
Explanation: Differentiable update of a value matrix: first erase selected components via gate e and weights w, then add new content a. Broadcasting applies over slots and value dimensions.
Attention Complexity
\text{time: } O(n_q n_k d_k + n_q n_k d_v), \qquad \text{memory: } O(n_q n_k)
Explanation: Forming QK^\top costs O(n_q n_k d_k) and multiplying the attention weights by V costs O(n_q n_k d_v). The n_q \times n_k attention matrix dominates memory.
Multi-Head Composition
\operatorname{MultiHead}(Q, K, V) = \operatorname{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O, \quad \text{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V)
Explanation: Multiple attention heads are computed in parallel on projected Q, K, V and then concatenated and linearly mixed. This allows modeling diverse relations.
Stable Softmax
\operatorname{softmax}(x)_i = \frac{\exp(x_i - \max_j x_j)}{\sum_k \exp(x_k - \max_j x_j)}
Explanation: Subtracting the maximum before exponentiation prevents overflow and keeps numerical values in a safe range.
Cross-Entropy Loss
\mathcal{L}(y, p) = -\sum_i y_i \log p_i
Explanation: When training attention to select correct items, cross-entropy between target distribution y and predicted p encourages the model to assign high probability to correct keys.
Complexity Analysis
Time is dominated by the O(n_q n_k d_k) similarity computation and the O(n_q n_k d_v) value mixing; memory is dominated by the n_q \times n_k attention matrix. This is why very long contexts call for chunking, streaming KV caches, or the approximate-attention strategies mentioned above.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

// Simple row-major matrix wrapper
struct Matrix {
    int rows, cols;
    vector<double> data; // size = rows * cols
    Matrix(int r=0, int c=0, double v=0.0): rows(r), cols(c), data(r*c, v) {}
    inline double& operator()(int r, int c) { return data[r*cols + c]; }
    inline double operator()(int r, int c) const { return data[r*cols + c]; }
};

// Numerically stable softmax over a vector, in place
void softmax_inplace(vector<double>& x) {
    double mx = *max_element(x.begin(), x.end());
    double sum = 0.0;
    for (double &v : x) { v = exp(v - mx); sum += v; }
    for (double &v : x) v /= (sum + 1e-12);
}

// Compute C = A * B^T (A: m x d, B: n x d) -> C: m x n
Matrix matmul_ABt(const Matrix& A, const Matrix& B) {
    assert(A.cols == B.cols);
    int m = A.rows, n = B.rows, d = A.cols;
    Matrix C(m, n, 0.0);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            double s = 0.0;
            const int aoff = i*A.cols;
            const int boff = j*B.cols;
            for (int k = 0; k < d; ++k) s += A.data[aoff + k] * B.data[boff + k];
            C(i,j) = s;
        }
    }
    return C;
}

// Compute C = A * B (A: m x n, B: n x p) -> C: m x p
Matrix matmul(const Matrix& A, const Matrix& B) {
    assert(A.cols == B.rows);
    int m = A.rows, n = A.cols, p = B.cols;
    Matrix C(m, p, 0.0);
    for (int i = 0; i < m; ++i) {
        for (int k = 0; k < n; ++k) {
            double aik = A(i,k);
            for (int j = 0; j < p; ++j) C(i,j) += aik * B(k,j);
        }
    }
    return C;
}

// Apply mask: add mask to logits (entries are 0 to keep, very negative to block)
void add_mask(Matrix& S, const Matrix* mask) {
    if (!mask) return;
    assert(mask->rows == S.rows && mask->cols == S.cols);
    for (int i = 0; i < S.rows; ++i)
        for (int j = 0; j < S.cols; ++j)
            S(i,j) += (*mask)(i,j);
}

// Row-wise softmax over S, in place
void rowwise_softmax(Matrix& S) {
    vector<double> row(S.cols);
    for (int i = 0; i < S.rows; ++i) {
        for (int j = 0; j < S.cols; ++j) row[j] = S(i,j);
        softmax_inplace(row);
        for (int j = 0; j < S.cols; ++j) S(i,j) = row[j];
    }
}

// Scaled dot-product attention: O = softmax((Q K^T)/sqrt(dk) + mask) V
Matrix scaled_dot_product_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                                    const Matrix* mask = nullptr, double temperature = 1.0) {
    assert(Q.cols == K.cols);
    assert(K.rows == V.rows);
    int dk = Q.cols;
    double scale = 1.0 / sqrt((double)dk);
    Matrix S = matmul_ABt(Q, K);            // (nq x nk)
    for (int i = 0; i < S.rows; ++i)        // scale and apply temperature
        for (int j = 0; j < S.cols; ++j)
            S(i,j) = (S(i,j) * scale) / max(1e-12, temperature);
    add_mask(S, mask);                      // optional mask
    rowwise_softmax(S);                     // attention weights A
    Matrix O = matmul(S, V);                // (nq x dv)
    return O;
}

// Utility to print a matrix
void print_matrix(const Matrix& M, const string& name) {
    cout << name << " (" << M.rows << "x" << M.cols << ")\n";
    cout.setf(ios::fixed); cout << setprecision(4);
    for (int i = 0; i < M.rows; ++i) {
        for (int j = 0; j < M.cols; ++j) cout << setw(8) << M(i,j) << ' ';
        cout << '\n';
    }
}

int main() {
    // Example: 2 queries, 4 keys/values, dk=3, dv=2
    int nq = 2, nk = 4, dk = 3, dv = 2;
    Matrix Q(nq, dk), K(nk, dk), V(nk, dv);

    // Initialize deterministic small numbers for demonstration
    // Q
    Q(0,0)=0.2; Q(0,1)=0.1; Q(0,2)=0.7;
    Q(1,0)=0.9; Q(1,1)=0.0; Q(1,2)=0.1;
    // K
    K(0,0)=0.1; K(0,1)=0.2; K(0,2)=0.6;
    K(1,0)=0.9; K(1,1)=0.1; K(1,2)=0.0;
    K(2,0)=0.0; K(2,1)=0.9; K(2,2)=0.1;
    K(3,0)=0.3; K(3,1)=0.3; K(3,2)=0.4;
    // V
    V(0,0)=1.0; V(0,1)=0.0;
    V(1,0)=0.0; V(1,1)=1.0;
    V(2,0)=0.5; V(2,1)=0.5;
    V(3,0)=0.2; V(3,1)=0.8;

    // Optional causal-like mask: forbid attending to the last two keys for the second query
    Matrix mask(nq, nk, 0.0);
    mask(1,2) = -1e9; mask(1,3) = -1e9; // large negative approximates -inf

    Matrix O = scaled_dot_product_attention(Q, K, V, &mask, /*temperature=*/1.0);

    print_matrix(O, "Output O = Attention(Q,K,V)");
    return 0;
}
```
This program implements scaled dot-product attention on CPU with row-major matrices. It computes S = QK^T, scales by 1/sqrt(d_k), applies an optional mask (using large negative numbers), performs a numerically stable row-wise softmax to get attention weights, and finally multiplies by V to produce outputs. The example shows two queries attending over four keys/values, with a mask restricting the second query.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct MemoryKV {
    int slots, d_key, d_val;
    vector<double> K; // slots x d_key
    vector<double> V; // slots x d_val
    MemoryKV(int n, int dk, int dv): slots(n), d_key(dk), d_val(dv), K(n*dk,0.0), V(n*dv,0.0) {}

    // Access helpers
    inline double& key(int i, int j){ return K[i*d_key + j]; }
    inline double& val(int i, int j){ return V[i*d_val + j]; }

    // Normalize a vector (L2) with epsilon for stability
    static void l2_normalize(vector<double>& x) {
        double s=0; for(double v:x) s+=v*v; s = sqrt(s)+1e-12; for(double &v:x) v/=s;
    }

    // Cosine similarity between query q (size d_key) and key i
    double cos_sim_slot(const vector<double>& q, int i) const {
        double num=0, nq=0, nk=0;
        for (int j=0;j<d_key;++j){ double kj = K[i*d_key+j]; num += q[j]*kj; nq += q[j]*q[j]; nk += kj*kj; }
        return num / (sqrt(nq)*sqrt(nk) + 1e-12);
    }

    // Read: soft attention over slots using cosine similarity and temperature
    vector<double> read(const vector<double>& q, double temperature=1.0) const {
        vector<double> logits(slots);
        for (int i=0;i<slots;++i) logits[i] = cos_sim_slot(q,i) / max(1e-12, temperature);
        // stable softmax
        double mx = *max_element(logits.begin(), logits.end());
        double sum = 0.0; for (double &z: logits){ z = exp(z - mx); sum += z; }
        for (double &z: logits) z /= (sum + 1e-12);
        // weighted sum of values
        vector<double> out(d_val, 0.0);
        for (int i=0;i<slots;++i){
            double w = logits[i];
            for (int j=0;j<d_val;++j) out[j] += w * V[i*d_val + j];
        }
        return out;
    }

    // Differentiable erase-add write: weights w over slots, erase gate e in [0,1]^d_val, add vector a in R^d_val
    void write(const vector<double>& w, const vector<double>& e, const vector<double>& a) {
        assert((int)w.size()==slots && (int)e.size()==d_val && (int)a.size()==d_val);
        for (int i=0;i<slots;++i){
            double wi = std::clamp(w[i], 0.0, 1.0);
            for (int j=0;j<d_val;++j){
                double erase_gate = 1.0 - wi * std::clamp(e[j], 0.0, 1.0);
                V[i*d_val + j] = V[i*d_val + j] * erase_gate + wi * a[j];
            }
        }
    }
};

// Utility to print a vector
void print_vec(const vector<double>& x, const string& name){
    cout.setf(ios::fixed); cout << setprecision(4);
    cout << name << ": ";
    for(double v:x) cout << v << ' ';
    cout << '\n';
}

int main(){
    // Build a small memory with 4 slots, key dim 3, value dim 4
    MemoryKV mem(4, 3, 4);

    // Initialize keys to be roughly orthogonal
    double Kinit[4][3] = { {1,0,0}, {0,1,0}, {0,0,1}, {1,1,1} };
    for(int i=0;i<4;++i) for(int j=0;j<3;++j) mem.key(i,j) = Kinit[i][j];

    // Initialize values (e.g., one-hot categories)
    double Vinit[4][4] = { {1,0,0,0}, {0,1,0,0}, {0,0,1,0}, {0,0,0,1} };
    for(int i=0;i<4;++i) for(int j=0;j<4;++j) mem.val(i,j) = Vinit[i][j];

    // Query close to key #2 (index 1): add small noise
    vector<double> q = {0.02, 0.98, 0.01};

    // Read with moderate temperature
    vector<double> out = mem.read(q, /*temperature=*/0.5);
    print_vec(out, "Readout before writes"); // should be close to value[1] = [0,1,0,0]

    // Now perform a differentiable write: slightly move memory towards a new value
    // Suppose we want slot 1 to also encode [0.2, 0.8, 0, 0]
    vector<double> w = {0.0, 0.7, 0.0, 0.0}; // focus write on slot 1
    vector<double> e = {0.5, 0.5, 0.5, 0.5}; // erase half of old content where written
    vector<double> a = {0.2, 0.8, 0.0, 0.0}; // add new content
    mem.write(w, e, a);

    // Read again with the same query; output should shift towards a
    vector<double> out2 = mem.read(q, 0.5);
    print_vec(out2, "Readout after writes");

    // Demonstrate near-argmax behavior with low temperature
    vector<double> out3 = mem.read(q, /*temperature=*/0.05);
    print_vec(out3, "Readout with low temperature (near-hard)");

    return 0;
}
```
This program implements a small key-value memory with cosine-similarity attention for reading and an erase-add rule for differentiable writing. The read function turns cosine similarities into a softmax distribution (with temperature) over slots and returns the weighted sum of values. The write function gates erasure and addition per slot, demonstrating how memory can be updated smoothly. The example initializes nearly orthogonal keys and one-hot values, reads a target slot, performs a partial write to adjust a slot's content, and shows how outputs shift.