šŸ“š Theory Ā· Intermediate

Key-Value Memory Systems

Key Points

  • Key-Value memory systems store information as pairs where keys are used to look up values by similarity rather than exact match.
  • Attention implements a differentiable, soft lookup by turning similarities between a query and keys into a probability distribution over memory slots.
  • Scaled dot-product attention is the standard formulation used in Transformers, with softmax weights over QK^T multiplying the value matrix V.
  • Temperature scaling and proper normalization prevent softmax saturation and improve numerical stability.
  • The computational bottleneck of attention is the O(n_q n_k) pairwise similarity matrix, which limits very long-context usage.
  • Masks allow selective reading from memory by disallowing certain keys (e.g., future tokens in causal models).
  • Cosine similarity and additive (Bahdanau) scoring are common alternatives to dot-product for computing query–key scores.
  • Differentiable write operations (erase–add) enable models to update memory contents with gradient-based learning.

Prerequisites

  • Linear Algebra (vectors, matrices, dot products) — Attention is expressed as matrix multiplications QK^T and AV; understanding vectors and norms is essential.
  • Probability and Softmax — Attention weights are probabilities derived from softmax, which requires understanding normalization and distributions.
  • Calculus and Automatic Differentiation — Differentiable memory relies on gradients through similarity, softmax, and writes.
  • Numerical Stability Techniques — Stable softmax and safe normalization prevent overflow/underflow in practice.
  • Neural Networks and Representations — Q, K, V are typically learned projections within neural architectures like Transformers.

Detailed Explanation


01 Overview

Key-Value memory systems are a way for models to store and retrieve information using pairs of vectors: a key that represents how to find the information, and a value that is the information itself. When the model has a question (a query vector), it measures how similar that query is to each key. Instead of picking just one best key, attention turns these similarities into a soft probability distribution and takes a weighted sum of the values. This makes the lookup differentiable, so the whole process can be learned end-to-end with gradient descent. In modern deep learning, this idea appears most famously as attention in Transformers. There, the query (Q), key (K), and value (V) are learned linear projections of the same or different sequences. The attention weights are computed from Q and K, and the resulting distribution mixes the values V to produce outputs. Because this process is differentiable, the network can learn representations that make relevant keys highly similar to the right queries. Beyond Transformers, key-value memory shows up in memory networks, Neural Turing Machines, retrieval-augmented generation, and key-value caches used to speed up autoregressive inference.

02 Intuition & Analogies

Imagine a giant, well-organized toolbox. Each tool has a label (the key) and the tool itself (the value). When you need to fix something, you don’t just randomly pick a tool; you read the labels to find what’s most appropriate. In human memory, when someone asks you a question, your brain doesn’t search every memory exactly; it activates related memories based on association strength—this is content-based addressing. Key-Value memory systems do something similar: a query activates keys in proportion to how related they are, and you blend the corresponding values. Why a soft blend instead of picking just one? Think of a music recommendation: you may like songs that are similar to multiple genres you enjoy. A soft choice can combine multiple influences. Softness also makes it easy to learn: because your selection isn’t a hard yes/no, small changes in the model’s parameters smoothly change the output, letting gradients flow. Temperature is like your pickiness. If you’re very picky (low temperature), you almost choose one key (sharp distribution). If you’re easygoing (high temperature), you consider many keys (broad distribution). Masks are like sticky notes saying ā€œdon’t use these tools right now,ā€ such as not looking into the future when predicting the next word. Finally, writing to memory can be gentle: rather than replacing a value outright, you can partially erase components and add new information, allowing incremental updates that remain differentiable and trainable.

03 Formal Definition

Let Q ∈ R^{n_q Ɨ d_k} be a matrix of query vectors, K ∈ R^{n_k Ɨ d_k} the key vectors, and V ∈ R^{n_k Ɨ d_v} the value vectors. A similarity function s: R^{d_k} Ɨ R^{d_k} → R scores how well a query matches a key. In scaled dot-product attention, the similarities are S = QK^T / √d_k. The attention distribution for query i is a_i = softmax(S_{i,:}), yielding nonnegative weights that sum to 1. The readout is O = AV, where A ∈ R^{n_q Ɨ n_k} stacks the a_i row vectors. Alternative scoring functions include cosine similarity s(q, k) = q^T k / (∄q∄∄k∄) and additive (Bahdanau) attention s(q, k) = w^T tanh(W_q q + W_k k + b). A mask M ∈ {0, āˆ’āˆž}^{n_q Ɨ n_k} can be added to S so that softmax zeros out disallowed positions. A temperature Ļ„ > 0 rescales the logits to S/Ļ„, controlling the entropy of A. Differentiable write operations can be defined over a memory matrix. For erase–add writes (as in Neural Turing Machines), given write weights w ∈ [0,1]^{n_k}, an erase vector e ∈ [0,1]^{d_v}, and an add vector a ∈ R^{d_v}, the value matrix updates elementwise as V' = V āŠ™ (1 āˆ’ w e^T) + w a^T, where w e^T broadcasts over value dimensions.

