「日拱一碼」027 深度學習庫——PyTorch Geometric(PyG)

目錄

數據處理與轉換

數據表示

數據加載

數據轉換

特征歸一化

添加自環

隨機擾動

組合轉換

圖神經網絡層

圖卷積層(GCNConv)

圖注意力層(GATConv)

池化

全局池化(Global Pooling)

全局平均池化

全局最大池化

全局求和池化

基于注意力的池化(Attention-based Pooling)

基于圖的池化(Graph-based Pooling)

層次化池化(Hierarchical Pooling)

采樣

子圖采樣(Subgraph Sampling)

鄰域采樣(Neighbor Sampling)

模型訓練與評估

訓練過程

測試過程

異構圖處理

異構圖定義

異構圖卷積

圖生成模型

Deep Graph Infomax (DGI)

Graph Autoencoder (GAE)

Variational Graph Autoencoder (VGAE)


PyTorch Geometric(PyG)是PyTorch的一個擴展庫,專注于圖神經網絡(GNN)的實現。它提供了豐富的圖數據處理工具、圖神經網絡層和模型。以下是對PyG庫中常用方法的介紹

數據處理與轉換

數據表示

PyG使用 torch_geometric.data.Data 類來表示圖數據,包含節點特征 x 、邊索引 edge_index 、邊特征 edge_attr 等

## 數據處理與轉換
# 1. 數據表示
import torch
from torch_geometric.data import Data# 創建一個簡單的圖
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float)  # 節點特征
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)  # 邊索引
edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float)  # 邊特征data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(data)  # Data(x=[3, 2], edge_index=[2, 4], edge_attr=[4])

數據加載

PyG提供了 torch_geometric.data.DataLoader 類,用于批量加載圖數據

# 2. 數據加載
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader# 加載Cora數據集
dataset = Planetoid(root='./data', name='Cora')
loader = DataLoader(dataset, batch_size=32, shuffle=True)print(f"節點數: {data.num_nodes}")  # 3
print(f"邊數: {data.num_edges}")  # 4
print(f"特征維度: {data.num_node_features}")  # 2
print(f"類別數: {dataset.num_classes}")  # 7for batch in loader:print(batch)# DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708],#           batch=[2708], ptr=[2])

數據轉換

  • 特征歸一化

NormalizeFeatures ?是一個常用的轉換方法,用于將節點特征歸一化到單位范數(如 0, 1 或 -1, 1)

# 3. 數據轉換
# 3.1 特征歸一化
from torch_geometric.transforms import NormalizeFeaturesdataset = Planetoid(root='./data', name='Cora', transform=NormalizeFeatures())# 查看歸一化后的特征
data = dataset[0]
print(data.x)
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
  • 添加自環

AddSelfLoops ?是一個轉換方法,用于為圖中的每個節點添加自環(即每個節點連接到自己)

# 3.2 添加自環
from torch_geometric.transforms import AddSelfLoopsdataset = Planetoid(root='./data', name='Cora', transform=AddSelfLoops())# 查看添加自環后的邊索引
data = dataset[0]
print(data.edge_index)
# tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
#         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]])
  • 隨機擾動

RandomNodeSplit ?是一個轉換方法,用于隨機劃分訓練集、驗證集和測試集

