
Message Passing Framework

Key Points

  • Message Passing Neural Networks (MPNNs) learn on graphs by letting nodes repeatedly exchange and aggregate messages from their neighbors.
  • Each layer computes neighbor messages with a message function M, aggregates them with AGG (like sum/mean/max/attention), and updates node states with an update function U.
  • The core equations are m_v = AGG_{u ∈ N(v)} M(h_u, h_v, e_{uv}) and h_v' = U(h_v, m_v).
  • The framework unifies many GNNs, including Graph Convolutional Networks (GCN), Graph Attention Networks (GAT), and Graph Isomorphism Networks (GIN).
  • Sum aggregation is powerful and preserves counts; mean smooths features; max picks strongest signals; attention learns importance weights.
  • Time complexity is linear in the number of edges for most practical layers, typically O(|E| d²) or O(|E| d) depending on parameterization.
  • Oversmoothing can occur if you average too many times; use residuals, normalization, or limited depth to avoid it.
  • MPNNs can produce node-level, edge-level, or whole-graph embeddings by adding appropriate readout functions.

Prerequisites

  • Linear algebra (vectors, matrices, matrix-vector multiplication) — Message and update functions are typically linear maps followed by nonlinearities; understanding dimensions and matrix multiplication is essential.
  • Graph theory basics — You need to know nodes, edges, adjacency, and neighborhoods to understand message passing over graphs.
  • Neural networks and activations — MPNNs are neural networks; concepts like layers, activations (ReLU), and parameters generalize to graphs.
  • Softmax and numerical stability — Attention-based MPNNs use softmax; stable implementation avoids overflow by subtracting the maximum logit.
  • Time and space complexity analysis — Design choices trade off accuracy and speed; understanding O(|E|) vs O(|V|) terms helps scale to large graphs.

Detailed Explanation


01 Overview

A Message Passing Neural Network (MPNN) is a general recipe for building neural networks that operate on graphs. In an MPNN, each node carries a feature vector (its "hidden state"). A layer of computation works by letting each node gather information from its neighbors through messages. First, for every edge (u → v), a message M(h_u, h_v, e_{uv}) is computed using the sender’s features, the receiver’s features, and optionally the edge’s features. Next, all incoming messages at v are aggregated with a permutation-invariant operator AGG (commonly sum, mean, max, or a learned attention-weighted sum). Finally, the node uses an update function U to combine its previous state with the aggregated message, producing a new state. Repeating this process for several layers lets information flow across multiple hops in the graph. The MPNN framework captures a wide range of known graph neural networks as special cases. For example, Graph Convolutional Networks are message passing with normalized mean aggregation and a linear update, while Graph Attention Networks use learned attention weights inside AGG. Because MPNNs are invariant to node order and can incorporate edge features, they are a natural fit for molecules, social networks, knowledge graphs, and more. The framework is modular, separating design choices (message, aggregation, update) so you can adapt it to your domain while keeping time complexity roughly linear in edges.

02 Intuition & Analogies

Imagine a group chat where each person (node) regularly summarizes what their friends (neighbors) are saying. In each round, everyone reads messages from their friends, blends them into a digest, and updates their opinion. After a few rounds, individuals have absorbed information that originated several hops away. That’s an MPNN: gossip with structure.

  • Messages: Think of each friend crafting a short note for you. They might tailor it depending on who you are (receiver-aware) or include relationship context (edge features like friendship strength or type).
  • Aggregation: When your inbox has many notes, you need a fair way to combine them regardless of order. Summing is like counting how many friends mention each topic. Averaging is like taking the consensus. Taking a max is like only keeping the strongest argument. Attention is like listening more carefully to trusted friends.
  • Update: After reading, you revise your opinion: combine your prior belief with what you learned. A simple linear update plus a nonlinearity is like taking a weighted compromise; a recurrent-style update (e.g., GRU) is like remembering long-term context while integrating fresh news.

Two or three rounds spread information across 2–3 hops, often enough for local tasks. More rounds can blur everyone’s opinions (oversmoothing), making all nodes sound alike. To avoid that, you can keep some individuality (residual connections), normalize inputs (batch/graph norm), or constrain depth. The beauty of MPNNs is that they respect graph structure: you only hear from neighbors, and the combination doesn’t depend on how nodes are indexed, only on who is connected to whom.

