圖卷積神經網絡(Graph Convolutional Network, GCN)

最近看論文看到了圖卷積神經網絡的內容,之前整理過圖神經網絡的內容,這里再補充一下,方便以后查閱。

圖卷積神經網絡(Graph Convolutional Network, GCN)

  • 圖卷積神經網絡
    • 1. 什么是圖卷積神經網絡(GCN)?
    • 2. GCN的原理
      • 2.1 圖的表示
      • 2.2 譜圖卷積
      • 2.3 GCN的層級傳播規則
      • 2.4 消息傳遞框架
    • 3. GCN的結構
    • 4. GCN的應用
    • 5. GCN的優點與局限性
      • 優點:
      • 局限性:
    • 6. GCN代碼示例
      • 代碼說明:
    • 7. 如何擴展GCN?
    • 8. 總結
  • 譜圖理論(Spectral Graph Theory)
    • 1. 背景:為什么需要譜圖理論?
    • 2. 什么是譜圖理論?
    • 3. 圖的拉普拉斯矩陣
      • 3.1 定義
      • 3.2 歸一化拉普拉斯矩陣
      • 3.3 拉普拉斯矩陣的性質
    • 4. 通過拉普拉斯矩陣定義卷積操作
      • 4.1 圖傅里葉變換
      • 4.2 譜卷積
      • 4.3 GCN的簡化
    • 5. 為什么用拉普拉斯矩陣定義卷積?
    • 6. 局限性與改進
  • 圖注意力網絡(GAT)
    • 1. 背景:為什么需要GAT?
    • 2. GAT的原理
    • 3. GAT的數學公式
      • 3.1 注意力系數
      • 3.2 歸一化注意力系數
      • 3.3 加權聚合
      • 3.4 多頭注意力
      • 3.5 完整傳播規則
    • 4. GAT的結構
    • 5. GAT與GCN的對比
    • 6. GAT的優勢
    • 7. GAT的局限性
    • 8. GAT的應用
    • 9. GAT的實現與代碼
      • 代碼說明:
    • 10. GAT的擴展與改進
    • 11. 總結

先驗知識:圖神經網絡 GNN

圖卷積神經網絡

1. 什么是圖卷積神經網絡(GCN)?

GCN是卷積神經網絡(CNN)的擴展,適用于非歐幾里得空間的數據(如圖結構數據)。傳統CNN處理規則的網格數據(如圖像或時間序列),而GCN處理節點和邊構成的圖結構數據。圖由節點(vertices)和邊(edges)組成,節點表示實體,邊表示實體間的關系。

GCN的核心思想是通過消息傳遞機制,利用圖的拓撲結構,將節點特征與其鄰居的特征聚合,從而學習節點的表示(embedding)。這些表示可用于節點分類、鏈接預測或圖分類等任務。

2. GCN的原理

GCN基于譜圖理論(Spectral Graph Theory)和消息傳遞框架。以下是其核心原理:

2.1 圖的表示

一個圖 G = ( V , E ) G = (V, E) G=(V,E) 由以下部分組成:

  • 節點集 V V V,節點數為 n n n
  • 邊集 E E E,表示節點之間的連接。
  • 鄰接矩陣 A A A,大小為 n × n n \times n n×n,其中 A i j = 1 A_{ij} = 1 Aij?=1 表示節點 i i i j j j 之間有邊,否則為 0 0 0
  • 節點特征矩陣 X X X,大小為 n × d n \times d n×d,其中每行是節點 i i i d d d 維特征向量。
  • 度矩陣 D D D,對角矩陣,其中 D i i = ∑ j A i j D_{ii} = \sum_j A_{ij} Dii?=j?Aij? 表示節點 i i i 的度。

2.2 譜圖卷積

GCN最初基于譜圖理論,通過圖的拉普拉斯矩陣定義卷積操作。圖拉普拉斯矩陣定義為:
L = D ? A L = D - A L=D?A
歸一化拉普拉斯矩陣為:
L n o r m = I ? D ? 1 / 2 A D ? 1 / 2 L_{norm} = I - D^{-1/2} A D^{-1/2} Lnorm?=I?D?1/2AD?1/2
其中 I I I 是單位矩陣, D ? 1 / 2 D^{-1/2} D?1/2 是度矩陣的對角元素的倒數平方根。

譜圖卷積通過拉普拉斯矩陣的特征分解,將卷積操作定義在圖的頻域上。然而,計算特征分解的復雜度較高( O ( n 3 ) O(n^3) O(n3)),因此實際中常用近似方法。

2.3 GCN的層級傳播規則

現代GCN(如Kipf & Welling, 2017)使用簡化的消息傳遞機制。一層GCN的傳播規則為:
H ( l + 1 ) = σ ( D ~ ? 1 / 2 A ~ D ~ ? 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D~?1/2A~D~?1/2H(l)W(l))
其中:

  • H ( l ) H^{(l)} H(l):第 l l l 層的節點特征矩陣, H ( 0 ) = X H^{(0)} = X H(0)=X
  • A ~ = A + I \tilde{A} = A + I A~=A+I:添加自環的鄰接矩陣(每個節點與自身連接)。
  • D ~ \tilde{D} D~ A ~ \tilde{A} A~ 對應的度矩陣, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii} = \sum_j \tilde{A}_{ij} D~ii?=j?A~ij?
  • W ( l ) W^{(l)} W(l):第 l l l 層的可學習權重矩陣。
  • σ \sigma σ:激活函數(如ReLU)。
  • D ~ ? 1 / 2 A ~ D ~ ? 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~?1/2A~D~?1/2:歸一化的鄰接矩陣,用于平衡不同度節點的影響。