# 3.3 隨機擾動
from torch_geometric.transforms import RandomNodeSplitdataset = Planetoid(root='./data', name='Cora', transform=RandomNodeSplit(num_splits=10))# 查看劃分后的掩碼
data = dataset[0]
print(data.train_mask)
# tensor([[False,  True,  True,  ..., False, False,  True],
#         [False, False,  True,  ...,  True, False, False],
#         [False,  True, False,  ..., False, False, False],
#         ...,
#         [ True,  True,  True,  ..., False, False, False],
#         [ True,  True,  True,  ..., False, False,  True],
#         [ True,  True,  True,  ...,  True, False,  True]])
print(data.val_mask)
# tensor([[False, False, False,  ..., False,  True, False],
#         [False, False, False,  ..., False, False, False],
#         [False, False, False,  ..., False, False,  True],
#         ...,
#         [False, False, False,  ...,  True,  True, False],
#         [False, False, False,  ..., False,  True, False],
#         [False, False, False,  ..., False, False, False]])
print(data.test_mask)
# tensor([[ True, False, False,  ...,  True, False, False],
#         [ True,  True, False,  ..., False,  True,  True],
#         [ True, False,  True,  ...,  True,  True, False],
#         ...,
#         [False, False, False,  ..., False, False,  True],
#         [False, False, False,  ...,  True, False, False],
#         [False, False, False,  ..., False,  True, False]])
  • 組合轉換

可以將多個轉換方法組合在一起,形成一個復合轉換

# 3.4 組合轉換
from torch_geometric.transforms import Compose, NormalizeFeatures, AddSelfLoops# 定義一個復合轉換
transform = Compose([NormalizeFeatures(), AddSelfLoops()])# 創建一個數據集,并應用復合轉換
dataset = Planetoid(root='./data', name='Cora', transform=transform)# 查看轉換后的數據
data = dataset[0]
print(data.x)
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
print(data.edge_index)
# tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
#         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]])

圖神經網絡層

圖卷積層(GCNConv)

GCNConv是圖卷積網絡(GCN)的基本層

## 圖神經網絡層
# 1. 圖卷積層 GCNConv
import torch
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self, in_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, 16)self.conv2 = GCNConv(16, out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = torch.relu(x)x = self.conv2(x, edge_index)return xmodel = GCN(in_channels=dataset.num_features, out_channels=dataset.num_classes)
print(model)
# GCN(
#   (conv1): GCNConv(1433, 16)
#   (conv2): GCNConv(16, 7)
# )

圖注意力層(GATConv)

GATConv是圖注意力網絡(GAT)的基本層

# 2. 圖注意力層 GATConv
from torch_geometric.nn import GATConvclass GAT(torch.nn.Module):def __init__(self, in_channels, out_channels):super(GAT, self).__init__()self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=True, dropout=0.6)def forward(self, x, edge_index):x = torch.dropout(x, p=0.6, training=self.training)x = self.conv1(x, edge_index)x = torch.relu(x)x = torch.dropout(x, p=0.6, training=self.training)x = self.conv2(x, edge_index)return xmodel = GAT(in_channels=dataset.num_features, out_channels=dataset.num_classes)
print(model)
# GAT(
#   (conv1): GATConv(1433, 8, heads=8)
#   (conv2): GATConv(64, 7, heads=1)
# )

池化

全局池化(Global Pooling)

全局池化將整個圖的所有節點聚合為一個全局表示

  • 全局平均池化
## 池化
# 1. 全局池化
# 1.1 全局平均池化
from torch_geometric.nn import global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加載數據集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 獲取一個批次的數據
for batch in loader:x = batch.xbatch_index = batch.batchglobal_mean = global_mean_pool(x, batch_index)print("Global Mean Pooling Result:", global_mean)break
# tensor([[0.7647, 0.0588, 0.1176, 0.0000, 0.0588, 0.0000, 0.0000],
#         [0.7500, 0.1250, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6250, 0.1250, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5217, 0.1739, 0.3043, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5455, 0.2727, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6400, 0.1200, 0.2400, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6364, 0.0909, 0.2727, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7857, 0.0714, 0.1429, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8000, 0.0667, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7647, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0588],
#         [0.5000, 0.1667, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7692, 0.0769, 0.1538, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7826, 0.0435, 0.1739, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8000, 0.0500, 0.1500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6667, 0.0833, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8125, 0.0625, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8235, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6000, 0.1000, 0.2000, 0.0000, 0.0000, 0.1000, 0.0000],
#         [0.4615, 0.1538, 0.3077, 0.0000, 0.0000, 0.0769, 0.0000],
#         [0.7647, 0.0588, 0.1765, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6000, 0.0500, 0.2000, 0.0000, 0.0000, 0.1500, 0.0000],
#         [0.6000, 0.1000, 0.2000, 0.1000, 0.0000, 0.0000, 0.0000],
#         [0.7647, 0.0588, 0.1176, 0.0000, 0.0000, 0.0588, 0.0000],
#         [0.8000, 0.0667, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.4615, 0.1538, 0.3077, 0.0769, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8421, 0.0526, 0.1053, 0.0000, 0.0000, 0.0000, 0.0000]])
  • 全局最大池化