03 Formal Definition

Given a graph G = (V, E) with node features h_v^{(t)} ∈ R^{d_t} at layer t and (optional) edge features e_{uv} ∈ R^{d_e}, an MPNN layer is defined by:

1) Message computation per edge: m_{u→v}^{(t)} = M_t(h_u^{(t)}, h_v^{(t)}, e_{uv}).
2) Node-wise aggregation: m_v^{(t)} = AGG({ m_{u→v}^{(t)} : u ∈ N(v) }), where AGG is permutation-invariant (sum, mean, max, attention-weighted sum, etc.).
3) Node update: h_v^{(t+1)} = U_t(h_v^{(t)}, m_v^{(t)}).

A common linear instantiation is M_t(h_u, h_v, e_{uv}) = φ(W_m^{(t)} [h_u ∥ e_{uv}] + b_m^{(t)}) and U_t(h_v, m_v) = φ(W_u^{(t)} [h_v ∥ m_v] + b_u^{(t)}), where [· ∥ ·] denotes concatenation and φ is a nonlinearity (e.g., ReLU). Attention uses learned coefficients α_{uv}^{(t)} with ∑_{u ∈ N(v)} α_{uv}^{(t)} = 1. Permutation invariance of AGG ensures node updates are independent of neighbor ordering. Stacking T layers yields T-hop receptive fields. A readout function R({h_v^{(T)}}) produces graph-level embeddings, typically via a sum/mean/max or an attention readout.

04 When to Use

Use MPNNs whenever your data is naturally a graph and you need to respect connectivity and permutation invariance.

  • Molecules and materials: Predict properties (energy, solubility) from atoms (nodes) and bonds (edges). Edge features encode bond types; message functions can incorporate chemistry-aware transforms.
  • Social and communication networks: Infer user attributes or community structures from interaction graphs. Attention can downweight noisy neighbors.
  • Knowledge graphs and recommender systems: Nodes are entities/items; edges are relations/interactions. Edge features or relation-specific transformations fit well into MPNN message functions.
  • Program analysis and compilers: Control/data-flow graphs where messages propagate semantics along edges.
  • Traffic and sensor networks: Nodes as sensors/locations; edges as roads/flows. Temporal extensions pass messages over time and space.

Prefer simple sum/mean aggregators for speed and stability when neighbor importance is roughly uniform. Use attention when neighbors contribute unequally or the graph is noisy. Limit depth (2–4 layers) for local tasks; consider residuals and normalization for deeper models. For whole-graph predictions, follow with a readout pooling. If you need expressivity close to the Weisfeiler–Lehman test, consider GIN-style sum + MLP updates.

⚠️ Common Mistakes

  • Ignoring permutation invariance: Using an order-sensitive aggregator (e.g., concatenating neighbors in arbitrary order) breaks graph symmetry and leads to unstable behavior. Always use sum/mean/max/attention or another invariant reducer.
  • Oversmoothing from too many layers: Repeated averaging drives node features toward a common subspace, erasing distinctions. Mitigate with residual connections, normalization, careful depth, and potentially jumping knowledge connections.
  • Degree bias with mean aggregation: High-degree nodes’ messages get diluted. Consider sum with normalization, degree-aware scaling, or attention to rebalance.
  • Mismatched dimensions and silent broadcasting bugs: Concatenations and matrix multiplications must align exactly; add explicit assertions for shape checks in C++.
  • Neglecting edge directionality: Treating a directed graph as undirected (or vice versa) changes semantics. Create both directions explicitly if needed and set distinct edge features.
  • Not handling isolated nodes: Nodes with no neighbors should get a well-defined aggregated message (e.g., zeros) to avoid NaNs.
  • Numerical instability in attention: Large logits cause overflow in exp. Use max-subtraction trick in softmax and bounded activations (e.g., LeakyReLU).

Key Formulas

MPNN Aggregation

m_v^{(t)} = AGG_{u ∈ N(v)} M_t(h_u^{(t)}, h_v^{(t)}, e_{uv})