04 When to Use

Use key-value memory and attention when relationships depend on content similarity rather than fixed positions. Classic examples include machine translation (aligning target words to relevant source words), document question answering (finding relevant passages), and summarization (selecting salient tokens). In language models, self-attention lets each token attend to previous tokens to capture long-range dependencies and compositional structure. Beyond sequences, use key-value memories for retrieval-augmented systems: encode a knowledge base into keys/values and let queries softly retrieve facts. In reinforcement learning or program induction, differentiable external memory allows algorithms to store and recall intermediate results. If you need structured access patterns (e.g., differentiable stacks or tapes), content-based addressing can be combined with location-based addressing. Choose dot-product attention for efficiency on GPUs/TPUs and compatibility with multi-head variants. Prefer cosine similarity when scale invariance is desired. Apply masks when you must restrict attention (causal decoding, padding). If your context is very long and O(n^2) is prohibitive, consider approximate attention (local windows, sparsity, low-rank kernels) or retrieval that narrows candidate keys before soft attention.

āš ļøCommon Mistakes

• Missing scaling in dot-product attention: without dividing by √d_k, logits grow with dimensionality, making softmax overly peaky and gradients unstable.
• Unstable softmax: computing softmax(x) via exp(x)/Ī£ exp(x) without subtracting max(x) risks overflow; always use the log-sum-exp trick.
• Shape confusion between Q, K, V: mixing up row-major conventions and dimensions often leads to silent logic errors. Clearly document shapes and check them at runtime.
• Forgetting masks or using the wrong mask polarity: adding 0 for masked positions does nothing; you must add large negative numbers (or āˆ’āˆž) to the logits so softmax produces zeros.
• Ignoring normalization in cosine similarity: dividing a raw dot product by the norms of near-zero vectors can produce NaNs. Add ε to denominators and validate inputs.
• Memory blow-up: materializing the full attention matrix A of size n_q Ɨ n_k can be infeasible for long sequences. Consider chunking, flash-attention-style kernels, or streaming KV caches.
• Over-sharp temperatures: setting the temperature too low effectively performs argmax, killing gradients and slowing learning; tune Ļ„ or learn it.
• Writing without bounds: erase–add writes require erase gates in [0,1]; failing to clamp or parameterize them can destabilize training.

Key Formulas

Scaled Dot-Product Attention

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Explanation: Compute pairwise similarities between queries and keys, scale by the key dimension, turn them into probabilities with softmax, and mix values accordingly. This is the core operation in Transformer layers.

Softmax with Temperature

a_i = softmax(s_i / Ļ„),  a_ij = exp(s_ij / Ļ„) / Ī£_{t=1}^{n_k} exp(s_it / Ļ„)

Explanation: Given a row of scores s_i, softmax produces nonnegative weights that sum to 1. The temperature Ļ„ controls sharpness: small Ļ„ yields peaky distributions; large Ļ„ yields smoother ones.

Cosine Similarity (Stabilized)

s(q, k) = q^T k / (∄q∄₂ ∄k∄₂ + ε)

Explanation: Cosine similarity measures angle-based similarity and is scale-invariant. Adding ϵ avoids division by zero in practice.

Additive (Bahdanau) Score

s(q, k) = w^T tanh(W_q q + W_k k + b)

Explanation: An MLP-based similarity that can capture complex relations between q and k. It is often used with smaller dimensions and recurrent models.

Masked Attention

S' = S + M,  M_ij ∈ {0, āˆ’āˆž}

Explanation: Adding a mask M with -āˆž entries forces softmax to assign zero probability to forbidden positions. This is essential for causal and padded attention.

Erase–Add Memory Write

V' = V āŠ™ (1 āˆ’ w e^T) + w a^T

Explanation: Differentiable update of a value matrix: first erase selected components via gate e and weights w, then add new content a. Broadcasting applies over slots and value dimensions.

Attention Complexity

O(n_q n_k d_k + n_q n_k d_v)

Explanation: Forming QK^T costs O(n_q n_k d_k) and multiplying by V costs O(n_q n_k d_v). The n_q Ɨ n_k attention matrix dominates memory.