# 1.2 全局最大池化
from torch_geometric.nn import global_max_pool# 獲取一個批次的數據
for batch in loader:x = batch.xbatch_index = batch.batchglobal_max = global_max_pool(x, batch_index)print("Global Max Pooling Result:", global_max)break
  • 全局求和池化
# 3. 全局求和池化
from torch_geometric.nn import global_add_pool# 獲取一個批次的數據
for batch in loader:x = batch.xbatch_index = batch.batchglobal_sum = global_add_pool(x, batch_index)print("Global Sum Pooling Result:", global_sum)break

基于注意力的池化(Attention-based Pooling)

基于注意力的池化方法通過學習節點的重要性權重來進行池化。一個常見的例子是 Set2Set 池化

# 2. 基于注意力的池化——Set2Set
from torch_geometric.nn import Set2Set
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加載數據集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定義 Set2Set 池化
set2set = Set2Set(in_channels=dataset.num_node_features, processing_steps=3)# 獲取一個批次的數據
for batch in loader:x = batch.xbatch_index = batch.batchglobal_set2set = set2set(x, batch_index)print("Set2Set Pooling Result:", global_set2set)break
# Set2Set Pooling Result: tensor([[ 0.1719,  0.0986,  0.1594, -0.0438,  0.1743,  0.1663, -0.0578,  0.8464,
#           0.0492,  0.1045,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1730,  0.0987,  0.1601, -0.0420,  0.1730,  0.1658, -0.0549,  0.8733,
#           0.0405,  0.0862,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1686,  0.0919,  0.1603, -0.0525,  0.1807,  0.1707, -0.0683,  0.7540,
#           0.0466,  0.1994,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1601,  0.1165,  0.1425, -0.0525,  0.1782,  0.1602, -0.0836,  0.6232,
#           0.2237,  0.1531,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1725,  0.0987,  0.1598, -0.0428,  0.1736,  0.1660, -0.0562,  0.8611,
#           0.0444,  0.0945,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1673,  0.1060,  0.1527, -0.0473,  0.1761,  0.1642, -0.0679,  0.7570,
#           0.1187,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1579,  0.0996,  0.1486, -0.0658,  0.1874,  0.1695, -0.0954,  0.5284,
#           0.1662,  0.3054,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1584,  0.0969,  0.1503, -0.0665,  0.1881,  0.1709, -0.0949,  0.5327,
#           0.1503,  0.3170,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1634,  0.0976,  0.1537, -0.0581,  0.1835,  0.1695, -0.0809,  0.6464,
#           0.1135,  0.2401,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1704,  0.0952,  0.1599, -0.0479,  0.1774,  0.1684, -0.0626,  0.8042,
#           0.0466,  0.1492,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1634,  0.1081,  0.1488, -0.0522,  0.1789,  0.1640, -0.0776,  0.6743,
#           0.1595,  0.1661,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1564,  0.1193,  0.1384, -0.0562,  0.1800,  0.1590, -0.0922,  0.5527,
#           0.2663,  0.1810,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1704,  0.0952,  0.1599, -0.0479,  0.1774,  0.1684, -0.0626,  0.8042,
#           0.0466,  0.1492,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1730,  0.0987,  0.1601, -0.0420,  0.1730,  0.1658, -0.0549,  0.8733,
#           0.0405,  0.0862,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1604,  0.0972,  0.1517, -0.0631,  0.1863,  0.1704, -0.0893,  0.5779,
#           0.1356,  0.2864,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1711,  0.0985,  0.1589, -0.0451,  0.1752,  0.1666, -0.0599,  0.8281,
#           0.0550,  0.1169,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1584,  0.1178,  0.1406, -0.0542,  0.1790,  0.1597, -0.0875,  0.5910,
#           0.2432,  0.1659,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1673,  0.1060,  0.1527, -0.0473,  0.1761,  0.1642, -0.0679,  0.7570,
#           0.1187,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1602,  0.1097,  0.1457, -0.0562,  0.1811,  0.1638, -0.0856,  0.6077,
#           0.1926,  0.1997,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1696,  0.1047,  0.1549, -0.0444,  0.1743,  0.1641, -0.0623,  0.8062,
#           0.0945,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1604,  0.0972,  0.1517, -0.0631,  0.1863,  0.1704, -0.0893,  0.5779,
#           0.1356,  0.2864,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1638,  0.0919,  0.1567, -0.0605,  0.1855,  0.1723, -0.0815,  0.6416,
#           0.0853,  0.2731,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1732,  0.1049,  0.1525, -0.0508,  0.1755,  0.1624, -0.0665,  0.7700,
#           0.0553,  0.1160,  0.0000,  0.0000,  0.0586,  0.0000],
#         [ 0.1711,  0.0985,  0.1589, -0.0451,  0.1752,  0.1666, -0.0599,  0.8281,
#           0.0550,  0.1169,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1573,  0.0968,  0.1494, -0.0685,  0.1891,  0.1712, -0.0982,  0.5063,
#           0.1589,  0.3349,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1729,  0.1053,  0.1365, -0.0582,  0.1904,  0.1594, -0.0881,  0.5637,
#           0.0878,  0.1812,  0.0746,  0.0000,  0.0927,  0.0000],
#         [ 0.1586,  0.1026,  0.1477, -0.0628,  0.1855,  0.1678, -0.0924,  0.5526,
#           0.1742,  0.2733,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1646,  0.1075,  0.1500, -0.0506,  0.1781,  0.1641, -0.0746,  0.6999,
#           0.1469,  0.1533,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1695,  0.0983,  0.1579, -0.0477,  0.1770,  0.1672, -0.0641,  0.7909,
#           0.0670,  0.1421,  0.0000,  0.0000,  0.0000,  0.0000]],
#        grad_fn=<CatBackward0>)