Explanation: Node v collects messages from its neighbors using a message function M and combines them with a permutation-invariant aggregator. This summarizes neighbor information into a single vector.

MPNN Update

h_v^{(t+1)} = U_t(h_v^{(t)}, m_v^{(t)})

Explanation: The node's new state is computed from its previous state and the aggregated neighbor message. This is analogous to taking a step of information integration.

Linear Message with Edge Features

M_t(h_u, h_v, e_{uv}) = φ(W_m^{(t)} [h_u ∥ e_{uv}] + b_m^{(t)})

Explanation: A common message parameterization linearly transforms the sender and edge features followed by a nonlinearity. It is simple, fast, and effective in many domains.

Linear Update

U_t(h_v, m_v) = φ(W_u^{(t)} [h_v ∥ m_v] + b_u^{(t)})

Explanation: The update concatenates the old node state and the aggregated message, applies a linear map, and then a nonlinearity. This mixes self-information with neighborhood context.

Attention Coefficients

α_{uv}^{(t)} = softmax_{u ∈ N(v)}( a^{(t)⊤} [W^{(t)} h_u^{(t)} ∥ W^{(t)} h_v^{(t)}] )

Explanation: Attention assigns an importance weight to each neighbor’s contribution, normalized to sum to one across neighbors. It helps focus on the most relevant neighbors.

Attention-weighted Message

m_v^{(t)} = ∑_{u ∈ N(v)} α_{uv}^{(t)} W^{(t)} h_u^{(t)}

Explanation: The aggregated message for attention-based layers is a weighted sum of transformed neighbor features. Weights are the attention coefficients.

Softmax

softmax(z_i) = exp(z_i) / ∑_j exp(z_j)

Explanation: Converts raw scores into probabilities that sum to one. In attention, it normalizes neighbor importance.

GCN Layer

H^{(t+1)} = σ(Ã H^{(t)} W),  Ã = D̃^{−1/2} (A + I) D̃^{−1/2}

Explanation: This is a specific MPNN where messages are mean-aggregated with symmetric degree normalization and self-loops. It stabilizes training by balancing node degrees.

GIN Update

h_v^{(t+1)} = MLP^{(t)}( (1 + ε^{(t)}) h_v^{(t)} + ∑_{u ∈ N(v)} h_u^{(t)} )

Explanation: The Graph Isomorphism Network uses sum aggregation and an MLP to reach 1-WL expressivity. The trainable ε controls self-contribution.

Per-layer Time Complexity

T = O(|E| d_m (d_x + d_e) + |V| d_o (d_x + d_m))

Explanation: Computing messages scales with the number of edges and dimensions; updates scale with the number of nodes. This highlights near-linear complexity in graph size.

Oversmoothing Limit (Intuition)

lim_{t→∞} Ã^t H^{(0)} = 1 π^⊤ H^{(0)}

Explanation: Repeated normalized averaging approaches a rank-1 projection for connected, aperiodic graphs, making node features similar. This explains oversmoothing with many layers.

Graph-level Readout

h_G = READOUT({h_v^{(T)} : v ∈ V}) = ∑_{v ∈ V} h_v^{(T)}

Explanation: A simple way to turn node embeddings into a single graph embedding is to sum them. Other options are mean, max, or attention-based pooling.

Complexity Analysis

Let |V| be the number of nodes, |E| the number of (directed) edges processed per layer, d_x the node feature size, d_e the edge feature size, d_m the message size, and d_o the output node size. For a linear message M(h_u, h_v, e_{uv}) = φ(W_m [h_u ∥ e_{uv}] + b_m) that ignores h_v in practice (common for efficiency), each edge requires O(d_m (d_x + d_e)) multiplications. Summed over edges, message computation costs O(|E| d_m (d_x + d_e)). Aggregation itself is an element-wise reduction and costs O(|E| d_m). For the update U(h_v, m_v) = φ(W_u [h_v ∥ m_v] + b_u), each node requires O(d_o (d_x + d_m)), totaling O(|V| d_o (d_x + d_m)). Therefore, the total per-layer time is roughly O(|E| d_m (d_x + d_e) + |V| d_o (d_x + d_m)), which is near-linear in graph size when dimensions are fixed. Attention-based layers add the cost of computing attention logits per edge and a per-node softmax. If logits use a vector dot product on projected features, this is O(|E| d_o) extra; if a matrix is used, it can be O(|E| d_o²). Numerically stable softmax also requires one pass to find the max and one to normalize, but remains linear in neighbors. Space complexity per layer is dominated by storing node features O(|V| d_x), (optionally) edge features O(|E| d_e), temporary messages O(|E| d_m) if materialized, and new node features O(|V| d_o). Many implementations avoid storing all messages by streaming edges, reducing peak memory to O(|V| d_o + |E|) besides parameters. Parameters scale as O(d_m (d_x + d_e) + d_o (d_x + d_m)) per layer for linear maps.

