GCN從理論到實踐——基于PyTorch的圖卷積網絡層實現

Hi,大家好,我是半畝花海。圖卷積網絡(Graph Convolutional Network, GCN)是一種處理圖結構數據的深度學習模型。它通過聚合鄰居節點的信息來更新每個節點的特征表示,廣泛應用于社交網絡分析、推薦系統和生物信息學等領域。本實驗通過實現一個簡單的 GCN 層,展示了其核心思想,并通過具體代碼示例說明了 GCN 層的工作原理。

目錄

一、圖卷積網絡的含義

二、實驗展示——基于PyTorch的圖卷積網絡(GCN)層實現

(一)實驗目標

(二)實驗方法

(三)實驗結果分析

(四)思考與總結

三、完整代碼

四、參考文章


一、圖卷積網絡的含義

說起圖卷積神經網絡(Graph Convolutional networks, GCN),可以先探討一下卷積神經網絡(CNN),CNN 中的卷積本質上就是利用共享參數的過濾器,通過計算中心像素點以及相鄰像素點的加權和來實現空間特征的提取。而 GCN 也是如此,類似于圖像中的卷積處理,它依賴于節點間的消息傳遞方法,這意味著節點與其鄰居點交換信息,并相互發送消息。

在看具體的數學表達式之前,我們可以試著直觀地理解 GCN 是如何工作的,可分為以下兩大步驟:

  • 第一步:每個節點創建一個特征向量,表示它要發送給所有鄰居的消息。
  • 第二步:消息被發送到相鄰節點,這樣每個節點均會從其相鄰節點接收一條消息。

下面的圖可視化了以上兩大步驟:

那么隨后該如何組合節點、接收消息呢?

由于節點間消息的數量不同,需要一個適用于任意數量的操作,通常的方法是求和或取平均值。令?H^{(l)}?表示節點?以前的特征表示,H^{(l+1)}?為整合消息后的特征表示,GCN 層定義如下:

H^{(l+1)}=\sigma\left(\hat{D}^{-1 / 2} \hat{A} \hat{D}^{-1 / 2} H^{(l)} W^{(l)}\right)

W^{(l)}?是將輸入特征轉換為消息的權重參數。在鄰接矩陣 A 的基礎上,加上單位矩陣,以便每個節點也向自身發送消息,即:\hat{A}=A+I。最后,為了取平均值的運算,需要用到矩陣 \hat{D},這是一個對角矩陣,D_{i}?表示節點 i 的鄰居數。\sigma 表示一個任意的激活函數,當然,不一定是 Sigmoid,事實上,在 GNN 中通常使用基于 ReLU 的激活函數

二、實驗展示——基于PyTorch的圖卷積網絡(GCN)層實現

(一)實驗目標

  • 理解 GCN 層的基本原理
  • 實現一個簡單的 GCN 層,并通過手動設置權重矩陣驗證其計算過程
  • 分析輸入節點特征與鄰接矩陣如何影響輸出特征

(二)實驗方法

在 PyTorch 中實現 GCN 層時,我們可以靈活地利用張量進行運算,不必定義矩陣?\hat{D},只需將求和的消息除以之后的鄰居數即可。此外,線性層便是以上的權重矩陣,同時可添加偏置(bias)。基于 PyTorch,定義GCN層的具體步驟如下所示。

1.?導入必要的庫

import torch
import torch.nn as nn
  • torch:PyTorch 深度學習框架的核心庫,用于張量操作和自動求導。
  • torch.nn:提供了構建神經網絡所需的模塊和函數。

2.?定義圖卷積層(GCNLayer)

class GCNLayer(nn.Module):def __init__(self, c_in, c_out):"""Inputs::param c_in: 輸入特征維度:param c_out: 輸出特征維度"""super().__init__()self.projection = nn.Linear(c_in, c_out)  # 線性層
  • GCNLayer 繼承自 nn.Module,是 PyTorch 中所有神經網絡模塊的基類。
  • c_inc_out 分別表示輸入特征和輸出特征的維度。
  • self.projection 是 PyTorch 中的線性變換層,將輸入特征從 c_in 維映射到 c_out 維。其公式為:

