PyTorch處理數據--Dataset和DataLoader

? ? ? ?在 PyTorch 中,Dataset?和?DataLoader?是處理數據的核心工具。它們的作用是將數據高效地加載到模型中,支持批量處理、多線程加速和數據增強等功能。

一、Dataset:數據集的抽象?

Dataset?是一個抽象類,用于表示數據集的接口。你需要繼承?torch.utils.data.Dataset?并實現以下兩個方法:

  • __len__(): 返回數據集的總樣本數。
  • __getitem__(idx): 根據索引?idx?返回一個樣本(數據和標簽)。
?示例:自定義 Dataset
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transform  # 數據預處理/增強函數def __len__(self):return len(self.data)def __getitem__(self, idx):sample = {"data": self.data[idx], "label": self.labels[idx]}if self.transform:sample = self.transform(sample)return sample
?使用場景?
  • 加載圖像、文本、表格數據等。
  • 支持數據預處理(如歸一化、裁剪)和數據增強(如隨機翻轉)。

二、?DataLoader:高效加載數據?

DataLoader?負責將?Dataset?包裝成一個可迭代對象,支持批量加載、多線程加速和數據打亂。

?基本用法
from torch.utils.data import DataLoader# 假設 dataset 是你的 CustomDataset 實例
data_loader = DataLoader(dataset,batch_size=32,       # 批量大小shuffle=True,        # 是否打亂數據(訓練時建議開啟)num_workers=4,       # 多線程加載數據的進程數drop_last=False      # 是否丟棄最后不足一個 batch 的數據
)

??遍歷 DataLoader

for batch in data_loader:data = batch["data"]    # 形狀:[batch_size, ...]labels = batch["label"] # 形狀:[batch_size]# 將數據送入模型訓練...

、pytorch內置數據集

PyTorch 提供了一系列內置數據集,這些數據集可以直接用于訓練模型。這些數據集涵蓋了多種領域,如圖像、文本、音頻等。以下是一些常用的PyTorch內置數據集:

圖像數據集
  1. MNIST: 手寫數字數據集,包含0到9的手寫數字圖片。

    from torchvision import datasets
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  2. CIFAR10/CIFAR100: 包含彩色圖片的數據集,CIFAR10有60000張32x32的彩色圖片,分為10個類別;CIFAR100類似但有100個類別。

    cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
  3. ImageNet: 包含超過1400萬張圖片的非常龐大的數據集,常用于圖像識別和分類任務。

    import torchvision.datasets as datasets
    imagenet_train = datasets.ImageNet(root='./data', split='train', download=True)
  4. STL10: 一個用于計算機視覺研究的小型圖像數據集,包含96x96的彩色圖片。

    stl10_train = datasets.STL10(root='./data', split='train', download=True)
  5. SVHN: 包含數字圖片的數據集,與MNIST類似但包含更多實際場景的圖片。

    svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本數據集

? ? 1.Text8: 一個用于自然語言處理的小型文本數據集。

from torchtext.datasets import Text8
text8_train = Text8(split=('train',))

? ? 2.?AG_NEWS: 包含新聞文章的文本數據集,分為4個類別。

from torchtext.datasets import AG_NEWS
ag_news_train = AG_NEWS(split=('train',))

音頻數據集??

? 1. Speech Commands: 一個用于語音識別的數據集,包含約65,000個單詞發音的音頻文件。?

from torchaudio.datasets import SPEECHCOMMANDS
speech_commands = SPEECHCOMMANDS(root="./data", download=True)

?使用方法
要使用這些數據集,首先需要導入torchvision(對于圖像數據集)、torchtext(對于文本數據集)或torchaudio(對于音頻數據集),然后使用其提供的類來加載數據。通常還包括一些數據預處理步驟,例如轉換(transforms)。

import torchvision.transforms as transforms
from torchvision import datasetstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

四、完整代碼示例

步驟 1:創建數據集
import numpy as np
from torch.utils.data import Dataset, DataLoader# 生成示例數據(假設是 10 個樣本,每個樣本是長度為 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,))  # 二分類標簽class MyDataset(Dataset):def __init__(self, data, labels):self.data = torch.tensor(data, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.data)def __getitem__(self, idx):return {"data": self.data[idx],"label": self.labels[idx]}dataset = MyDataset(data, labels)
?步驟 2:創建 DataLoader
data_loader = DataLoader(dataset,batch_size=2,shuffle=True,num_workers=2
)

