Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py環境安裝教程及圖數據集制作

最近需要訓練圖卷積神經網絡(Graph Convolution Neural Network, GCNN),在配置GCNN環境上總結了一些經驗。

我覺得對于初學者而言,圖神經網絡的訓練會有2個難點:

①環境配置

②數據集制作

一、環境配置

我最初光想到要給GCNN配環境就覺得有些困難,感覺相比于目標檢測、分類識別這些任務用規則數據,圖神經網絡的模型、數據都是圖,所以內心覺得會比較難。

我之前更有一個誤區,就是覺得不規則結構的圖數據不能用CUDA進行并行加速。實際上,圖,在電腦里也是以張量這種規則結構數據存在的,完全能用CUDA進行加速計算,訓練GCN前配置CUDA完全OK。


以下是我配置的環境,可用CUDA成功運行link_pred.py

幾個關鍵包的版本:

torch? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 2.4.1
torch-geometric ? ? ? ? ? ? ? 2.3.1
torchaudio? ? ? ? ? ? ? ? ? ? ? ?2.4.1
torchvision? ? ? ? ? ? ? ? ? ? ? ?0.14.0
torchviz? ? ? ? ? ? ? ? ? ? ? ? ? ? 0.0.2

pandas? ? ? ? ? ? ? ? ? ? ? ? ? ? ?1.0.3

numpy? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 1.20.0

?CUDA: 11.8

注意要先安裝好CUDA,顯示了:

?

再安裝GPU版本的torch,不然python檢測安裝的是cpu版本的torch。這時,就得卸載重新安裝了

環境配置成功:

print(torch.__version__)
print(torch.cuda.is_available())

如果CUDA環境安裝失敗,會打印:

2.4.1+cpu
False

其實只安裝torch和CUDA還好,如果你的python中有numpy和pandas可能解決版本之間的沖突會耗費不少時間,我就是在numpy和pandas版本上試了很久,最終找到現在的版本是相互兼容的。

CUDA的版本切換可以參考我的另一篇博客:

CUDA版本切換

二、數據集制作

掌握圖數據集制作的關鍵在于掌握slices切片:

for ...data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index,             edge_label=Edge_label)data_list.append(data)
data_, slices = self.collate(data_list)  # 將不同大小的圖數據對齊,填充
torch.save((data_, slices), self.processed_paths[0])

和CNN不同的是,GCN沒有樣本維度,需要把所有樣本拼成一張大圖喂給GCN進行訓練?

數據集生成代碼:

#作者:zhouzhichao
#創建時間:2025/5/30
#內容:生成200個樣本的PYG數據集import h5py
import hdf5storage
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_samplingbase_dir = "D:\\無線通信網絡認知\\論文1\\experiment\\直推式拓撲推理實驗\\拓撲生成\\200樣本\\"N = 30
grapg_size = N
train_n = 31
M = 3000class graph_data(InMemoryDataset):def __init__(self, root, signals=None, tp_list = None, transform=None, pre_transform=None):# self.Signals = Signals# self.Tp_list = Tp_listself.signals = signalsself.tp_list = tp_listsuper().__init__(root, transform, pre_transform)# self.data, self.slices = torch.load(self.processed_paths[0])self.data = torch.load(self.processed_paths[0])# 返回process方法所需的保存文件名。你之后保存的數據集名字和列表里的一致@propertydef processed_file_names(self):return ['gcn_data.pt']# 生成數據集所用的方法def process(self):# data_list = []# for k in range(200):# signals = self.Signals[:, :, k]# tp_list = np.array(mat_file[self.Tp_list[0, k]])signals = self.signalstp_list =self.tp_list# tp = Tp[:,:,k]X = torch.tensor(signals, dtype=torch.float)# 所有的邊Edge_index = torch.tensor(tp_list, dtype=torch.long)# 所有的邊1標簽edge_label = np.ones((tp_list.shape[1]))# edge_label = np.zeros((tp_list.shape[1]))Edge_label = torch.tensor(edge_label, dtype=torch.float)neg_edge_index = negative_sampling(edge_index=Edge_index, num_nodes=grapg_size,num_neg_samples=Edge_index.shape[1], method='sparse')# 拼接正負樣本索引# c = 0# for i in range(31):#     for i in range(31):#         if torch.equal(Edge_index[:, i], neg_edge_index[:, i]):#             c = c + 1# print("c: ",c)Edge_label_index = Edge_indexperm = torch.randperm(Edge_index.size(1))Edge_index = Edge_index[:, perm]Edge_index = Edge_index[:, :train_n]Edge_label_index = torch.cat([Edge_label_index, neg_edge_index],dim=-1,)# 拼接正負樣本Edge_label = torch.cat([Edge_label,Edge_label.new_zeros(neg_edge_index.size(1))], dim=0)# Edge_label = torch.cat([#     Edge_label,#     Edge_label.new_ones(neg_edge_index.size(1))# ], dim=0)data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)torch.save(data, self.processed_paths[0])# data_list.append(data)# data_, slices = self.collate(data_list)  # 將不同大小的圖數據對齊,填充# torch.save((data_, slices), self.processed_paths[0])for snr in [0,20,40]:print("snr: ", snr)mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# mat_file = hdf5storage.loadmat(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')# 獲取數據集Signals = mat_file["Signals"][()]# signals = np.swapaxes(signals, 1, 0)Tp = mat_file["Tp"][()]Tp_list = mat_file["Tp_list"][()]# tp_list = tp_list - 1# 關閉文件# mat_file.close()# graph_data("gcn_data")# n = Signals.shape[2]n = 10for i in range(n):signals = Signals[:,:,i]tp_list = np.array(mat_file[Tp_list[0, i]])root = "gcn_data-"+str(i)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)graph_data(root, signals = signals, tp_list = tp_list)print("")print("...圖數據生成完成...")