直觀解釋:

  • 每層GCN聚合節點的鄰居特征(包括自身),通過 A ~ \tilde{A} A~ 實現。
  • 歸一化( D ~ ? 1 / 2 \tilde{D}^{-1/2} D~?1/2) 防止高階節點主導聚合。
  • 權重矩陣 W ( l ) W^{(l)} W(l) 進行特征變換,激活函數引入非線性。

2.4 消息傳遞框架

GCN可以看作消息傳遞神經網絡(Message Passing Neural Network, MPNN)的一種:

  1. 聚合:收集鄰居節點的特征(通過 A ~ \tilde{A} A~)。
  2. 更新:結合自身特征和聚合特征,更新節點表示(通過 W ( l ) W^{(l)} W(l) σ \sigma σ)。

3. GCN的結構

一個典型的GCN模型包含以下部分:

  1. 輸入層:接受圖的鄰接矩陣 A A A 和節點特征矩陣 X X X
  2. 多層GCN:堆疊若干GCN層,每層執行特征聚合和變換。
  3. 輸出層:
    • 對于節點分類,輸出每個節點的類別概率(通過Softmax)。
    • 對于圖分類,需要池化層(如全局平均池化)將節點特征匯總為圖特征。
  4. 損失函數:
    • 節點分類:交叉熵損失。
    • 圖分類:交叉熵或回歸損失,取決于任務。

4. GCN的應用

GCN在許多領域有廣泛應用:

  1. 社交網絡:
    • 節點分類:預測用戶興趣或社區歸屬。
    • 鏈接預測:推薦好友或合作關系。
  2. 推薦系統:
    • 使用用戶-物品交互圖,預測用戶偏好。
  3. 化學分子分析:
    • 圖表示分子結構,預測分子性質(如毒性或溶解度)。
  4. 知識圖譜:
    • 實體分類或關系預測。
  5. 生物信息學:
    • 分析蛋白質相互作用網絡。

5. GCN的優點與局限性

優點:

  • 適應圖結構:能有效處理非規則的圖數據。
  • 局部性:通過鄰居聚合,捕獲局部拓撲信息。
  • 可擴展性:可以堆疊多層,學習復雜模式。

局限性:

  1. 過平滑問題:
    • 堆疊過多GCN層會導致節點特征趨于相似,丟失區分度。
  2. 固定拓撲:
    • GCN依賴靜態圖結構,無法直接處理動態圖。
  3. 計算復雜度:
    • 對于大規模圖,矩陣運算(如 A ~ H \tilde{A} H A~H)可能耗時。
  4. 邊信息:
    • 基本GCN不考慮邊的權重或類型,后續變體(如GAT)改進此問題。

6. GCN代碼示例

以下是一個基于PyTorch Geometric的GCN實現,用于節點分類任務。我們使用Cora數據集(一個常用的學術引用網絡數據集),其中節點是論文,邊是引用關系,目標是預測論文的類別。

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv# 加載Cora數據集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定義GCN模型
class GCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, x, edge_index):# 第一層GCNx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, p=0.5, training=self.training)# 第二層GCNx = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 設置設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)# 初始化模型
model = GCN(in_channels=dataset.num_features,hidden_channels=16,out_channels=dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)# 訓練模型
def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 測試模型
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()return acc# 訓練循環
for epoch in range(200):loss = train()if epoch % 10 == 0:acc = test()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')# 最終測試
final_acc = test()
print(f'Final Test Accuracy: {final_acc:.4f}')

代碼說明:

  • 數據集:Cora數據集包含2708個節點(論文),每個節點有1433維特征(詞袋表示),7個類別,圖有10556條邊。
  • 模型:兩層GCN,第一層將特征從1433維降到16維,第二層輸出7維(對應類別)。
  • 訓練:使用Adam優化器,交叉熵損失,僅對訓練掩碼(train_mask)的節點計算損失。
  • 測試:在測試掩碼(test_mask)上計算分類準確率。
  • 依賴:需要安裝torchtorch_geometric
    pip install torch torch-geometric
    

7. 如何擴展GCN?

GCN是圖神經網絡(GNN)的基礎,許多改進模型基于GCN:

  1. 圖注意力網絡(GAT):引入注意力機制,動態分配鄰居權重。
  2. GraphSAGE:通過采樣鄰居,適應大規模圖。
  3. JK-Net:通過跳躍連接緩解過平滑問題。
  4. APPNP:結合個性化PageRank,增強傳播效果。

8. 總結

