PYG - Cora數據集加載 (自動加載+手動實現)

本文從Cora的例子來展示PYG如何加載圖數據集。
Cora 是一個小型的有標注的圖數據集,包含以下內容:

  • data.x:2708 個節點(即 2708 篇論文),每個節點有 1433 個特征,形狀為 (2708, 1433)。
  • data.edge_index:5429 條邊(即 5429 個引用關系),形狀為 (2, 5429)。
  • data.y:節點標簽,共 7 類,形狀為 (2708,)。(共有 7 個類別,表示論文的研究領域)
  • data.train_mask:訓練集掩碼,布爾向量,表示哪些節點用于訓練。
  • data.val_mask:驗證集掩碼,布爾向量,表示哪些節點用于驗證。
  • data.test_mask:測試集掩碼,布爾向量,表示哪些節點用于測試。

數據主要描述了論文之間的引用關系以及每篇論文的主題。可用于進行訓練節點分類問題(即判斷每篇論文屬于哪個類別)

1.自動加載

1.1 數據加載操作詳解

PYG庫提供了自動加載數據集的方法:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/Planetoid', name='Cora')
dataset[0]
print(len(dataset))  # 輸出: 1
print(data)

1
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

對于 Planetoid 類來說:

  • 它是一個專門為 Planetoid 系列數據集(Cora、CiteSeer、PubMed) 設計的類。
  • 這些數據集的主要特點是:它們實際上是單圖數據集,即整個數據集中只包含一個圖。

dataset 是一個包含 單個 Data 對象(圖) 的數據集對象。


由于 Planetoid 類的數據集中只有一個圖,因此:

  • dataset[0] 返回了這個唯一的圖,類型是 Data 對象,表示整個 Cora 數據集的圖。
  • Dataset 是一個可索引的對象,dataset[0] 的作用就是提取第一(也是唯一)個圖。

  • dataset = Planetoid(root='data/Planetoid', name='Cora') 加載了 Cora 數據集,它是一個 單圖數據集,包含一張圖的節點特征、邊索引、節點標簽和數據集劃分信息。
  • dataset[0] 提取了該圖的數據,返回了一個 Data 對象,表示整個圖。
  • dataset 本身是一個數據集管理器,幫助加載和存儲數據,同時提供一些元信息和操作方法。

1. 2 數據加載的過程

  1. 下載數據:

    • 如果指定路徑 'data/Planetoid' 下沒有數據集文件,Planetoid 類會從 指定的遠程服務器(由 PyG 維護)下載 Cora 數據集文件,并存儲在 'data/Planetoid/Cora' 文件夾下。
    • 數據集下載地址為:
      • Cora 數據集原始文件
  2. 解壓文件:

    • 下載的數據集是 .zip.tar 格式,會被自動解壓為一系列文件,主要包括:
      • ind.cora.x:訓練節點的特征矩陣;
      • ind.cora.tx:測試節點的特征矩陣;
      • ind.cora.allx:包含訓練節點和一些驗證節點的特征矩陣;
      • ind.cora.y:訓練節點的標簽;
      • ind.cora.ty:測試節點的標簽;
      • ind.cora.ally:訓練和驗證節點的標簽;
      • ind.cora.graph:節點的鄰接表(圖結構信息);
      • ind.cora.test.index:測試節點的索引。
        如圖所示:
        請添加圖片描述
  3. 解析數據:

    • PyG 將原始文件的內容解析為圖數據格式(Data 對象),將以下內容整合起來:
      • 節點特征矩陣 x
      • 圖的邊信息 edge_index
      • 節點標簽 y
      • 訓練、驗證和測試集的掩碼(train_maskval_masktest_mask)。
  4. 數據存儲:

    • 如果數據加載成功,解析后的數據將被緩存到指定路徑(data/Planetoid/Cora)中,后續運行時會直接加載解析后的緩存文件,而不會重復下載和解析。

2. 數據集原始文件的形式

原始文件(以 ind.cora.* 為前綴)是以下幾種內容的存儲形式:

