【PYG】dataloader和densedataloader

DenseDataLoader 是專門用于處理稠密圖數據的,而 DataLoader 通常用于處理稀疏圖數據。兩者的主要區別在于它們的輸入數據格式和處理方式。DenseDataLoader 適合處理固定大小的鄰接矩陣和節點特征矩陣的數據,而 DataLoader 更加靈活,可以處理稀疏表示的圖數據。

主要區別

  • DataLoader:

    • 適合處理稀疏圖數據。
    • 通常與 torch_geometric.data.Data 一起使用,其中邊索引是稀疏表示的。
    • 更加靈活,適合處理各種不同形狀和大小的圖。
  • DenseDataLoader:

    • 適合處理稠密圖數據。
    • 通常與固定大小的鄰接矩陣和節點特征矩陣一起使用。
    • 更高效地處理固定大小的圖數據。

使用示例

使用 DenseDataLoader

如果你有固定大小的鄰接矩陣和節點特征矩陣,可以直接使用 DenseDataLoader 加載數據:

1. 導入必要的庫
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader
2. 定義數據集類
class MyDenseDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresself.adj_matrix = self.create_adj_matrix(num_nodes)def create_adj_matrix(self, num_nodes):# 創建環形圖的鄰接矩陣adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1return adj_matrixdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 創建隨機特征和標簽x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1))  # 每個節點一個標簽return Data(x=x, adj=self.adj_matrix, y=y)
3. 創建數據集和封裝數據
# 參數設置
num_samples = 100  # 樣本數
num_nodes = 10  # 每個圖中的節點數
num_node_features = 8  # 每個節點的特征數# 創建數據集
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)
4. 使用 DenseDataLoader
# 使用 DenseDataLoader 加載數據
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 從 DenseDataLoader 中獲取一個批次的數據并查看其形狀
for data in loader:print("Batch node features shape:", data.x.shape)  # 期望輸出形狀為 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape)  # 期望輸出形狀為 (32, 10, 10)print("Batch labels shape:", data.y.shape)  # 期望輸出形狀為 (32, 10, 1)break  # 僅查看第一個批次的形狀

解釋

  1. 導入庫

    • 導入 torchtorch_geometric.data 中的 Datatorch_geometric.loader 中的 DenseDataLoader
  2. 定義 MyDenseDataset

    • __init__ 方法初始化數據集參數,并創建鄰接矩陣。
    • create_adj_matrix 方法創建環形圖的鄰接矩陣。
    • __len__ 方法返回數據集的樣本數量。
    • __getitem__ 方法生成每個樣本的隨機節點特征和標簽,并返回節點特征矩陣、鄰接矩陣和標簽。
  3. 創建數據集

    • 使用 MyDenseDataset 類創建一個包含 100 個樣本的數據集,每個樣本包含 10 個節點,每個節點有 8 個特征。
  4. 使用 DenseDataLoader

    • 使用 DenseDataLoader 加載 dataset,設置批次大小為 32,并進行隨機打亂。
    • 在獲取一個批次的數據時,檢查 xadjy 的形狀,以確保其符合期望的三維形狀。

通過這個完整的示例代碼,你可以生成、封裝和加載稠密圖數據,并確保每個批次的數據形狀保持正確。這種方法適合處理節點數和邊數固定的圖數據,提高數據加載和處理的效率。

定義數據集類并使用 DenseDataLoader

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader  # 更新導入路徑class MyDenseDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresself.adj_matrix = self.create_adj_matrix(num_nodes)def create_adj_matrix(self, num_nodes):# 創建環形圖的鄰接矩陣adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1print(adj_matrix)return adj_matrixdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 創建隨機特征和標簽x = torch.randn((self.num_nodes, self.num_node_features))y = torch.randn((self.num_nodes, 1))  # 每個節點一個標簽return Data(x, self.adj_matrix, y=y)# 創建數據集
num_samples = 100  # 樣本數
num_nodes = 10  # 每個圖中的節點數
num_node_features = 8  # 每個節點的特征數
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)# 使用 DenseDataLoader 加載數據
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)# 從 DenseDataLoader 中獲取一個批次的數據并查看其形狀
for data in loader:print("Batch node features shape:", data.x.shape)  # 期望輸出形狀為 (32, 10, 8)# print("Batch adjacency matrix shape:", data.adj.shape)  # 期望輸出形狀為 (32, 10, 10)print("Batch labels shape:", data.y.shape)  # 期望輸出形狀為 (32, 10, 1)break  # 僅查看第一個批次的形狀

使用 DataLoader