GCN是一種強大的圖神經網絡,通過消息傳遞機制聚合鄰居特征,學習圖中節點的表示。其基于譜圖理論,結構簡單,適合節點分類、鏈接預測等任務。盡管GCN在許多領域表現優異,但過平滑和計算復雜度問題需要通過變體或優化解決。

譜圖理論(Spectral Graph Theory)

1. 背景:為什么需要譜圖理論?

在傳統卷積神經網絡(CNN)中,卷積操作適用于規則的網格數據(如圖像的像素網格),通過滑動窗口提取局部特征。然而,圖結構數據(如社交網絡、分子結構)是非規則的,非歐幾里得空間的數據,節點之間的連接(邊)沒有固定模式。因此,直接應用傳統卷積不可行。

譜圖理論提供了一種數學框架,通過分析圖的拓撲結構(鄰接關系),將圖上的操作(如卷積)定義在頻域(類似于傅里葉變換)。GCN的早期工作(例如Bruna等人,2013)利用譜圖理論,將圖上的卷積定義為基于圖拉普拉斯矩陣的操作。

2. 什么是譜圖理論?

譜圖理論是圖論的一個分支,研究圖的性質通過其矩陣表示(如鄰接矩陣或拉普拉斯矩陣)的特征值(eigenvalues)和特征向量(eigenvectors)。這些特征值和特征向量描述了圖的拓撲結構,例如連通性、聚類特性等。

在GCN中,譜圖理論的核心思想是將圖上的信號(節點特征)投影到圖的頻域(由拉普拉斯矩陣的特征向量定義),進行類似傅里葉變換的操作,再轉換回空間域。這種方法允許我們在圖上定義卷積,類似于圖像上的卷積。

3. 圖的拉普拉斯矩陣

拉普拉斯矩陣(Laplacian Matrix)是圖的矩陣表示,用于捕捉圖的拓撲結構。以下是其定義和性質:

3.1 定義

對于一個無向圖 G = ( V , E ) G = (V, E) G=(V,E),有 n n n 個節點,拉普拉斯矩陣 L L L 定義為:
L = D ? A L = D - A L=D?A

  • A A A:鄰接矩陣,大小 n × n n \times n n×n,其中 A i j = 1 A_{ij} = 1 Aij?=1 如果節點 i i i j j j 之間有邊,否則為 0 0 0
  • D D D:度矩陣,對角矩陣,大小 n × n n \times n n×n,其中 D i i = ∑ j A i j D_{ii} = \sum_j A_{ij} Dii?=j?Aij? 表示節點 i i i 的度(連接的邊數),非對角元素為 0 0 0

例如,對于一個簡單圖:

在這里插入圖片描述

  • 鄰接矩陣 A A A:
    A = [ 0 1 1 1 1 0 0 0 1 0 0 1 1 0 1 0 ] A = \begin{bmatrix} 0 & 1 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 1 & 0 & 0 & 1 \\ 1 & 0 & 1 & 0 \end{bmatrix} A= ?0111?1000?1001?1010? ?
  • 度矩陣 D D D:
    D = [ 3 0 0 0 0 1 0 0 0 0 2 0 0 0 0 2 ] D = \begin{bmatrix} 3 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 2 & 0 \\ 0 & 0 & 0 & 2 \end{bmatrix} D= ?3000?0100?0020?0002? ?
  • 拉普拉斯矩陣 L L L:
    L = D ? A = [ 3 ? 1 ? 1 ? 1 ? 1 1 0 0 ? 1 0 2 ? 1 ? 1 0 ? 1 2 ] L = D - A = \begin{bmatrix} 3 & -1 & -1 & -1 \\ -1 & 1 & 0 & 0 \\ -1 & 0 & 2 & -1 \\ -1 & 0 & -1 & 2 \end{bmatrix} L=D?A= ?3?1?1?1??1100??102?1??10?12? ?

3.2 歸一化拉普拉斯矩陣

為了平衡不同度節點的影響,常用歸一化拉普拉斯矩陣:
L n o r m = I ? D ? 1 / 2 A D ? 1 / 2 L_{norm} = I - D^{-1/2} A D^{-1/2} Lnorm?=I?D?1/2AD?1/2

  • D ? 1 / 2 D^{-1/2} D?1/2:對角矩陣,其對角元素為 D i i ? 1 / 2 = 1 / 度 i D_{ii}^{-1/2} = 1 / \sqrt{\text{度}_i} Dii?1/2?=1/i? ?
  • I I I:單位矩陣。

歸一化拉普拉斯矩陣的特征值在 [ 0 , 2 ] [0, 2] [0,2] 范圍內,適合數值計算。

3.3 拉普拉斯矩陣的性質

  • 對稱性:對于無向圖, L L L 是對稱矩陣,因此有實特征值和正交特征向量。
  • 正半定性: L L L 的特征值非負,反映圖的連通性(例如,特征值 0 0 0 的重數等于連通分量的個數)。
  • 頻域解釋:拉普拉斯矩陣的特征向量形成圖的“頻域基”,類似傅里葉變換中的正弦和余弦函數。特征值表示“頻率”,低頻率對應平滑信號,高頻率對應快速變化的信號。

4. 通過拉普拉斯矩陣定義卷積操作

