How I Study AI - Learn AI Papers & Lectures the Easy Way
📚 Theory · Intermediate

Self-Attention as Graph Neural Network

Key Points

  • Self-attention can be viewed as message passing on a fully connected graph where each token (node) sends a weighted message to every other token.
  • Queries and keys compute edge weights (attention scores), while values are the messages that get aggregated to update each node’s representation.
  • In matrix form, attention is the row-wise softmax of the score matrix (a learned, input-dependent weighted adjacency) multiplied by the value matrix, which has the same form as a GNN layer with learned, input-dependent edge weights.
  • Multi-head attention corresponds to running several message-passing channels in parallel and concatenating their outputs.
  • Masks (causal, padding, or top-k sparsity) are equivalent to removing or down-weighting edges in the graph.
  • The computational bottleneck is the dense n-by-n attention matrix, giving O(n^2 d) time and O(n^2) memory; sparse attention reduces this to O(m d) with m edges.
  • Self-attention without positional information is permutation equivariant over nodes, just like many GNNs on sets.
  • Viewing attention as a GNN clarifies design choices: normalization by softmax, residual connections, and neighborhood selection all map cleanly to graph operations.

Prerequisites

  • →Linear algebra (vectors, matrices, dot products) — Attention uses matrix multiplications and dot products to compute compatibility and aggregation.
  • →Probability and softmax — Attention weights are normalized with softmax and interpreted as probabilities over neighbors.
  • →Graph basics (nodes, edges, adjacency) — Understanding attention as message passing relies on graph terminology and structures.
  • →Neural network layers and training — Q/K/V are learned projections; residuals and normalization stabilize deep models.
  • →Transformer architecture overview — Places self-attention within the larger block used in practice (multi-head, MLP, residuals).
  • →Time/space complexity analysis — Dense attention is quadratic; sparse strategies trade accuracy for efficiency.
  • →Numerical stability techniques — Stable softmax and masking are essential to avoid overflow/underflow.

Detailed Explanation


01 Overview

Self-attention is the core operation behind Transformers. Conceptually, each item in a sequence (a word, token, or feature vector) looks at all other items and decides how much to pay attention to them. This decision is learned via two projections called queries and keys, whose dot product produces a compatibility score between any pair of items. After normalizing these scores with a softmax, we take a weighted sum of value vectors to produce the updated representation for each item. A powerful perspective is to see self-attention as a graph neural network (GNN) on a fully connected graph. Each token is a node, every pair of tokens shares a directed edge, and the attention weights act as edge strengths. The updated representation of a node is the aggregation of messages arriving from its neighbors, where a message is a transformed value and the edge weight is the attention coefficient. This graph view unifies Transformers and GNNs: masking corresponds to removing edges; multi-head attention corresponds to multiple parallel message functions; and sparsity corresponds to restricting the neighborhood. It also makes complexity trade-offs clearer: dense attention is O(n^2 d), motivating sparse or local neighborhoods for long contexts. Overall, the equivalence provides both intuition and practical tools for modifying attention with graph ideas (top-k edges, edge dropout, normalization) or adapting GNNs to sequence tasks.

02 Intuition & Analogies

Imagine a group discussion where each participant (node) listens to everyone else but chooses who to trust more based on how relevant their statements are. A participant crafts a question (query) describing what they want to know and compares it to how others present their information (keys). If someone’s presentation closely matches the question, the listener assigns them a high attention score. Then the listener collects everyone’s statements (values), but averages them with weights proportional to those attention scores. The result is a tailored summary for that listener. In graph terms, the group of people forms a fully connected network. The attention scores are like adjustable volume knobs on the edges: if you trust someone, you turn up the volume of their message; if not, you turn it down. The final understanding each person gets is a weighted blend of the others’ contributions. If a rule forbids peeking into the future (causal mask), it’s like saying you can only listen to people who spoke earlier—edges to future nodes are muted. If you only have time for the top few opinions (top-k), you open just a few highest-volume channels and ignore the rest, making the conversation faster. Multi-head attention is like listening through multiple ears, each tuned to different aspects (syntax, long-range dependency, local context). You then combine these parallel impressions into a richer understanding. This analogy mirrors GNN message passing: nodes exchange information across edges with weights that depend on their features, and then aggregate to update their states.

03 Formal Definition

Let $X \in \mathbb{R}^{n \times d_{\text{model}}}$ be node features (tokens). Define learned projections $W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$. Compute $Q = X W_Q$, $K = X W_K$, $V = X W_V$, and write $q_i$, $k_j$, $V_j$ for row $i$ of $Q$, row $j$ of $K$, and row $j$ of $V$. Scaled dot-product attention produces weights $A_{ij} = \mathrm{softmax}_j\!\left(\frac{q_i k_j^\top}{\sqrt{d_k}} + M_{ij}\right)$, where $M_{ij} \in \mathbb{R} \cup \{-\infty\}$ is an additive mask (e.g., causal or padding). The output is $H' = A V$, $H' \in \mathbb{R}^{n \times d_v}$. Interpret this as a directed fully connected graph $G$ with nodes $1..n$ and edge weight from $j$ to $i$ equal to $A_{ij}$. Then attention realizes a GNN layer with message function $m(h_i, h_j, e_{ij}) = A_{ij} V_j$ and sum aggregation: $h_i' = \sum_{j=1}^{n} A_{ij} V_j$. In general GNN notation, for neighborhood $N(i)$ and parameters $\theta$, $h_i' = \mathrm{AGG}_{j \in N(i)}\, m_\theta(h_i, h_j, e_{ij})$. Self-attention sets $N(i) = \{1, \ldots, n\}$ (subject to masking), uses a softmax-normalized, input-dependent adjacency $A$, and a linear value projection $V$. Multi-head attention computes multiple $A^{(h)}$, $V^{(h)}$ in parallel and concatenates the results, often followed by an output projection and residual + normalization.

04 When to Use

  • Sequences and sets: When every element might depend on every other (language modeling, protein sequences, time series), attention’s fully connected message passing captures long-range interactions.
  • Permutation-sensitive vs. invariant tasks: Without positional encodings, self-attention is permutation equivariant, useful for set modeling; with positional signals, it models ordered sequences.
  • Graph problems with dense interactions: If your domain naturally forms a complete graph (e.g., all-to-all particle interactions, routing between all cities), attention is a natural GNN.
  • When you need adaptive neighborhoods: Attention learns edge weights from data. If you expect the relevant neighbors to change by context, attention is advantageous over fixed kernels.
  • Long-context efficiency trade-offs: For very long n, use sparse variants (local windows, top-k, block-sparse) to reduce O(n^2) cost. This aligns with graph techniques that prune edges or use locality.
  • Multimodal fusion: Viewing modalities as nodes in a complete graph, cross/self-attention aggregates evidence across modalities similarly to heterogeneous GNNs.
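The local-window case from the long-context bullet can be sketched as follows. This is a minimal illustration, not a production kernel: the function name `window_attention` and the symmetric window rule N(i) = { j : |i - j| <= w } are assumptions made here, and Q, K, V are taken as already projected.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Local-window attention (illustrative): node i aggregates only from
// N(i) = { j : |i - j| <= w }, so only O(n * w) logits are ever computed.
// Q and K are n x d_k, V is n x d_v.
Matrix window_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                        std::size_t w) {
    const std::size_t n = Q.size();
    assert(n > 0 && K.size() == n && V.size() == n);
    const std::size_t dk = Q[0].size(), dv = V[0].size();
    assert(dk > 0);
    const double scale = 1.0 / std::sqrt(static_cast<double>(dk));

    Matrix H(n, std::vector<double>(dv, 0.0));
    for (std::size_t i = 0; i < n; ++i) {
        // Clip the window to the valid index range [0, n - 1].
        const std::size_t lo = (i > w) ? i - w : 0;
        const std::size_t hi = (w >= n - 1 - i) ? n - 1 : i + w;

        // Scaled logits for the edges j -> i inside the window.
        std::vector<double> s(hi - lo + 1, 0.0);
        for (std::size_t j = lo; j <= hi; ++j) {
            for (std::size_t t = 0; t < dk; ++t) s[j - lo] += Q[i][t] * K[j][t];
            s[j - lo] *= scale;
        }

        // Stable softmax over the window only.
        const double m = *std::max_element(s.begin(), s.end());
        double z = 0.0;
        for (double& x : s) { x = std::exp(x - m); z += x; }

        // Weighted sum of the values of the kept neighbors.
        for (std::size_t j = lo; j <= hi; ++j)
            for (std::size_t c = 0; c < dv; ++c)
                H[i][c] += (s[j - lo] / z) * V[j][c];
    }
    return H;
}
```

With w = 0 every node keeps only its self-loop, so the output equals V; with w >= n - 1 the window covers all nodes and the result is ordinary dense attention.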

⚠️Common Mistakes

  • Forgetting the $1/\sqrt{d_k}$ scaling: Without it, dot products grow with dimension, causing softmax saturation and poor gradients. Always multiply logits by $1/\sqrt{d_k}$ (that is, divide by $\sqrt{d_k}$).
  • Softmax over the wrong axis: Attention weights for node i must sum to 1 across j (row-wise softmax). Column-wise normalization is incorrect for the standard formulation.
  • Masking incorrectly: Add a large negative number (approximating $-\infty$) to disallowed logits before softmax, not after. Zeroing entries after softmax leaves rows that no longer sum to 1 unless you renormalize.
  • Numerical instability: Compute softmax with logit-max subtraction to avoid overflow. Extremely negative masks should be applied before the subtraction step to preserve correctness.
  • Confusing Q/K/V roles: Keys are compared with queries to produce weights; values are what you aggregate. Mixing these leads to dimensional or conceptual errors.
  • Ignoring permutation properties: Self-attention without positional encodings is permutation equivariant; if your task needs order awareness, add positional signals.
  • Overlooking complexity: Dense attention is O(n^2 d) time and O(n^2) memory. For large n, use sparse/top-k masks or chunked/flash attention implementations.
  • Omitting residuals/normalization: Deep stacks are hard to train without them; they stabilize training and are standard in both Transformers and many GNN architectures.
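The masking and stability points above can be made concrete with a short sketch. The function names `masked_softmax` and `zero_after_softmax` are illustrative, and returning an all-zero row when every edge is masked is a choice made here, not something the text prescribes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Correct order: drop masked edges BEFORE normalizing, then use a stable
// softmax (subtract the max over the allowed logits).
// allowed[j] == true means edge j -> i is kept.
// Assumption (ours): a row with no allowed edge returns all zeros.
std::vector<double> masked_softmax(const std::vector<double>& logits,
                                   const std::vector<bool>& allowed) {
    assert(logits.size() == allowed.size());
    const double neg_inf = -std::numeric_limits<double>::infinity();
    double m = neg_inf;
    for (std::size_t j = 0; j < logits.size(); ++j)
        if (allowed[j]) m = std::max(m, logits[j]);

    std::vector<double> w(logits.size(), 0.0);
    if (m == neg_inf) return w;  // nothing to attend to

    double sum = 0.0;
    for (std::size_t j = 0; j < logits.size(); ++j)
        if (allowed[j]) { w[j] = std::exp(logits[j] - m); sum += w[j]; }
    for (double& x : w) x /= sum;  // sum >= 1 because the max contributes exp(0)
    return w;
}

// The mistake: softmax over all logits, then zero out the masked entries.
// The surviving weights no longer sum to 1.
std::vector<double> zero_after_softmax(const std::vector<double>& logits,
                                       const std::vector<bool>& allowed) {
    std::vector<double> w =
        masked_softmax(logits, std::vector<bool>(logits.size(), true));
    for (std::size_t j = 0; j < w.size(); ++j)
        if (!allowed[j]) w[j] = 0.0;
    return w;
}
```

For logits (1, 2, 3) with the last edge masked, `masked_softmax` returns a row that sums to 1, while `zero_after_softmax` returns a row that sums to roughly 0.33, because most of the probability mass sat on the masked edge.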

Key Formulas

Linear Projections

$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

Explanation: Inputs X are projected to queries, keys, and values using learned matrices. These define how nodes attend to and message each other.

Scaled Dot-Product Attention Weights

$$A_{ij} = \mathrm{softmax}_j\!\left(\frac{q_i k_j^\top}{\sqrt{d_k}} + M_{ij}\right)$$

Explanation: Attention weights are a row-wise softmax of scaled dot products plus an optional mask. Each row sums to 1 so weights form a probability distribution over neighbors.
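A short derivation shows where the $1/\sqrt{d_k}$ factor comes from. It rests on an assumption that is not part of the definition above: the components of $q_i$ and $k_j$ are independent with zero mean and unit variance.

```latex
\begin{aligned}
q_i k_j^\top &= \sum_{t=1}^{d_k} q_{it}\, k_{jt}, \\
\mathbb{E}[q_{it} k_{jt}] &= \mathbb{E}[q_{it}]\,\mathbb{E}[k_{jt}] = 0, \qquad
\operatorname{Var}(q_{it} k_{jt}) = \mathbb{E}[q_{it}^2]\,\mathbb{E}[k_{jt}^2] = 1, \\
\operatorname{Var}\!\left(q_i k_j^\top\right) &= d_k, \qquad
\operatorname{Var}\!\left(\frac{q_i k_j^\top}{\sqrt{d_k}}\right) = 1 .
\end{aligned}
```

Under that assumption the unscaled logit has standard deviation $\sqrt{d_k}$, so dividing by $\sqrt{d_k}$ keeps the logits at unit scale regardless of the head dimension.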

Attention Output

$$H' = A V$$

Explanation: The new node features are the attention-weighted average of values. This is identical to multiplying the learned adjacency A by the value matrix V.

Node-wise Update

$$h_i' = \sum_{j=1}^{n} A_{ij} V_j$$

Explanation: Each node aggregates messages from all nodes using attention weights as edge strengths. It is the standard GNN message passing with sum aggregation.

General GNN Layer

$$h_i' = \mathrm{AGG}_{j \in N(i)}\, m_\theta(h_i, h_j, e_{ij})$$

Explanation: A GNN updates node i by aggregating messages from its neighborhood using a learned message function. Self-attention sets N(i) to all nodes (subject to masking).

Row-Stochastic Weights

$$\sum_{j=1}^{n} A_{ij} = 1, \quad A_{ij} \ge 0$$

Explanation: After softmax, each row of the attention matrix is a probability distribution. This ensures normalized aggregation and stable scaling.

Matrix Form of Self-Attention

$$\mathrm{Attn}(X) = \mathrm{softmax}\!\left(\frac{X W_Q (X W_K)^\top}{\sqrt{d_k}} + M\right) X W_V$$

Explanation: Attention can be written compactly using matrix multiplication for efficient computation. The mask M encodes disabled or biased edges.

Multi-Head Combination

$$\mathrm{MultiHead}(X) = \mathrm{Concat}\!\left(H'^{(1)}, \ldots, H'^{(H)}\right) W_O$$

Explanation: Multiple attention heads are computed in parallel and concatenated, then mixed by an output projection. This diversifies relational patterns captured.
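The multi-head formula can be sketched as several independent message-passing channels whose outputs are concatenated per node and then mixed. The `HeadParams` struct and the function names are hypothetical, introduced only for this illustration; masks, residuals, and normalization are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical container for one head's projection matrices.
struct HeadParams { Matrix WQ, WK, WV; };

Matrix matmul(const Matrix& A, const Matrix& B) {
    assert(!A.empty() && !B.empty() && A[0].size() == B.size());
    Matrix C(A.size(), std::vector<double>(B[0].size(), 0.0));
    for (std::size_t i = 0; i < A.size(); ++i)
        for (std::size_t k = 0; k < B.size(); ++k)
            for (std::size_t j = 0; j < B[0].size(); ++j)
                C[i][j] += A[i][k] * B[k][j];
    return C;
}

// One channel: H' = softmax(Q K^T / sqrt(d_k)) V, computed row by row.
Matrix attention_head(const Matrix& X, const HeadParams& p) {
    const Matrix Q = matmul(X, p.WQ), K = matmul(X, p.WK), V = matmul(X, p.WV);
    const std::size_t n = X.size(), dk = Q[0].size(), dv = V[0].size();
    const double scale = 1.0 / std::sqrt(static_cast<double>(dk));
    Matrix H(n, std::vector<double>(dv, 0.0));
    for (std::size_t i = 0; i < n; ++i) {
        std::vector<double> s(n, 0.0);
        for (std::size_t j = 0; j < n; ++j) {
            for (std::size_t t = 0; t < dk; ++t) s[j] += Q[i][t] * K[j][t];
            s[j] *= scale;
        }
        const double m = *std::max_element(s.begin(), s.end());
        double z = 0.0;
        for (double& x : s) { x = std::exp(x - m); z += x; }
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t c = 0; c < dv; ++c)
                H[i][c] += (s[j] / z) * V[j][c];
    }
    return H;
}

// Concatenate the per-head outputs for each node, then apply W_O.
Matrix multi_head(const Matrix& X, const std::vector<HeadParams>& heads,
                  const Matrix& WO) {
    Matrix cat(X.size());
    for (const HeadParams& p : heads) {
        const Matrix H = attention_head(X, p);
        for (std::size_t i = 0; i < X.size(); ++i)
            cat[i].insert(cat[i].end(), H[i].begin(), H[i].end());
    }
    return matmul(cat, WO);  // WO has (number of heads * d_v) rows
}
```

Each head builds its own adjacency from its own W_Q and W_K, so different heads can weight the same edge differently; W_O is the only place where the channels interact.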

Dense Attention Complexity

$$T(n, d) = O(n^2 d), \quad S(n) = O(n^2)$$

Explanation: Computing all n-by-n pairwise scores and normalizing them is quadratic in n. Memory is also quadratic due to storing logits or weights.

Top-k Sparse Complexity

$$T(n, k, d) = O(n k d), \quad S(n, k) = O(n k)$$

Explanation: Restricting to k neighbors per node reduces both time and memory from quadratic to near-linear in n, assuming efficient sparse operations.

Permutation Equivariance

$$f(PX) = P f(X)$$

Explanation: Here P is any n × n permutation matrix and f is the self-attention layer. Without positional information, reordering inputs only reorders outputs. This mirrors GNNs on sets where node labels are arbitrary.
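The equivariance property can be checked numerically. The sketch below uses identity projections (Q = K = V = X), an assumption made only to keep it short; learned projections act on each row independently and do not change the argument. The helper names are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Self-attention with identity projections (assumption: Q = K = V = X).
Matrix self_attention(const Matrix& X) {
    const std::size_t n = X.size();
    assert(n > 0 && !X[0].empty());
    const std::size_t d = X[0].size();
    const double scale = 1.0 / std::sqrt(static_cast<double>(d));
    Matrix H(n, std::vector<double>(d, 0.0));
    for (std::size_t i = 0; i < n; ++i) {
        std::vector<double> s(n, 0.0);
        for (std::size_t j = 0; j < n; ++j) {
            for (std::size_t t = 0; t < d; ++t) s[j] += X[i][t] * X[j][t];
            s[j] *= scale;
        }
        const double m = *std::max_element(s.begin(), s.end());
        double z = 0.0;
        for (double& x : s) { x = std::exp(x - m); z += x; }
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t c = 0; c < d; ++c)
                H[i][c] += (s[j] / z) * X[j][c];
    }
    return H;
}

// Row permutation: output row i is input row perm[i], i.e. (P X)_i = X_{perm[i]}.
Matrix permute_rows(const Matrix& X, const std::vector<std::size_t>& perm) {
    assert(perm.size() == X.size());
    Matrix Y(X.size());
    for (std::size_t i = 0; i < X.size(); ++i) Y[i] = X[perm[i]];
    return Y;
}

// Largest absolute entry of f(P X) - P f(X); expected to be ~0 up to rounding.
double equivariance_gap(const Matrix& X, const std::vector<std::size_t>& perm) {
    const Matrix lhs = self_attention(permute_rows(X, perm));
    const Matrix rhs = permute_rows(self_attention(X), perm);
    double gap = 0.0;
    for (std::size_t i = 0; i < lhs.size(); ++i)
        for (std::size_t c = 0; c < lhs[i].size(); ++c)
            gap = std::max(gap, std::fabs(lhs[i][c] - rhs[i][c]));
    return gap;
}
```

The gap is not exactly zero in general because the permuted sums are accumulated in a different order; adding a positional term to X before the call would break the property, which is the point of positional encodings.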

Complexity Analysis

Let $n$ be the number of nodes/tokens, $d_{\text{model}}$ the input dimension, and $d_k$, $d_v$ the key/value dimensions per head. For a single head, computing $Q = X W_Q$, $K = X W_K$, $V = X W_V$ costs $O(n\, d_{\text{model}} d_k + n\, d_{\text{model}} d_k + n\, d_{\text{model}} d_v)$. For typical settings with $d_{\text{model}} \approx d_k \approx d_v = d$, this is $O(n d^2)$. The dominant cost is forming pairwise logits $S = Q K^\top / \sqrt{d_k}$, which requires $O(n^2 d)$. Row-wise softmax over $n$ rows is $O(n^2)$, and the final product $H' = A V$ is $O(n^2 d_v) \approx O(n^2 d)$. Thus, dense self-attention time is $T(n, d) = O(n^2 d)$, and space is dominated by storing $S$ or $A$, giving $S(n) = O(n^2)$.

For $H$ heads, cost scales linearly: $O(H n^2 d_{\text{head}})$, where $d_{\text{head}} = d_{\text{model}} / H$, so the overall complexity remains $O(n^2 d_{\text{model}})$. Batching $B$ sequences multiplies time and memory by $B$. Masks do not change asymptotic cost unless they enforce sparsity that is also exploited computationally.

Sparse/top-k attention reduces the number of active edges from $n^2$ to $m = O(n k)$. Computing logits only for selected neighbors and aggregating over them yields $O(m d)$ time and $O(m)$ memory. If selecting top-k per node requires a full sort, the selection step is $O(n^2 \log n)$ naively; practical systems use approximate top-k or locality constraints to achieve $O(n k)$ selection. Local/windowed attention similarly achieves $O(n w d)$ with window size $w$. Memory-efficient implementations (e.g., flash attention) reduce memory to $O(n d)$ while preserving $O(n^2 d)$ time by computing the softmax block by block on the fly instead of storing the full $n \times n$ matrix. In graph terms, attention behaves like a GNN on a dense graph with $m = n^2$ edges; moving to sparse neighborhoods directly improves both time and space while keeping the same message-passing semantics.
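The memory-efficient idea mentioned above can be illustrated for a single query row with a streaming ("online") softmax. This is a simplified sketch that processes one key at a time; it is not the FlashAttention kernel, which works on tiles, and the function names are assumptions for illustration. A dense reference is included so the two can be compared.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Streaming aggregation for one query: keeps only a running max m, a running
// normalizer z, and a running weighted sum acc, so the row of n logits is
// never stored. Extra memory is O(d_v) instead of O(n).
std::vector<double> streaming_attention_row(const std::vector<double>& q,
                                            const Matrix& K, const Matrix& V) {
    assert(!K.empty() && K.size() == V.size() && K[0].size() == q.size());
    const double scale = 1.0 / std::sqrt(static_cast<double>(q.size()));
    double m = -std::numeric_limits<double>::infinity();
    double z = 0.0;
    std::vector<double> acc(V[0].size(), 0.0);
    for (std::size_t j = 0; j < K.size(); ++j) {
        double s = 0.0;
        for (std::size_t t = 0; t < q.size(); ++t) s += q[t] * K[j][t];
        s *= scale;
        const double m_new = std::max(m, s);
        const double correction = std::exp(m - m_new);  // exp(-inf) = 0 at j = 0
        const double p = std::exp(s - m_new);
        z = z * correction + p;  // rescale old mass to the new max
        for (std::size_t c = 0; c < acc.size(); ++c)
            acc[c] = acc[c] * correction + p * V[j][c];
        m = m_new;
    }
    for (double& x : acc) x /= z;
    return acc;
}

// Dense reference: materialize all n logits, then softmax, then aggregate.
std::vector<double> dense_attention_row(const std::vector<double>& q,
                                        const Matrix& K, const Matrix& V) {
    const double scale = 1.0 / std::sqrt(static_cast<double>(q.size()));
    std::vector<double> s(K.size(), 0.0);
    for (std::size_t j = 0; j < K.size(); ++j) {
        for (std::size_t t = 0; t < q.size(); ++t) s[j] += q[t] * K[j][t];
        s[j] *= scale;
    }
    const double m = *std::max_element(s.begin(), s.end());
    double z = 0.0;
    for (double& x : s) { x = std::exp(x - m); z += x; }
    std::vector<double> out(V[0].size(), 0.0);
    for (std::size_t j = 0; j < K.size(); ++j)
        for (std::size_t c = 0; c < out.size(); ++c)
            out[c] += (s[j] / z) * V[j][c];
    return out;
}
```

Both functions do the same O(n d) work per query; the streaming version trades the stored logit row for a rescaling step each time a new maximum appears, which is why time stays quadratic overall while memory does not.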

Code Examples

Dense single-head self-attention as fully connected message passing
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <vector>
using namespace std;

using Matrix = vector<vector<double>>;

// C = A * B for dense row-major matrices (i-k-j loop order for cache locality).
Matrix matmul(const Matrix& A, const Matrix& B) {
    size_t n = A.size(), m = A[0].size(), p = B[0].size();
    Matrix C(n, vector<double>(p, 0.0));
    for (size_t i = 0; i < n; ++i)
        for (size_t k = 0; k < m; ++k) {
            double aik = A[i][k];
            for (size_t j = 0; j < p; ++j)
                C[i][j] += aik * B[k][j];
        }
    return C;
}

Matrix transpose(const Matrix& A) {
    size_t n = A.size(), m = A[0].size();
    Matrix AT(m, vector<double>(n));
    for (size_t i = 0; i < n; ++i)
        for (size_t j = 0; j < m; ++j)
            AT[j][i] = A[i][j];
    return AT;
}

// Numerically stable softmax of a non-empty vector.
vector<double> softmax(const vector<double>& v) {
    double m = *max_element(v.begin(), v.end());
    double sum = 0.0;
    vector<double> e(v.size());
    for (size_t i = 0; i < v.size(); ++i) {
        e[i] = exp(v[i] - m);  // subtract the max for numerical stability
        sum += e[i];
    }
    // sum >= 1 because the max element contributes exp(0), so no epsilon is needed
    for (size_t i = 0; i < v.size(); ++i) e[i] /= sum;
    return e;
}

Matrix rowwise_softmax(const Matrix& A) {
    Matrix S(A.size(), vector<double>(A[0].size()));
    for (size_t i = 0; i < A.size(); ++i) {
        S[i] = softmax(A[i]);
    }
    return S;
}

// Element-wise sum; can be used to add an additive mask M to the logits.
Matrix add(const Matrix& A, const Matrix& B) {
    Matrix C = A;
    for (size_t i = 0; i < A.size(); ++i)
        for (size_t j = 0; j < A[0].size(); ++j)
            C[i][j] += B[i][j];
    return C;
}

Matrix scale(const Matrix& A, double s) {
    Matrix B = A;
    for (auto& r : B) for (auto& x : r) x *= s;
    return B;
}

void printMatrix(const string& name, const Matrix& M, int prec = 4) {
    cout << name << " (" << M.size() << "x" << M[0].size() << ")\n";
    cout.setf(std::ios::fixed); cout << setprecision(prec);
    for (const auto& row : M) {
        for (double x : row) cout << setw(9) << x << ' ';
        cout << '\n';
    }
    cout.unsetf(std::ios::fixed);
    cout << "\n";
}

int main() {
    // Toy data: n nodes, d_model input dim, d_k key/query dim, d_v value dim
    int n = 4, d_model = 6, d_k = 3, d_v = 4;
    std::mt19937 rng(42);
    std::normal_distribution<double> N(0.0, 1.0);

    // Random matrix with entries scaled by 1/sqrt(number of columns)
    auto randMat = [&](int r, int c) {
        Matrix M(r, vector<double>(c));
        for (int i = 0; i < r; ++i)
            for (int j = 0; j < c; ++j)
                M[i][j] = N(rng) / sqrt((double)c);
        return M;
    };

    Matrix X = randMat(n, d_model);
    Matrix WQ = randMat(d_model, d_k);
    Matrix WK = randMat(d_model, d_k);
    Matrix WV = randMat(d_model, d_v);

    // Q, K, V projections
    Matrix Q = matmul(X, WQ); // (n x d_k)
    Matrix K = matmul(X, WK); // (n x d_k)
    Matrix V = matmul(X, WV); // (n x d_v)

    // Scaled dot-product logits S = Q K^T / sqrt(d_k)
    Matrix KT = transpose(K);
    Matrix S = matmul(Q, KT);
    S = scale(S, 1.0 / sqrt((double)d_k));

    // (Optional) an additive mask M could be applied here: S = add(S, M)

    // Attention weights A = softmax_row(S)
    Matrix A = rowwise_softmax(S);

    // Output H' = A V
    Matrix H = matmul(A, V);

    printMatrix("Attention weights A", A);
    printMatrix("Output H'", H);

    // Sanity: each row of A should sum to ~1
    cout.setf(std::ios::fixed); cout << setprecision(6);
    cout << "Row sums of A: ";
    for (int i = 0; i < n; ++i) {
        double sum = accumulate(A[i].begin(), A[i].end(), 0.0);
        cout << sum << ' ';
    }
    cout << "\n";
    return 0;
}

This program computes single-head scaled dot-product self-attention as matrix operations. It forms Q, K, V via projection matrices (randomly initialized here; learned in a trained model), computes logits S = QK^T / sqrt(d_k), applies a row-wise softmax to obtain attention weights A (an input-dependent, row-stochastic adjacency), and aggregates values via H' = A V. The printed A is exactly the edge-weight matrix of a fully connected, directed graph used for message passing.

Time: O(n^2 d + n d^2) ≈ O(n^2 d) for d_model ≈ d_k ≈ d_v = d
Space: O(n^2) to store logits/weights plus O(n d) for activations
Message passing view: aggregate over edges derived from attention
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <vector>
using namespace std;
using Matrix = vector<vector<double>>;

// C = A * B for dense row-major matrices.
Matrix matmul(const Matrix& A, const Matrix& B) {
    size_t n = A.size(), m = A[0].size(), p = B[0].size();
    Matrix C(n, vector<double>(p, 0.0));
    for (size_t i = 0; i < n; ++i)
        for (size_t k = 0; k < m; ++k) {
            double aik = A[i][k];
            for (size_t j = 0; j < p; ++j)
                C[i][j] += aik * B[k][j];
        }
    return C;
}

Matrix transpose(const Matrix& A) {
    size_t n = A.size(), m = A[0].size();
    Matrix AT(m, vector<double>(n));
    for (size_t i = 0; i < n; ++i)
        for (size_t j = 0; j < m; ++j)
            AT[j][i] = A[i][j];
    return AT;
}

// Numerically stable softmax of a non-empty vector.
vector<double> softmax(const vector<double>& v) {
    double m = *max_element(v.begin(), v.end());
    double sum = 0.0;
    vector<double> e(v.size());
    for (size_t i = 0; i < v.size(); ++i) { e[i] = exp(v[i] - m); sum += e[i]; }
    // sum >= 1 because the max element contributes exp(0), so no epsilon is needed
    for (size_t i = 0; i < v.size(); ++i) e[i] /= sum;
    return e;
}

Matrix rowwise_softmax(const Matrix& A) {
    Matrix S(A.size(), vector<double>(A[0].size()));
    for (size_t i = 0; i < A.size(); ++i) S[i] = softmax(A[i]);
    return S;
}

int main() {
    // Dimensions
    int n = 3, d_model = 4, d_k = 2, d_v = 3;

    // Small fixed matrices for reproducibility (X is n x d_model)
    Matrix X = {{0.2, 1.0, -0.5, 0.3},
                {1.1, -0.2, 0.7, -0.3},
                {-0.4, 0.5, 0.9, -1.2}};
    Matrix WQ = {{0.1, -0.2}, {0.3, 0.4}, {-0.5, 0.6}, {0.7, -0.8}};   // d_model x d_k
    Matrix WK = {{-0.4, 0.1}, {0.2, -0.3}, {0.5, 0.6}, {-0.7, 0.8}};   // d_model x d_k
    Matrix WV = {{0.2, -0.1, 0.3}, {-0.4, 0.5, 0.1},
                 {0.6, -0.2, -0.3}, {0.7, 0.9, -0.6}};                 // d_model x d_v

    // Compute Q, K, V
    Matrix Q = matmul(X, WQ);
    Matrix K = matmul(X, WK);
    Matrix V = matmul(X, WV);

    // Logits S = Q K^T / sqrt(d_k)
    Matrix KT = transpose(K);
    Matrix S(n, vector<double>(n, 0.0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int t = 0; t < d_k; ++t)
                S[i][j] += Q[i][t] * KT[t][j];
    double scale = 1.0 / sqrt((double)d_k);
    for (int i = 0; i < n; ++i) for (int j = 0; j < n; ++j) S[i][j] *= scale;

    // Attention weights A (adjacency)
    Matrix A = rowwise_softmax(S);

    // MESSAGE PASSING: h_i' = sum_j A_ij * V_j
    Matrix H(n, vector<double>(d_v, 0.0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            double w = A[i][j]; // edge weight from j -> i
            for (int c = 0; c < d_v; ++c) H[i][c] += w * V[j][c];
        }
    }

    // Print A and H
    cout.setf(std::ios::fixed); cout << setprecision(4);
    cout << "Attention (Adjacency) A:\n";
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) cout << setw(9) << A[i][j] << ' ';
        cout << '\n';
    }
    cout << "\nAggregated output H':\n";
    for (int i = 0; i < n; ++i) {
        for (int c = 0; c < d_v; ++c) cout << setw(9) << H[i][c] << ' ';
        cout << '\n';
    }
    return 0;
}

This example constructs attention weights A and then explicitly performs message passing: for each node i, it sums incoming messages from all nodes j, each message being V_j scaled by A_{ij}. The result H' matches the matrix product A V, making the GNN equivalence concrete.

Time: O(n^2 d + n d^2) ≈ O(n^2 d)
Space: O(n^2) for A plus O(n d) for Q, K, V
Causal masked and top-k sparse attention (edge pruning)
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <utility>
#include <vector>
using namespace std;
using Matrix = vector<vector<double>>;

// Numerically stable softmax of a non-empty vector.
vector<double> softmax(const vector<double>& v) {
    double m = *max_element(v.begin(), v.end());
    double sum = 0.0; vector<double> e(v.size());
    for (size_t i = 0; i < v.size(); ++i) { e[i] = exp(v[i] - m); sum += e[i]; }
    // sum >= 1 because the max element contributes exp(0), so no epsilon is needed
    for (size_t i = 0; i < v.size(); ++i) e[i] /= sum;
    return e;
}

// Indices of the k largest values, sorted by descending value (ties arbitrary)
vector<int> topk_indices(const vector<double>& v, int k) {
    int kk = min(k, (int)v.size());
    if (kk <= 0) return {};  // guard: nth_element below needs at least one element
    vector<pair<double,int>> vp; vp.reserve(v.size());
    for (int i = 0; i < (int)v.size(); ++i) vp.push_back({v[i], i});
    auto by_value_desc = [](const pair<double,int>& a, const pair<double,int>& b) {
        return a.first > b.first;
    };
    // Partition so the kk largest values occupy the first kk slots, then order them
    nth_element(vp.begin(), vp.begin() + kk - 1, vp.end(), by_value_desc);
    sort(vp.begin(), vp.begin() + kk, by_value_desc);
    vector<int> idx; idx.reserve(kk);
    for (int t = 0; t < kk; ++t) idx.push_back(vp[t].second);
    return idx;
}

int main() {
    int n = 6, d_k = 4, d_v = 3;
    std::mt19937 rng(0);
    std::normal_distribution<double> N(0.0, 1.0);

    // Random Q, K, V stand in for already-projected features (no W_Q, W_K, W_V here)
    Matrix Q(n, vector<double>(d_k)), K(n, vector<double>(d_k));
    Matrix V(n, vector<double>(d_v));
    for (int i = 0; i < n; ++i) {
        for (int t = 0; t < d_k; ++t) {
            Q[i][t] = N(rng) / sqrt((double)d_k);
            K[i][t] = N(rng) / sqrt((double)d_k);
        }
        for (int t = 0; t < d_v; ++t) V[i][t] = N(rng) / sqrt((double)d_v);
    }

    // Compute dense logits S = Q K^T / sqrt(d_k)
    Matrix S(n, vector<double>(n, 0.0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int t = 0; t < d_k; ++t)
                S[i][j] += Q[i][t] * K[j][t];
    double scale = 1.0 / sqrt((double)d_k);
    for (int i = 0; i < n; ++i) for (int j = 0; j < n; ++j) S[i][j] *= scale;

    // Apply causal mask: disallow edges to future nodes (j > i)
    const double NEG_INF = -1e9;
    for (int i = 0; i < n; ++i)
        for (int j = i + 1; j < n; ++j)
            S[i][j] = NEG_INF;

    // Top-k sparsification per node: keep only k strongest allowed neighbors
    int k = 2;
    vector<vector<int>> neighbors(n);
    vector<vector<double>> sparseWeights(n); // normalized over kept neighbors
    for (int i = 0; i < n; ++i) {
        // Only consider allowed j (logit well above the mask value)
        vector<double> allowed;
        vector<int> allowed_idx;
        for (int j = 0; j < n; ++j)
            if (S[i][j] > NEG_INF / 2) { allowed.push_back(S[i][j]); allowed_idx.push_back(j); }
        // Select top-k among allowed (j = i is always allowed, so this is non-empty)
        vector<int> top_local = topk_indices(allowed, min(k, (int)allowed.size()));
        // Build logits for just top-k and softmax normalize
        vector<double> top_logits; top_logits.reserve(top_local.size());
        neighbors[i].reserve(top_local.size());
        for (int id : top_local) {
            top_logits.push_back(allowed[id]);
            neighbors[i].push_back(allowed_idx[id]);
        }
        sparseWeights[i] = softmax(top_logits); // normalized over kept neighbors
    }

    // Aggregate using sparse edges: h_i' = sum_{j in N_k(i)} alpha_ij * V_j
    Matrix H(n, vector<double>(d_v, 0.0));
    for (int i = 0; i < n; ++i) {
        for (size_t e = 0; e < neighbors[i].size(); ++e) {
            int j = neighbors[i][e]; double w = sparseWeights[i][e];
            for (int t = 0; t < d_v; ++t) H[i][t] += w * V[j][t];
        }
    }

    // Print sparse neighbors and weights
    cout.setf(std::ios::fixed); cout << setprecision(4);
    for (int i = 0; i < n; ++i) {
        cout << "Node " << i << " neighbors: ";
        for (size_t e = 0; e < neighbors[i].size(); ++e)
            cout << "(" << neighbors[i][e] << ", w=" << sparseWeights[i][e] << ") ";
        cout << "\n";
    }
    cout << "\nSparse output H':\n";
    for (int i = 0; i < n; ++i) {
        for (int t = 0; t < d_v; ++t) cout << setw(9) << H[i][t] << ' ';
        cout << '\n';
    }
    return 0;
}
95

This code demonstrates two graph operations on attention: (1) a causal mask removes edges to future nodes (j > i), and (2) top-k sparsification keeps only the k strongest incoming edges per node and renormalizes their weights with a softmax. The result is a sparse message-passing update that approximates dense attention while reducing the number of edges.

Time: O(n^2 d_k) to compute logits, plus O(n^2) expected for the nth_element-based selection and O(n k log k) to sort the kept edges; aggregation is O(n k d_v)
Space: O(n^2) for dense logits if computed fully; the sparse representation stores O(n k) edges
#self-attention #graph neural network #message passing #transformer #queries keys values #softmax #adjacency matrix #multi-head attention #sparse attention #causal mask #top-k neighbors #permutation equivariance #scaled dot-product #complexity analysis