edge_index
是 PyTorch Geometric 中常用的表示圖邊的張量。它通常是一個形狀為 [2, num_edges]
的二維張量,其中 num_edges
表示圖中邊的數量。每一列表示一條邊,包含兩個節點的索引。
- 實際上這是COO存儲格式,官方文檔里也有寫,還有一種是鄰接矩陣的存儲格式,兩種方式是可以互相轉換的
https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
edge_index[:, :10]
表示取出 edge_index
張量的前 10 列,即前 10 條邊的節點索引。
在 Python 中,使用切片語法 [:,] 是一種方便的方式來選擇多維數組或張量的特定部分(另外一部分Python語法知識)
程序輸出結果
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
First 10 edges (edge_index[:, :10]):
tensor([[ 0, 0, 0, 1, 1, 1, 2, 2, 2, 2],[ 633, 1862, 2582, 2, 652, 654, 1, 332, 1454, 1666]])
edge_index 0: tensor([ 0, 633])
edge_index 1: tensor([ 0, 1862])
edge_index 2: tensor([ 0, 2582])
edge_index 3: tensor([1, 2])
edge_index 4: tensor([ 1, 652])
edge_index 5: tensor([ 1, 654])
edge_index 6: tensor([2, 1])
edge_index 7: tensor([ 2, 332])
edge_index 8: tensor([ 2, 1454])
edge_index 9: tensor([ 2, 1666])
示例代碼
假設你已經使用 PyTorch Geometric 加載了 Cora 數據集,并且 edge_index
已經被定義,以下代碼展示如何查看前 10 條邊的信息:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# 加載 Cora 數據集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())# 獲取數據集中的第一個圖
data = dataset[0]# 打印數據集的基本信息
print(data)# 獲取邊索引
edge_index = data.edge_index# 打印前 10 條邊的節點索引
print("First 10 edges (edge_index[:, :10]):")
print(edge_index[:, :10])for i in range(10):print(f"edge_index {i}: {edge_index[:,i]}")
示例輸出
假設 edge_index
的前 10 列如下所示:
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],[ 1, 2, 0, 4, 5, 3, 7, 6, 5, 8]])
這表示:
- 第一條邊是從節點 0 到節點 1。
- 第二條邊是從節點 1 到節點 2。
- 第三條邊是從節點 2 到節點 0。
- 第四條邊是從節點 3 到節點 4。
- 第五條邊是從節點 4 到節點 5。
- 第六條邊是從節點 5 到節點 3。
- 第七條邊是從節點 6 到節點 7。
- 第八條邊是從節點 7 到節點 6。
- 第九條邊是從節點 8 到節點 5。
- 第十條邊是從節點 9 到節點 8。
解釋
edge_index
的形狀為[2, num_edges]
,其中num_edges
表示邊的數量。edge_index[:, :10]
表示取出前 10 條邊的節點索引。- 輸出的張量第一行表示每條邊的起始節點,第二行表示每條邊的結束節點。
通過這種方式,你可以方便地查看和理解數據集中邊的表示方式。