output=input \cdot weight^{T}+bias

3.?前向傳播

def forward(self, node_feats, adj_matrix):"""輸入::param node_feats: 節點特征表示,大小為 [batch_size, num_nodes, c_in]:param adj_matrix: 鄰接矩陣,大小為 [batch_size, num_nodes, num_nodes]:return: 更新后的節點特征"""num_neighbors = adj_matrix.sum(dim=-1, keepdims=True)  # 各節點的鄰居數node_feats = self.projection(node_feats)  # 將特征轉化為消息# 各鄰居節點消息求和并求平均node_feats = torch.bmm(adj_matrix, node_feats)node_feats = node_feats / num_neighborsreturn node_feats
  • 輸入參數:
    • node_feats:表示每個節點的特征,形狀為 [batch_size, num_nodes, c_in]
    • adj_matrix:圖的鄰接矩陣,形狀為 [batch_size, num_nodes, num_nodes]
  • 步驟解析:
    • 計算鄰居數量:num_neighbors = adj_matrix.sum(dim=-1, keepdims=True) 計算每個節點的鄰居數量(包括自身)。
    • 線性變換:node_feats = self.projection(node_feats) 對節點特征進行線性變換
    • 鄰居信息聚合:torch.bmm(adj_matrix, node_feats) 使用批量矩陣乘法(Batch Matrix Multiplication)將鄰居節點的消息加權求和
    • 歸一化:node_feats = node_feats / num_neighbors 將聚合結果按鄰居數量歸一化,得到每個節點的更新特征。

4.?實驗數據準備

node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
adj_matrix = torch.Tensor([[[1, 1, 0, 0],[1, 1, 1, 1],[0, 1, 1, 1],[0, 1, 1, 1]]])
print("節點特征:\n", node_feats)
print("添加自連接的鄰接矩陣:\n", adj_matrix)

(1)節點特征

  • node_feats 是一個形狀為 [1, 4, 2] 的張量,表示一個批次中 4 個節點的特征,每個節點有 2 維特征。
節點特征:tensor([[[0., 1.],[2., 3.],[4., 5.],[6., 7.]]])

(2)鄰接矩陣

  • adj_matrix 是一個形狀為 [1, 4, 4] 的張量,表示圖的鄰接矩陣。
添加自連接的鄰接矩陣:tensor([[[1., 1., 0., 0.],[1., 1., 1., 1.],[0., 1., 1., 1.],[0., 1., 1., 1.]]])
  • 鄰接矩陣中的元素為 1 表示兩個節點之間存在連接,0 表示無連接。

5.?初始化GCN層并設置權重

