圖神經網絡(GNN)基本概念與核心原理
圖神經網絡(GNN)是一類專門處理圖結構數據的神經網絡模型 (GTAT: empowering graph neural networks with cross attention | Scientific Reports)。圖結構數據由節點(表示實體)和邊(表示實體間關系)構成,每個節點和邊都可以帶有特征信息。GNN的核心思想是通過多輪**消息傳遞(message passing)**來迭代更新節點的表示:每層GNN會讓每個節點收集并聚合其鄰居節點的特征,然后通過一個神經網絡變換這些聚合信息,更新自身的表示 (Graph neural network - Wikipedia) (GTAT: empowering graph neural networks with cross attention | Scientific Reports)。這樣,多層堆疊的GNN可以讓信息在圖中從一個節點傳遞到遠處的節點,從而學習到圖的全局結構特征。
- 圖結構和特征:圖由節點和邊組成,節點可對應機器、任務、地理位置等實體,節點特征描述實體屬性(如機器人狀態、任務需求等),邊可表示實體間的聯系或拓撲結構。
- 消息傳遞與聚合:在每一層GNN中,每個節點會收集所有鄰居節點的特征(如將鄰居特征求和或求平均),并結合自身特征輸入一個神經網絡進行變換。這樣,節點能“看到”局部鄰域的信息,形成新的表示。
- 迭代更新與表達:通過多層GNN的迭代,每個節點的信息融合來自更遠節點的影響,最終輸出的節點表示(或全圖表示)可用于后續任務,如節點分類、圖分類或回歸等。經過訓練后的GNN能夠自動提取圖結構中的有效信息,無需手工設計特征。
經過若干層GNN后,我們可以得到每個節點或整個圖的高維嵌入(embedding),并據此完成分類、回歸等任務。這種基于圖結構的神經網絡具有很強的表達能力,能夠捕捉節點間的復雜關系 (GTAT: empowering graph neural networks with cross attention | Scientific Reports) (Graph neural network - Wikipedia)。
GNN示例:基于GCN的簡單實現
下面以PyTorch Geometric為例,演示一個簡單的兩層圖卷積網絡(GCN)實現,用于對圖中節點進行分類。代碼中對每行添加了中文注釋說明。
import torch
import torch.nn.functional as F