Multi-Head Composition

MultiHead(Q, K, V) = concat(head_1, …, head_h) W_O

Explanation: Multiple attention heads are computed in parallel on projected Q, K, V and then concatenated and linearly mixed. This allows modeling diverse relations.

Stable Softmax

softmax(x)_j = exp(x_j āˆ’ max_i x_i) / Ī£_i exp(x_i āˆ’ max_i x_i)

Explanation: Subtracting the maximum before exponentiation prevents overflow and keeps numerical values in a safe range.

Cross-Entropy Loss

L = āˆ’Ī£_{i=1}^{n} y_i log p_i

Explanation: When training attention to select correct items, cross-entropy between target distribution y and predicted p encourages the model to assign high probability to correct keys.

Complexity Analysis

Let n_q be the number of queries, n_k the number of keys/values, d_k the key dimension, and d_v the value dimension. Computing the similarity matrix S = QK^T costs O(n_q n_k d_k) time. Applying a row-wise softmax costs O(n_q n_k). Multiplying the attention matrix A by V costs O(n_q n_k d_v). Thus, the total time complexity is O(n_q n_k (d_k + d_v)), often summarized as O(n_q n_k d), where d is a representative dimension. The dominant memory cost is storing A of size n_q Ɨ n_k, i.e., O(n_q n_k), plus the inputs O(n_q d_k + n_k d_k + n_k d_v) and the output O(n_q d_v). For self-attention over sequences of length n, where n_q = n_k = n and typically d_k ā‰ˆ d_v ā‰ˆ d, the computation is O(n^2 d) and memory is O(n^2). This quadratic scaling is the main limitation for very long sequences. Multi-head attention with h heads typically splits the model dimension across heads (each head uses d/h dimensions), so the overall scaling remains O(n^2 d). If sparse or local attention restricts each query to k ≪ n_k keys, compute drops to O(n_q k d) and the memory for attention weights drops to O(n_q k). Retrieval-augmented systems that first shortlist candidates (e.g., via approximate nearest-neighbor search) similarly reduce n_k before applying soft attention. During inference with KV caches in autoregressive models, forming QK^T at step t costs O(t d_k) for the single new query, and memory grows linearly with context length for the stored K and V.

Code Examples

Scaled Dot-Product Attention with Optional Mask (Batched, CPU)
#include <bits/stdc++.h>
using namespace std;

// Simple row-major matrix wrapper
struct Matrix {
    int rows, cols;
    vector<double> data; // size = rows * cols
    Matrix(int r=0, int c=0, double v=0.0): rows(r), cols(c), data(r*c, v) {}
    inline double& operator()(int r, int c) { return data[r*cols + c]; }
    inline double operator()(int r, int c) const { return data[r*cols + c]; }
};

// Stable softmax over a vector (in place); returns the pre-normalization sum of exponentials
double softmax_inplace(vector<double>& x) {
    double mx = *max_element(x.begin(), x.end());
    double sum = 0.0;
    for (double &v : x) { v = exp(v - mx); sum += v; }
    for (double &v : x) v /= (sum + 1e-12);
    return sum;
}

// Compute C = A * B^T (A: m x d, B: n x d) -> C: m x n
Matrix matmul_ABt(const Matrix& A, const Matrix& B) {
    assert(A.cols == B.cols);
    int m = A.rows, n = B.rows, d = A.cols;
    Matrix C(m, n, 0.0);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            double s = 0.0;
            const int aoff = i*A.cols;
            const int boff = j*B.cols;
            for (int k = 0; k < d; ++k) s += A.data[aoff + k] * B.data[boff + k];
            C(i,j) = s;
        }
    }
    return C;
}

// Compute C = A * B (A: m x n, B: n x p) -> C: m x p
Matrix matmul(const Matrix& A, const Matrix& B) {
    assert(A.cols == B.rows);
    int m = A.rows, n = A.cols, p = B.cols;
    Matrix C(m, p, 0.0);
    for (int i = 0; i < m; ++i) {
        for (int k = 0; k < n; ++k) {
            double aik = A(i,k);
            for (int j = 0; j < p; ++j) C(i,j) += aik * B(k,j);
        }
    }
    return C;
}

// Apply mask: add mask to logits (entries are 0 to keep, very negative to block)
void add_mask(Matrix& S, const Matrix* mask) {
    if (!mask) return;
    assert(mask->rows == S.rows && mask->cols == S.cols);
    for (int i = 0; i < S.rows; ++i)
        for (int j = 0; j < S.cols; ++j)
            S(i,j) += (*mask)(i,j);
}