文件名內容描述
ind.cora.x稀疏矩陣,訓練集中節點的特征矩陣,大小為 (num_train_nodes, num_features)
ind.cora.tx稀疏矩陣,測試集中節點的特征矩陣,大小為 (num_test_nodes, num_features)
ind.cora.allx稀疏矩陣,包含訓練集和部分驗證集中節點的特征矩陣,大小為 (num_allx_nodes, num_features)
ind.cora.y訓練集的標簽,大小為 (num_train_nodes, num_classes) 的獨熱編碼矩陣。
ind.cora.ty測試集的標簽,大小為 (num_test_nodes, num_classes) 的獨熱編碼矩陣。
ind.cora.ally訓練和驗證集的標簽,大小為 (num_allx_nodes, num_classes) 的獨熱編碼矩陣。
ind.cora.graph字典格式,存儲圖的鄰接表,鍵為節點 ID,值為該節點的鄰居節點列表。
ind.cora.test.index列表形式,包含測試節點的索引。

3. 加載后的數據形式

加載后,數據以 torch_geometric.data.Data 對象的形式存儲,主要包含以下內容:

屬性描述形狀
data.x節點的特征矩陣,每一行表示一個節點的特征向量。(num_nodes, num_features)
data.edge_index圖的邊信息,存儲為 COO 格式的索引矩陣(兩個一維數組,分別表示邊的起始節點和結束節點)。(2, num_edges)
data.y節點的標簽,每個節點對應一個整數,表示其所屬類別的索引值。(num_nodes,)
data.train_mask訓練節點的布爾掩碼,值為 True 的位置表示該節點屬于訓練集。(num_nodes,)
data.val_mask驗證節點的布爾掩碼,值為 True 的位置表示該節點屬于驗證集。(num_nodes,)
data.test_mask測試節點的布爾掩碼,值為 True 的位置表示該節點屬于測試集。(num_nodes,)

4. 加載后的具體內容

Cora 數據集為例,加載后的數據具有以下具體特性:

  • 節點數num_nodes = 2708(共 2708 篇論文)。
  • 特征數num_features = 1433(每篇論文的特征是一個 1433 維向量,表示詞袋模型中的單詞出現情況)。
  • 邊數num_edges = 10556(論文之間的引用關系,構成無向圖)。
  • 類別數num_classes = 7(每篇論文屬于 7 個主題之一)。
  • 掩碼分布
    • 訓練集:140 個節點;
    • 驗證集:500 個節點;
    • 測試集:1000 個節點。

手動讀取數據集

下面手動實現的 CoraData 類代碼,經過修改后與 PyTorch Geometric (PyG) 的 Planetoid 類功能一致,可以直接生成標準的 Data 對象,用于圖神經網絡訓練。


完整代碼:CoraData