如果你使用的是 DataLoader,則數據應當是 torch_geometric.data.Data 對象,并將數據封裝在列表中:

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # 更新導入路徑class MyDataset(torch.utils.data.Dataset):def __init__(self, num_samples, num_nodes, num_node_features):self.num_samples = num_samplesself.num_nodes = num_nodesself.num_node_features = num_node_featuresdef __len__(self):return self.num_samplesdef __getitem__(self, idx):x = torch.randn(self.num_nodes, self.num_node_features)edge_index = torch.tensor([[i, (i + 1) % self.num_nodes] for i in range(self.num_nodes)], dtype=torch.long).t().contiguous()y = torch.randn(self.num_nodes, 1)return Data(x=x, edge_index=edge_index, y=y)# 創建數據集
num_samples = 100  # 樣本數
num_nodes = 10  # 每個圖中的節點數
num_node_features = 8  # 每個節點的特征數
dataset = MyDataset(num_samples, num_nodes, num_node_features)# 使用 DataLoader 加載數據
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 迭代加載數據
for batch in loader:print("Batch node features shape:", batch.x.shape)  # 期望輸出形狀為 (320, 8)print("Batch edge index shape:", batch.edge_index.shape)

總結

  • DenseDataLoader:處理固定大小的鄰接矩陣和節點特征矩陣的數據,__getitem__ 返回Data(x, adj, y)。
  • DataLoader:處理 torch_geometric.data.Data 對象,__getitem__ 返回一個 Data 對象。

確保數據格式與使用的加載器相匹配,以避免屬性錯誤和其他兼容性問題。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/39930.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/39930.shtml
英文地址,請注明出處:http://en.pswp.cn/web/39930.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

flask中解決圖片不顯示的問題(很細微的點)

我在編寫flask項目的時候,在編寫html的時候,發現不管我的圖片路徑如何變化,其就是顯示不出來。如下圖我框中的地方。 我嘗試過使用瀏覽器打開,是可以的。 一旦運行這個flask項目,就無法顯示了。 我查閱資料后。發現…

簡易版async/await

參考:https://juejin.cn/post/7007031572238958629?searchId20240704101813568E9B5B1013C881A239#heading-15 總結一下async/await的知識點 1、 await只能在async函數中使用,不然會報錯 2、 async函數返回的是一個Promise對象,有無值看有…

泛微開發修煉之旅--29用計劃任務定時發送郵件提醒

文章鏈接:29用計劃任務定時發送郵件提醒

[單master節點k8s部署]17.監控系統構建(二)Prometheus安裝

prometheus server安裝 創建sa賬號,對prometheus server進行授權。因為Prometheus是安裝在pod里面,以pod的形式去運行的,因此需要創建sa,并對他做rbac授權。 apiVersion: v1 kind: ServiceAccount metadata:name: monitornamesp…

k8s-第九節-命名空間

命名空間 如果一個集群中部署了多個應用,所有應用都在一起,就不太好管理,也可以導致名字沖突等。 我們可以使用 namespace 把應用劃分到不同的命名空間,跟代碼里的 namespace 是一個概念,只是為了劃分空間。 # 創建命…

LeetCode熱題100刷題4:76. 最小覆蓋子串、239. 滑動窗口最大值、53. 最大子數組和、56. 合并區間

76. 最小覆蓋子串 滑動窗口解決字串問題。 labuladong的算法小抄中關于滑動窗口的算法總結&#xff1a; class Solution { public:string minWindow(string s, string t) {unordered_map<char,int> need,window;for(char c : t) {need[c];}int left 0, right 0;int …

2.8億東亞五國建筑數據分享

數據是GIS的血液&#xff01; 我們現在為你分享東亞5國的2.8億條建筑輪廓數據&#xff0c;該數據包括中國、日本、朝鮮、韓國和蒙古5個東亞國家完整、高質量的建筑物輪廓數據&#xff0c;你可以在文末查看領取方法。 數據介紹 雖然開源的全球的建筑數據已經有微軟的建筑數據…

elementUI中table組件固定列時會渲染兩次模板內容問題

今天在使用elementUI的table組件時&#xff0c;由于業務需要固定表格的前幾項列&#xff0c;然后獲取表格對象時發現竟然有兩個對象。 查閱資料發現&#xff0c;elementUI的固定列的實現原理是將兩個表格拼裝而成&#xff0c;因此獲取的對象也是兩個。對于需要使用對象的方法的…

vxe-table的序號一樣

使用vxe-table的時候&#xff0c;有的時候會出現序號相同的現象&#xff0c;這種現象一般出現在我們后面自己添加的行中&#xff0c;就像這種 此時的這三個序號是相同的&#xff0c;我來說一下原因&#xff0c;這是在添加新的一行的時候&#xff0c;有的時候數據很多&#xff0…