??步驟 3:使用 DataLoader 訓練模型

model = ...  # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch in data_loader:x = batch["data"]y = batch["label"]# 前向傳播outputs = model(x)loss = loss_fn(outputs, y)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()

五、常見問題解決

?(1)數據格式不匹配?
  • ?問題?:DataLoader?返回的數據形狀與模型輸入不匹配。
  • ?解決?:檢查?Dataset?的?__getitem__?返回的數據類型和形狀,確保與模型輸入一致。
?(2)多線程加載卡頓?
  • ?問題?:設置?num_workers>0?時程序卡死或報錯。
  • ?解決?:在 Windows 系統中,多線程可能需要將代碼放在?if __name__ == "__main__":?塊中運行。
?(3)數據增強?
  • 使用?torchvision.transforms?中的工具(如?RandomCropRandomHorizontalFlip)對圖像數據進行增強:
    from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    
?(4)內存不足?
  • 對于大型數據集,使用?torch.utils.data.DataLoader?的?persistent_workers=True(PyTorch 1.7+)或優化數據加載邏輯。

六、高級功能

  • 分布式訓練?:使用?torch.utils.data.distributed.DistributedSampler?配合多 GPU。
  • ?預加載數據?:使用?torch.utils.data.TensorDataset?直接加載 Tensor 數據。
  • ?自定義采樣器?:通過?sampler?參數控制數據采樣順序(如平衡類別采樣)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/73713.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/73713.shtml
英文地址,請注明出處:http://en.pswp.cn/web/73713.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Android 藍牙/Wi-Fi通信協議之:經典藍牙(BT 2.1/3.0+)介紹

在 Android 開發中,經典藍牙(BT 2.1/3.0)支持多種協議,其中 RFCOMM/SPP(串口通信)、A2DP(音頻流傳輸)和 HFP(免提通話)是最常用的。以下是它們在 Android 中的…

R002-云計算

1 概念 英文名:Cloud Computing 核心:云計算的核心概念就是以互聯網為中心,在網站上提供快速且安全的云計算服務與數據存儲,讓每一個使用互聯網的人都可以使用網絡上的龐大計算資源與數據中心 2.分類 基礎設施即服務(IaaS)它向…

降維(DimensionalityReduction)基礎知識2

文章目錄 五、基于局部結構保持的降維1、Laplacian Eigenmaps(拉普拉斯特征映射)(1)鄰接矩陣(2)圖論基礎(3)Laplace算子1、散度(Divergence)2、拉普拉斯算子3…

物聯網中的物模型是什么意思,在嵌入式軟件開發中如何體現?

1. 物模型的概念 物模型(Thing Model)是物聯網中對物理設備或虛擬設備的抽象描述,定義了設備的屬性、事件和服務。它是設備與云平臺或其他設備之間交互的基礎,用于統一描述設備的能力和行為。 1.1 物模型的組成 屬性&#xff0…

【藍橋杯】單片機設計與開發,PWM

一、PWM概述 用來輸出特定的模擬電壓。 二、PWM的輸出 三、例程一:單片機P34引腳輸出1kHZ的頻率 void Timer0Init(void);unsigned char PWMtt 0;void main(void) {P20XA0;P00X00;P20X80;P00XFF;Timer0Init();EA1;ET01;ET11;while(1);}void Timer0Init(void) //1…

C#中,什么是委托,什么是事件及它們之間的關系