基于圖的池化(Graph-based Pooling)

基于圖的池化方法通過圖的結構信息來進行池化。常見的方法包括 TopKPooling,通過選擇重要性最高的節點來進行池化

# 3. 基于圖的池化——TopKPooling
from torch_geometric.nn import TopKPooling
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加載數據集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定義 TopKPooling
pool = TopKPooling(in_channels=dataset.num_node_features, ratio=0.5) # 獲取一個批次的數據
for batch in loader:x = batch.xedge_index = batch.edge_indexbatch_index = batch.batchx, edge_index, _, batch_index, _, _ = pool(x, edge_index, batch=batch_index)print("TopKPooling Result:", x)break
# tensor([[-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         ...,
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000]],
#        grad_fn=<MulBackward0>)

層次化池化(Hierarchical Pooling)

層次化池化通過多層池化操作生成圖的層次化表示。一個常見的例子是 EdgePooling,通過邊的合并操作來進行池化

# 4. 層次化池化——EdgePooling
from torch_geometric.nn import EdgePooling
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加載數據集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定義 EdgePooling
pool = EdgePooling(in_channels=dataset.num_node_features)  # 獲取一個批次的數據
for batch in loader:x = batch.xedge_index = batch.edge_indexbatch_index = batch.batchx, edge_index, batch_index, _ = pool(x, edge_index, batch=batch_index)print("EdgePooling Result:", x)break# tensor([[0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         ...,
#         [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
#         [1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000]],
#        grad_fn=<MulBackward0>)