Mac 運行 Windows 軟件,Parallels Desktop 19和 CrossOver 24全面對比

Parallels Desktop 和 CrossOver 都是能滿足你「在 Mac 上運行 Windows 軟件」需求的工具。可能很多人都已經知道 Parallels Desktop 是「虛擬機」&#xff0c;但 CrossOver 其實并不是「虛擬機」。這兩款軟件有相同的作用&#xff0c;但由于實現原理的不同&#xff0c;兩者也有…

系統提示我未定義與 ‘double‘ 類型的輸入參數相對應的函數 ‘finverse‘,如何解決?

&#x1f3c6;本文收錄于「Bug調優」專欄&#xff0c;主要記錄項目實戰過程中的Bug之前因后果及提供真實有效的解決方案&#xff0c;希望能夠助你一臂之力&#xff0c;幫你早日登頂實現財富自由&#x1f680;&#xff1b;同時&#xff0c;歡迎大家關注&&收藏&&…

Kubernetes 部署簡單的應用

Kubernetes 部署簡單的應用 Kubernetes 是一個強大的容器編排平臺&#xff0c;它可以幫助我們自動化應用程序的部署、擴展和管理。在本期文章中&#xff0c;我們將學習如何使用 Kubernetes 部署一個簡單的應用程序。 1. 環境準備 確保你已經安裝了 Kubernetes 集群&#xff…

【python模塊】argparse

文章目錄 argparse模塊介紹基本用法add_argument() argparse模塊介紹 argparse 模塊是 Python 標準庫中的一個用于編寫用戶友好的命令行接口&#xff08;CLI&#xff09;的模塊。它允許程序定義它所需要的命令行參數&#xff0c;然后 argparse 會自動從 sys.argv 解析出那些參…

TCP粘包解決方法

一. 產生原因及解決方法 產生原因&#xff1a;TCP是面向連接、基于字節流的協議&#xff0c;其無邊界標記。當服務端處理速度比不其接收速度時&#xff0c;就很容易產生粘包現象。 解決方法&#xff1a;目前主要有兩種解決方法&#xff0c;一個是在內容中添加分割標識&#xf…

人臉識別考勤系統

人臉識別考勤系統是一種利用生物識別技術進行自動身份驗證的現代解決方案&#xff0c;它通過分析和比對人臉特征來進行員工的出勤記錄。這種系統不僅提升了工作效率&#xff0c;還大大減少了人為錯誤和欺詐行為的可能性。 一、工作原理 人臉識別考勤系統的核心在于其生物識別…

深入剖析Python中的Pandas庫:通過實戰案例全方位解讀數據清洗與預處理藝術

引言 隨著大數據時代的到來&#xff0c;數據的質量直接影響到最終分析結果的可靠性和有效性。在這個背景下&#xff0c;Python憑借其靈活強大且易于上手的特點&#xff0c;在全球范圍內被廣泛應用于數據科學領域。而在Python的數據處理生態中&#xff0c;Pandas庫無疑是最耀眼…

高級策略:解讀 SQL 中的復雜連接

了解基本連接 在深入研究復雜連接之前&#xff0c;讓我們先回顧一下基本連接的基礎知識。 INNER JOIN&#xff1a;根據指定的連接條件檢索兩個表中具有匹配值的記錄。LEFT JOIN&#xff1a;從左表檢索所有記錄&#xff0c;并從右表中檢索匹配的記錄&#xff08;如果有&#x…

管道支架安裝

工程結構施工完畢后&#xff0c;系統管道安裝完畢后的第一步任務就是管道支架的制作安裝&#xff0c;作為對管道固定和承重作用至關重要的支、托、吊架&#xff0c;有些項目部在施工中卻往往因為對它們的重要性認識不足&#xff0c;因存在僥幸心里或經驗主義&#xff0c;導致支…

NIO為什么會導致CPU100%?

1. Java IO 類型概覽 BIO&#xff1a;阻塞I/O&#xff0c;每個連接一個線程&#xff0c;簡單但遇到高并發時性能瓶頸明顯。NIO&#xff1a;非阻塞I/O&#xff0c;JDK 1.4引入&#xff0c;一個線程處理多個IO操作&#xff0c;提高資源利用率和系統吞吐量。AIO&#xff1a;異步I…

技術探索:利用Python庫wxauto實現Windows微信客戶端的全面自動化管理

項目地址&#xff1a;github-wxauto 點擊即可訪問 項目官網&#xff1a;wxauto 點擊即可訪問 &#x1f602;什么是wxauto? wxauto 是作者在2020年開發的一個基于 UIAutomation 的開源 Python 微信自動化庫&#xff0c;最初只是一個簡單的腳本&#xff0c;只能獲取消息和發送…