DenseDataLoader
是專門用于處理稠密圖數據的,而 DataLoader
通常用于處理稀疏圖數據。兩者的主要區別在于它們的輸入數據格式和處理方式。DenseDataLoader
適合處理固定大小的鄰接矩陣和節點特征矩陣的數據,而 DataLoader
更加靈活,可以處理稀疏表示的圖數據。
主要區別
-
DataLoader
:- 適合處理稀疏圖數據。
- 通常與
torch_geometric.data.Data
一起使用,其中邊索引是稀疏表示的。 - 更加靈活,適合處理各種不同形狀和大小的圖。
-
DenseDataLoader
:- 適合處理稠密圖數據。
- 通常與固定大小的鄰接矩陣和節點特征矩陣一起使用。
- 更高效地處理固定大小的圖數據。
使用示例
使用 DenseDataLoader
如果你有固定大小的鄰接矩陣和節點特征矩陣,可以直接使用 DenseDataLoader
加載數據:
1. 導入必要的庫
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader
2. 定義數據集類
class MyDenseDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresself.adj_matrix = self.create_adj_matrix(num_nodes)def create_adj_matrix(self, num_nodes):# 創建環形圖的鄰接矩陣adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1return adj_matrixdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 創建隨機特征和標簽x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1)) # 每個節點一個標簽return Data(x=x, adj=self.adj_matrix, y=y)
3. 創建數據集和封裝數據
# 參數設置
num_samples = 100 # 樣本數
num_nodes = 10 # 每個圖中的節點數
num_node_features = 8 # 每個節點的特征數# 創建數據集
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)
4. 使用 DenseDataLoader
# 使用 DenseDataLoader 加載數據
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 從 DenseDataLoader 中獲取一個批次的數據并查看其形狀
for data in loader:print("Batch node features shape:", data.x.shape) # 期望輸出形狀為 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape) # 期望輸出形狀為 (32, 10, 10)print("Batch labels shape:", data.y.shape) # 期望輸出形狀為 (32, 10, 1)break # 僅查看第一個批次的形狀
解釋
-
導入庫:
- 導入
torch
、torch_geometric.data
中的Data
和torch_geometric.loader
中的DenseDataLoader
。
- 導入
-
定義
MyDenseDataset
類:__init__
方法初始化數據集參數,并創建鄰接矩陣。create_adj_matrix
方法創建環形圖的鄰接矩陣。__len__
方法返回數據集的樣本數量。__getitem__
方法生成每個樣本的隨機節點特征和標簽,并返回節點特征矩陣、鄰接矩陣和標簽。
-
創建數據集:
- 使用
MyDenseDataset
類創建一個包含 100 個樣本的數據集,每個樣本包含 10 個節點,每個節點有 8 個特征。
- 使用
-
使用
DenseDataLoader
:- 使用
DenseDataLoader
加載dataset
,設置批次大小為 32,并進行隨機打亂。 - 在獲取一個批次的數據時,檢查
x
、adj
和y
的形狀,以確保其符合期望的三維形狀。
- 使用
通過這個完整的示例代碼,你可以生成、封裝和加載稠密圖數據,并確保每個批次的數據形狀保持正確。這種方法適合處理節點數和邊數固定的圖數據,提高數據加載和處理的效率。
定義數據集類并使用 DenseDataLoader
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader # 更新導入路徑class MyDenseDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresself.adj_matrix = self.create_adj_matrix(num_nodes)def create_adj_matrix(self, num_nodes):# 創建環形圖的鄰接矩陣adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1print(adj_matrix)return adj_matrixdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 創建隨機特征和標簽x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1)) # 每個節點一個標簽return Data(x, self.adj_matrix, y=y)# 創建數據集
num_samples = 100 # 樣本數
num_nodes = 10 # 每個圖中的節點數
num_node_features = 8 # 每個節點的特征數
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)# 使用 DenseDataLoader 加載數據
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 從 DenseDataLoader 中獲取一個批次的數據并查看其形狀
for data in loader:print("Batch node features shape:", data.x.shape) # 期望輸出形狀為 (32, 10, 8)# print("Batch adjacency matrix shape:", data.adj.shape) # 期望輸出形狀為 (32, 10, 10)print("Batch labels shape:", data.y.shape) # 期望輸出形狀為 (32, 10, 1)break # 僅查看第一個批次的形狀
使用 DataLoader
如果你使用的是 DataLoader
,則數據應當是 torch_geometric.data.Data
對象,并將數據封裝在列表中:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader # 更新導入路徑class MyDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresdef __len__(self):return self.num_samplesdef __getitem__(self, idx):x = torch.randn(self.num_nodes, self.num_node_features)edge_index = torch.tensor([[i, (i + 1) % self.num_nodes] for i in range(self.num_nodes)], dtype=torch.long).t().contiguous()y = torch.randn(self.num_nodes, 1)return Data(x=x, edge_index=edge_index, y=y)# 創建數據集
num_samples = 100 # 樣本數
num_nodes = 10 # 每個圖中的節點數
num_node_features = 8 # 每個節點的特征數
dataset = MyDataset(num_samples, num_nodes, num_node_features)# 使用 DataLoader 加載數據
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 迭代加載數據
for batch in loader:print("Batch node features shape:", batch.x.shape) # 期望輸出形狀為 (320, 8)print("Batch edge index shape:", batch.edge_index.shape)
總結
DenseDataLoader
:處理固定大小的鄰接矩陣和節點特征矩陣的數據,__getitem__
返回Data(x, adj, y)。DataLoader
:處理torch_geometric.data.Data
對象,__getitem__
返回一個Data
對象。
確保數據格式與使用的加載器相匹配,以避免屬性錯誤和其他兼容性問題。