采樣

子圖采樣(Subgraph Sampling)

子圖采樣是從原始圖中提取一個子圖,通常用于減少計算復雜度和增強模型的泛化能力

## 采樣
# 1. 子圖采樣
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import k_hop_subgraph# 加載數據集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]# 選擇一個起始節點
start_node = 0
num_hops = 2  # 采樣半徑# 提取子圖
sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(start_node, num_hops, data.edge_index)# 創建子圖
sub_data = Data(x=data.x[sub_nodes], edge_index=sub_edge_index, y=data.y[sub_nodes])print("Original Graph Nodes:", data.num_nodes)  # 2708
print("Subgraph Nodes:", sub_data.num_nodes)  # 8
print("Subgraph Edges:", sub_data.edge_index.shape[1])  # 20

鄰域采樣(Neighbor Sampling)

鄰域采樣通過選擇節點的鄰居來生成子圖,適用于大規模圖數據

# 2. 鄰域采樣
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader# 加載數據集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]# 定義 NeighborSampler
loader = NeighborLoader(data,num_neighbors=[10, 10],  # 每層采樣的鄰居數量batch_size=1024,shuffle=True,
)# 遍歷數據加載器
for batch in loader:print(batch)break

模型訓練與評估

訓練過程

## 模型訓練與評估
# 1. 訓練過程
import torch.nn.functional as Foptimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss

測試過程

# 2. 測試過程
@torch.no_grad()
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())acc = correct / int(data.test_mask.sum())return accfor epoch in range(200):loss = train()acc = test()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

異構圖處理

異構圖定義

## 異構圖處理
# 1. 異構圖定義
from torch_geometric.data import HeteroData
import torchdata = HeteroData()
# 添加兩種類型節點
data['user'].x = torch.randn(4, 16)  # 4個用戶
data['movie'].x = torch.randn(5, 32)  # 5部電影
# 添加邊
data['user', 'rates', 'movie'].edge_index = torch.tensor([[0, 0, 1, 2, 3], [0, 2, 3, 1, 4]]  # user->movie評分關系
)

異構圖卷積

