Multi-Head Attention
Key Points
- Multi-Head Attention runs several attention mechanisms in parallel so each head can focus on different relationships in the data.
- Each head uses its own learned projections of queries, keys, and values; the head outputs are then concatenated and linearly mixed.
- The core operation is scaled dot-product attention: softmax of QK^T scaled by 1/sqrt(d_k), multiplied by V.
- Scaling by 1/sqrt(d_k) keeps gradients stable and prevents extremely peaked or flat softmax distributions.
- Masks (e.g., causal or padding masks) are added to attention scores to block illegal or irrelevant positions.
- Multi-Head Attention enables models like Transformers to capture long-range dependencies efficiently.
- Time complexity is dominated by forming the attention matrices, which is O(h T^2) for sequence length T and h heads.
- Numerical stability requires subtracting the row-wise max before softmax and careful treatment of masks.
Prerequisites
- Linear Algebra (vectors, matrices, transpose, matrix multiplication) — Attention is expressed with matrix multiplications, transposes, and concatenations.
- Dot Product and Similarity — Attention scores are computed via dot products between queries and keys.
- Softmax and Numerical Stability — Converting scores to probabilities requires softmax; stability tricks are essential.
- Linear (Fully Connected) Layers — Q, K, V, and output projections are linear layers applied to token embeddings.
- Basic Sequence Modeling Concepts — Understanding tokens, positions, and masks clarifies how attention operates over sequences.
Detailed Explanation
01 Overview
Multi-Head Attention (MHA) is a neural network component that learns how different parts of a sequence should pay attention to one another. Instead of computing a single attention distribution, MHA runs several attention "heads" in parallel. Each head has its own learned linear projections that map inputs into query (Q), key (K), and value (V) spaces. Within each head, attention scores are computed from Q and K, normalized via a softmax, and used to weight the values V. The multiple head outputs are then concatenated and passed through a final linear layer. This design allows the model to capture different types of relationships simultaneously—such as positional patterns, semantic similarity, or syntax-like structure—because different heads can specialize.
MHA is a central building block of the Transformer architecture, powering advances in language modeling, translation, vision transformers, and more. It replaces recurrence and convolutions with a flexible all-to-all interaction pattern controlled by learned weights and attention distributions. The key computational trick is the scaled dot-product attention, which ensures stable training by dividing raw dot products by the square root of the key dimension. MHA supports masking to enforce constraints (like causality in autoregressive models) and to ignore padded tokens. Overall, MHA balances expressiveness and parallelism: it is highly parallelizable on modern hardware and rich enough to capture long-range dependencies.
02 Intuition & Analogies
Imagine you are reading a paragraph and trying to understand the meaning of a particular word. You might look back at some earlier words for context (like subjects or adjectives), glance ahead for clarifications, or even re-check names or places mentioned before. Now imagine several friends each doing this reading task but with different strategies: one focuses on matching names, another on verb tense, another on nearby words for context, and another on punctuation or structure. After they each bring back their findings, you combine everyone’s perspective to form a more complete understanding. That’s the spirit of Multi-Head Attention.
In this analogy, each "friend" is an attention head. The head first translates the current word into a question (query) about what it needs to know. It then compares the query to all other words, each represented as a key, to see which are most relevant to answering that question. The similarities form attention scores. The head then gathers information (values) mostly from the words it considers important and produces a weighted combination—a summary focusing on what it judged relevant.
Different heads have different perspectives because they learn different projections: they transform the same input into different Q, K, and V spaces. One head might learn to track coreference (who is doing what), another might track position or rhythm, and another might capture topical similarity. Finally, we stitch all these focused summaries together (concatenate) and mix them with a final linear layer. The result is a rich, multi-faceted representation of each token informed by the whole sequence.
03 Formal Definition
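In symbols (using the notation from the Key Formulas section, with an input sequence X of shape T × d_model), scaled dot-product attention, the per-head computation, and the multi-head aggregation are:

```latex
\text{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

\text{head}_i = \text{Attention}\!\left(X W_i^Q,\; X W_i^K,\; X W_i^V\right),
\qquad W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k},\;
W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}

\mathrm{MultiHead}(X) = \mathrm{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O,
\qquad W^O \in \mathbb{R}^{(h\, d_v) \times d_{\text{model}}}
```

The softmax is applied row-wise over the T × T score matrix, so each query position receives a probability distribution over key positions.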
04 When to Use
Use Multi-Head Attention when you need a model to capture relationships among elements of a sequence (or set) without relying on recurrence or convolution. It shines in tasks where long-range dependencies matter and parallel processing is desired. Examples include:
- Natural language processing: machine translation, language modeling, summarization, question answering, where tokens must attend to semantically related tokens across long distances.
- Vision: Vision Transformers split images into patches and use MHA to relate patches across the whole image, capturing global context that can be hard for small convolutions.
- Speech and audio: relate segments across time for recognition, enhancement, or classification.
- Multimodal tasks: align information across modalities (text–image attention in image captioning or VQA).
- Retrieval and memory: attend over a set of candidate items, memory slots, or database entries.
Choose MHA when you need flexible, content-based interactions and the ability to learn multiple complementary patterns simultaneously. When sequence lengths are very large (T is big), consider efficient attention variants (sparse, low-rank, locality-sensitive hashing) to reduce O(T^2) costs. For strictly autoregressive generation, ensure causal masks are applied; for batched data with padding, include padding masks to avoid attending to meaningless tokens.
⚠️Common Mistakes
- Forgetting the scaling factor 1/\sqrt{d_k}: Without scaling, dot products grow with d_k, leading to saturated softmax (very peaked) and vanishing gradients. Always divide scores by \sqrt{d_k}.
- Incorrect masking: Using multiplicative masks with zeros post-softmax, or mismatched mask shapes, can leak information. Use additive masks before softmax with 0 for allowed and a large negative number (approximating -\infty) for disallowed positions.
- Unstable softmax: Computing \exp(scores) directly can overflow. Subtract the row-wise maximum score before exponentiation for numerical stability.
- Shape confusion: Mixing up (T \times d) vs. (d \times T), or concatenating along the wrong axis, leads to silent errors. Keep a clear convention and assert shapes.
- Weight sharing errors: Each head must have its own W^Q, W^K, W^V (unless using a fused projection). Accidentally reusing matrices breaks head diversity.
- Causal vs. bidirectional: Forgetting to switch masks between training (bidirectional in encoders) and generation (causal in decoders) can cause information leakage.
- Ignoring padding: Not masking out padded tokens allows attention to consider garbage positions, harming performance and sometimes causing NaNs.
- Mismatched dimensions: Ensure that h d_v equals the input dimension to W^O and that d_k, d_v are consistent across heads to simplify implementation.
Key Formulas
Scaled Dot-Product Attention
Attention(Q, K, V) = softmax(QK^T / \sqrt{d_k}) V
Explanation: Compute similarity scores via dot products between queries and keys, scale by 1/\sqrt{d_k}, normalize with softmax, and form a weighted sum of values. This is the core operation inside each head.
Per-Head Output
head_i = Attention(X W_i^Q, X W_i^K, X W_i^V)
Explanation: Each head first projects inputs into separate Q, K, V spaces, then applies scaled dot-product attention to produce its own representation.
Multi-Head Aggregation
MultiHead(X) = Concat(head_1, ..., head_h) W^O
Explanation: Concatenate the outputs from all h heads along the feature dimension and apply a final linear projection to return to the model dimension.
Parameter Count (without biases)
Params = h (d_model d_k + d_model d_k + d_model d_v) + (h d_v) d_model
Explanation: Total parameters for separate per-head projections: W_i^Q and W_i^K (each d_model × d_k) and W_i^V (d_model × d_v) for each head, plus the output projection W^O ((h d_v) × d_model). Bias terms, if used, add linear terms.
Masked Scores
Attention(Q, K, V; M) = softmax(QK^T / \sqrt{d_k} + M) V
Explanation: Add an attention mask M (0 for allowed, −∞ for disallowed) to the scaled scores before softmax. This enforces constraints like causality or padding.
Softmax
softmax(z)_i = exp(z_i) / \sum_j exp(z_j)
Explanation: Converts a vector of real numbers into a probability distribution that sums to 1. Used to turn attention scores into attention weights.
Time Complexity (single batch, self-attention)
O(h T^2 d_k + T d_model^2)
Explanation: The T×T attention score matrix per head leads to the dominant h T^2 d_k term; the Q/K/V projections and the output layer add terms linear in T from their matrix multiplications.
Memory Complexity (activations)
O(h T^2)
Explanation: Storing attention probabilities across h heads requires O(h T^2) memory, often the limiting factor for long sequences.
Complexity Analysis
For self-attention over a sequence of length T, each of the h heads forms a T×T score matrix, so time is dominated by O(h T^2 d_k) for the score and value products, with an additional O(T d_model^2) for the Q/K/V and output projections. Activation memory is dominated by the O(h T^2) attention probabilities, which is why very long sequences motivate the efficient attention variants mentioned above.
Code Examples
```cpp
#include <bits/stdc++.h>
using namespace std;

using Matrix = vector<vector<double>>;

Matrix transpose(const Matrix &A){
    int r = A.size(), c = A[0].size();
    Matrix T(c, vector<double>(r));
    for(int i = 0; i < r; ++i)
        for(int j = 0; j < c; ++j) T[j][i] = A[i][j];
    return T;
}

Matrix matmul(const Matrix &A, const Matrix &B){
    int r = A.size(), k = A[0].size(), c = B[0].size();
    Matrix C(r, vector<double>(c, 0.0));
    for(int i = 0; i < r; ++i)
        for(int j = 0; j < c; ++j){
            double s = 0;
            for(int t = 0; t < k; ++t) s += A[i][t] * B[t][j];
            C[i][j] = s;
        }
    return C;
}

Matrix scale(const Matrix &A, double s){
    Matrix C = A;
    for(auto &row : C) for(double &x : row) x *= s;
    return C;
}

// Row-wise numerically stable softmax: subtract each row's max before exp.
Matrix softmax_rows(const Matrix &S){
    int r = S.size(), c = S[0].size();
    Matrix P(r, vector<double>(c));
    for(int i = 0; i < r; ++i){
        double m = S[i][0];
        for(int j = 1; j < c; ++j) m = max(m, S[i][j]);
        double sum = 0.0;
        for(int j = 0; j < c; ++j){ P[i][j] = exp(S[i][j] - m); sum += P[i][j]; }
        for(int j = 0; j < c; ++j) P[i][j] /= (sum + 1e-12);
    }
    return P;
}

// Causal mask of size T x T: 0 on/below the diagonal, -1e9 above (disallowed).
Matrix causal_mask(int T){
    Matrix M(T, vector<double>(T, 0.0));
    const double NEG_INF = -1e9;
    for(int i = 0; i < T; ++i)
        for(int j = i + 1; j < T; ++j) M[i][j] = NEG_INF;
    return M;
}

void print_matrix(const string &name, const Matrix &A, int prec = 3){
    cout << name << " (" << A.size() << "x" << A[0].size() << ")\n";
    cout.setf(std::ios::fixed);
    cout << setprecision(prec);
    for(const auto &row : A){
        for(double x : row) cout << setw(8) << x << ' ';
        cout << '\n';
    }
}

// Scaled Dot-Product Attention: softmax((QK^T)/sqrt(dk) + M) V
Matrix scaled_dot_product_attention(const Matrix &Q, const Matrix &K,
                                    const Matrix &V, const Matrix *mask){
    int Tq = Q.size();
    int dk = Q[0].size();
    int Tk = K.size();
    // Scores = Q K^T / sqrt(dk)
    Matrix S = matmul(Q, transpose(K));
    S = scale(S, 1.0 / sqrt((double)dk));
    // Add mask if provided; M must be Tq x Tk (0 for allowed, -inf for disallowed)
    if(mask){
        const Matrix &M = *mask;
        for(int i = 0; i < Tq; ++i)
            for(int j = 0; j < Tk; ++j) S[i][j] += M[i][j];
    }
    // Output = softmax(S) V
    return matmul(softmax_rows(S), V);
}

int main(){
    int T = 4; // sequence length; d_k = 3 and d_v = 2 in the toy matrices below

    Matrix Q = {{0.2, 0.1, 0.4},
                {0.0, 0.5, 0.3},
                {0.1, 0.0, 0.2},
                {0.3, 0.2, 0.1}}; // T x d_k
    Matrix K = {{0.2, 0.0, 0.1},
                {0.1, 0.4, 0.3},
                {0.3, 0.1, 0.2},
                {0.0, 0.2, 0.2}}; // T x d_k
    Matrix V = {{ 0.5,  0.0},
                {-0.2,  0.1},
                { 0.3, -0.1},
                { 0.0,  0.2}}; // T x d_v

    Matrix M = causal_mask(T); // prevent looking ahead

    Matrix O = scaled_dot_product_attention(Q, K, V, &M);
    print_matrix("Attention output O", O);
    return 0;
}
```
This program implements the core scaled dot-product attention for a single head. It constructs scores S = QK^T, scales by 1/sqrt(d_k), adds a causal mask (upper triangle set to -inf), applies a numerically stable row-wise softmax, and multiplies by V to get the output. The example uses small toy matrices so you can inspect the numbers easily.
```cpp
#include <bits/stdc++.h>
using namespace std;

using Matrix = vector<vector<double>>;

Matrix transpose(const Matrix &A){
    int r = A.size(), c = A[0].size();
    Matrix T(c, vector<double>(r));
    for(int i = 0; i < r; ++i)
        for(int j = 0; j < c; ++j) T[j][i] = A[i][j];
    return T;
}

Matrix matmul(const Matrix &A, const Matrix &B){
    int r = A.size(), k = A[0].size(), c = B[0].size();
    Matrix C(r, vector<double>(c, 0.0));
    for(int i = 0; i < r; ++i)
        for(int j = 0; j < c; ++j){
            double s = 0;
            for(int t = 0; t < k; ++t) s += A[i][t] * B[t][j];
            C[i][j] = s;
        }
    return C;
}

Matrix scale(const Matrix &A, double s){
    Matrix C = A;
    for(auto &row : C) for(double &x : row) x *= s;
    return C;
}

Matrix softmax_rows(const Matrix &S){
    int r = S.size(), c = S[0].size();
    Matrix P(r, vector<double>(c));
    for(int i = 0; i < r; ++i){
        double m = S[i][0];
        for(int j = 1; j < c; ++j) m = max(m, S[i][j]);
        double sum = 0.0;
        for(int j = 0; j < c; ++j){ P[i][j] = exp(S[i][j] - m); sum += P[i][j]; }
        for(int j = 0; j < c; ++j) P[i][j] /= (sum + 1e-12);
    }
    return P;
}

Matrix causal_mask(int T){
    Matrix M(T, vector<double>(T, 0.0));
    const double NEG_INF = -1e9;
    for(int i = 0; i < T; ++i)
        for(int j = i + 1; j < T; ++j) M[i][j] = NEG_INF;
    return M;
}

Matrix scaled_dot_product_attention(const Matrix &Q, const Matrix &K,
                                    const Matrix &V, const Matrix *mask){
    int dk = Q[0].size();
    Matrix S = matmul(Q, transpose(K));
    S = scale(S, 1.0 / sqrt((double)dk));
    if(mask){
        for(size_t i = 0; i < S.size(); ++i)
            for(size_t j = 0; j < S[0].size(); ++j) S[i][j] += (*mask)[i][j];
    }
    return matmul(softmax_rows(S), V);
}

// Initialize a matrix with small random values for demonstration
Matrix randn(int r, int c, std::mt19937 &gen, double scale = 0.1){
    std::normal_distribution<double> dist(0.0, scale);
    Matrix A(r, vector<double>(c));
    for(int i = 0; i < r; ++i)
        for(int j = 0; j < c; ++j) A[i][j] = dist(gen);
    return A;
}

void print_matrix(const string &name, const Matrix &A, int prec = 3, int max_rows = 6){
    cout << name << " (" << A.size() << "x" << A[0].size() << ")\n";
    cout.setf(std::ios::fixed);
    cout << setprecision(prec);
    int r = A.size(), c = A[0].size();
    for(int i = 0; i < min(r, max_rows); ++i){
        for(int j = 0; j < c; ++j) cout << setw(8) << A[i][j] << ' ';
        cout << '\n';
    }
    if(r > max_rows) cout << "...\n";
}

int main(){
    // Dimensions
    int T = 5;       // sequence length
    int d_model = 8; // model dimension
    int h = 2;       // number of heads
    int d_k = 4;     // per-head key/query dim (often d_model / h)
    int d_v = 4;     // per-head value dim (often d_model / h)

    // Input sequence X: T x d_model (toy values)
    Matrix X(T, vector<double>(d_model));
    for(int i = 0; i < T; ++i)
        for(int j = 0; j < d_model; ++j) X[i][j] = 0.01 * (i + 1) * (j + 1);

    std::mt19937 gen(42);

    // Per-head projection matrices
    vector<Matrix> WQ(h), WK(h), WV(h);
    for(int i = 0; i < h; ++i){
        WQ[i] = randn(d_model, d_k, gen);
        WK[i] = randn(d_model, d_k, gen);
        WV[i] = randn(d_model, d_v, gen);
    }
    // Output projection W^O: (h d_v) x d_model
    Matrix WO = randn(h * d_v, d_model, gen);

    // Optional mask (causal for autoregressive)
    Matrix M = causal_mask(T);

    // Compute per-head Q, K, V and attention outputs
    vector<Matrix> heads(h);
    for(int i = 0; i < h; ++i){
        Matrix Qi = matmul(X, WQ[i]); // T x d_k
        Matrix Ki = matmul(X, WK[i]); // T x d_k
        Matrix Vi = matmul(X, WV[i]); // T x d_v
        heads[i] = scaled_dot_product_attention(Qi, Ki, Vi, &M); // T x d_v
    }

    // Concatenate heads along the feature dimension: T x (h d_v)
    Matrix Hcat(T, vector<double>(h * d_v));
    for(int t = 0; t < T; ++t)
        for(int i = 0; i < h; ++i)
            for(int j = 0; j < d_v; ++j)
                Hcat[t][i * d_v + j] = heads[i][t][j];

    // Final output projection: Y = Hcat * WO (T x d_model)
    Matrix Y = matmul(Hcat, WO);

    print_matrix("Input X", X);
    print_matrix("Concatenated heads Hcat", Hcat);
    print_matrix("Output Y", Y);
    return 0;
}
```
This program performs a full Multi-Head Attention forward pass for a single sequence (no batching). It creates per-head projection matrices W^Q, W^K, W^V, applies them to X, runs scaled dot-product attention with a causal mask in each head, concatenates head outputs, and applies the final projection W^O. The dimensions are small and randomly initialized for demonstration.