import os
import os.path as osp
import pickle
import numpy as np
import torch
from torch_geometric.data import Data
import scipy.sparse as sp
import urllib.requestclass CoraData(object):download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"filenames = ["ind.cora.{}".format(name) for name in['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]def __init__(self, data_root="cora", rebuild=False):"""Cora 數據加載器,包括下載、處理和緩存功能。處理后的數據可以通過屬性 .data 獲取,返回 PyG 標準的 Data 對象。Args:data_root: str, 數據存儲的根目錄rebuild: bool, 是否強制重新構建數據"""self.data_root = data_rootsave_file = osp.join(self.data_root, "processed_cora.pkl")if osp.exists(save_file) and not rebuild:print("Using Cached file: {}".format(save_file))self._data = pickle.load(open(save_file, "rb"))else:self.maybe_download()self._data = self.process_data()with open(save_file, "wb") as f:pickle.dump(self.data, f)print("Cached file: {}".format(save_file))@propertydef data(self):"""返回 PyG 標準的 Data 對象"""return self._datadef maybe_download(self):save_path = osp.join(self.data_root, "raw")for name in self.filenames:if not osp.exists(osp.join(save_path, name)):self.download_data("{}/{}".format(self.download_url, name), save_path)def process_data(self):"""處理數據并生成 PyG 標準的 Data 對象,包括以下屬性:- x: 節點特征,(2708, 1433)- y: 節點標簽,共 7 類,(2708,)- edge_index: 圖邊索引,(2, num_edges)- train_mask: 訓練集掩碼,(2708,)- val_mask: 驗證集掩碼,(2708,)- test_mask: 測試集掩碼,(2708,)"""print("Processing data ...")# 讀取原始數據x, tx, allx, y, ty, ally, graph, test_index = [self.read_data(osp.join(self.data_root, "raw", name)) for name in self.filenames]train_index = np.arange(y.shape[0])  # 訓練集索引 [0, 1, ..., 139]val_index = np.arange(y.shape[0], y.shape[0] + 500)  # 驗證集索引 [140, ..., 639]sorted_test_index = sorted(test_index)  # 排序后的測試集索引# 特征和標簽拼接x = np.concatenate((allx, tx), axis=0)  # (2708, 1433)y = np.concatenate((ally, ty), axis=0).argmax(axis=1)  # (2708,)# 重新排序測試集數據x[test_index] = x[sorted_test_index]y[test_index] = y[sorted_test_index]# 創建訓練、驗證、測試掩碼num_nodes = x.shape[0]train_mask = np.zeros(num_nodes, dtype=np.bool_)val_mask = np.zeros(num_nodes, dtype=np.bool_)test_mask = np.zeros(num_nodes, dtype=np.bool_)train_mask[train_index] = Trueval_mask[val_index] = Truetest_mask[test_index] = True# 構造 edge_indexedge_index = self.build_edge_index(graph)# 轉換為 PyTorch 格式x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.long)edge_index = torch.tensor(edge_index, dtype=torch.long)train_mask = torch.tensor(train_mask, dtype=torch.bool)val_mask = torch.tensor(val_mask, dtype=torch.bool)test_mask = torch.tensor(test_mask, dtype=torch.bool)# 打印基本信息print("Node feature shape: ", x.shape)print("Node label shape: ", y.shape)print("Edge index shape: ", edge_index.shape)print("Number of training nodes: ", train_mask.sum().item())print("Number of validation nodes: ", val_mask.sum().item())print("Number of test nodes: ", test_mask.sum().item())# 返回 PyG 的 Data 對象return Data(x=x, y=y, edge_index=edge_index,train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)@staticmethoddef build_edge_index(graph):"""根據鄰接表生成 edge_index 格式 (2, num_edges)。"""edge_index = []for src, dst in graph.items():edge_index.extend([[src, v] for v in dst])  # 正向邊edge_index.extend([[v, src] for v in dst])  # 反向邊edge_index = np.array(edge_index).T  # 轉置為 (2, num_edges)return edge_index@staticmethoddef read_data(path):"""讀取數據文件,根據文件名選擇加載方式。"""name = osp.basename(path)if name == "ind.cora.test.index":out = np.genfromtxt(path, dtype="int64")return outelse:out = pickle.load(open(path, "rb"), encoding="latin1")out = out.toarray() if hasattr(out, "toarray") else outreturn out@staticmethoddef download_data(url, save_path):"""從指定 URL 下載數據,并保存到本地路徑。"""if not os.path.exists(save_path):os.makedirs(save_path)data = urllib.request.urlopen(url)filename = os.path.split(url)[-1]with open(os.path.join(save_path, filename), 'wb') as f:f.write(data.read())return True

代碼解析

  1. 下載和緩存功能

    • 如果處理后的數據已緩存 (processed_cora.pkl),直接加載緩存。
    • 如果未緩存,則從 GitHub 下載原始數據,處理后存儲為緩存文件。
  2. 數據處理:process_data

    • 加載原始數據,并將訓練、驗證、測試節點特征拼接成完整矩陣。
    • 生成 PyG 格式的 edge_index(用于圖神經網絡的鄰接表表示)。
    • 生成訓練、驗證和測試集掩碼。
  3. 鄰接表轉換為邊索引

    • build_edge_index 將鄰接表 (graph) 轉換為 edge_index 格式。
    • edge_index 是一個形狀為 (2, num_edges) 的數組,列表示一條邊的起點和終點。
  4. 返回 PyG 數據對象

    • 數據對象包括 xyedge_indextrain_maskval_masktest_mask

運行代碼測試

要測試 CoraData 類,可以直接運行以下代碼:

cora_data = CoraData(data_root="cora", rebuild=True)
data = cora_data.data  # 獲取 PyG 的 Data 對象
print(data)

輸出示例:

Processing data ...
Node feature shape:  torch.Size([2708, 1433])
Node label shape:  torch.Size([2708])
Edge index shape:  torch.Size([2, 10556])
Number of training nodes:  140
Number of validation nodes:  500
Number of test nodes:  1000
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

該類的功能與 PyTorch Geometric 的 Planetoid 類一致,支持加載 Cora 數據集,并生成標準的 PyG Data 對象,適用于圖神經網絡模型訓練。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63754.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63754.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63754.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

《 火星人 》

題目描述 人類終于登上了火星的土地并且見到了神秘的火星人。人類和火星人都無法理解對方的語言,但是我們的科學家發明了一種用數字交流的方法。這種交流方法是這樣的,首先,火星人把一個非常大的數字告訴人類科學家,科學家破解這…

機器學習基礎算法 (二)-邏輯回歸

python 環境的配置參考 從零開始:Python 環境搭建與工具配置 邏輯回歸是一種用于解決二分類問題的機器學習算法,它可以預測輸入數據屬于某個類別的概率。本文將詳細介紹邏輯回歸的原理、Python 實現、模型評估和調優,并結合垃圾郵件分類案例進…

BiTCN-BiGRU基于雙向時間卷積網絡結合雙向門控循環單元的數據多特征分類預測(多輸入單輸出)

Matlab實現BiTCN-BiGRU基于雙向時間卷積網絡結合雙向門控循環單元的數據多特征分類預測(多輸入單輸出) 目錄 Matlab實現BiTCN-BiGRU基于雙向時間卷積網絡結合雙向門控循環單元的數據多特征分類預測(多輸入單輸出)分類效果基本描述…

云備份項目--工具類編寫

4. 文件工具類的設計 4.1 整體的類 該類實現對文件進行操作 FileUtil.hpp如下 /* 該類實現對文件進行操作 */ #pragma once #include <iostream> #include <string> #include <fstream> #include <vector> #include <sys/types.h> #include …

51c大模型~合集94

我自己的原文哦~ https://blog.51cto.com/whaosoft/12897659 #D(R,O) Grasp 重塑跨智能體靈巧手抓取&#xff0c;NUS邵林團隊提出全新交互式表征&#xff0c;斬獲CoRL Workshop最佳機器人論文獎 本文的作者均來自新加坡國立大學 LinS Lab。本文的共同第一作者為上海交通大…

【大學英語】英語范文十八篇,書信,議論文,材料分析

關注作者了解更多 我的其他CSDN專欄 過程控制系統 工程測試技術 虛擬儀器技術 可編程控制器 工業現場總線 數字圖像處理 智能控制 傳感器技術 嵌入式系統 復變函數與積分變換 單片機原理 線性代數 大學物理 熱工與工程流體力學 數字信號處理 光電融合集成電路…

一起學Git【第一節:Git的安裝】

Git是什么&#xff1f; Git是什么&#xff1f;相信大家點擊進來已經有了初步的認識&#xff0c;這里就簡單的進行介紹。 Git是一個開源的分布式版本控制系統&#xff0c;由Linus Torvalds創建&#xff0c;用于有效、高速地處理從小到大的項目版本管理。Git是目前世界上最流行…

消息隊列 Kafka 架構組件及其特性

Kafka 人們通常有時會將 Kafka 中的 Topic 比作隊列&#xff1b; 在 Kafka 中&#xff0c;數據是以主題&#xff08;Topic&#xff09;的形式組織的&#xff0c;每個 Topic 可以被分為多個分區&#xff08;Partition&#xff09;。每個 Partition 是一個有序的、不可變的消息…

《Mycat核心技術》第06章:Mycat問題處理總結

作者&#xff1a;冰河 星球&#xff1a;http://m6z.cn/6aeFbs 博客&#xff1a;https://binghe.gitcode.host 文章匯總&#xff1a;https://binghe.gitcode.host/md/all/all.html 星球項目地址&#xff1a;https://binghe.gitcode.host/md/zsxq/introduce.html 沉淀&#xff0c…

【day11】面向對象編程進階(繼承)

概述 本文深入探討面向對象編程的核心概念&#xff0c;包括繼承、方法重寫、this和super關鍵字的使用&#xff0c;以及抽象類和方法的定義與實現。通過本文的學習&#xff0c;你將能夠&#xff1a; 理解繼承的優勢。掌握繼承的使用方法。了解繼承后成員變量和成員方法的訪問特…

隨手記:小程序兼容后臺的wangEditor富文本配置鏈接

場景&#xff1a; 在后臺配置wangEditor富文本&#xff0c;可以文字配置鏈接&#xff0c;圖片配置鏈接&#xff0c;產生的json格式為&#xff1a; 例子&#xff1a; <h1><a href"https://uniapp.dcloud.net.cn/" target"_blank"><span sty…

6.8 Newman自動化運行Postman測試集

歡迎大家訂閱【軟件測試】 專欄&#xff0c;開啟你的軟件測試學習之旅&#xff01; 文章目錄 1 安裝Node.js2 安裝Newman3 使用Newman運行Postman測試集3.1 導出Postman集合3.2 使用Newman運行集合3.3 Newman常用參數3.4 Newman報告格式 4 使用定時任務自動化執行腳本4.1 編寫B…

工具環境 | 工具準備

搭建一套驗證環境需要的工具如下&#xff1a; 虛擬機&#xff1a;推薦virtualbox ubuntu VM好用&#xff0c;但是免費的好像木有了&#xff0c;ubuntu界面版更加容易上手。 網上找安裝教程即可&#xff0c;注意實現文件共享、復制粘貼等功能。 EDA&#xff1a;VCS Veridi 工…

計算機網絡之王道考研讀書筆記-2

第 2 章 物理層 2.1 通信基礎 2.1.1 基本概念 1.數據、信號與碼元 通信的目的是傳輸信息。數據是指傳送信息的實體。信號則是數據的電氣或電磁表現&#xff0c;是數據在傳輸過程中的存在形式。碼元是數字通信中數字信號的計量單位&#xff0c;這個時長內的信號稱為 k 進制碼…

ROS2學習配套C++知識

第3章 訂閱和發布——話題通信探索 3.3.1 發布速度控制海龜畫圓 std::bind cstd::bind綁定成員函數時&#xff0c;需要加上作用域以及取址符號 因為不會將成員函數隱式的轉換成指針&#xff0c;因此需要加&符號&#xff1b; 后面的第一個參數必須跟具體對象&#xff0c;c…

法規標準-C-NCAP評測標準解析(2024版)

文章目錄 什么是C-NCAP&#xff1f;C-NCAP 評測標準C-NCAP評測維度三大維度的評測場景及對應分數評星標準 自動駕駛相關評測場景評測方法及評測標準AEB VRU——評測內容(測什么&#xff1f;)AEB VRU——評測方法(怎么測&#xff1f;)車輛直行與前方縱向行走的行人測試場景&…

第十七屆山東省職業院校技能大賽 中職組“網絡安全”賽項任務書正式賽題

第十七屆山東省職業院校技能大賽 中職組“網絡安全”賽項任務書-A 目錄 一、競賽階段 二、競賽任務書內容 &#xff08;一&#xff09;拓撲圖 &#xff08;二&#xff09;模塊A 基礎設施設置與安全加固(200分) &#xff08;三&#xff09;B模塊安全事件響應/網絡安全數據取證/…

mlr3機器學習AUC的置信區間提取

如果你在mlr3拿到機器學習的預測數據 ROC 過程原理探索 假設數據 df <- data.frame(Airis$Sepal.Length, groupsample(x c(0,1),size 150,replace T)) 分組為 0,1 # 變量A為連續性變量 library(pROC) roc_obj <- roc(df g r o u p , d f group, df group,dfA, le…

Halcon例程代碼解讀:安全環檢測(附源碼|圖像下載鏈接)

安全環檢測核心思路與代碼詳解 項目目標 本項目的目標是檢測圖像中的安全環位置和方向。通過形狀匹配技術&#xff0c;從一張模型圖像中提取安全環的特征&#xff0c;并在后續圖像中識別多個實例&#xff0c;完成檢測和方向標定。 實現思路 安全環檢測分為以下核心步驟&…

Java——多線程進階知識

目錄 一、常見的鎖策略 樂觀鎖VS悲觀鎖 讀寫鎖 重量級鎖VS輕量級鎖 總結&#xff1a; 自旋鎖&#xff08;Spin Lock&#xff09; 公平鎖VS非公平鎖 可重入鎖VS不可重入鎖 二、CAS 何為CAS CAS有哪些應用 1&#xff09;實現原子類 2&#xff09;實現自旋鎖 CAS的ABA…