在譜圖理論中,圖上的卷積操作通過拉普拉斯矩陣的特征分解定義,類似于信號處理中的傅里葉變換。以下是具體步驟:

4.1 圖傅里葉變換

拉普拉斯矩陣 L L L 可以分解為:
L = U Λ U T L = U \Lambda U^T L=UΛUT

  • U U U: 特征向量矩陣,列是 L L L 的特征向量 u 1 , u 2 , … , u n u_1, u_2, \ldots, u_n u1?,u2?,,un?,表示圖的“頻域基”。
  • Λ \Lambda Λ: 對角矩陣,對角元素是特征值 λ 1 , λ 2 , … , λ n \lambda_1, \lambda_2, \ldots, \lambda_n λ1?,λ2?,,λn?,表示“頻率”。

對于節點特征向量 x ∈ R n x \in \mathbb{R}^n xRn(每個節點一個標量特征),其圖傅里葉變換定義為:
x ^ = U T x \hat{x} = U^T x x^=UTx

  • x ^ \hat{x} x^ 是頻域中的系數,表示 x x x 在特征向量基上的投影。
  • 逆變換為:
    x = U x ^ x = U \hat{x} x=Ux^

4.2 譜卷積

在頻域中,卷積等價于逐頻率相乘。假設有一個濾波器(卷積核) g g g,其頻域表示為 g ( Λ ) g(\Lambda) g(Λ)(對角矩陣,元素為 g ( λ i ) g(\lambda_i) g(λi?))。圖上的卷積定義為:
x ? g = U g ( Λ ) U T x x * g = U g(\Lambda) U^T x x?g=Ug(Λ)UTx

  • g ( Λ ) g(\Lambda) g(Λ): 濾波器在頻域的響應,控制如何放大或抑制不同頻率的信號。
  • U T x U^T x UTx: 將信號 x x x 轉換為頻域。
  • U g ( Λ ) U T x U g(\Lambda) U^T x Ug(Λ)UTx: 應用濾波器后轉換回空間域。

因此,圖卷積是通過拉普拉斯矩陣的特征分解,將節點特征在頻域中與濾波器結合,再轉換回節點特征。

4.3 GCN的簡化

早期譜GCN(如Bruna等人)直接使用上述卷積,但計算 U U U U T U^T UT 的復雜度為 O ( n 3 ) O(n^3) O(n3),對大圖不可行。Kipf & Welling (2017) 提出了簡化版本:

  • 使用多項式濾波器:假設 g ( Λ ) g(\Lambda) g(Λ) 是拉普拉斯矩陣特征值的多項式,例如 g ( Λ ) = ∑ k θ k Λ k g(\Lambda) = \sum_k \theta_k \Lambda^k g(Λ)=k?θk?Λk
  • 近似為低階多項式(如一階),避免特征分解。
  • 最終傳播規則為:
    H ( l + 1 ) = σ ( D ~ ? 1 / 2 A ~ D ~ ? 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D~?1/2A~D~?1/2H(l)W(l))
    其中 A ~ = A + I \tilde{A} = A + I A~=A+I D ~ \tilde{D} D~ A ~ \tilde{A} A~ 的度矩陣。這種形式避免了昂貴的矩陣分解,直接操作鄰接矩陣。

5. 為什么用拉普拉斯矩陣定義卷積?

拉普拉斯矩陣在圖卷積中有以下優勢:

  1. 捕捉拓撲結構:拉普拉斯矩陣編碼了圖的連接性(鄰接關系和節點度),適合定義局部聚合操作。
  2. 頻域解釋:通過特征分解,拉普拉斯矩陣提供了一種頻域視角,類似圖像上的傅里葉變換,便于定義卷積。
  3. 平滑性:拉普拉斯矩陣與圖的平滑性相關(例如, x T L x x^T L x xTLx 測量信號 x x x 在圖上的變化),卷積操作可以平滑節點特征,聚合鄰居信息。
  4. 數學優雅:譜圖理論提供了統一的框架,將圖上的操作與傳統信號處理連接起來。

6. 局限性與改進

基于譜圖理論的GCN有以下局限性:

  1. 計算復雜度:特征分解對大圖不可行( O ( n 3 ) O(n^3) O(n3))。
  2. 泛化性:譜方法依賴圖的固定拉普拉斯矩陣,難以直接應用于不同結構的圖。
  3. 局部性:譜卷積本質上是全局操作,可能忽略局部特征。

因此,后續工作(如GraphSAGE、GAT)轉向空間域方法,直接在圖的拓撲上定義卷積(如鄰居聚合),避免譜分解,提高效率和靈活性。

圖注意力網絡(GAT)

圖注意力網絡(Graph Attention Network, GAT)是一種圖神經網絡(Graph Neural Network, GNN)的變體,由Veli?kovi?等人于2017年提出(論文:《Graph Attention Networks》)。它通過引入注意力機制(Attention Mechanism)改進了傳統的圖卷積神經網絡(GCN),能夠動態地為不同鄰居節點分配權重,從而更好地捕捉圖結構中的異質性關系

1. 背景:為什么需要GAT?