# 2. 異構圖卷積
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
from torch_geometric.transforms import NormalizeFeaturesclass HeteroGNN(torch.nn.Module):def __init__(self, in_channels, out_channels, hidden_channels):super().__init__()self.conv1 = HeteroConv({('user', 'rates', 'movie'): SAGEConv((in_channels['user'], in_channels['movie']), hidden_channels),('movie', 'rev_rates', 'user'): GCNConv(in_channels['movie'], hidden_channels, add_self_loops=False)  # 禁用自環}, aggr='sum')self.conv2 = HeteroConv({('user', 'rates', 'movie'): SAGEConv((hidden_channels, hidden_channels), out_channels),('movie', 'rev_rates', 'user'): GCNConv(hidden_channels, out_channels, add_self_loops=False)  # 禁用自環}, aggr='sum')def forward(self, x_dict, edge_index_dict):x_dict = self.conv1(x_dict, edge_index_dict)x_dict = {key: torch.relu(x) for key, x in x_dict.items()}x_dict = self.conv2(x_dict, edge_index_dict)return x_dict# 定義輸入和輸出通道數
in_channels = {'user': 16, 'movie': 32}
out_channels = 7  # 假設輸出通道數為7
hidden_channels = 64  # 假設隱藏層通道數為64# 實例化模型
model = HeteroGNN(in_channels, out_channels, hidden_channels)
print(model)
# HeteroGNN(
#   (conv1): HeteroConv(num_relations=2)
#   (conv2): HeteroConv(num_relations=2)
# )

圖生成模型

Deep Graph Infomax (DGI)

DGI 是一種無監督圖表示學習方法,通過最大化局部和全局圖表示之間的一致性來學習節點嵌入

## 圖生成模型
# 1. Deep Graph Infomax (DGI)
from torch_geometric.nn import DeepGraphInfomax
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn as nn
import torch.nn.functional as Fclass Encoder(nn.Module):def __init__(self, in_channels, hidden_channels):super(Encoder, self).__init__()self.conv = GCNConv(in_channels, hidden_channels)self.prelu = nn.PReLU(hidden_channels)def forward(self, x, edge_index):x = self.conv(x, edge_index)x = self.prelu(x)return xdef corruption(x, edge_index):return x[torch.randperm(x.size(0))], edge_indexdataset = Planetoid(root='./data', name='Cora')
data = dataset[0]encoder = Encoder(dataset.num_features, hidden_channels=512)
model = DeepGraphInfomax(hidden_channels=512, encoder=encoder,summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),corruption=corruption
)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()pos_z, neg_z, summary = model(data.x, data.edge_index)loss = model.loss(pos_z, neg_z, summary)loss.backward()optimizer.step()return lossfor epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

Graph Autoencoder (GAE)

GAE 是一種基于圖神經網絡的自編碼器,用于圖生成任務。它通過學習節點嵌入來重建圖的鄰接矩陣

# 2. Graph Autoencoder(GAE)
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GAE
import torch.nn.functional as Fclass Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = GCNConv(in_channels, 2 * out_channels)self.conv2 = GCNConv(2 * out_channels, out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)return self.conv2(x, edge_index)dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]encoder = Encoder(dataset.num_features, out_channels=16)
model = GAE(encoder)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()z = model.encode(data.x, data.edge_index)loss = model.recon_loss(z, data.edge_index)loss.backward()optimizer.step()return lossfor epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

Variational Graph Autoencoder (VGAE)

VGAE 是 GAE 的變體,通過引入變分推斷來學習節點嵌入的分布