layer = GCNLayer(c_in=2, c_out=2)
# 初始化權重矩陣
layer.projection.weight.data = torch.Tensor([[1., 0.], [0., 1.]])
layer.projection.bias.data = torch.Tensor([0., 0.])
  • 創建一個 GCNLayer 實例,輸入特征維度為 2,輸出特征維度也為 2。
  • 手動初始化權重矩陣和偏置(bia):
    • 權重矩陣為單位矩陣,表示不改變輸入特征:(該單位矩陣的值?I = 1
    • 偏置為零向量

由于權重矩陣是單位矩陣,偏置為零,線性變換的公式簡化為:

output=input \cdot weight^{T}+bias=input \cdot I+0=input

因此,線性變換后的節點特征與輸入特征相同

6.?前向傳播并計算輸出特征

# 將節點特征和添加自連接的鄰接矩陣輸入 GCN 層
with torch.no_grad():out_feats = layer(node_feats, adj_matrix)print("節點輸出特征:\n", out_feats)
  • 使用 torch.no_grad() 關閉梯度計算,避免不必要的內存開銷。
  • 調用 layer(node_feats, adj_matrix) 進行前向傳播,得到更新后的節點特征。
  • 輸出結果:
節點輸出特征:tensor([[[1., 2.],[3., 4.],[4., 5.],[4., 5.]]])

(三)實驗結果分析

1.?輸入數據

(1)節點特征

節點特征是一個大小為?[1, 4, 2]?的張量,表示一個批次中有 4 個節點,每個節點有 2 維特征。具體值如下:

tensor([[[0., 1.],[2., 3.],[4., 5.],[6., 7.]]])
  • 節點 0 的特征為:[0., 1.]
  • 節點 1 的特征為:[2., 3.]
  • 節點 2 的特征為:[4., 5.]
  • 節點 3 的特征為:[6., 7.]

(2)鄰接矩陣

鄰接矩陣是一個大小為?[1, 4, 4]?的張量,表示?4 個節點之間的連接關系。具體值如下:

tensor([[[1., 1., 0., 0.],[1., 1., 1., 1.],[0., 1., 1., 1.],[0., 1., 1., 1.]]])
  • 節點 0 的鄰居為:節點 0 和節點 1。
  • 節點 1 的鄰居為:節點 0、節點 1、節點 2 和節點 3。
  • 節點 2 的鄰居為:節點 1、節點 2 和節點 3。
  • 節點 3 的鄰居為:節點 1、節點 2 和節點 3。

如何通過鄰接矩陣來判斷每個節點的鄰居是什么?——看值為1的索引是多少,那么鄰居便是多少。

[[1., 1., 0., 0.],  # 節點0的鄰居:值為1的列索引為[0, 1],即節點0和節點1。[1., 1., 1., 1.],  # 節點1的鄰居:值為1的列索引為[0, 1, 2, 3],即節點0、節點1、節點2和節點3。[0., 1., 1., 1.],  # 節點2的鄰居:值為1的列索引為[1, 2, 3],即節點1、節點2和節點3。[0., 1., 1., 1.]]  # 節點3的鄰居:值為1的列索引為[1, 2, 3],即節點1、節點2和節點3。

本實驗中的圖 G?的圖示如下:

2. 輸出特征分析

經GCN層的前向傳播后,得到輸出特征,其形狀為 [1, 4, 2] 的張量,表示更新后的節點特征。

tensor([[[1., 2.],[3., 4.],[4., 5.],[4., 5.]]])

GCN 層通過鄰接矩陣聚合鄰居節點的消息。具體計算如下:對于每個節點,將其鄰居節點的特征相加。再將聚合后的特征除以鄰居數量,得到平均特征,即最終的輸出特征。下面逐節點分析輸出特征的計算過程:

(1)節點0的計算
  • 鄰居節點:節點0和節點1。
  • 聚合特征:[0., 1.] + [2., 3.] = [2., 4.]
  • 鄰居數量:2
  • 平均特征:[2., 4.] / 2 = [1., 2.]
(2)節點1的計算
  • 鄰居節點:節點0、節點1、節點2和節點3。
  • 聚合特征:[0., 1.] + [2., 3.] + [4., 5.] + [6., 7.] = [12., 16.]
  • 鄰居數量:4
  • 平均特征:[12., 16.] / 4 = [3., 4.]
(3)節點2的計算
  • 鄰居節點:節點1、節點2和節點3。
  • 聚合特征:[2., 3.] + [4., 5.] + [6., 7.] = [12., 15.]
  • 鄰居數量:3
  • 平均特征:[12., 15.] / 3 = [4., 5.]
(4)節點3的計算
  • 鄰居節點:節點1、節點2和節點3。
  • 聚合特征:[2., 3.] + [4., 5.] + [6., 7.] = [12., 15.]
  • 鄰居數量:3
  • 平均特征:[12., 15.] / 3 = [4., 5.]

通過上述分析可以看出,GCN 層的核心思想是通過聚合鄰居節點的信息來更新每個節點的特征表示。具體來說:

  • 線性變換 :首先對輸入特征進行線性變換(本實驗中權重矩陣為單位矩陣,因此特征未發生變化)。
  • 鄰居信息聚合 :通過鄰接矩陣將鄰居節點的特征加權求和。
  • 歸一化 :將聚合結果按鄰居數量歸一化,得到最終的節點特征。

(四)思考與總結

1. 思考

如上所見,第一個節點的輸出值是其自身和第二個節點的平均值,其他節點同理。當然,在具體實踐中,我們還希望允許節點之間的消息傳遞不僅僅局限于鄰居節點,還可以通過應用多個 GCN 層來實現,而很多的 GNN 即是由多個 GCN 和非線性(如 ReLU)的組合構建而成,如下圖所示:

通過以上 GCN 層的運算示例,發現一個問題,即節點 3 和 4 的輸出相同,這是因為它們具有相同的相鄰節點(包括自身)輸入,再取均值,所得到的值便一樣了。這在大部分情況下并不合理。

2. 總結

本實驗通過實現一個簡單的 GCN 層,展示了圖卷積網絡的核心思想——通過聚合鄰居節點的信息來更新節點特征。通過手動設置權重矩陣和偏置,我們驗證了 GCN 層的計算過程,并分析了輸入特征與鄰接矩陣對輸出特征的影響。實驗結果表明,GCN 層能夠有效地捕捉圖結構中的局部信息。

未來可以進一步擴展該實驗:

  • 引入非線性激活函數 :在 GCN 層中加入 ReLU 等非線性激活函數,增強模型的表達能力。
  • 多層 GCN :堆疊多個 GCN 層,以捕獲更高階的鄰居信息。
  • 真實數據集實驗 :在實際圖數據集(如 Cora 或 Citeseer)上測試 GCN 模型的性能。
  • 優化算法 :結合梯度下降等優化算法,訓練 GCN 模型以完成特定任務(如節點分類或鏈接預測)。

通過這些擴展,可以更全面地理解圖卷積網絡的工作原理及其在實際問題中的應用價值。


三、完整代碼

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Project : GNN/GCN
@File    : gcn1.py
@IDE     : PyCharm
@Author  : 半畝花海
@Date    : 2025/02/28 21:33
"""
import torch
import torch.nn as nnclass GCNLayer(nn.Module):def __init__(self, c_in, c_out):"""Inputs::param c_in: 輸入特征:param c_out: 輸出特征"""super().__init__()self.projection = nn.Linear(c_in, c_out);  # 線性層def forward(self, node_feats, adj_matrix):"""輸入:param node_feats: 節點特征表示,大小為[batch_size,num_nodes,c_in]:param adj_matrix: 鄰接矩陣:[batch_size,num_nodes,num_nodes]:return:"""num_neighbors = adj_matrix.sum(dim=-1, keepdims=True)  # 各節點的鄰居數node_feats = self.projection(node_feats)  # 將特征轉化為消息# 各鄰居節點消息求和并求平均node_feats = torch.bmm(adj_matrix, node_feats)node_feats = node_feats / num_neighborsreturn node_featsnode_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
adj_matrix = torch.Tensor([[[1, 1, 0, 0],[1, 1, 1, 1],[0, 1, 1, 1],[0, 1, 1, 1]]])
print("節點特征:\n", node_feats)
print("添加自連接的鄰接矩陣:\n", adj_matrix)layer = GCNLayer(c_in=2, c_out=2)
# 初始化權重矩陣
layer.projection.weight.data = torch.Tensor([[1., 0.], [0., 1.]])
layer.projection.bias.data = torch.Tensor([0., 0.])# 將節點特征和添加自連接的鄰接矩陣輸入 GCN 層
with torch.no_grad():out_feats = layer(node_feats, adj_matrix)print("節點輸出特征:\n", out_feats)

四、參考文章

[1]?實戰-----基于 PyTorch 的 GNN 搭建_pytorch gnn-CSDN博客

[2]?圖神經網絡簡單理解 — — 附帶案例_圖神經網絡實例-CSDN博客

[3]?一文快速預覽經典深度學習模型(二)——遷移學習、半監督學習、圖神經網絡(GNN)、聯邦學習_遷移學習 圖神經網絡-CSDN博客

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/72422.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/72422.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/72422.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MyBatis-Plus 邏輯刪除實現

在很多企業級應用中,數據刪除操作通常采用 邏輯刪除 的方式,而不是物理刪除。邏輯刪除指的是通過更新字段(例如 is_deleted 或 status)來標記數據為刪除狀態,而不是真的從數據庫中刪除記錄。這樣做的好處是保留數據的歷…

STM32_IIC外設工作流程

STM32 IC 外設工作流程(基于寄存器) 在 STM32 中,IC 通信主要通過一系列寄存器控制。理解這些寄存器的作用,能夠幫助我們掌握 IC 硬件的運行機制,實現高效的數據傳輸。本文以 STM32F1(如 STM32F103&#x…

集合遍歷的多種方式

目錄 1.增強for 2.迭代器(在遍歷的過程中需要刪除元素,請使用迭代器) 3.雙列集合 4.Lambda表達式(forEach方法) 1.單列集合: 2.雙列集合: 4.Stream 流 5.普通for循環 6.列表迭代器 7.總結 1.增強for 注&…

DeepSeek在MATLAB上的部署與應用

在科技飛速發展的當下,人工智能與編程語言的融合不斷拓展著創新邊界。DeepSeek作為一款備受矚目的大語言模型,其在自然語言處理領域展現出強大的能力。而MATLAB,作為科學計算和工程領域廣泛應用的專業軟件,擁有豐富的工具包和高效…

value_counts()和unique()

我今天發現一個很有意思的問題哈 import scanpy as sc import numpy as npX np.random.randn(10,3) adata1 sc.AnnData(X) adata1.obs["sample"] "H1" print(adata1)X np.random.randn(20,3) adata2 sc.AnnData(X) adata2.obs["sample"] &…

每日OJ_牛客_游游的字母串_枚舉_C++_Java

目錄 牛客_游游的字母串_枚舉 題目解析 C代碼 Java代碼 牛客_游游的字母串_枚舉 游游的字母串 描述: 對于一個小寫字母而言,游游可以通過一次操作把這個字母變成相鄰的字母。a和b相鄰,b和c相鄰,以此類推。特殊的&#xff0…

【AI深度學習基礎】Pandas完全指南入門篇:數據處理的瑞士軍刀 (含完整代碼)

📚 Pandas 系列文章導航 入門篇 🌱進階篇 🚀終極篇 🌌 📌 一、引言 在大數據與 AI 驅動的時代,數據預處理和分析是深度學習與機器學習的基石。Pandas 作為 Python 生態中最強大的數據處理庫,以…

數字萬用表的使用教程

福祿克經濟型數字萬用表前面板按鍵功能介紹示意圖 1. 萬用表簡單介紹 萬用表是一種帶有整流器的、可以測量交、直流電流、電壓及電阻等多種電學參量的磁電式儀表。分為數字萬用表,鉗形萬用表, (1)表筆分為紅、黑二只。使用時黑色…

C# IComparable<T> 使用詳解

總目錄 前言 在C#編程中&#xff0c;IComparable<T> 是一個非常重要的接口&#xff0c;它允許我們為自定義類型提供默認的比較邏輯。這對于實現排序、搜索和其他需要基于特定規則進行比較的操作特別有用。本文將詳細介紹 IComparable<T> 的使用方法、應用場景及其…

DeepSeek使用手冊分享-附PDF下載連接

本次主要分享DeepSeek從技術原理到使用技巧內容&#xff0c;這里展示一些基本內容&#xff0c;后面附上詳細PDF下載鏈接。 DeepSeek基本介紹 DeepSeek公司和模型的基本簡介&#xff0c;以及DeepSeek高性能低成本獲得業界的高度認可的原因。 DeepSeek技術路線解析 DeepSeek V3…

Hugging Face 推出 FastRTC:實時語音視頻應用開發變得得心應手

估值超過 40 億美元的 AI 初創公司 Hugging Face 推出了 FastRTC&#xff0c;這是一個開源 Python 庫&#xff0c;旨在消除開發者在構建實時音頻和視頻 AI 應用時的主要障礙。 "在 Python 中正確構建實時 WebRTC 和 Websocket 應用一直都很困難&#xff0c;"FastRTC…

for循環相關(循環的過程中對數據進行刪除會踩坑)

# 錯誤方式&#xff0c; 有坑&#xff0c;結果不是你想要的。 user_list ["劉的話", "范德彪", "劉華強", 劉尼古拉斯趙四, "宋小寶", "劉能"] for item in user_list: if item.startswith("劉"): …

Qt顯示一個hello world

一、顯示思路 思路一&#xff1a;通過圖形化方式&#xff0c;界面上創建出一個控件顯示。 思路二&#xff1a;通過編寫C代碼在界面上創建控件顯示。 二、思路一實現 點開 Froms 的 widget.ui&#xff0c;拖拽 label 控件&#xff0c;顯示 hello world 即可。 qmake 基于 .…

復合機器人為 CNC 毛坯件上下料注入 “智能強心針”

在競爭日益激烈的 CNC 加工行業&#xff0c;如何提升生產效率、保證產品質量、實現智能化生產成為眾多企業亟待解決的問題。富唯智能憑借其先進的復合機器人技術&#xff0c;成功為多家 CNC 加工企業提供了毛坯件上下料的優質解決方案&#xff0c;有效提升了生產效能&#xff0…

電商業務數據測試用例參考

1. 數據采集層測試 用例編號測試目標測試場景預期結果TC-001驗證用戶行為日志采集完整性模擬用戶瀏覽、點擊、加購行為Kafka Topic中日志記錄數與模擬量一致TC-002驗證無效數據過濾規則發送爬蟲請求&#xff08;高頻IP&#xff09;清洗后數據中無該IP的日志記錄 2. 數據處理層…

Spring Cloud Gateway 網關的使用

在之前的學習中&#xff0c;所有的微服務接口都是對外開放的&#xff0c;這就意味著用戶可以直接訪問&#xff0c;為了保證對外服務的安全性&#xff0c;服務端實現的微服務接口都帶有一定的權限校驗機制&#xff0c;但是由于使用了微服務&#xff0c;就需要每一個服務都進行一…

webstorm的Live Edit插件配合chrome擴展程序JetBrains IDE Support實現實時預覽html效果

前言 我們平時在前端網頁修改好代碼要點擊刷新再去看修改的效果&#xff0c;這樣比較麻煩&#xff0c;那么很多軟件都提供了實時預覽的功能&#xff0c;我們一邊編輯代碼一邊可以看到效果。下面說的是webstorm。 1 Live Edit 首先我們需要在webstorm的settings里安裝插件Live …

map的operator[]的實現

map的operator[]的實現 operator[]里包含插入操作&#xff0c;所以我們先看一下首先看一下map的insert函數 返回值是一個pair類型。正常的常見的insert&#xff0c;插入成功返回true&#xff0c;失敗返回false 這里設計的insert不單單返回布爾值&#xff0c;而是返回一個pair…

定時器的編碼器接口模式

選擇編碼器接口模式的方法是&#xff1a;如果計數器只在TI2的邊沿計數&#xff0c;則置TIMx_SMCR寄存器中的SMS001&#xff0c;如果只在TI1邊沿計數&#xff0c;則置SMS010&#xff0c;如果計數器同時在TI1和TI2邊沿計數&#xff0c;則置SMS 011 明確一點&#xff0c;計數器…

Openshift配置默認調度

配置默認調度選擇角色為worker的機器運行pod。 編輯scheduler oc edit schedulers.config.openshift.iospec:defaultNodeSelector: node-role.kubernetes.io/worker ## 添加這一段如果pod需要運行在非worker主機&#xff0c;需要配置pod所在的項目添加注解 openshift.io/node…