圖神經網絡(如GCN)通過聚合節點鄰居的特征來學習節點表示,適用于處理圖結構數據(如社交網絡、分子結構)。然而,GCN存在以下局限性:

  1. 等權重聚合:GCN假設所有鄰居對節點的貢獻相同(通過歸一化的鄰接矩陣),無法區分鄰居的重要性。例如,在社交網絡中,某些好友的影響可能更大。
  2. 固定拓撲依賴:GCN的聚合權重完全由圖的拓撲結構(鄰接矩陣和節點度)決定,缺乏靈活性。
  3. 無法捕捉異質性:對于高度異質的圖(節點或邊的關系差異顯著),GCN的表現可能受限。

GAT通過引入注意力機制解決了這些問題,允許模型動態學習每個鄰居的貢獻權重,類似于自然語言處理中的Transformer模型。這種機制使GAT能夠聚焦于對任務更重要的鄰居,提高表示能力和泛化性。

2. GAT的原理

GAT的核心思想是通過注意力機制為每個節點的鄰居分配不同的權重,然后基于這些權重聚合鄰居特征。其操作可以概括為以下步驟:

  1. 計算注意力系數:為每條邊(或鄰居對)計算一個注意力分數,表示鄰居的重要性。
  2. 歸一化注意力系數:使用Softmax將注意力分數歸一化為權重。
  3. 加權聚合:根據歸一化的注意力權重,聚合鄰居的特征。
  4. 多頭注意力:可選地使用多組注意力機制(類似Transformer),增強模型表達能力。

GAT仍然基于消息傳遞框架(Message Passing Neural Network, MPNN),但其聚合方式比GCN更靈活。

3. GAT的數學公式

假設有一個圖 G = ( V , E ) G = (V, E) G=(V,E),包含 n n n 個節點,節點特征矩陣為 X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d,其中每個節點有 d d d 維特征。以下是GAT一層的主要數學表達。

3.1 注意力系數

對于節點 i i i 和其鄰居 j ∈ N i j \in \mathcal{N}_i jNi?(包括節點 i i i 自身,若考慮自環),GAT首先將節點特征通過線性變換映射到新的特征空間:
h i = W x i , h j = W x j h_i = W x_i, \quad h_j = W x_j hi?=Wxi?,hj?=Wxj?

  • W ∈ R d ′ × d W \in \mathbb{R}^{d' \times d} WRd×d:可學習的權重矩陣,將特征從 d d d 維映射到 d ′ d' d 維。
  • x i , x j x_i, x_j xi?,xj?:節點 i i i j j j 的輸入特征。
  • h i , h j h_i, h_j hi?,hj?:映射后的特征。

然后,計算節點 i i i j j j 之間的注意力系數 e i j e_{ij} eij?
e i j = a ( W x i , W x j ) e_{ij} = a(W x_i, W x_j) eij?=a(Wxi?,Wxj?)

  • a ( ? , ? ) a(\cdot, \cdot) a(?,?):注意力函數,通常是一個前饋神經網絡。例如,Veli?kovi?等人使用單層感知器:
    e i j = LeakyReLU ( a T [ W x i ∥ W x j ] ) e_{ij} = \text{LeakyReLU} \left( a^T [W x_i \parallel W x_j] \right) eij?=LeakyReLU(aT[Wxi?Wxj?])
    • a ∈ R 2 d ′ a \in \mathbb{R}^{2d'} aR2d:可學習的注意力向量。
    • [ W x i ∥ W x j ] [W x_i \parallel W x_j] [Wxi?Wxj?]:將 W x i W x_i Wxi? W x j W x_j Wxj? 拼接,得到 2 d ′ 2d' 2d 維向量。
    • LeakyReLU:激活函數,增加非線性。

3.2 歸一化注意力系數

為了使注意力系數可比較(類似概率分布),對節點 i i i 的所有鄰居 j ∈ N i j \in \mathcal{N}_i jNi? 應用Softmax歸一化:
α i j = exp ? ( e i j ) ∑ k ∈ N i exp ? ( e i k ) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})} αij?=kNi??exp(eik?)exp(eij?)?

  • α i j \alpha_{ij} αij?:歸一化的注意力權重,表示鄰居 j j j 對節點 i i i 的相對重要性。

3.3 加權聚合

使用歸一化的注意力權重,聚合鄰居的特征,更新節點 i i i 的表示:
h i ′ = σ ( ∑ j ∈ N i α i j W x j ) h_i' = \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij} W x_j \right) hi?=σ ?jNi??αij?Wxj? ?

  • h i ′ h_i' hi?:節點 i i i 的更新特征。
  • σ \sigma σ:激活函數(如ELU或ReLU)。
  • W x j W x_j Wxj?:鄰居 j j j 的變換特征。

3.4 多頭注意力