Code Examples

Sum/Mean/Max MPNN Layer with Edge Features (Single Forward Pass)
#include <iostream>
#include <vector>
#include <cmath>
#include <limits>
#include <random>
#include <iomanip>
#include <algorithm> // std::max
#include <string>

// Simple, self-contained C++ implementation of a single MPNN layer.
// - Message: M(h_u, h_v, e_uv) = ReLU(W_m [h_u || e_uv] + b_m)
// - Aggregation: sum | mean | max (element-wise)
// - Update: U(h_v, m_v) = ReLU(W_u [h_v || m_v] + b_u)
// This code demonstrates the mechanics of message passing, not training.

struct Edge {
    int to;                  // neighbor index
    std::vector<double> ef;  // edge feature
};

struct Graph {
    int n;                               // number of nodes
    std::vector<std::vector<Edge>> adj;  // adjacency list with edge features
};

// Scalar ReLU
static inline double relu(double x) { return x > 0 ? x : 0.0; }

// Apply ReLU element-wise to a vector
std::vector<double> relu_vec(const std::vector<double>& x) {
    std::vector<double> y = x;
    for (double& v : y) v = relu(v);
    return y;
}

// Matrix-vector multiply: y = W x (W: rows x cols)
std::vector<double> matvec(const std::vector<std::vector<double>>& W, const std::vector<double>& x) {
    int rows = (int)W.size();
    int cols = (int)W[0].size();
    std::vector<double> y(rows, 0.0);
    for (int i = 0; i < rows; ++i) {
        double s = 0.0;
        for (int j = 0; j < cols; ++j) s += W[i][j] * x[j];
        y[i] = s;
    }
    return y;
}

// Vector add: a + b (same size)
std::vector<double> vec_add(const std::vector<double>& a, const std::vector<double>& b) {
    int n = (int)a.size();
    std::vector<double> c(n);
    for (int i = 0; i < n; ++i) c[i] = a[i] + b[i];
    return c;
}

// Concatenate two vectors
std::vector<double> concat(const std::vector<double>& a, const std::vector<double>& b) {
    std::vector<double> c;
    c.reserve(a.size() + b.size());
    c.insert(c.end(), a.begin(), a.end());
    c.insert(c.end(), b.begin(), b.end());
    return c;
}

// Initialize a matrix with small random numbers for demo
std::vector<std::vector<double>> rand_matrix(int rows, int cols, std::mt19937& rng, double scale = 0.1) {
    std::uniform_real_distribution<double> dist(-scale, scale);
    std::vector<std::vector<double>> W(rows, std::vector<double>(cols));
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            W[i][j] = dist(rng);
    return W;
}

// Initialize a vector with small random numbers for demo
std::vector<double> rand_vector(int n, std::mt19937& rng, double scale = 0.1) {
    std::uniform_real_distribution<double> dist(-scale, scale);
    std::vector<double> b(n);
    for (int i = 0; i < n; ++i) b[i] = dist(rng);
    return b;
}

enum class Aggregator { SUM, MEAN, MAX };

struct MPNNLayer {
    int d_x;  // node feature size in
    int d_e;  // edge feature size
    int d_m;  // message size (also output size here)
    Aggregator agg;

    // Parameters
    std::vector<std::vector<double>> W_m;  // d_m x (d_x + d_e)
    std::vector<double> b_m;               // d_m
    std::vector<std::vector<double>> W_u;  // d_m x (d_x + d_m)
    std::vector<double> b_u;               // d_m