訓練代碼:

#作者:zhouzhichao
#創建時間:25年5月29日
#內容:統計圖中有關系節點和無關系節點的GCN特征歐式距離import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\無線通信網絡認知\論文1\experiment\直推式拓撲推理實驗\GCN推理')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())mode = "gcn"class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv1 = GCNConv(Input_L, 1000)self.conv2 = GCNConv(1000, 20)def encode(self, x, edge_index):x1 = self.conv1(x, edge_index)x1_1 = x1.relu()x2 = self.conv2(x1_1, edge_index)x2_2 = x2.relu()return x2_2def decode(self, z, edge_label_index):# 節點和邊都是矩陣,不同的計算方法致使:節點->節點,節點->邊# nodes_relation = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)# distances  = torch.norm(z[edge_label_index[0]] - z[edge_label_index[1]], dim=-1)distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)# print("distance_squared: ",distance_squared)return distance_squareddef decode_all(self, z):prob_adj = z @ z.t()  # 得到所有邊概率矩陣return (prob_adj > 0).nonzero(as_tuple=False).t()  # 返回概率大于0的邊,以edge_index的形式@torch.no_grad()def test(self,input_data):model.eval()z = model.encode(input_data.x, input_data.edge_index)out = model.decode(z, input_data.edge_label_index).view(-1)out = 1 - outN = 30
train_n = 31
M = 3000
# snr = -20
# for train_n in range(1,51):
# for M in range(3000, 499, -100):
for snr in [0,20,40]:print("snr: ", snr)for I in range(10):root = "gcn_data-"+str(I)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)gcn_data = graph_data(root)Input_L = gcn_data.x.shape[1]model = Net()# model = Net().to(device)optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()z = model.encode(gcn_data.x, gcn_data.edge_index)# out = model.decode(z, train_data.edge_label_index).view(-1).sigmoid()out = model.decode(z, gcn_data.edge_label_index).view(-1)out = 1 - outloss = criterion(out, gcn_data.edge_label)loss.backward()optimizer.step()return lossmin_loss = 99999count = 0#早停for epoch in range(10000):loss = train()if loss<min_loss:min_loss = losscount = 0count = count + 1if count>100:breakprint("epoch:  ",epoch,"   loss: ",round(loss.item(),2), "   min_loss: ",round(min_loss.item(),2))z = model.encode(gcn_data.x, gcn_data.edge_index)out = model.decode(z, gcn_data.edge_label_index).view(-1)list_0 = []list_1 = []for i in range(len(gcn_data.edge_label)):true_label = gcn_data.edge_label[i].item()euclidean_distance_value = out[i].item()if true_label==1:list_1.append(euclidean_distance_value)if true_label==0:list_0.append(euclidean_distance_value)minlength = min(len(list_1), len(list_0))list_1 = random.sample(list_1, minlength)list_0 = random.sample(list_0, minlength)value = list_1 + list_0large_class = list(np.full(len(value), snr))small_class = list(np.full(len(list_1), 1)) + list(np.full(len(list_0), 0))data = {'large_class': large_class,'small_class': small_class,'value': value}# 創建一個 DataFramedf = pd.DataFrame(data)## # 保存到 Excel 文件file_path = 'D:\無線通信網絡認知\論文1\大修意見\圖聚類、閾值相似性圖實驗補充\\' + mode + '_similarity_' + str(snr) + 'db_'+str(I)+'.xlsx'df.to_excel(file_path, index=False)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/908052.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/908052.shtml
英文地址,請注明出處:http://en.pswp.cn/news/908052.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025年微信小程序開發:AR/VR與電商的最新案例