為了增強模型的表達能力和穩定性,GAT通常使用多頭注意力(Multi-Head Attention)。運行 K K K 個獨立的注意力機制,得到 K K K 組特征,然后拼接或平均:

  • 中間層:拼接多頭輸出:
    h i ′ = ∥ k = 1 K σ ( ∑ j ∈ N i α i j k W k x j ) h_i' = \parallel_{k=1}^K \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k W^k x_j \right) hi?=k=1K?σ ?jNi??αijk?Wkxj? ?
    • W k W^k Wk:第 k k k 頭的權重矩陣。
    • α i j k \alpha_{ij}^k αijk?:第 k k k 頭的注意力權重。
    • 輸出維度為 K × d ′ K \times d' K×d
  • 輸出層:平均多頭輸出:
    h i ′ = σ ( 1 K ∑ k = 1 K ∑ j ∈ N i α i j k W k x j ) h_i' = \sigma \left( \frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k W^k x_j \right) hi?=σ ?K1?k=1K?jNi??αijk?Wkxj? ?
    • 平均操作減少參數量,適合分類任務。

3.5 完整傳播規則

一層GAT的傳播規則可以總結為:
H ′ = σ ( ∑ j ∈ N i α i j W X ) H' = \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij} W X \right) H=σ ?jNi??αij?WX ?
其中 H ′ H' H 是更新后的特征矩陣, X X X 是輸入特征矩陣, α i j \alpha_{ij} αij? 通過注意力機制計算。

4. GAT的結構

一個典型的GAT模型包含以下部分:

  1. 輸入層
    • 輸入:圖的鄰接矩陣(或邊索引列表)和節點特征矩陣 X X X
  2. 多層GAT
    • 堆疊若干GAT層,每層執行注意力機制和特征聚合。
    • 中間層通常使用多頭注意力(拼接),輸出層可能使用單頭或平均。
    • 每層后可添加激活函數(如ELU)和Dropout(防止過擬合)。
  3. 輸出層
    • 節點分類:通過Softmax輸出每個節點的類別概率。
    • 圖分類:通過池化(如全局平均池化)將節點特征匯總為圖特征。
  4. 損失函數
    • 節點分類:交叉熵損失。
    • 圖分類:交叉熵或回歸損失,取決于任務。

5. GAT與GCN的對比

特性GCNGAT
鄰居聚合方式等權重(基于歸一化鄰接矩陣)動態權重(通過注意力機制)
權重計算固定(由圖結構決定)可學習(注意力系數動態調整)
表達能力較弱(無法區分鄰居重要性)較強(捕捉異質性關系)
計算復雜度較低(矩陣乘法)較高(需計算注意力系數)
過平滑問題顯著(多層后特征趨同)較輕(注意力機制保留差異)

直觀解釋:

  • GCN像“平均池化”,對所有鄰居一視同仁。
  • GAT像“加權池化”,根據任務動態選擇重要鄰居,類似“聰明地聽意見”。

6. GAT的優勢

  1. 動態權重分配
    • 注意力機制允許模型根據任務自動學習鄰居的重要性。例如,在社交網絡中,某些好友的影響可能更大。
  2. 捕捉異質性
    • GAT能處理節點或邊關系差異顯著的圖,適合復雜網絡。
  3. 多頭注意力
    • 類似Transformer,增強模型表達能力,捕捉多種關系模式。
  4. 可解釋性
    • 注意力系數 α i j \alpha_{ij} αij? 可視化,揭示哪些鄰居對預測更重要。
  5. 緩解過平滑
    • 相比GCN,GAT通過選擇性聚合減少多層后特征趨同的問題。

7. GAT的局限性

  1. 計算復雜度
    • 計算注意力系數需要為每條邊執行操作,復雜度為 O ( ∣ E ∣ ? d ) O(|E| \cdot d) O(E?d),對稠密圖或大圖計算成本高。
    • 多頭注意力進一步增加計算量。
  2. 內存需求
    • 存儲注意力系數和多頭特征需要更多內存。
  3. 穩定性問題
    • 注意力機制可能導致訓練不穩定,尤其在深層網絡中,需小心調整超參數(如Dropout率、學習率)。
  4. 邊信息限制
    • 基本GAT不直接利用邊特征(如權重或類型),需擴展模型(如EGAT)。
  5. 過擬合風險
    • 在小圖或稀疏圖上,注意力機制可能過擬合,需正則化(如Dropout)。

8. GAT的應用

GAT在許多圖結構數據的任務中表現出色,包括:

  1. 社交網絡
    • 節點分類:預測用戶興趣、社區歸屬。
    • 鏈接預測:推薦好友或合作關系。
  2. 推薦系統
    • 使用用戶-物品交互圖,預測用戶偏好。
  3. 化學分子分析
    • 圖表示分子結構,預測分子性質(如毒性、溶解度)。
  4. 知識圖譜
    • 實體分類或關系預測。
  5. 生物信息學
    • 分析蛋白質相互作用網絡,預測蛋白質功能。
  6. 交通網絡
    • 預測交通流量或路徑優化。

9. GAT的實現與代碼