// Row-wise softmax over S, in place
void rowwise_softmax(Matrix& S) {
    vector<double> row(S.cols);
    for (int i = 0; i < S.rows; ++i) {
        for (int j = 0; j < S.cols; ++j) row[j] = S(i,j);
        softmax_inplace(row);
        for (int j = 0; j < S.cols; ++j) S(i,j) = row[j];
    }
}

// Scaled dot-product attention: O = softmax((Q K^T)/sqrt(dk) + mask) V
Matrix scaled_dot_product_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                                    const Matrix* mask = nullptr, double temperature = 1.0) {
    assert(Q.cols == K.cols);
    assert(K.rows == V.rows);
    int dk = Q.cols;
    double scale = 1.0 / sqrt((double)dk);
    Matrix S = matmul_ABt(Q, K); // (nq x nk)
    // scale and temperature
    for (int i = 0; i < S.rows; ++i)
        for (int j = 0; j < S.cols; ++j)
            S(i,j) = (S(i,j) * scale) / max(1e-12, temperature);
    add_mask(S, mask);       // optional mask
    rowwise_softmax(S);      // attention weights A
    Matrix O = matmul(S, V); // (nq x dv)
    return O;
}

// Utility to print a matrix
void print_matrix(const Matrix& M, const string& name) {
    cout << name << " (" << M.rows << "x" << M.cols << ")\n";
    cout.setf(ios::fixed); cout << setprecision(4);
    for (int i = 0; i < M.rows; ++i) {
        for (int j = 0; j < M.cols; ++j) cout << setw(8) << M(i,j) << ' ';
        cout << '\n';
    }
}

int main() {
    // Example: 2 queries, 4 keys/values, dk=3, dv=2
    int nq = 2, nk = 4, dk = 3, dv = 2;
    Matrix Q(nq, dk), K(nk, dk), V(nk, dv);

    // Initialize deterministic small numbers for demonstration
    // Q
    Q(0,0)=0.2; Q(0,1)=0.1; Q(0,2)=0.7;
    Q(1,0)=0.9; Q(1,1)=0.0; Q(1,2)=0.1;
    // K
    K(0,0)=0.1; K(0,1)=0.2; K(0,2)=0.6;
    K(1,0)=0.9; K(1,1)=0.1; K(1,2)=0.0;
    K(2,0)=0.0; K(2,1)=0.9; K(2,2)=0.1;
    K(3,0)=0.3; K(3,1)=0.3; K(3,2)=0.4;
    // V
    V(0,0)=1.0; V(0,1)=0.0;
    V(1,0)=0.0; V(1,1)=1.0;
    V(2,0)=0.5; V(2,1)=0.5;
    V(3,0)=0.2; V(3,1)=0.8;

    // Optional mask: forbid the second query from attending to the last two keys
    Matrix mask(nq, nk, 0.0);
    mask(1,2) = -1e9; mask(1,3) = -1e9; // large negative approximates -inf

    Matrix O = scaled_dot_product_attention(Q, K, V, &mask, /*temperature=*/1.0);

    print_matrix(O, "Output O = Attention(Q,K,V)");
    return 0;
}

This program implements scaled dot-product attention on CPU with row-major matrices. It computes S = QK^T, scales by 1/sqrt(d_k), applies an optional mask (using large negative numbers), performs a numerically stable row-wise softmax to get attention weights, and finally multiplies by V to produce outputs. The example shows two queries attending over four keys/values, with a mask restricting the second query.

Time: O(n_q n_k d_k + n_q n_k d_v). Space: O(n_q n_k + n_q d_k + n_k d_k + n_k d_v + n_q d_v).
Cosine-Similarity Key-Value Memory with Differentiable Erase–Add Write
#include <bits/stdc++.h>
using namespace std;

struct MemoryKV {
    int slots, d_key, d_val;
    vector<double> K; // slots x d_key
    vector<double> V; // slots x d_val
    MemoryKV(int n, int dk, int dv): slots(n), d_key(dk), d_val(dv), K(n*dk,0.0), V(n*dv,0.0) {}

    // Access helpers
    inline double& key(int i, int j){ return K[i*d_key + j]; }
    inline double& val(int i, int j){ return V[i*d_val + j]; }

    // Normalize a vector (L2) with epsilon for stability (helper; not used in this demo)
    static void l2_normalize(vector<double>& x) {
        double s=0; for(double v:x) s+=v*v; s = sqrt(s)+1e-12; for(double &v:x) v/=s;
    }