    MPNNLayer(int d_x_, int d_e_, int d_m_, Aggregator agg_, std::mt19937& rng)
        : d_x(d_x_), d_e(d_e_), d_m(d_m_), agg(agg_) {
        W_m = rand_matrix(d_m, d_x + d_e, rng);
        b_m = rand_vector(d_m, rng);
        W_u = rand_matrix(d_m, d_x + d_m, rng);
        b_u = rand_vector(d_m, rng);
    }

    // Forward pass: returns new node features h'
    std::vector<std::vector<double>> forward(
        const Graph& G,
        const std::vector<std::vector<double>>& node_feat  // size: n x d_x
    ) const {
        int n = G.n;
        std::vector<std::vector<double>> messages(n, std::vector<double>(d_m, 0.0));
        std::vector<int> deg(n, 0);

        // For MAX aggregator, initialize with -inf
        if (agg == Aggregator::MAX) {
            for (int v = 0; v < n; ++v) {
                for (int k = 0; k < d_m; ++k) messages[v][k] = -std::numeric_limits<double>::infinity();
            }
        }

        // 1) Compute and aggregate edge messages
        for (int v = 0; v < n; ++v) {
            for (const auto& e : G.adj[v]) {
                int u = e.to;
                std::vector<double> x = concat(node_feat[u], e.ef);       // [h_u || e_uv]
                std::vector<double> m = matvec(W_m, x);                   // W_m x
                for (int i = 0; i < d_m; ++i) m[i] = relu(m[i] + b_m[i]); // bias + ReLU

                if (agg == Aggregator::SUM || agg == Aggregator::MEAN) {
                    for (int i = 0; i < d_m; ++i) messages[v][i] += m[i];
                } else if (agg == Aggregator::MAX) {
                    for (int i = 0; i < d_m; ++i) messages[v][i] = std::max(messages[v][i], m[i]);
                }
            }
            deg[v] = (int)G.adj[v].size();
        }

        // If MEAN, divide by degree (handle isolated nodes)
        if (agg == Aggregator::MEAN) {
            for (int v = 0; v < n; ++v) {
                if (deg[v] > 0) {
                    for (int i = 0; i < d_m; ++i) messages[v][i] /= (double)deg[v];
                }
            }
        }

        // If MAX and node is isolated, replace -inf by 0
        if (agg == Aggregator::MAX) {
            for (int v = 0; v < n; ++v) {
                if (deg[v] == 0) {
                    for (int i = 0; i < d_m; ++i) messages[v][i] = 0.0;
                }
            }
        }

        // 2) Update node states: h'_v = ReLU(W_u [h_v || m_v] + b_u)
        std::vector<std::vector<double>> out(n, std::vector<double>(d_m, 0.0));
        for (int v = 0; v < n; ++v) {
            std::vector<double> z = concat(node_feat[v], messages[v]);
            std::vector<double> y = matvec(W_u, z);
            for (int i = 0; i < d_m; ++i) out[v][i] = relu(y[i] + b_u[i]);
        }
        return out;
    }
};