以下是一個基于PyTorch Geometric的GAT實現,用于節點分類任務。我們使用Cora數據集(一個常用的學術引用網絡數據集),其中節點是論文,邊是引用關系,目標是預測論文的類別。

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv# 加載Cora數據集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定義GAT模型
class GAT(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, heads=8):super(GAT, self).__init__()# 第一層GAT,使用多頭注意力self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)# 第二層GAT,輸出類別self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)def forward(self, x, edge_index):# 第一層GATx = self.conv1(x, edge_index)x = F.elu(x)x = F.dropout(x, p=0.6, training=self.training)# 第二層GATx = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 設置設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)# 初始化模型
model = GAT(in_channels=dataset.num_features,hidden_channels=8,out_channels=dataset.num_classes,heads=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)# 訓練模型
def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 測試模型
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()return acc# 訓練循環
for epoch in range(200):loss = train()if epoch % 10 == 0:acc = test()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')# 最終測試
final_acc = test()
print(f'Final Test Accuracy: {final_acc:.4f}')

代碼說明:

  • 數據集:Cora數據集,包含2708個節點(論文),每個節點有1433維特征(詞袋表示),7個類別,圖有10556條邊。
  • 模型:兩層GAT,第一層將特征從1433維降到8維(使用多頭注意力),第二層輸出7維(對應類別)。
  • 注意力機制:GAT使用注意力系數動態加權鄰居特征,增強模型對重要鄰居的關注。
  • 訓練:使用Adam優化器,交叉熵損失,僅對訓練掩碼(train_mask)的節點計算損失。
  • 測試:在測試掩碼(test_mask)上計算分類準確率。
  • 依賴:需要安裝torchtorch_geometric
    pip install torch torch-geometric
    

10. GAT的擴展與改進

GAT是GNN領域的重要進展,許多后續工作在其基礎上改進:

  1. GATv2(2021):改進注意力機制,解決原始GAT的表達能力瓶頸,增強性能。
  2. EGAT:引入邊特征,擴展到加權或有類型的圖。
  3. HGT(Heterogeneous Graph Transformer):結合GAT和Transformer,處理異構圖。
  4. 采樣優化:如GraphSAGE的采樣策略,結合GAT,適應大規模圖。
  5. 動態圖:擴展GAT到時序圖,處理動態拓撲。

11. 總結

圖注意力網絡(GAT)通過引入注意力機制,改進了GCN的局限性,能夠動態學習鄰居的重要性,增強對異質圖的建模能力。其核心是計算注意力系數、歸一化加權和多頭注意力,數學上基于消息傳遞框架。GAT在社交網絡、推薦系統、化學等領域的節點分類、鏈接預測等任務中表現優異,但計算復雜度和內存需求是其挑戰。相比GCN,GAT更靈活、可解釋,但在實際應用中需權衡性能和成本。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/905748.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/905748.shtml
英文地址,請注明出處:http://en.pswp.cn/news/905748.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

安裝win11硬盤分區MBR還是GPT_裝win11系統分區及安裝教程

最近有網友問我,裝win11系統分區有什么要求裝win11系統硬盤分區用mbr還是GPT?我們知道現在的引導模式有uefi和legacy兩種引導模式,如果采用的是uefi引導模式,分區類型對應的就是gpt分區(guid),如果引導模式采用的是legacy&#xf…

服務培訓QDA 的安裝調試方法,硬件模塊的講解和軟件控制臺使用及系統測試

#服務培訓##質譜儀##軟件控制##硬件模塊# 以下是關于Waters QDa單桿液質質譜儀的安裝調試、硬件模塊講解以及軟件控制臺使用培訓的相關內容: 安裝調試 場地準備:用戶需要提前準備好實驗室,確保實驗室環境符合儀器的要求,如溫度、…

在K8S集群中部署EFK日志收集

目錄 引言環境準備安裝自定義資源部署ElasticsearchMaster 節點與 Data 節點的區別生產優化建議安裝好以后測試ES是否正常部署Fluentd測試filebeat是否正常推送日志部署Kibana獲取賬號密碼,賬號是:elastic集群測試 引言 系統版本為 Centos7.9內核版本為…

polarctf-web-[rce1]

考點: (1)RCE(exec函數) (2)空格繞過 (3)執行函數(exec函數) (4)閉合(ping命令閉合) 題目來源:Polarctf-web-[rce1] 解題: 這段代碼實現了一個簡單的 Ping 測試工具,用戶可以通過表單提交一個 IP 地址,服務器會執…

【串流VR手勢】Pico 4 Ultra Enterprise 在 SteamVR 企業串流中無法識別手勢的問題排查與解決過程(Pico4UE串流手勢問題)

寫在前面的話 此前(用Pico 4U)接入了MRTK3,現項目落地需要部署,發現串流場景中,Pico4UE的企業串流無法正常識別手勢。(一體機方式部署使用無問題) 花了半小時解決,怕忘,…

ES(Elasticsearch)的應用與代碼示例

Elasticsearch應用與代碼示例技術文章大綱 一、引言 Elasticsearch在現代化應用中的核心作用典型應用場景分析(日志分析/全文檢索/數據聚合) 二、環境準備(前提條件) Elasticsearch 8.x集群部署要點IK中文分詞插件配置指南Ingest Attachment插件安裝…

臨床決策支持系統的提示工程優化路徑深度解析