    // Cosine similarity between query q (size d_key) and key i
    double cos_sim_slot(const vector<double>& q, int i) const {
        double num=0, nq=0, nk=0;
        for (int j=0;j<d_key;++j){ double kj = K[i*d_key+j]; num += q[j]*kj; nq += q[j]*q[j]; nk += kj*kj; }
        return num / (sqrt(nq)*sqrt(nk) + 1e-12);
    }

    // Read: soft attention over slots using cosine similarity and temperature
    vector<double> read(const vector<double>& q, double temperature=1.0) const {
        vector<double> logits(slots);
        for (int i=0;i<slots;++i) logits[i] = cos_sim_slot(q,i) / max(1e-12, temperature);
        // stable softmax
        double mx = *max_element(logits.begin(), logits.end());
        double sum = 0.0; for (double &z: logits){ z = exp(z - mx); sum += z; }
        for (double &z: logits) z /= (sum + 1e-12);
        // weighted sum of values
        vector<double> out(d_val, 0.0);
        for (int i=0;i<slots;++i){
            double w = logits[i];
            for (int j=0;j<d_val;++j) out[j] += w * V[i*d_val + j];
        }
        return out;
    }

    // Differentiable erase-add write using weights w over slots, erase e in [0,1]^d_val, add a in R^{d_val}
    void write(const vector<double>& w, const vector<double>& e, const vector<double>& a) {
        assert((int)w.size()==slots && (int)e.size()==d_val && (int)a.size()==d_val);
        for (int i=0;i<slots;++i){
            double wi = std::clamp(w[i], 0.0, 1.0);
            for (int j=0;j<d_val;++j){
                double erase_gate = 1.0 - wi * std::clamp(e[j], 0.0, 1.0);
                V[i*d_val + j] = V[i*d_val + j] * erase_gate + wi * a[j];
            }
        }
    }
};

// Utility to print a vector
void print_vec(const vector<double>& x, const string& name){
    cout.setf(ios::fixed); cout << setprecision(4);
    cout << name << ": ";
    for(double v:x) cout << v << ' ';
    cout << '\n';
}

int main(){
    // Build a small memory with 4 slots, key dim 3, value dim 4
    MemoryKV mem(4, 3, 4);

    // Initialize keys to be roughly orthogonal
    double Kinit[4][3] = { {1,0,0}, {0,1,0}, {0,0,1}, {1,1,1} };
    for(int i=0;i<4;++i) for(int j=0;j<3;++j) mem.key(i,j) = Kinit[i][j];

    // Initialize values (e.g., one-hot categories)
    double Vinit[4][4] = { {1,0,0,0}, {0,1,0,0}, {0,0,1,0}, {0,0,0,1} };
    for(int i=0;i<4;++i) for(int j=0;j<4;++j) mem.val(i,j) = Vinit[i][j];

    // Query close to key #2 (index 1): add small noise
    vector<double> q = {0.02, 0.98, 0.01};

    // Read with moderate temperature
    vector<double> out = mem.read(q, /*temperature=*/0.5);
    print_vec(out, "Readout before writes"); // dominated by value[1] = [0,1,0,0]

    // Now perform a differentiable write: slightly move memory towards a new value
    // Suppose we want slot 1 to also encode [0.2, 0.8, 0, 0]
    vector<double> w = {0.0, 0.7, 0.0, 0.0}; // focus write on slot 1
    vector<double> e = {0.5, 0.5, 0.5, 0.5}; // erase half of old content where written
    vector<double> a = {0.2, 0.8, 0.0, 0.0}; // add new content
    mem.write(w, e, a);

    // Read again with the same query; output should shift towards a
    vector<double> out2 = mem.read(q, 0.5);
    print_vec(out2, "Readout after writes");

    // Demonstrate near-argmax behavior with low temperature
    vector<double> out3 = mem.read(q, /*temperature=*/0.05);
    print_vec(out3, "Readout with low temperature (near-hard)");

    return 0;
}

This program implements a small key-value memory with cosine-similarity attention for reading and an erase–add rule for differentiable writing. The read function turns cosine similarities into a softmax distribution (with temperature) over slots and returns the weighted sum of values. The write function gates erasure and addition per slot, demonstrating how memory can be updated smoothly. The example initializes nearly orthogonal keys and one-hot values, reads a target slot, performs a partial write to adjust a slot’s content, and shows how outputs shift.

Time: read O(slots Ā· d_key + slots Ā· d_val); write O(slots Ā· d_val). Space: O(slots Ā· d_key + slots Ā· d_val) for the stored memory.