引言 微信小程序自2017年推出以來&#xff0c;已成為中國移動互聯網生態的核心組成部分。根據最新數據&#xff0c;截至2025年&#xff0c;微信小程序的日活躍用戶超過4.5億&#xff0c;總數超過430萬&#xff0c;覆蓋電商、社交、線下服務等多個領域&#xff08;WeChat Mini …

互聯網向左,區塊鏈向右

2008年&#xff0c;中本聰首次提出了比特幣的設想&#xff0c;這打開了去中心化的大門。 比特幣白皮書清晰的描述了去中心化支付的解決方案&#xff0c;并分別從以下幾個方面闡述了他的理念&#xff1a; 一、由轉賬雙方點對點的通訊&#xff0c;而不通過中心化的第三方&#xf…

PV操作的C++代碼示例講解

文章目錄 一、PV操作基本概念&#xff08;一&#xff09;信號量&#xff08;二&#xff09;P操作&#xff08;三&#xff09;V操作 二、PV操作的意義三、C中實現PV操作的方法&#xff08;一&#xff09;使用信號量實現PV操作代碼解釋&#xff1a; &#xff08;二&#xff09;使…

《對象創建的秘密:Java 內存布局、逃逸分析與 TLAB 優化詳解》

大家好呀&#xff01;今天我們來聊聊Java世界里那些"看不見摸不著"但又超級重要的東西——對象在內存里是怎么"住"的&#xff0c;以及JVM這個"超級管家"是怎么幫我們優化管理的。放心&#xff0c;我會用最接地氣的方式講解&#xff0c;保證連小學…

簡單實現Ajax基礎應用

Ajax不是一種技術&#xff0c;而是一個編程概念。HTML 和 CSS 可以組合使用來標記和設置信息樣式。JavaScript 可以修改網頁以動態顯示&#xff0c;并允許用戶與新信息進行交互。內置的 XMLHttpRequest 對象用于在網頁上執行 Ajax&#xff0c;允許網站將內容加載到屏幕上而無需…

詳解開漏輸出和推挽輸出

開漏輸出和推挽輸出 以上是 GPIO 配置為輸出時的內部示意圖&#xff0c;我們要關注的其實就是這兩個 MOS 管的開關狀態&#xff0c;可以組合出四種狀態&#xff1a; 兩個 MOS 管都關閉時&#xff0c;輸出處于一個浮空狀態&#xff0c;此時他對其他點的電阻是無窮大的&#xff…

Matlab實現LSTM-SVM回歸預測,作者:機器學習之心

Matlab實現LSTM-SVM回歸預測&#xff0c;作者&#xff1a;機器學習之心 目錄 Matlab實現LSTM-SVM回歸預測&#xff0c;作者&#xff1a;機器學習之心效果一覽基本介紹程序設計參考資料 效果一覽 基本介紹 代碼主要功能 該代碼實現了一個LSTM-SVM回歸預測模型&#xff0c;核心流…

Leetcode - 周賽 452

目錄 一&#xff0c;3566. 等積子集的劃分方案二&#xff0c;3567. 子矩陣的最小絕對差三&#xff0c;3568. 清理教室的最少移動四&#xff0c;3569. 分割數組后不同質數的最大數目 一&#xff0c;3566. 等積子集的劃分方案 題目列表 本題有兩種做法&#xff0c;dfs 選或不選…

【FAQ】HarmonyOS SDK 閉源開放能力 —Account Kit(5)

1.問題描述&#xff1a; 集成華為一鍵登錄的LoginWithHuaweiIDButton&#xff0c; 但是Button默認名字叫 “華為賬號一鍵登錄”&#xff0c;太長無法顯示&#xff0c;能否簡寫成“一鍵登錄”與其他端一致&#xff1f; 解決方案&#xff1a; 問題分兩個場景&#xff1a; 一、…

Asp.Net Core SignalR的分布式部署

