本文從Cora的例子來展示PYG如何加載圖數據集。
Cora 是一個小型的有標注的圖數據集,包含以下內容:
- data.x:2708 個節點(即 2708 篇論文),每個節點有 1433 個特征,形狀為 (2708, 1433)。
- data.edge_index:5429 條邊(即 5429 個引用關系),形狀為 (2, 5429)。
- data.y:節點標簽,共 7 類,形狀為 (2708,)。(共有 7 個類別,表示論文的研究領域)
- data.train_mask:訓練集掩碼,布爾向量,表示哪些節點用于訓練。
- data.val_mask:驗證集掩碼,布爾向量,表示哪些節點用于驗證。
- data.test_mask:測試集掩碼,布爾向量,表示哪些節點用于測試。
數據主要描述了論文之間的引用關系以及每篇論文的主題。可用于進行訓練節點分類問題(即判斷每篇論文屬于哪個類別)
1.自動加載
1.1 數據加載操作詳解
PYG庫提供了自動加載數據集的方法:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/Planetoid', name='Cora')
dataset[0]
print(len(dataset)) # 輸出: 1
print(data)
1
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
對于 Planetoid
類來說:
- 它是一個專門為 Planetoid 系列數據集(Cora、CiteSeer、PubMed) 設計的類。
- 這些數據集的主要特點是:它們實際上是單圖數據集,即整個數據集中只包含一個圖。
dataset
是一個包含 單個 Data
對象(圖) 的數據集對象。
由于 Planetoid
類的數據集中只有一個圖,因此:
dataset[0]
返回了這個唯一的圖,類型是Data
對象,表示整個 Cora 數據集的圖。Dataset
是一個可索引的對象,dataset[0]
的作用就是提取第一(也是唯一)個圖。
dataset = Planetoid(root='data/Planetoid', name='Cora')
加載了 Cora 數據集,它是一個 單圖數據集,包含一張圖的節點特征、邊索引、節點標簽和數據集劃分信息。dataset[0]
提取了該圖的數據,返回了一個Data
對象,表示整個圖。dataset
本身是一個數據集管理器,幫助加載和存儲數據,同時提供一些元信息和操作方法。
1. 2 數據加載的過程
-
下載數據:
- 如果指定路徑
'data/Planetoid'
下沒有數據集文件,Planetoid
類會從 指定的遠程服務器(由 PyG 維護)下載 Cora 數據集文件,并存儲在'data/Planetoid/Cora'
文件夾下。 - 數據集下載地址為:
- Cora 數據集原始文件
- 如果指定路徑
-
解壓文件:
- 下載的數據集是
.zip
或.tar
格式,會被自動解壓為一系列文件,主要包括:ind.cora.x
:訓練節點的特征矩陣;ind.cora.tx
:測試節點的特征矩陣;ind.cora.allx
:包含訓練節點和一些驗證節點的特征矩陣;ind.cora.y
:訓練節點的標簽;ind.cora.ty
:測試節點的標簽;ind.cora.ally
:訓練和驗證節點的標簽;ind.cora.graph
:節點的鄰接表(圖結構信息);ind.cora.test.index
:測試節點的索引。
如圖所示:
- 下載的數據集是
-
解析數據:
- PyG 將原始文件的內容解析為圖數據格式(
Data
對象),將以下內容整合起來:- 節點特征矩陣
x
; - 圖的邊信息
edge_index
; - 節點標簽
y
; - 訓練、驗證和測試集的掩碼(
train_mask
、val_mask
、test_mask
)。
- 節點特征矩陣
- PyG 將原始文件的內容解析為圖數據格式(
-
數據存儲:
- 如果數據加載成功,解析后的數據將被緩存到指定路徑(
data/Planetoid/Cora
)中,后續運行時會直接加載解析后的緩存文件,而不會重復下載和解析。
- 如果數據加載成功,解析后的數據將被緩存到指定路徑(
2. 數據集原始文件的形式
原始文件(以 ind.cora.*
為前綴)是以下幾種內容的存儲形式:
文件名 | 內容描述 |
---|---|
ind.cora.x | 稀疏矩陣,訓練集中節點的特征矩陣,大小為 (num_train_nodes, num_features) 。 |
ind.cora.tx | 稀疏矩陣,測試集中節點的特征矩陣,大小為 (num_test_nodes, num_features) 。 |
ind.cora.allx | 稀疏矩陣,包含訓練集和部分驗證集中節點的特征矩陣,大小為 (num_allx_nodes, num_features) 。 |
ind.cora.y | 訓練集的標簽,大小為 (num_train_nodes, num_classes) 的獨熱編碼矩陣。 |
ind.cora.ty | 測試集的標簽,大小為 (num_test_nodes, num_classes) 的獨熱編碼矩陣。 |
ind.cora.ally | 訓練和驗證集的標簽,大小為 (num_allx_nodes, num_classes) 的獨熱編碼矩陣。 |
ind.cora.graph | 字典格式,存儲圖的鄰接表,鍵為節點 ID,值為該節點的鄰居節點列表。 |
ind.cora.test.index | 列表形式,包含測試節點的索引。 |
3. 加載后的數據形式
加載后,數據以 torch_geometric.data.Data
對象的形式存儲,主要包含以下內容:
屬性 | 描述 | 形狀 |
---|---|---|
data.x | 節點的特征矩陣,每一行表示一個節點的特征向量。 | (num_nodes, num_features) |
data.edge_index | 圖的邊信息,存儲為 COO 格式的索引矩陣(兩個一維數組,分別表示邊的起始節點和結束節點)。 | (2, num_edges) |
data.y | 節點的標簽,每個節點對應一個整數,表示其所屬類別的索引值。 | (num_nodes,) |
data.train_mask | 訓練節點的布爾掩碼,值為 True 的位置表示該節點屬于訓練集。 | (num_nodes,) |
data.val_mask | 驗證節點的布爾掩碼,值為 True 的位置表示該節點屬于驗證集。 | (num_nodes,) |
data.test_mask | 測試節點的布爾掩碼,值為 True 的位置表示該節點屬于測試集。 | (num_nodes,) |
4. 加載后的具體內容
以 Cora 數據集為例,加載后的數據具有以下具體特性:
- 節點數:
num_nodes = 2708
(共 2708 篇論文)。 - 特征數:
num_features = 1433
(每篇論文的特征是一個 1433 維向量,表示詞袋模型中的單詞出現情況)。 - 邊數:
num_edges = 10556
(論文之間的引用關系,構成無向圖)。 - 類別數:
num_classes = 7
(每篇論文屬于 7 個主題之一)。 - 掩碼分布:
- 訓練集:140 個節點;
- 驗證集:500 個節點;
- 測試集:1000 個節點。
手動讀取數據集
下面手動實現的 CoraData
類代碼,經過修改后與 PyTorch Geometric (PyG
) 的 Planetoid
類功能一致,可以直接生成標準的 Data
對象,用于圖神經網絡訓練。
完整代碼:CoraData
import os
import os.path as osp
import pickle
import numpy as np
import torch
from torch_geometric.data import Data
import scipy.sparse as sp
import urllib.requestclass CoraData(object):download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"filenames = ["ind.cora.{}".format(name) for name in['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]def __init__(self, data_root="cora", rebuild=False):"""Cora 數據加載器,包括下載、處理和緩存功能。處理后的數據可以通過屬性 .data 獲取,返回 PyG 標準的 Data 對象。Args:data_root: str, 數據存儲的根目錄rebuild: bool, 是否強制重新構建數據"""self.data_root = data_rootsave_file = osp.join(self.data_root, "processed_cora.pkl")if osp.exists(save_file) and not rebuild:print("Using Cached file: {}".format(save_file))self._data = pickle.load(open(save_file, "rb"))else:self.maybe_download()self._data = self.process_data()with open(save_file, "wb") as f:pickle.dump(self.data, f)print("Cached file: {}".format(save_file))@propertydef data(self):"""返回 PyG 標準的 Data 對象"""return self._datadef maybe_download(self):save_path = osp.join(self.data_root, "raw")for name in self.filenames:if not osp.exists(osp.join(save_path, name)):self.download_data("{}/{}".format(self.download_url, name), save_path)def process_data(self):"""處理數據并生成 PyG 標準的 Data 對象,包括以下屬性:- x: 節點特征,(2708, 1433)- y: 節點標簽,共 7 類,(2708,)- edge_index: 圖邊索引,(2, num_edges)- train_mask: 訓練集掩碼,(2708,)- val_mask: 驗證集掩碼,(2708,)- test_mask: 測試集掩碼,(2708,)"""print("Processing data ...")# 讀取原始數據x, tx, allx, y, ty, ally, graph, test_index = [self.read_data(osp.join(self.data_root, "raw", name)) for name in self.filenames]train_index = np.arange(y.shape[0]) # 訓練集索引 [0, 1, ..., 139]val_index = np.arange(y.shape[0], y.shape[0] + 500) # 驗證集索引 [140, ..., 639]sorted_test_index = sorted(test_index) # 排序后的測試集索引# 特征和標簽拼接x = np.concatenate((allx, tx), axis=0) # (2708, 1433)y = np.concatenate((ally, ty), axis=0).argmax(axis=1) # (2708,)# 重新排序測試集數據x[test_index] = x[sorted_test_index]y[test_index] = y[sorted_test_index]# 創建訓練、驗證、測試掩碼num_nodes = x.shape[0]train_mask = np.zeros(num_nodes, dtype=np.bool_)val_mask = np.zeros(num_nodes, dtype=np.bool_)test_mask = np.zeros(num_nodes, dtype=np.bool_)train_mask[train_index] = Trueval_mask[val_index] = Truetest_mask[test_index] = True# 構造 edge_indexedge_index = self.build_edge_index(graph)# 轉換為 PyTorch 格式x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.long)edge_index = torch.tensor(edge_index, dtype=torch.long)train_mask = torch.tensor(train_mask, dtype=torch.bool)val_mask = torch.tensor(val_mask, dtype=torch.bool)test_mask = torch.tensor(test_mask, dtype=torch.bool)# 打印基本信息print("Node feature shape: ", x.shape)print("Node label shape: ", y.shape)print("Edge index shape: ", edge_index.shape)print("Number of training nodes: ", train_mask.sum().item())print("Number of validation nodes: ", val_mask.sum().item())print("Number of test nodes: ", test_mask.sum().item())# 返回 PyG 的 Data 對象return Data(x=x, y=y, edge_index=edge_index,train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)@staticmethoddef build_edge_index(graph):"""根據鄰接表生成 edge_index 格式 (2, num_edges)。"""edge_index = []for src, dst in graph.items():edge_index.extend([[src, v] for v in dst]) # 正向邊edge_index.extend([[v, src] for v in dst]) # 反向邊edge_index = np.array(edge_index).T # 轉置為 (2, num_edges)return edge_index@staticmethoddef read_data(path):"""讀取數據文件,根據文件名選擇加載方式。"""name = osp.basename(path)if name == "ind.cora.test.index":out = np.genfromtxt(path, dtype="int64")return outelse:out = pickle.load(open(path, "rb"), encoding="latin1")out = out.toarray() if hasattr(out, "toarray") else outreturn out@staticmethoddef download_data(url, save_path):"""從指定 URL 下載數據,并保存到本地路徑。"""if not os.path.exists(save_path):os.makedirs(save_path)data = urllib.request.urlopen(url)filename = os.path.split(url)[-1]with open(os.path.join(save_path, filename), 'wb') as f:f.write(data.read())return True
代碼解析
-
下載和緩存功能:
- 如果處理后的數據已緩存 (
processed_cora.pkl
),直接加載緩存。 - 如果未緩存,則從 GitHub 下載原始數據,處理后存儲為緩存文件。
- 如果處理后的數據已緩存 (
-
數據處理:
process_data
:- 加載原始數據,并將訓練、驗證、測試節點特征拼接成完整矩陣。
- 生成 PyG 格式的
edge_index
(用于圖神經網絡的鄰接表表示)。 - 生成訓練、驗證和測試集掩碼。
-
鄰接表轉換為邊索引:
build_edge_index
將鄰接表 (graph
) 轉換為edge_index
格式。edge_index
是一個形狀為(2, num_edges)
的數組,列表示一條邊的起點和終點。
-
返回 PyG 數據對象:
- 數據對象包括
x
、y
、edge_index
、train_mask
、val_mask
和test_mask
。
- 數據對象包括
運行代碼測試
要測試 CoraData
類,可以直接運行以下代碼:
cora_data = CoraData(data_root="cora", rebuild=True)
data = cora_data.data # 獲取 PyG 的 Data 對象
print(data)
輸出示例:
Processing data ...
Node feature shape: torch.Size([2708, 1433])
Node label shape: torch.Size([2708])
Edge index shape: torch.Size([2, 10556])
Number of training nodes: 140
Number of validation nodes: 500
Number of test nodes: 1000
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
該類的功能與 PyTorch Geometric 的 Planetoid
類一致,支持加載 Cora
數據集,并生成標準的 PyG Data
對象,適用于圖神經網絡模型訓練。