1. 委托(Delegate) 定義與作用 ?委托?是類型安全的函數指針,用于封裝方法,支持多播(鏈式調用)。?核心能力?:將方法作為參數傳遞或異步回調。 使用場景 回調機制(如異步操作完…

從替代到超越,禪道國產化替代解決方案2.0發布!

3月22日,由禪道攜手上海惠艾信息科技、麥哲思科技共同舉辦的禪道?中國行北京站活動圓滿落下帷幕。 除深入探究AI賦能研發項目管理外,禪道在活動現場正式發布了《禪道國產化替代解決方案2.0》,助力企業全方位構建自主可控的研發項目管理新體…

【VirtualBox 安裝 Ubuntu 22.04】

網上教程良莠不齊,有一個CSDN的教程雖然很全面,但是截圖冗余,看蒙了給我,這里記錄一個整潔的教程鏈接。以備后患。 下載安裝全流程 UP還在記錄生活,看的我好羨慕,嗚嗚。 [VirtualBox網絡配置超全詳解]&am…

2025美國網絡專線國內服務商推薦

在海外業務競爭加劇的背景下,穩定高效的美國網絡專線已成為外貿企業、跨國電商及跨國企業的剛需。面對復雜的國際網絡環境和嚴苛的業務要求,國內服務商Ogcloud憑借其創新的SD-WAN技術架構與全球化網絡布局,正成為企業拓展北美市場的優選合作伙…

2.2.2 引入配置文件和定義配置類

本實戰通過三種方式實現Spring Boot中的配置加載與管理。首先,通過PropertySource加載自定義配置文件,結合ConfigurationProperties注解將配置文件中的屬性綁定到Java類中,實現配置的靈活管理。其次,利用ImportResource加載XML配置…

Django:構建高性能Web應用

引言:為何選擇Django? 在當今快速發展的互聯網時代,Web應用的開發效率與可維護性成為開發者關注的核心。Django作為一款基于Python的高級Web框架,以其"開箱即用"的特性、強大的ORM系統、優雅的URL路由設計,…

【銀河麒麟高級服務器操作系統 】虛擬機運行數據庫存儲異常現象分析及處理全流程

更多銀河麒麟操作系統產品及技術討論,歡迎加入銀河麒麟操作系統官方論壇 https://forum.kylinos.cn 了解更多銀河麒麟操作系統全新產品,請點擊訪問 麒麟軟件產品專區:https://product.kylinos.cn 開發者專區:https://developer…

《2核2G阿里云神操作!Ubuntu+Ollama低成本部署Deepseek模型實戰》

簡介: “本文為AI開發者揭秘如何在阿里云2核2G輕量級ECS服務器上,通過Ubuntu系統與Ollama框架實現Deepseek模型的高效部署。無需昂貴硬件,手把手教程涵蓋環境配置、資源優化及避坑指南,助力初學者用極低成本在云端跑通行業領先的大…

【bug解決】NameError: name ‘fused_act_ext‘ is not defined

問題 使用basicsr庫做超分的時候發現NameError: name fused_act_ext is not defined這個問題,一直不斷重復的使用pip uninstall basicsr 和 BASICSR_EXTTrue pip install basicsr 發現一直沒有執行編譯過程,導致一直推理失敗 原因 之前已經安裝過basi…

Anaconda開始菜單里添加JupyterLab快捷方式

Anaconda開始菜單里添加JupyterLab快捷方式 在 Windows 系統安裝 Anaconda 后,發現開始菜單只有 Jupyter Notebook,卻找不到Jupyter Lab入口。其實這是因為最新版 Anaconda 默認未預裝 Lab 組件,本篇介紹一種添加 Jupyter Lab入口到開始菜單…

【Qt】modbus客戶端筆記

Qt 中基于 Modbus 協議的通用客戶端學習筆記 一、概述 本客戶端利用 Qt 的 QModbusTcpClient 實現與 Modbus 服務器的通信,具備連接、讀寫寄存器、心跳檢測、自動重連等功能,旨在提供一個可靠且易用的 Modbus 客戶端框架,方便在不同項目中集…

解決Vmware 運行虛擬機Ubuntu22.04卡頓、終端打字延遲問題

親測可用 打開虛擬機設置,關閉加速3D圖形 (應該是顯卡驅動的問題,不知道那個版本的驅動不會出現這個問題,所以干脆把加速關了)

【網絡】Socket套接字

目錄 一、端口號 二、初識TCP/UDP協議 三、網絡字節序 3.1 概念 3.2 常用API 四、Socket套接字 4.1 概念 4.2 常用API (1)socket (2)bind sockaddr結構 (3)listen (4)a…

內聯函數/函數重載/函數參數缺省

一、內聯函數 為了減少函數調用的開銷 在函數定義前加“inline”關鍵字,即可定義內聯函數 二、函數重載 1.名字相同 2.參數個數或者參數類型不同 編譯器根據調用語句實參的個數和類型判斷應該調用哪個函數 三、函數的缺省參數 定義函數的時候可以讓最右邊的連…

基于神經網絡的文本分類的設計與實現

標題:基于神經網絡的文本分類的設計與實現 內容:1.摘要 在信息爆炸的時代,大量文本數據的分類處理變得至關重要。本文旨在設計并實現一種基于神經網絡的文本分類系統。通過構建合適的神經網絡模型,采用公開的文本數據集進行訓練和測試。在實驗中&#x…