文章目錄 前言一、核心二、解決方案架構三、實現方案1.使用 Azure SignalR Service2.Redis Backplane(Redis 背板方案&#xff09;3.負載均衡配置粘性會話要求無粘性會話方案&#xff08;僅WebSockets&#xff09;完整部署示例&#xff08;Redis Docker&#xff09;性能優化技…

L2-054 三點共線 - java

L2-054 三點共線 語言時間限制內存限制代碼長度限制棧限制Java (javac)2600 ms512 MB16KB8192 KBPython (python3)2000 ms256 MB16KB8192 KB其他編譯器2000 ms64 MB16KB8192 KB 題目描述&#xff1a; 給定平面上 n n n 個點的坐標 ( x _ i , y _ i ) ( i 1 , ? , n ) (x\_i…

【 java 基礎知識 第一篇 】

目錄 1.概念 1.1.java的特定有哪些&#xff1f; 1.2.java有哪些優勢哪些劣勢&#xff1f; 1.3.java為什么可以跨平臺&#xff1f; 1.4JVM,JDK,JRE它們有什么區別&#xff1f; 1.5.編譯型語言與解釋型語言的區別&#xff1f; 2.數據類型 2.1.long與int類型可以互轉嗎&…

高效背誦英語四級范文

以下是結合認知科學和實戰驗證的 ??高效背誦英語作文五步法??&#xff0c;助你在30分鐘內牢固記憶一篇作文&#xff0c;特別適配考前沖刺場景&#xff1a; &#x1f4dd; ??一、解構作文&#xff08;5分鐘&#xff09;?? ??拆解邏輯框架?? 用熒光筆標出&#xff…

RHEL7安裝教程

RHEL7安裝教程 下載RHEL7鏡像 通過網盤分享的文件&#xff1a;RHEL 7.zip 鏈接: https://pan.baidu.com/s/1ExLhdJigj-tcrHJxIca5XA?pwdjrrj 提取碼: jrrj --來自百度網盤超級會員v6的分享安裝 1.打開VMware&#xff0c;新建虛擬機&#xff0c;選擇自定義然后下一步 2.點擊…

結構型設計模式之Decorator(裝飾器)

結構型設計模式之Decorator&#xff08;裝飾器&#xff09; 前言&#xff1a; 本案例通過李四舉例&#xff0c;不改變源代碼的情況下 對“才藝”進行增強。 摘要&#xff1a; 摘要&#xff1a; 裝飾器模式是一種結構型設計模式&#xff0c;允許動態地為對象添加功能而不改變其…

Kotlin委托機制使用方式和原理

目錄 類委托屬性委托簡單的實現屬性委托Kotlin標準庫中提供的幾個委托延遲屬性LazyLazy委托參數可觀察屬性Observable委托vetoable委托屬性儲存在Map中 實踐方式雙擊back退出Fragment/Activity傳參ViewBinding和委托 類委托 類委托有點類似于Java中的代理模式 interface Base…

SpringBoot接入Kimi實踐記錄輕松上手

kimi簡單使用 什么是Kimi API 官網&#xff1a;https://platform.moonshot.cn/ Kimi API 并不是一個我所熟知的廣泛通用的術語。我的推測是&#xff0c;你可能想問的是關于 API 的一些基礎知識。API&#xff08;Application Programming Interface&#xff0c;應用程序編程接…

書籍在其他數都出現k次的數組中找到只出現一次的數(7)0603

題目 給定一個整型數組arr和一個大于1的整數k。已知arr中只有1個數出現了1次&#xff0c;其他的數都出現了k次&#xff0c;請返回只出現了1次的數。 解答&#xff1a; 對此題進行思路轉換&#xff0c;可以將此題&#xff0c;轉換成k進制數。 k進制的兩個數c和d&#xff0c;…

React 項目初始化與搭建指南

React 項目初始化有多種方式&#xff0c;可以選擇已有的腳手架工具快速創建項目&#xff0c;也可以自定義項目結構并使用構建工具實現項目的構建打包流程。 1. 腳手架方案 1.1. Vite 通過 Vite 創建 React 項目非常簡單&#xff0c;只需一行命令即可完成。Vite 的工程初始化…

大模型模型推理的成本過高,如何進行量化或蒸餾優化

在人工智能的浪潮中,大模型已經成為推動技術革新的核心引擎。從自然語言處理到圖像生成,再到復雜的多模態任務,像GPT、BERT、T5這樣的龐大模型展現出了驚人的能力。它們在翻譯、對話系統、內容生成等領域大放異彩,甚至在醫療、金融等行業中也開始扮演重要角色。可以說,這些…