int main() {
    std::mt19937 rng(42);

    // Build a small undirected graph with 4 nodes.
    // We'll add both directions for each edge so every message is explicit.
    Graph G; G.n = 4; G.adj.assign(G.n, {});

    auto add_edge = [&](int u, int v, const std::vector<double>& ef){
        G.adj[u].push_back({v, ef});
    };

    // Example edges with 2D edge features (e.g., type and weight)
    add_edge(0, 1, {1.0, 0.5}); add_edge(1, 0, {1.0, 0.5});
    add_edge(1, 2, {0.0, 1.0}); add_edge(2, 1, {0.0, 1.0});
    add_edge(2, 3, {1.0, 0.2}); add_edge(3, 2, {1.0, 0.2});
    add_edge(0, 2, {0.5, 0.5}); add_edge(2, 0, {0.5, 0.5});

    // Node features (4 nodes, 3D features)
    std::vector<std::vector<double>> H = {
        {1.0, 0.0, 0.5},  // node 0
        {0.3, 0.8, 0.1},  // node 1
        {0.0, 1.0, 0.2},  // node 2
        {0.5, 0.4, 0.9}   // node 3
    };

    int d_x = 3;  // node feature size
    int d_e = 2;  // edge feature size
    int d_m = 4;  // message/out size

    // Choose aggregator: SUM, MEAN, or MAX
    MPNNLayer layer_sum(d_x, d_e, d_m, Aggregator::SUM, rng);
    MPNNLayer layer_mean(d_x, d_e, d_m, Aggregator::MEAN, rng);
    MPNNLayer layer_max(d_x, d_e, d_m, Aggregator::MAX, rng);

    auto print_mat = [](const std::vector<std::vector<double>>& X, const std::string& title){
        std::cout << title << "\n";
        for (const auto& r : X) {
            for (double v : r) std::cout << std::fixed << std::setprecision(4) << v << "\t";
            std::cout << "\n";
        }
        std::cout << "\n";
    };

    print_mat(H, "Input node features H:");

    auto H_sum = layer_sum.forward(G, H);
    print_mat(H_sum, "After 1 MPNN layer (SUM aggregation):");

    auto H_mean = layer_mean.forward(G, H);
    print_mat(H_mean, "After 1 MPNN layer (MEAN aggregation):");

    auto H_max = layer_max.forward(G, H);
    print_mat(H_max, "After 1 MPNN layer (MAX aggregation):");

    return 0;
}

This program builds a small undirected graph with 4 nodes and 2D edge features and runs a single MPNN layer three times, each with a different aggregator (sum/mean/max). Messages are computed from sender node features concatenated with edge features, aggregated at each receiver, and then combined with the receiver’s own features to produce updated node embeddings. The purpose is to illustrate how the choice of AGG affects the result while keeping message and update functions identical.

Time: O(|E| d_m (d_x + d_e) + |V| d_m (d_x + d_m))
Space: O(|V| d_x + |E| d_e + |V| d_m)
Graph Attention Message Passing (GAT-style, No Edge Features)
#include <iostream>
#include <vector>
#include <cmath>
#include <limits>
#include <random>
#include <iomanip>
#include <string>

// A minimal Graph Attention (single-head) layer as an MPNN instantiation.
// h'_v = ReLU( sum_{u in N(v)} alpha_{uv} * (W h_u) )
// alpha_{uv} = softmax_u( LeakyReLU( a^T [W h_u || W h_v] ) )

struct Edge { int to; };
struct Graph { int n; std::vector<std::vector<Edge>> adj; };

static inline double relu(double x) { return x > 0 ? x : 0.0; }
static inline double leaky_relu(double x, double neg_slope = 0.2) { return x >= 0 ? x : neg_slope * x; }

std::vector<std::vector<double>> rand_matrix(int rows, int cols, std::mt19937& rng, double scale = 0.1) {
    std::uniform_real_distribution<double> dist(-scale, scale);
    std::vector<std::vector<double>> W(rows, std::vector<double>(cols));
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            W[i][j] = dist(rng);
    return W;
}

std::vector<double> rand_vector(int n, std::mt19937& rng, double scale = 0.1) {
    std::uniform_real_distribution<double> dist(-scale, scale);
    std::vector<double> b(n);
    for (int i = 0; i < n; ++i) b[i] = dist(rng);
    return b;
}

std::vector<double> matvec(const std::vector<std::vector<double>>& W, const std::vector<double>& x) {
    int rows = (int)W.size();
    int cols = (int)W[0].size();
    std::vector<double> y(rows, 0.0);
    for (int i = 0; i < rows; ++i) {
        double s = 0.0;
        for (int j = 0; j < cols; ++j) s += W[i][j] * x[j];
        y[i] = s;
    }
    return y;
}

std::vector<double> concat(const std::vector<double>& a, const std::vector<double>& b) {
    std::vector<double> c; c.reserve(a.size() + b.size());
    c.insert(c.end(), a.begin(), a.end());
    c.insert(c.end(), b.begin(), b.end());
    return c;
}

struct GATLayer {
    int d_in;                            // input feature size
    int d_out;                           // output (head) size
    std::vector<std::vector<double>> W;  // d_out x d_in
    std::vector<double> a;               // 2*d_out vector for attention

