最近看論文看到了圖卷積神經網絡的內容,之前整理過圖神經網絡的內容,這里再補充一下,方便以后查閱。
圖卷積神經網絡(Graph Convolutional Network, GCN)
- 圖卷積神經網絡
- 1. 什么是圖卷積神經網絡(GCN)?
- 2. GCN的原理
- 2.1 圖的表示
- 2.2 譜圖卷積
- 2.3 GCN的層級傳播規則
- 2.4 消息傳遞框架
- 3. GCN的結構
- 4. GCN的應用
- 5. GCN的優點與局限性
- 優點:
- 局限性:
- 6. GCN代碼示例
- 代碼說明:
- 7. 如何擴展GCN?
- 8. 總結
- 譜圖理論(Spectral Graph Theory)
- 1. 背景:為什么需要譜圖理論?
- 2. 什么是譜圖理論?
- 3. 圖的拉普拉斯矩陣
- 3.1 定義
- 3.2 歸一化拉普拉斯矩陣
- 3.3 拉普拉斯矩陣的性質
- 4. 通過拉普拉斯矩陣定義卷積操作
- 4.1 圖傅里葉變換
- 4.2 譜卷積
- 4.3 GCN的簡化
- 5. 為什么用拉普拉斯矩陣定義卷積?
- 6. 局限性與改進
- 圖注意力網絡(GAT)
- 1. 背景:為什么需要GAT?
- 2. GAT的原理
- 3. GAT的數學公式
- 3.1 注意力系數
- 3.2 歸一化注意力系數
- 3.3 加權聚合
- 3.4 多頭注意力
- 3.5 完整傳播規則
- 4. GAT的結構
- 5. GAT與GCN的對比
- 6. GAT的優勢
- 7. GAT的局限性
- 8. GAT的應用
- 9. GAT的實現與代碼
- 代碼說明:
- 10. GAT的擴展與改進
- 11. 總結
先驗知識:圖神經網絡 GNN
圖卷積神經網絡
1. 什么是圖卷積神經網絡(GCN)?
GCN是卷積神經網絡(CNN)的擴展,適用于非歐幾里得空間的數據(如圖結構數據)。傳統CNN處理規則的網格數據(如圖像或時間序列),而GCN處理節點和邊構成的圖結構數據。圖由節點(vertices)和邊(edges)組成,節點表示實體,邊表示實體間的關系。
GCN的核心思想是通過消息傳遞機制,利用圖的拓撲結構,將節點特征與其鄰居的特征聚合,從而學習節點的表示(embedding)。這些表示可用于節點分類、鏈接預測或圖分類等任務。
2. GCN的原理
GCN基于譜圖理論(Spectral Graph Theory)和消息傳遞框架。以下是其核心原理:
2.1 圖的表示
一個圖 G = ( V , E ) G = (V, E) G=(V,E) 由以下部分組成:
- 節點集 V V V,節點數為 n n n。
- 邊集 E E E,表示節點之間的連接。
- 鄰接矩陣 A A A,大小為 n × n n \times n n×n,其中 A i j = 1 A_{ij} = 1 Aij?=1 表示節點 i i i 和 j j j 之間有邊,否則為 0 0 0。
- 節點特征矩陣 X X X,大小為 n × d n \times d n×d,其中每行是節點 i i i 的 d d d 維特征向量。
- 度矩陣 D D D,對角矩陣,其中 D i i = ∑ j A i j D_{ii} = \sum_j A_{ij} Dii?=∑j?Aij? 表示節點 i i i 的度。
2.2 譜圖卷積
GCN最初基于譜圖理論,通過圖的拉普拉斯矩陣定義卷積操作。圖拉普拉斯矩陣定義為:
L = D ? A L = D - A L=D?A
歸一化拉普拉斯矩陣為:
L n o r m = I ? D ? 1 / 2 A D ? 1 / 2 L_{norm} = I - D^{-1/2} A D^{-1/2} Lnorm?=I?D?1/2AD?1/2
其中 I I I 是單位矩陣, D ? 1 / 2 D^{-1/2} D?1/2 是度矩陣的對角元素的倒數平方根。
譜圖卷積通過拉普拉斯矩陣的特征分解,將卷積操作定義在圖的頻域上。然而,計算特征分解的復雜度較高( O ( n 3 ) O(n^3) O(n3)),因此實際中常用近似方法。
2.3 GCN的層級傳播規則
現代GCN(如Kipf & Welling, 2017)使用簡化的消息傳遞機制。一層GCN的傳播規則為:
H ( l + 1 ) = σ ( D ~ ? 1 / 2 A ~ D ~ ? 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D~?1/2A~D~?1/2H(l)W(l))
其中:
- H ( l ) H^{(l)} H(l):第 l l l 層的節點特征矩陣, H ( 0 ) = X H^{(0)} = X H(0)=X。
- A ~ = A + I \tilde{A} = A + I A~=A+I:添加自環的鄰接矩陣(每個節點與自身連接)。
- D ~ \tilde{D} D~: A ~ \tilde{A} A~ 對應的度矩陣, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii} = \sum_j \tilde{A}_{ij} D~ii?=∑j?A~ij?。
- W ( l ) W^{(l)} W(l):第 l l l 層的可學習權重矩陣。
- σ \sigma σ:激活函數(如ReLU)。
- D ~ ? 1 / 2 A ~ D ~ ? 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~?1/2A~D~?1/2:歸一化的鄰接矩陣,用于平衡不同度節點的影響。
直觀解釋:
- 每層GCN聚合節點的鄰居特征(包括自身),通過 A ~ \tilde{A} A~ 實現。
- 歸一化( D ~ ? 1 / 2 \tilde{D}^{-1/2} D~?1/2) 防止高階節點主導聚合。
- 權重矩陣 W ( l ) W^{(l)} W(l) 進行特征變換,激活函數引入非線性。
2.4 消息傳遞框架
GCN可以看作消息傳遞神經網絡(Message Passing Neural Network, MPNN)的一種:
- 聚合:收集鄰居節點的特征(通過 A ~ \tilde{A} A~)。
- 更新:結合自身特征和聚合特征,更新節點表示(通過 W ( l ) W^{(l)} W(l) 和 σ \sigma σ)。
3. GCN的結構
一個典型的GCN模型包含以下部分:
- 輸入層:接受圖的鄰接矩陣 A A A 和節點特征矩陣 X X X。
- 多層GCN:堆疊若干GCN層,每層執行特征聚合和變換。
- 輸出層:
- 對于節點分類,輸出每個節點的類別概率(通過Softmax)。
- 對于圖分類,需要池化層(如全局平均池化)將節點特征匯總為圖特征。
- 損失函數:
- 節點分類:交叉熵損失。
- 圖分類:交叉熵或回歸損失,取決于任務。
4. GCN的應用
GCN在許多領域有廣泛應用:
- 社交網絡:
- 節點分類:預測用戶興趣或社區歸屬。
- 鏈接預測:推薦好友或合作關系。
- 推薦系統:
- 使用用戶-物品交互圖,預測用戶偏好。
- 化學分子分析:
- 圖表示分子結構,預測分子性質(如毒性或溶解度)。
- 知識圖譜:
- 實體分類或關系預測。
- 生物信息學:
- 分析蛋白質相互作用網絡。
5. GCN的優點與局限性
優點:
- 適應圖結構:能有效處理非規則的圖數據。
- 局部性:通過鄰居聚合,捕獲局部拓撲信息。
- 可擴展性:可以堆疊多層,學習復雜模式。
局限性:
- 過平滑問題:
- 堆疊過多GCN層會導致節點特征趨于相似,丟失區分度。
- 固定拓撲:
- GCN依賴靜態圖結構,無法直接處理動態圖。
- 計算復雜度:
- 對于大規模圖,矩陣運算(如 A ~ H \tilde{A} H A~H)可能耗時。
- 邊信息:
- 基本GCN不考慮邊的權重或類型,后續變體(如GAT)改進此問題。
6. GCN代碼示例
以下是一個基于PyTorch Geometric的GCN實現,用于節點分類任務。我們使用Cora數據集(一個常用的學術引用網絡數據集),其中節點是論文,邊是引用關系,目標是預測論文的類別。
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv# 加載Cora數據集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定義GCN模型
class GCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, x, edge_index):# 第一層GCNx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, p=0.5, training=self.training)# 第二層GCNx = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 設置設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)# 初始化模型
model = GCN(in_channels=dataset.num_features,hidden_channels=16,out_channels=dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)# 訓練模型
def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 測試模型
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()return acc# 訓練循環
for epoch in range(200):loss = train()if epoch % 10 == 0:acc = test()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')# 最終測試
final_acc = test()
print(f'Final Test Accuracy: {final_acc:.4f}')
代碼說明:
- 數據集:Cora數據集包含2708個節點(論文),每個節點有1433維特征(詞袋表示),7個類別,圖有10556條邊。
- 模型:兩層GCN,第一層將特征從1433維降到16維,第二層輸出7維(對應類別)。
- 訓練:使用Adam優化器,交叉熵損失,僅對訓練掩碼(
train_mask
)的節點計算損失。 - 測試:在測試掩碼(
test_mask
)上計算分類準確率。 - 依賴:需要安裝
torch
和torch_geometric
:pip install torch torch-geometric
7. 如何擴展GCN?
GCN是圖神經網絡(GNN)的基礎,許多改進模型基于GCN:
- 圖注意力網絡(GAT):引入注意力機制,動態分配鄰居權重。
- GraphSAGE:通過采樣鄰居,適應大規模圖。
- JK-Net:通過跳躍連接緩解過平滑問題。
- APPNP:結合個性化PageRank,增強傳播效果。
8. 總結
GCN是一種強大的圖神經網絡,通過消息傳遞機制聚合鄰居特征,學習圖中節點的表示。其基于譜圖理論,結構簡單,適合節點分類、鏈接預測等任務。盡管GCN在許多領域表現優異,但過平滑和計算復雜度問題需要通過變體或優化解決。
譜圖理論(Spectral Graph Theory)
1. 背景:為什么需要譜圖理論?
在傳統卷積神經網絡(CNN)中,卷積操作適用于規則的網格數據(如圖像的像素網格),通過滑動窗口提取局部特征。然而,圖結構數據(如社交網絡、分子結構)是非規則的,非歐幾里得空間的數據,節點之間的連接(邊)沒有固定模式。因此,直接應用傳統卷積不可行。
譜圖理論提供了一種數學框架,通過分析圖的拓撲結構(鄰接關系),將圖上的操作(如卷積)定義在頻域(類似于傅里葉變換)。GCN的早期工作(例如Bruna等人,2013)利用譜圖理論,將圖上的卷積定義為基于圖拉普拉斯矩陣的操作。
2. 什么是譜圖理論?
譜圖理論是圖論的一個分支,研究圖的性質通過其矩陣表示(如鄰接矩陣或拉普拉斯矩陣)的特征值(eigenvalues)和特征向量(eigenvectors)。這些特征值和特征向量描述了圖的拓撲結構,例如連通性、聚類特性等。
在GCN中,譜圖理論的核心思想是將圖上的信號(節點特征)投影到圖的頻域(由拉普拉斯矩陣的特征向量定義),進行類似傅里葉變換的操作,再轉換回空間域。這種方法允許我們在圖上定義卷積,類似于圖像上的卷積。
3. 圖的拉普拉斯矩陣
拉普拉斯矩陣(Laplacian Matrix)是圖的矩陣表示,用于捕捉圖的拓撲結構。以下是其定義和性質:
3.1 定義
對于一個無向圖 G = ( V , E ) G = (V, E) G=(V,E),有 n n n 個節點,拉普拉斯矩陣 L L L 定義為:
L = D ? A L = D - A L=D?A
- A A A:鄰接矩陣,大小 n × n n \times n n×n,其中 A i j = 1 A_{ij} = 1 Aij?=1 如果節點 i i i 和 j j j 之間有邊,否則為 0 0 0。
- D D D:度矩陣,對角矩陣,大小 n × n n \times n n×n,其中 D i i = ∑ j A i j D_{ii} = \sum_j A_{ij} Dii?=∑j?Aij? 表示節點 i i i 的度(連接的邊數),非對角元素為 0 0 0。
例如,對于一個簡單圖:
- 鄰接矩陣 A A A:
A = [ 0 1 1 1 1 0 0 0 1 0 0 1 1 0 1 0 ] A = \begin{bmatrix} 0 & 1 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 1 & 0 & 0 & 1 \\ 1 & 0 & 1 & 0 \end{bmatrix} A= ?0111?1000?1001?1010? ? - 度矩陣 D D D:
D = [ 3 0 0 0 0 1 0 0 0 0 2 0 0 0 0 2 ] D = \begin{bmatrix} 3 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 2 & 0 \\ 0 & 0 & 0 & 2 \end{bmatrix} D= ?3000?0100?0020?0002? ? - 拉普拉斯矩陣 L L L:
L = D ? A = [ 3 ? 1 ? 1 ? 1 ? 1 1 0 0 ? 1 0 2 ? 1 ? 1 0 ? 1 2 ] L = D - A = \begin{bmatrix} 3 & -1 & -1 & -1 \\ -1 & 1 & 0 & 0 \\ -1 & 0 & 2 & -1 \\ -1 & 0 & -1 & 2 \end{bmatrix} L=D?A= ?3?1?1?1??1100??102?1??10?12? ?
3.2 歸一化拉普拉斯矩陣
為了平衡不同度節點的影響,常用歸一化拉普拉斯矩陣:
L n o r m = I ? D ? 1 / 2 A D ? 1 / 2 L_{norm} = I - D^{-1/2} A D^{-1/2} Lnorm?=I?D?1/2AD?1/2
- D ? 1 / 2 D^{-1/2} D?1/2:對角矩陣,其對角元素為 D i i ? 1 / 2 = 1 / 度 i D_{ii}^{-1/2} = 1 / \sqrt{\text{度}_i} Dii?1/2?=1/度i??。
- I I I:單位矩陣。
歸一化拉普拉斯矩陣的特征值在 [ 0 , 2 ] [0, 2] [0,2] 范圍內,適合數值計算。
3.3 拉普拉斯矩陣的性質
- 對稱性:對于無向圖, L L L 是對稱矩陣,因此有實特征值和正交特征向量。
- 正半定性: L L L 的特征值非負,反映圖的連通性(例如,特征值 0 0 0 的重數等于連通分量的個數)。
- 頻域解釋:拉普拉斯矩陣的特征向量形成圖的“頻域基”,類似傅里葉變換中的正弦和余弦函數。特征值表示“頻率”,低頻率對應平滑信號,高頻率對應快速變化的信號。
4. 通過拉普拉斯矩陣定義卷積操作
在譜圖理論中,圖上的卷積操作通過拉普拉斯矩陣的特征分解定義,類似于信號處理中的傅里葉變換。以下是具體步驟:
4.1 圖傅里葉變換
拉普拉斯矩陣 L L L 可以分解為:
L = U Λ U T L = U \Lambda U^T L=UΛUT
- U U U: 特征向量矩陣,列是 L L L 的特征向量 u 1 , u 2 , … , u n u_1, u_2, \ldots, u_n u1?,u2?,…,un?,表示圖的“頻域基”。
- Λ \Lambda Λ: 對角矩陣,對角元素是特征值 λ 1 , λ 2 , … , λ n \lambda_1, \lambda_2, \ldots, \lambda_n λ1?,λ2?,…,λn?,表示“頻率”。
對于節點特征向量 x ∈ R n x \in \mathbb{R}^n x∈Rn(每個節點一個標量特征),其圖傅里葉變換定義為:
x ^ = U T x \hat{x} = U^T x x^=UTx
- x ^ \hat{x} x^ 是頻域中的系數,表示 x x x 在特征向量基上的投影。
- 逆變換為:
x = U x ^ x = U \hat{x} x=Ux^
4.2 譜卷積
在頻域中,卷積等價于逐頻率相乘。假設有一個濾波器(卷積核) g g g,其頻域表示為 g ( Λ ) g(\Lambda) g(Λ)(對角矩陣,元素為 g ( λ i ) g(\lambda_i) g(λi?))。圖上的卷積定義為:
x ? g = U g ( Λ ) U T x x * g = U g(\Lambda) U^T x x?g=Ug(Λ)UTx
- g ( Λ ) g(\Lambda) g(Λ): 濾波器在頻域的響應,控制如何放大或抑制不同頻率的信號。
- U T x U^T x UTx: 將信號 x x x 轉換為頻域。
- U g ( Λ ) U T x U g(\Lambda) U^T x Ug(Λ)UTx: 應用濾波器后轉換回空間域。
因此,圖卷積是通過拉普拉斯矩陣的特征分解,將節點特征在頻域中與濾波器結合,再轉換回節點特征。
4.3 GCN的簡化
早期譜GCN(如Bruna等人)直接使用上述卷積,但計算 U U U 和 U T U^T UT 的復雜度為 O ( n 3 ) O(n^3) O(n3),對大圖不可行。Kipf & Welling (2017) 提出了簡化版本:
- 使用多項式濾波器:假設 g ( Λ ) g(\Lambda) g(Λ) 是拉普拉斯矩陣特征值的多項式,例如 g ( Λ ) = ∑ k θ k Λ k g(\Lambda) = \sum_k \theta_k \Lambda^k g(Λ)=∑k?θk?Λk。
- 近似為低階多項式(如一階),避免特征分解。
- 最終傳播規則為:
H ( l + 1 ) = σ ( D ~ ? 1 / 2 A ~ D ~ ? 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D~?1/2A~D~?1/2H(l)W(l))
其中 A ~ = A + I \tilde{A} = A + I A~=A+I, D ~ \tilde{D} D~ 是 A ~ \tilde{A} A~ 的度矩陣。這種形式避免了昂貴的矩陣分解,直接操作鄰接矩陣。
5. 為什么用拉普拉斯矩陣定義卷積?
拉普拉斯矩陣在圖卷積中有以下優勢:
- 捕捉拓撲結構:拉普拉斯矩陣編碼了圖的連接性(鄰接關系和節點度),適合定義局部聚合操作。
- 頻域解釋:通過特征分解,拉普拉斯矩陣提供了一種頻域視角,類似圖像上的傅里葉變換,便于定義卷積。
- 平滑性:拉普拉斯矩陣與圖的平滑性相關(例如, x T L x x^T L x xTLx 測量信號 x x x 在圖上的變化),卷積操作可以平滑節點特征,聚合鄰居信息。
- 數學優雅:譜圖理論提供了統一的框架,將圖上的操作與傳統信號處理連接起來。
6. 局限性與改進
基于譜圖理論的GCN有以下局限性:
- 計算復雜度:特征分解對大圖不可行( O ( n 3 ) O(n^3) O(n3))。
- 泛化性:譜方法依賴圖的固定拉普拉斯矩陣,難以直接應用于不同結構的圖。
- 局部性:譜卷積本質上是全局操作,可能忽略局部特征。
因此,后續工作(如GraphSAGE、GAT)轉向空間域方法,直接在圖的拓撲上定義卷積(如鄰居聚合),避免譜分解,提高效率和靈活性。
圖注意力網絡(GAT)
圖注意力網絡(Graph Attention Network, GAT)是一種圖神經網絡(Graph Neural Network, GNN)的變體,由Veli?kovi?等人于2017年提出(論文:《Graph Attention Networks》)。它通過引入注意力機制(Attention Mechanism)改進了傳統的圖卷積神經網絡(GCN),能夠動態地為不同鄰居節點分配權重,從而更好地捕捉圖結構中的異質性關系
1. 背景:為什么需要GAT?
圖神經網絡(如GCN)通過聚合節點鄰居的特征來學習節點表示,適用于處理圖結構數據(如社交網絡、分子結構)。然而,GCN存在以下局限性:
- 等權重聚合:GCN假設所有鄰居對節點的貢獻相同(通過歸一化的鄰接矩陣),無法區分鄰居的重要性。例如,在社交網絡中,某些好友的影響可能更大。
- 固定拓撲依賴:GCN的聚合權重完全由圖的拓撲結構(鄰接矩陣和節點度)決定,缺乏靈活性。
- 無法捕捉異質性:對于高度異質的圖(節點或邊的關系差異顯著),GCN的表現可能受限。
GAT通過引入注意力機制解決了這些問題,允許模型動態學習每個鄰居的貢獻權重,類似于自然語言處理中的Transformer模型。這種機制使GAT能夠聚焦于對任務更重要的鄰居,提高表示能力和泛化性。
2. GAT的原理
GAT的核心思想是通過注意力機制為每個節點的鄰居分配不同的權重,然后基于這些權重聚合鄰居特征。其操作可以概括為以下步驟:
- 計算注意力系數:為每條邊(或鄰居對)計算一個注意力分數,表示鄰居的重要性。
- 歸一化注意力系數:使用Softmax將注意力分數歸一化為權重。
- 加權聚合:根據歸一化的注意力權重,聚合鄰居的特征。
- 多頭注意力:可選地使用多組注意力機制(類似Transformer),增強模型表達能力。
GAT仍然基于消息傳遞框架(Message Passing Neural Network, MPNN),但其聚合方式比GCN更靈活。
3. GAT的數學公式
假設有一個圖 G = ( V , E ) G = (V, E) G=(V,E),包含 n n n 個節點,節點特征矩陣為 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d,其中每個節點有 d d d 維特征。以下是GAT一層的主要數學表達。
3.1 注意力系數
對于節點 i i i 和其鄰居 j ∈ N i j \in \mathcal{N}_i j∈Ni?(包括節點 i i i 自身,若考慮自環),GAT首先將節點特征通過線性變換映射到新的特征空間:
h i = W x i , h j = W x j h_i = W x_i, \quad h_j = W x_j hi?=Wxi?,hj?=Wxj?
- W ∈ R d ′ × d W \in \mathbb{R}^{d' \times d} W∈Rd′×d:可學習的權重矩陣,將特征從 d d d 維映射到 d ′ d' d′ 維。
- x i , x j x_i, x_j xi?,xj?:節點 i i i 和 j j j 的輸入特征。
- h i , h j h_i, h_j hi?,hj?:映射后的特征。
然后,計算節點 i i i 和 j j j 之間的注意力系數 e i j e_{ij} eij?:
e i j = a ( W x i , W x j ) e_{ij} = a(W x_i, W x_j) eij?=a(Wxi?,Wxj?)
- a ( ? , ? ) a(\cdot, \cdot) a(?,?):注意力函數,通常是一個前饋神經網絡。例如,Veli?kovi?等人使用單層感知器:
e i j = LeakyReLU ( a T [ W x i ∥ W x j ] ) e_{ij} = \text{LeakyReLU} \left( a^T [W x_i \parallel W x_j] \right) eij?=LeakyReLU(aT[Wxi?∥Wxj?])- a ∈ R 2 d ′ a \in \mathbb{R}^{2d'} a∈R2d′:可學習的注意力向量。
- [ W x i ∥ W x j ] [W x_i \parallel W x_j] [Wxi?∥Wxj?]:將 W x i W x_i Wxi? 和 W x j W x_j Wxj? 拼接,得到 2 d ′ 2d' 2d′ 維向量。
- LeakyReLU:激活函數,增加非線性。
3.2 歸一化注意力系數
為了使注意力系數可比較(類似概率分布),對節點 i i i 的所有鄰居 j ∈ N i j \in \mathcal{N}_i j∈Ni? 應用Softmax歸一化:
α i j = exp ? ( e i j ) ∑ k ∈ N i exp ? ( e i k ) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})} αij?=∑k∈Ni??exp(eik?)exp(eij?)?
- α i j \alpha_{ij} αij?:歸一化的注意力權重,表示鄰居 j j j 對節點 i i i 的相對重要性。
3.3 加權聚合
使用歸一化的注意力權重,聚合鄰居的特征,更新節點 i i i 的表示:
h i ′ = σ ( ∑ j ∈ N i α i j W x j ) h_i' = \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij} W x_j \right) hi′?=σ ?j∈Ni?∑?αij?Wxj? ?
- h i ′ h_i' hi′?:節點 i i i 的更新特征。
- σ \sigma σ:激活函數(如ELU或ReLU)。
- W x j W x_j Wxj?:鄰居 j j j 的變換特征。
3.4 多頭注意力
為了增強模型的表達能力和穩定性,GAT通常使用多頭注意力(Multi-Head Attention)。運行 K K K 個獨立的注意力機制,得到 K K K 組特征,然后拼接或平均:
- 中間層:拼接多頭輸出:
h i ′ = ∥ k = 1 K σ ( ∑ j ∈ N i α i j k W k x j ) h_i' = \parallel_{k=1}^K \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k W^k x_j \right) hi′?=∥k=1K?σ ?j∈Ni?∑?αijk?Wkxj? ?- W k W^k Wk:第 k k k 頭的權重矩陣。
- α i j k \alpha_{ij}^k αijk?:第 k k k 頭的注意力權重。
- 輸出維度為 K × d ′ K \times d' K×d′。
- 輸出層:平均多頭輸出:
h i ′ = σ ( 1 K ∑ k = 1 K ∑ j ∈ N i α i j k W k x j ) h_i' = \sigma \left( \frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k W^k x_j \right) hi′?=σ ?K1?k=1∑K?j∈Ni?∑?αijk?Wkxj? ?- 平均操作減少參數量,適合分類任務。
3.5 完整傳播規則
一層GAT的傳播規則可以總結為:
H ′ = σ ( ∑ j ∈ N i α i j W X ) H' = \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij} W X \right) H′=σ ?j∈Ni?∑?αij?WX ?
其中 H ′ H' H′ 是更新后的特征矩陣, X X X 是輸入特征矩陣, α i j \alpha_{ij} αij? 通過注意力機制計算。
4. GAT的結構
一個典型的GAT模型包含以下部分:
- 輸入層:
- 輸入:圖的鄰接矩陣(或邊索引列表)和節點特征矩陣 X X X。
- 多層GAT:
- 堆疊若干GAT層,每層執行注意力機制和特征聚合。
- 中間層通常使用多頭注意力(拼接),輸出層可能使用單頭或平均。
- 每層后可添加激活函數(如ELU)和Dropout(防止過擬合)。
- 輸出層:
- 節點分類:通過Softmax輸出每個節點的類別概率。
- 圖分類:通過池化(如全局平均池化)將節點特征匯總為圖特征。
- 損失函數:
- 節點分類:交叉熵損失。
- 圖分類:交叉熵或回歸損失,取決于任務。
5. GAT與GCN的對比
特性 | GCN | GAT |
---|---|---|
鄰居聚合方式 | 等權重(基于歸一化鄰接矩陣) | 動態權重(通過注意力機制) |
權重計算 | 固定(由圖結構決定) | 可學習(注意力系數動態調整) |
表達能力 | 較弱(無法區分鄰居重要性) | 較強(捕捉異質性關系) |
計算復雜度 | 較低(矩陣乘法) | 較高(需計算注意力系數) |
過平滑問題 | 顯著(多層后特征趨同) | 較輕(注意力機制保留差異) |
直觀解釋:
- GCN像“平均池化”,對所有鄰居一視同仁。
- GAT像“加權池化”,根據任務動態選擇重要鄰居,類似“聰明地聽意見”。
6. GAT的優勢
- 動態權重分配:
- 注意力機制允許模型根據任務自動學習鄰居的重要性。例如,在社交網絡中,某些好友的影響可能更大。
- 捕捉異質性:
- GAT能處理節點或邊關系差異顯著的圖,適合復雜網絡。
- 多頭注意力:
- 類似Transformer,增強模型表達能力,捕捉多種關系模式。
- 可解釋性:
- 注意力系數 α i j \alpha_{ij} αij? 可視化,揭示哪些鄰居對預測更重要。
- 緩解過平滑:
- 相比GCN,GAT通過選擇性聚合減少多層后特征趨同的問題。
7. GAT的局限性
- 計算復雜度:
- 計算注意力系數需要為每條邊執行操作,復雜度為 O ( ∣ E ∣ ? d ) O(|E| \cdot d) O(∣E∣?d),對稠密圖或大圖計算成本高。
- 多頭注意力進一步增加計算量。
- 內存需求:
- 存儲注意力系數和多頭特征需要更多內存。
- 穩定性問題:
- 注意力機制可能導致訓練不穩定,尤其在深層網絡中,需小心調整超參數(如Dropout率、學習率)。
- 邊信息限制:
- 基本GAT不直接利用邊特征(如權重或類型),需擴展模型(如EGAT)。
- 過擬合風險:
- 在小圖或稀疏圖上,注意力機制可能過擬合,需正則化(如Dropout)。
8. GAT的應用
GAT在許多圖結構數據的任務中表現出色,包括:
- 社交網絡:
- 節點分類:預測用戶興趣、社區歸屬。
- 鏈接預測:推薦好友或合作關系。
- 推薦系統:
- 使用用戶-物品交互圖,預測用戶偏好。
- 化學分子分析:
- 圖表示分子結構,預測分子性質(如毒性、溶解度)。
- 知識圖譜:
- 實體分類或關系預測。
- 生物信息學:
- 分析蛋白質相互作用網絡,預測蛋白質功能。
- 交通網絡:
- 預測交通流量或路徑優化。
9. GAT的實現與代碼
以下是一個基于PyTorch Geometric的GAT實現,用于節點分類任務。我們使用Cora數據集(一個常用的學術引用網絡數據集),其中節點是論文,邊是引用關系,目標是預測論文的類別。
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv# 加載Cora數據集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定義GAT模型
class GAT(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, heads=8):super(GAT, self).__init__()# 第一層GAT,使用多頭注意力self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)# 第二層GAT,輸出類別self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)def forward(self, x, edge_index):# 第一層GATx = self.conv1(x, edge_index)x = F.elu(x)x = F.dropout(x, p=0.6, training=self.training)# 第二層GATx = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 設置設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)# 初始化模型
model = GAT(in_channels=dataset.num_features,hidden_channels=8,out_channels=dataset.num_classes,heads=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)# 訓練模型
def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 測試模型
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()return acc# 訓練循環
for epoch in range(200):loss = train()if epoch % 10 == 0:acc = test()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')# 最終測試
final_acc = test()
print(f'Final Test Accuracy: {final_acc:.4f}')
代碼說明:
- 數據集:Cora數據集,包含2708個節點(論文),每個節點有1433維特征(詞袋表示),7個類別,圖有10556條邊。
- 模型:兩層GAT,第一層將特征從1433維降到8維(使用多頭注意力),第二層輸出7維(對應類別)。
- 注意力機制:GAT使用注意力系數動態加權鄰居特征,增強模型對重要鄰居的關注。
- 訓練:使用Adam優化器,交叉熵損失,僅對訓練掩碼(
train_mask
)的節點計算損失。 - 測試:在測試掩碼(
test_mask
)上計算分類準確率。 - 依賴:需要安裝
torch
和torch_geometric
:pip install torch torch-geometric
10. GAT的擴展與改進
GAT是GNN領域的重要進展,許多后續工作在其基礎上改進:
- GATv2(2021):改進注意力機制,解決原始GAT的表達能力瓶頸,增強性能。
- EGAT:引入邊特征,擴展到加權或有類型的圖。
- HGT(Heterogeneous Graph Transformer):結合GAT和Transformer,處理異構圖。
- 采樣優化:如GraphSAGE的采樣策略,結合GAT,適應大規模圖。
- 動態圖:擴展GAT到時序圖,處理動態拓撲。
11. 總結
圖注意力網絡(GAT)通過引入注意力機制,改進了GCN的局限性,能夠動態學習鄰居的重要性,增強對異質圖的建模能力。其核心是計算注意力系數、歸一化加權和多頭注意力,數學上基于消息傳遞框架。GAT在社交網絡、推薦系統、化學等領域的節點分類、鏈接預測等任務中表現優異,但計算復雜度和內存需求是其挑戰。相比GCN,GAT更靈活、可解釋,但在實際應用中需權衡性能和成本。