# 3. Variational Graph Autoencoder(VGAE)
from torch_geometric.nn import VGAE
from torch_geometric.datasets import Planetoid# 定義數據集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]class Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = GCNConv(in_channels, 2 * out_channels)self.conv2 = GCNConv(2 * out_channels, 2 * out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)mu = x[:, :x.size(1) // 2]logstd = x[:, x.size(1) // 2:]return mu, logstd# 定義 Encoder
encoder = Encoder(dataset.num_features, out_channels=16)# 定義 VGAE 模型
model = VGAE(encoder)# 定義優化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 訓練函數
def train():model.train()optimizer.zero_grad()z = model.encode(data.x, data.edge_index)loss = model.recon_loss(z, data.edge_index)kl_loss = model.kl_loss()loss += kl_lossloss.backward()optimizer.step()return loss# 訓練模型
for epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

    本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
    如若轉載,請注明出處:http://www.pswp.cn/web/89092.shtml
    繁體地址,請注明出處:http://hk.pswp.cn/web/89092.shtml
    英文地址,請注明出處:http://en.pswp.cn/web/89092.shtml

    如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

    相關文章

    IoC容器深度解析:架構、原理與實現

    &#x1f31f; IoC容器深度解析&#xff1a;架構、原理與實現 引用&#xff1a; .NET IoC容器原理與實現等巫山的雲彩都消散撒下的碧色如何看淡 &#x1f50d; 一、引言&#xff1a;從服務定位器到IoC的演進 #mermaid-svg-BmRIuI4iMgiUqFVN {font-family:"trebuchet ms&…

    從零開始學前端html篇3

    表單基本結構表單是 HTML 中用于創建用戶輸入區域的標簽。它允許用戶輸入數據&#xff08;例如文本、選擇選項、文件等&#xff09;&#xff0c;并將這些數據提交到服務器進行處理。<form>&#xff0c;表單標簽&#xff0c;用于創建表單常用屬性&#xff1a;action&#…

    Linux系統調優和工具

    Linux系統調優和問題定位需要掌握一系列強大的工具&#xff0c;涵蓋系統監控、性能分析、故障排查等多個方面。以下是一些核心工具和它們的典型應用場景&#xff0c;分類整理如下&#xff1a; 一、系統資源監控&#xff08;實時概覽&#xff09;top / htop 功能&#xff1a; 實…

    如何快速有效地在WordPress中添加Instagram動態

    在當今社交媒體的時代&#xff0c;通過展示Instagram的最新動態&#xff0c;可以有效吸引讀者的目光&#xff0c;同時豐富網站內容。很多人想知道&#xff0c;如何把自己精心運營的Instagram內容無縫嵌入WordPress網站呢&#xff1f;別擔心&#xff0c;操作并不復雜&#xff0c…

    spring容器加載工具類

    在Spring框架中&#xff0c;工具類通常不需要被Spring容器管理&#xff0c;但如果確實需要獲取Spring容器中的Bean實例&#xff0c;可以通過靜態方法設置和獲取ApplicationContext。下面是一個典型的Spring容器加載工具類的實現&#xff1a;這個工具類通過實現ApplicationConte…

    定時器更新中斷與串口中斷

    問題&#xff1a;我想把打印姿態傳感器的角度&#xff0c;但是重定向的打印函數突然打印不出來。嘗試&#xff1a;我懷疑是優先級的問題&#xff0c;故調整了串口&#xff0c;定時器&#xff0c;dma的優先級可是發現調了還是沒有用&#xff0c;最終發現&#xff0c;我把定時器中…

    用Python向PDF添加文本:精確插入文本到PDF文檔

    PDF 文檔的版式特性使其適用于輸出不可變格式的報告與合同。但若要在此類文檔中插入或修改文本&#xff0c;常規方式難以實現。借助Python&#xff0c;我們可以高效地向 PDF 添加文本&#xff0c;實現從文檔生成到內容管理的自動化流程。 本文將從以下方面介紹Python實現PDF中…

    Quick API:賦能能源行業,化解數據痛點

    隨著全球能源結構的轉型和數字化的深入推進&#xff0c;能源行業正面臨前所未有的機遇與挑戰。海量的實時數據、復雜的業務系統、以及對數據安全和高效利用的迫切需求&#xff0c;都成為了能源企業在數字化轉型道路上的核心痛點。本文將深入探討麥聰Quick API如何憑借其獨特優勢…

    Google Chrome V8< 13.6.86 類型混淆漏洞

    【高危】Google Chrome V8< 13.6.86 類型混淆漏洞 漏洞描述 Google Chrome 是美國谷歌&#xff08;Google&#xff09;公司的一款Web瀏覽器&#xff0c;V8 是 Google 開發的高性能開源 JavaScript 和 WebAssembly 引擎&#xff0c;廣泛應用于 Chrome 瀏覽器和 Node.js 等環…

    力扣經典算法篇-23-環形鏈表(哈希映射法,快慢指針法)

    1、題干 給你一個鏈表的頭節點 head &#xff0c;判斷鏈表中是否有環。 如果鏈表中有某個節點&#xff0c;可以通過連續跟蹤 next 指針再次到達&#xff0c;則鏈表中存在環。 為了表示給定鏈表中的環&#xff0c;評測系統內部使用整數 pos 來表示鏈表尾連接到鏈表中的位置&…

    HarmonyOS DevEco Studio 小技巧 42 - 鴻蒙單向數據流

    在鴻蒙應用開發中&#xff0c;狀態管理是構建響應式界面的核心支柱&#xff0c;而 單向數據流&#xff08;Unidirectional Data Flow, UDF&#xff09;作為鴻蒙架構的重要設計原則&#xff0c;貫穿于組件通信、狀態更新和界面渲染的全流程。本文將結合鴻蒙 ArkUI 框架特性&…

    【LeetCode 3136. 有效單詞】解析

    目錄LeetCode中國站原文原始題目題目描述示例 1&#xff1a;示例 2&#xff1a;示例 3&#xff1a;提示&#xff1a;講解化繁為簡&#xff1a;如何優雅地“盤”邏輯判斷題第一部分&#xff1a;算法思想 —— “清單核對”與“一票否決”第二部分&#xff1a;代碼實現 —— 清晰…

    前端面試專欄-算法篇:24. 算法時間與空間復雜度分析

    &#x1f525; 歡迎來到前端面試通關指南專欄&#xff01;從js精講到框架到實戰&#xff0c;漸進系統化學習&#xff0c;堅持解鎖新技能&#xff0c;祝你輕松拿下心儀offer。 前端面試通關指南專欄主頁 前端面試專欄規劃詳情 算法時間與空間復雜度分析&#xff1a;從理論到實踐…

    bash中||與的區別

    在 Bash 中&#xff0c;|| 和 && 是兩種常用的邏輯操作符&#xff0c;用于控制命令的執行流程。它們的核心區別如下&#xff1a;1. ||&#xff08;邏輯 OR&#xff09; 作用&#xff1a;如果前一個命令失敗&#xff08;返回非零退出碼&#xff09;&#xff0c;則執行后…

    OpenCV實現感知哈希(Perceptual Hash)算法的類cv::img_hash::PHash

    操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 PHash是OpenCV中實現感知哈希&#xff08;Perceptual Hash&#xff09;算法的類。該算法用于快速比較圖像的視覺相似性。它將圖像壓縮為一個簡短的…

    數據庫遷移人大金倉數據庫

    遷移前的準備工作 安裝官方的kdts和KStudio工具 方案說明 一、數據庫遷移&#xff1a;可以使用kdts進行數據庫的按照先遷移表結構、后數據的順序遷移&#xff08;kdts的使用可以參考官方文檔&#xff09; 其他參考文檔 人大金倉官網&#xff1a;https://download.kingbase…

    uniapp 微信小程序Vue3項目使用內置組件movable-area封裝懸浮可拖拽按鈕(拖拽結束時自動吸附到最近的屏幕邊緣)

    一、最終效果 二、具體詳情請看movable-area與movable-view官方文檔說明 三、參數配置 1、代碼示例 <TFab title"新建訂單" click"addOrder" /> // title:表按鈕文案 // addOrder:點擊按鈕事件四、組件源碼 <template><movable-area cl…

    linux kernel為什么要用IS_ERR()宏來判斷指針合法性?

    在 Linux 內核中&#xff0c;IS_ERR() 宏的設計與內核的錯誤處理機制和指針編碼規范密切相關&#xff0c;主要用于判斷一個“可能攜帶錯誤碼的指針”是否代表異常狀態。其核心目的是解決內核中指針返回值與錯誤碼的統一表示問題。以下從技術背景、設計邏輯和實際場景三個維度詳…

    Cookie與Session:Web開發核心差異詳解

    理解 Cookie 和 Session 的區別對于 Web 開發至關重要,它們雖然經常一起使用,但扮演著不同的角色。核心區別在于: Cookie:存儲在客戶端(用戶的瀏覽器)的數據片段。 Session:存儲在服務器端的數據結構,用于跟蹤特定用戶的狀態。 下面是詳細的對比: 特性CookieSession…

    【相干、相參】 雷電名詞溯源

    〇、廢話因緣 最近某些國產的微波制造公司總是提到一個概念【相干】【相參】【嚴格相參】等等概念層出不窮&#xff0c;讓人苦惱。 一、這玩意還是英文溯源吧 這幾個概念都聚焦在一個單詞【Coherence】&#xff1b;所以就是說兩個波形之間有某種聯系&#xff0c;不一定就是完全…