    GATLayer(int d_in_, int d_out_, std::mt19937& rng) : d_in(d_in_), d_out(d_out_) {
        W = rand_matrix(d_out, d_in, rng);
        a = rand_vector(2 * d_out, rng);
    }

    std::vector<std::vector<double>> forward(const Graph& G, const std::vector<std::vector<double>>& H) const {
        int n = G.n;
        // Precompute Wh for all nodes
        std::vector<std::vector<double>> Wh(n, std::vector<double>(d_out, 0.0));
        for (int v = 0; v < n; ++v) Wh[v] = matvec(W, H[v]);

        std::vector<std::vector<double>> out(n, std::vector<double>(d_out, 0.0));

        for (int v = 0; v < n; ++v) {
            // 1) Compute attention logits for neighbors of v
            std::vector<double> logits; logits.reserve(G.adj[v].size());
            double max_logit = -std::numeric_limits<double>::infinity();
            for (const auto& e : G.adj[v]) {
                int u = e.to;
                std::vector<double> cat = concat(Wh[u], Wh[v]);  // [Wh_u || Wh_v]
                double s = 0.0;
                for (int i = 0; i < (int)cat.size(); ++i) s += a[i] * cat[i];
                double l = leaky_relu(s, 0.2);
                logits.push_back(l);
                if (l > max_logit) max_logit = l;  // for numerical stability
            }

            // 2) Softmax over neighbors of v
            std::vector<double> alpha(logits.size(), 0.0);
            double denom = 0.0;
            for (double l : logits) denom += std::exp(l - max_logit);
            for (size_t i = 0; i < logits.size(); ++i) alpha[i] = std::exp(logits[i] - max_logit) / (denom > 0 ? denom : 1.0);

            // 3) Aggregate weighted sum and apply nonlinearity
            std::vector<double> m(d_out, 0.0);
            for (size_t i = 0; i < G.adj[v].size(); ++i) {
                int u = G.adj[v][i].to;
                for (int k = 0; k < d_out; ++k) m[k] += alpha[i] * Wh[u][k];
            }
            for (int k = 0; k < d_out; ++k) out[v][k] = relu(m[k]);
        }
        return out;
    }
};

int main() {
    std::mt19937 rng(123);

    // Small directed graph (edges set both ways if you want undirected behavior)
    Graph G; G.n = 4; G.adj.assign(G.n, {});
    auto add_edge = [&](int u, int v){ G.adj[u].push_back({v}); };
    add_edge(0, 1); add_edge(1, 0);
    add_edge(1, 2); add_edge(2, 1);
    add_edge(2, 3); add_edge(3, 2);
    add_edge(0, 2); add_edge(2, 0);

    // Node features (4 nodes, 3D)
    std::vector<std::vector<double>> H = {
        {1.0, 0.0, 0.5},
        {0.3, 0.8, 0.1},
        {0.0, 1.0, 0.2},
        {0.5, 0.4, 0.9}
    };

    GATLayer gat(3, 4, rng);  // d_in=3, d_out=4

    auto print_mat = [](const std::vector<std::vector<double>>& X, const std::string& title){
        std::cout << title << "\n";
        for (const auto& r : X) {
            for (double v : r) std::cout << std::fixed << std::setprecision(4) << v << "\t";
            std::cout << "\n";
        }
        std::cout << "\n";
    };

    print_mat(H, "Input node features H:");
    auto H_out = gat.forward(G, H);
    print_mat(H_out, "After 1 GAT-style attention layer:");

    return 0;
}

This example implements a single-head Graph Attention layer as an MPNN: neighbors are projected, attention logits are computed from concatenated sender and receiver projections, softmax-normalized, and used to weight neighbor contributions. The max-subtraction trick ensures numerical stability of the softmax. It demonstrates how attention fits naturally into the message passing framework and how it adapts contributions per neighbor.

Time: O(|E| d_out + |V| d_in d_out) for dot-product attention (shown); with heavier parameterizations it can be O(|E| d_out²).
Space: O(|V| d_in + |V| d_out) plus adjacency storage