引言 隨著人工智能技術在醫療領域的迅猛發展,臨床決策支持系統(CDSS)正經歷從傳統規則引擎向智能提示工程的范式轉變。在這一背景下,如何構建既符合循證醫學原則又能適應個體化醫療需求的CDSS成為醫學人工智能領域的核心挑戰。本報告深入剖析了臨床決策支持系統中提示工程的…

火山RTC 8 SDK集成進項目中

一、SDK 集成預備工作 1、SDK下載 https://www.volcengine.com/docs/6348/75707 2、解壓后 3、放在自己項目中的位置 1)、include 2)、lib 3)、dll 暫時,只需要VolcEngineRTC.dll RTCFFmpeg.dll openh264-4.dll, 放在intLive2…

OkHttp用法-Java調用http服務

特點:高性能,支持異步請求,連接池優化 官方文檔:提供快速入門指南和高級功能(如攔截器、連接池)的詳細說明,GitHub倉庫包含豐富示例。 社區資源:中文教程豐富,GitHub高…

python中常用的參數以及命名規范

以下是 Python 中常見的命名規范、參數用法及在大型項目中常用的操作模式,供記錄參考: 1. 命名規范(Naming Conventions) 前綴/形式含義示例_age單下劃線:弱“私有”標記(可訪問但不建議外部使用&#xff…

第五十七篇 Java接口設計之道:從咖啡機到智能家居的編程哲學

目錄 引言:生活中的接口無處不在一、咖啡機與基礎接口:理解抽象契約1.1 咖啡制作的標準接口 二、智能家居與策略模式:靈活切換實現2.1 溫度調節策略場景 三、物流系統與工廠模式:標準接口下的多樣實現3.1 快遞運輸接口設計 四、健…

第二十六天打卡

全局變量 global_var 全局變量是定義在函數、類或者代碼塊外部的變量,它在整個程序文件內都能被訪問。在代碼里, global_var 就是一個全局變量,下面是相關代碼片段: print("\n--- 變量作用域示例 ---") global_var …

聯合查詢

目錄 1、笛卡爾積 2、聯合查詢 2.1、內連接 2.2、外連接 1、笛卡爾積 笛卡爾積: 笛卡爾積是讓兩個表通過排列組合的方式,得到的一個更大的表。笛卡爾積的列數,是這兩個表的列數相加,笛卡爾積的行數,是這兩個表的行…

【HTML5學習筆記2】html標簽(下)

1表格標簽 1.1表格作用 顯示數據 1.2基本語法 <table><tr> 一行<td>單元格1</td></tr> </table> 1.3表頭單元格標簽 表頭單元格會加粗并且居中 <table><tr> 一行<th>單元格1</th></tr> </table&g…

window 顯示驅動開發-分頁視頻內存資源

與 Microsoft Windows 2000 顯示驅動程序模型不同&#xff0c;Windows Vista 顯示驅動程序模型允許創建比可用物理視頻內存總量更多的視頻內存資源&#xff0c;然后根據需要分頁進出視頻內存。 換句話說&#xff0c;并非所有視頻內存資源都同時位于視頻內存中。 GPU 的管道中可…

《C 語言指針高級指南:字符、數組、函數指針的進階攻略》

目錄 一. 字符指針變量 二. 數組指針變量 三. 二維數組傳參 3.1 二維數組的本質 3.2 訪問方式與地址計算 3.3 二維數組的傳參方式 3.4 深入解析 *(*(arri)j) 與 arr[i][j] 的等價性 四. 函數指針變量 4.1 函數指針變量的創建 4.2 函數指針變量的使用 4.3 兩段"…

Unity:場景管理系統 —— SceneManagement 模塊

目錄 &#x1f3ac; 什么是 Scene&#xff08;場景&#xff09;&#xff1f; Unity 項目中的 Scene 通常負責什么&#xff1f; &#x1f30d; 一個 Scene 包含哪些元素&#xff1f; Scene 的切換與管理 &#x1f4c1; 如何創建與管理 Scenes&#xff1f; 什么是Scene Man…

內容中臺重構企業知識管理路徑

智能元數據驅動知識治理 現代企業知識管理的核心挑戰在于海量非結構化數據的有效治理。通過智能元數據分類引擎&#xff0c;系統可自動識別文檔屬性并生成多維標簽體系&#xff0c;例如將技術手冊按產品版本、功能模塊、適用場景進行動態標注。這種動態元數據框架不僅支持跨部…

Vue3:腳手架

工程環境配置 1.安裝nodejs 這里我已經安裝過了&#xff0c;只需要打開鏈接Node.js — Run JavaScript Everywhere直接下載nodejs&#xff0c;安裝直接一直下一步下一步 安裝完成之后我們來使用電腦的命令行窗口檢查一下版本 查看npm源 這里npm源的地址是淘寶的源&#xff0…

悅數圖數據庫一體機發布,讓復雜關聯計算開箱即用

在金融風控、政務治理、能源監測等關鍵領域&#xff0c;復雜數據關聯分析已成為業務決策的核心需求。然而&#xff0c;信創場景的特殊性——全棧自主可控、海量實時計算、系統高可用性——對傳統技術架構提出了近乎苛刻的要求。悅數圖數據庫一體機應運而生&#xff0c;以軟硬協…