Graph Neural Networks
Graph Neural Networks extend deep learning to non-Euclidean, relational data: citation graphs, molecular structures, social networks, and knowledge graphs. The key insight is that a node's representation should incorporate information from its neighborhood; repeating this aggregation over multiple layers captures multi-hop context.
Theory
[Interactive diagram: message passing on a small graph. Click a node to inspect its neighborhood; e.g. node A has features [1, 0.2], degree 3, and neighbors B, E, F.]
A GNN lets each node in a graph learn from its neighbors — like how a word's meaning in a sentence depends on surrounding context. Each node sends messages to its neighbors, which aggregate them and update their own state; after L rounds, each node's embedding encodes its L-hop neighborhood. The diagram above shows this message passing process for a single layer: each node receives messages from its immediate neighbors and updates its representation.
The Message Passing Framework
All major GNN variants fit the message passing neural network (MPNN) framework. At layer l, each node v updates its hidden state by aggregating messages from its neighbors N(v):

h_v^(l+1) = UPDATE( h_v^(l), AGGREGATE({ h_u^(l) : u ∈ N(v) }) )
The AGGREGATE function must be permutation-invariant: nodes have no natural ordering in a graph, so aggregation must produce the same result regardless of the order in which neighbors are processed. Sum, mean, and max satisfy this; operations that depend on neighbor ordering do not. This is why GNN aggregators are restricted to symmetric functions: the same graph can be presented with any node indexing, and an order-dependent aggregation would give different results each time.
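This constraint is easy to check numerically; a quick NumPy sanity check (toy feature vectors, illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
neighbors = rng.normal(size=(5, 8))     # 5 neighbor feature vectors
reordered = neighbors[::-1]             # same multiset, different order

# Sum, mean, and max ignore neighbor ordering:
assert np.allclose(neighbors.sum(axis=0), reordered.sum(axis=0))
assert np.allclose(neighbors.mean(axis=0), reordered.mean(axis=0))
assert np.allclose(neighbors.max(axis=0), reordered.max(axis=0))

# Concatenation depends on ordering, so it is not a valid aggregator:
assert not np.allclose(neighbors.reshape(-1), reordered.reshape(-1))
```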
The initial hidden state h_v^(0) = x_v (the raw node feature vector). After L layers, h_v^(L) is the node's final embedding, encoding its L-hop neighborhood.
Graph Convolutional Network (GCN)
Kipf & Welling (2017) define the layer-wise propagation:

H^(l+1) = σ( D̃^(−1/2) Ã D̃^(−1/2) H^(l) W^(l) )

where Ã = A + I is the adjacency matrix with added self-loops, D̃ is its degree matrix, and W^(l) is a learnable weight matrix. The symmetric normalization by D̃^(−1/2) keeps features on a comparable scale regardless of node degree. The per-node update for node v:

h_v^(l+1) = σ( Σ_{u ∈ N(v) ∪ {v}} h_u^(l) W^(l) / √( d̃_u d̃_v ) )
GraphSAGE: Inductive Learning
Hamilton et al. (2017) introduce sampling and aggregation for inductive generalization to unseen nodes:

h_v^(l+1) = σ( W^(l) · [ h_v^(l) ‖ AGGREGATE({ h_u^(l) : u ∈ N_s(v) }) ] )
where N_s(v) is a fixed-size sample of neighbors and ‖ is concatenation. AGGREGATE can be mean, LSTM (on shuffled neighbors), or max-pool. Crucially, GraphSAGE can embed nodes not seen during training.
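A minimal NumPy sketch of the mean-aggregator variant; the `sage_layer_mean` helper, graph, and shapes are illustrative assumptions, not the paper's reference implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def sage_layer_mean(h, adj_list, W, num_samples=2):
    """One GraphSAGE layer, mean aggregator (sketch):
    h_v_new = ReLU(W @ [h_v ; mean({h_u : u in sampled N(v)})])."""
    out = np.zeros((h.shape[0], W.shape[0]))
    for v, neigh in enumerate(adj_list):
        sampled = rng.choice(neigh, size=min(num_samples, len(neigh)),
                             replace=False)        # fixed-size neighbor sample
        agg = h[sampled].mean(axis=0)              # permutation-invariant mean
        out[v] = np.maximum(0, W @ np.concatenate([h[v], agg]))  # ReLU
    return out

adj_list = [[1, 2], [0, 2, 3], [0, 1], [1]]   # small undirected graph
h0 = rng.normal(size=(4, 5))                  # 4 nodes, 5-dim features
W = rng.normal(size=(8, 10))                  # maps concat(5 + 5) -> 8
h1 = sage_layer_mean(h0, adj_list, W)
print(h1.shape)                               # (4, 8)
```

Because the layer only needs a node's features and a neighbor sample, it can embed nodes that were never seen during training.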
Graph Attention Network (GAT)
Veličković et al. (2018) learn attention weights over neighbors:

e_vu = LeakyReLU( aᵀ [ W h_v ‖ W h_u ] )
α_vu = softmax of e_vu over u ∈ N(v) ∪ {v}
h_v^(l+1) = σ( Σ_{u ∈ N(v) ∪ {v}} α_vu W h_u^(l) )
Multi-head attention runs K independent attention heads and concatenates their outputs in hidden layers (or averages them in the final layer), improving training stability.
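A single-head update for one node can be sketched in NumPy; the `gat_node_update` helper and all shapes here are hypothetical toy choices:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def gat_node_update(h_v, h_neighbors, W, a):
    """Single-head GAT update for one node (sketch):
    e_u = LeakyReLU(a^T [W h_v ; W h_u]), alpha = softmax(e),
    h_v' = sum_u alpha_u * (W h_u) over N(v) plus the self-loop."""
    cand = np.vstack([h_v[None, :], h_neighbors])   # include self-loop
    z = cand @ W.T                                  # transformed features
    zv = h_v @ W.T
    e = np.array([np.concatenate([zv, zu]) @ a for zu in z])
    e = np.where(e > 0, e, 0.2 * e)                 # LeakyReLU, slope 0.2
    alpha = softmax(e)                              # attention over neighbors
    return alpha @ z                                # weighted combination

rng = np.random.default_rng(0)
h_v = rng.normal(size=6)
h_nb = rng.normal(size=(3, 6))      # 3 neighbors, 6-dim features
W = rng.normal(size=(4, 6))         # project 6 -> 4
a = rng.normal(size=8)              # attention vector over concat(4 + 4)
out = gat_node_update(h_v, h_nb, W, a)
print(out.shape)                    # (4,)
```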
Over-smoothing
A key limitation: stacking too many GNN layers causes all node embeddings to converge to similar values (over-smoothing). Representations lose distinguishability after ~4–8 hops. Solutions include residual connections (similar to ResNets), jumping knowledge networks (JK-Net, which concatenates hidden states from all layers), and DropEdge.
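The effect can be seen directly by applying the GCN propagation matrix repeatedly to distinct node features; a small NumPy demonstration on a toy path graph (illustrative only):

```python
import numpy as np

# Path graph 0–1–2–3–4; build the GCN propagation matrix D̃^{-1/2} Ã D̃^{-1/2}
n = 5
A = np.diag(np.ones(n - 1), 1) + np.diag(np.ones(n - 1), -1)
A_hat = A + np.eye(n)                     # add self-loops
d = A_hat.sum(axis=1)
P = A_hat / np.sqrt(np.outer(d, d))       # symmetric normalization

H = np.eye(n)                             # distinct one-hot node features
spread = lambda M: M.std(axis=0).sum()    # how distinguishable rows are

for k in (1, 2, 8, 32):
    Hk = np.linalg.matrix_power(P, k) @ H
    print(k, round(spread(Hk), 4))        # spread shrinks as k grows
```

With many propagation steps the rows of P^k H collapse toward the dominant eigenvector of P, which is exactly the over-smoothing described above.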
Walkthrough
Task: Node classification on a citation graph. Each node = paper, edges = citations, features = TF-IDF of abstract, labels = research area.
Step 1 — Graph Construction
Build adjacency list from citation data. Add self-loops (each paper cites itself so it keeps its own features). Normalize adjacency: Ã = A + I, D̃_ii = Σ_j Ã_ij.
Step 2 — Layer 0 (Initial Features)
Each node v starts with its feature vector x_v in ℝ^1433 (Cora dataset has 1433-dim bag-of-words features).
Step 3 — Layer 1 Aggregation
For each node, collect neighbor features and apply GCN normalization:

h_v^(1) = ReLU( Σ_{u ∈ N(v) ∪ {v}} x_u W^(0) / √( d̃_u d̃_v ) )
After layer 1, each node's embedding reflects its 1-hop neighborhood.
Step 4 — Layer 2 + Classification
A second GCN layer aggregates 2-hop neighborhoods. A final linear layer + softmax produces class probabilities. Trained with cross-entropy on labeled nodes only (semi-supervised).
On Cora dataset (2708 nodes, 7 classes, 140 training labels):
- 2-layer GCN: ~81% test accuracy
- GAT: ~83% test accuracy
- Baseline (logistic regression on raw features): ~72%
GCN implementation with PyTorch Geometric:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# Cora: 2708 nodes, 5429 edges, 1433 input features, 7 classes
data = Planetoid(root='/tmp/Cora', name='Cora')[0]

class GCN(torch.nn.Module):
    def __init__(self, in_channels: int, hidden: int, out_channels: int):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))       # 1-hop aggregation
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)               # 2-hop aggregation
        return F.log_softmax(x, dim=1)

model = GCN(data.num_features, hidden=16, out_channels=7)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # Semi-supervised: loss only on the 140 labeled training nodes
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
print(f"Test accuracy: {acc:.4f}")  # expect ~0.81
```

Manual message-passing step (NumPy), showing the math directly:
```python
import numpy as np

def gcn_layer(A: np.ndarray, H: np.ndarray, W: np.ndarray) -> np.ndarray:
    """One GCN layer: H_new = ReLU(D̃^{-1/2} Ã D̃^{-1/2} H W)."""
    n = A.shape[0]
    A_hat = A + np.eye(n)                 # add self-loops
    deg = A_hat.sum(axis=1)
    D_inv = np.diag(1.0 / np.sqrt(deg))   # D̃^{-1/2}
    A_sym = D_inv @ A_hat @ D_inv         # symmetric normalization
    return np.maximum(0, A_sym @ H @ W)   # ReLU

# 4-node chain: 0—1—2—3
A = np.array([[0,1,0,0],[1,0,1,0],[0,1,0,1],[0,0,1,0]], dtype=float)
H0 = np.eye(4)               # each node's initial feature = one-hot ID
W0 = np.random.randn(4, 8)
H1 = gcn_layer(A, H0, W0)    # shape (4, 8)
# H1[0] encodes info from nodes {0, 1}; H1[2] encodes {1, 2, 3}
# A second layer would propagate info two hops away
```

Analysis & Evaluation
Where Your Intuition Breaks
The intuition: deeper GNNs (more layers) should learn richer graph representations. In practice, adding layers causes over-smoothing: after many rounds of neighborhood aggregation, all node embeddings converge toward similar values (dominated by the graph's leading eigenvectors), destroying local structural information. A 10-layer GNN often performs worse than a 2–3-layer GNN on node classification because deep stacking erases the very local features that distinguish nodes. The fix is not to avoid depth entirely, but to use residual connections, graph normalization, or attention mechanisms that selectively weight neighbor contributions, or to apply jumping knowledge connections that combine embeddings from all layers.
GNN Variant Comparison
| Model | Aggregation | Inductive | Attention | Best For |
|---|---|---|---|---|
| GCN | Normalized sum | No (transductive) | No | Citation/social graphs |
| GraphSAGE | Sampled mean/pool | Yes | No | Large-scale, new nodes |
| GAT | Attention-weighted sum | Yes | Yes | Heterogeneous neighborhoods |
| GIN | Sum + MLP | Yes | No | Graph-level tasks, isomorphism |
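The GIN row's "Sum + MLP" aggregation can be sketched in NumPy; the `gin_layer` helper, toy graph, and shapes are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def gin_layer(A, H, W1, W2, eps=0.0):
    """One GIN layer (sketch): h_v' = MLP((1 + eps) * h_v + Σ_{u∈N(v)} h_u)."""
    agg = (1 + eps) * H + A @ H           # injective sum aggregation
    return np.maximum(0, agg @ W1) @ W2   # two-layer MLP with ReLU

# Triangle graph, 3 nodes with 4-dim features
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
H = rng.normal(size=(3, 4))
H1 = gin_layer(A, H, rng.normal(size=(4, 16)), rng.normal(size=(16, 4)))
print(H1.shape)               # (3, 4)

# Graph-level readout: sum-pool node embeddings into one graph embedding
g = H1.sum(axis=0)            # shape (4,)
```

The sum (rather than mean or max) keeps the aggregation injective over multisets of neighbor features, which is what gives GIN its discriminative power on graph-level tasks.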
Common Pitfalls
Over-smoothing: Use at most 2–3 layers for most tasks. Add residual connections if deeper networks are needed.
Scalability: Full-graph operations don't scale to 100M+ node graphs. Use mini-batch neighbor sampling (GraphSAGE) or cluster-based batching (Cluster-GCN).
Heterogeneous graphs: When nodes/edges have different types (user, item, category), use type-specific weight matrices or a heterogeneous GNN (HAN, HGT).
Leakage in temporal graphs: If edges represent future interactions, ensure your training subgraph only uses edges before the evaluation timestamp.
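The temporal-split rule above amounts to filtering the edge list by timestamp before building the training graph; a toy sketch with hypothetical timestamps:

```python
import numpy as np

# Hypothetical toy edge list: rows of (src, dst, timestamp)
edges = np.array([[0, 1, 5], [1, 2, 9], [2, 3, 14], [0, 3, 20]])
t_eval = 12                                  # evaluation cutoff

train_edges = edges[edges[:, 2] < t_eval]    # keep only past interactions
print(train_edges[:, :2].tolist())           # [[0, 1], [1, 2]]
```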
Applications Beyond Classification
- Link prediction: predict missing edges; used in knowledge graph completion and drug interaction prediction
- Graph classification: embed entire graphs for molecular property prediction (GIN works well)
- Recommendation: graph-based collaborative filtering (LightGCN removes nonlinear transforms for efficiency)
- Fraud detection: heterogeneous transaction graphs where suspicious behavior forms detectable subgraph patterns