Pytorch-02數據集和數據加載器的基本原理和基本操作

1. 為什么要有數據集類和數據加載器類?

一萬個人會有一萬種獲取并處理原始數據樣本的代碼,這會導致對數據的操作代碼標準不一,并且很難復用。
在這里插入圖片描述

為了解決這個問題,Pytorch提供了兩種最基本的數據相關類:

  • torch.utils.data.Dataset: 一個數據集對象,包含每個數據樣本路徑以及對應標簽
  • torch.utils.data.DataLoader:持有一個對Dataloader的迭代器,通過調用Dataset__getitem__函數方便地獲取實際的樣本-標簽對

PyTorch 為不同的任務類型提供了方便的預加載數據集,例如 torchvision.datasets、torchaudio.datasets 等。這些數據集都是 torch.utils.data.Dataset 的子類,可以直接通過dataset.數據集名稱的方式來方便的下載經典的數據集,在下面你會看到它的使用例。

2. Dataset類的使用方法

2.1 加載一個Fashion-MNIST數據集

Fashion-MNIST 是一個來自 Zalando 的文章圖像數據集,包含 60,000 個訓練樣本和 10,000 個測試樣本。每個樣本由一張 28×28 的灰度圖像和其對應的 10 個類別中的一個標簽組成。

這是一個使用TorchVision預加載數據集類加載Fashion-MNIST 數據集的例子,如下是每個參數代表的意思:

  • root:是存儲訓練/測試數據的路徑。
  • train:指定是訓練數據集還是測試數據集。
  • download=True:如果數據在 root 路徑下不可用,則從互聯網下載。
  • transform 和 target_transform:分別指定特征和標簽的轉換。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data", # 指定數據集實際存放的路徑(相對于本代碼文件)train=True, # 指定這是訓練集還是測試集download=True, # 如果在root下沒有數據,從網絡上自動下載transform=ToTensor() # 給每一張圖片轉換為Tensor的數據類型
)test_data = datasets.FashionMNIST(root="data", # 指定數據集實際存放的路徑(相對于本代碼文件)train=False, # 指定這是訓練集還是測試集download=True, # 如果在root下沒有數據,從網絡上自動下載transform=ToTensor() # 給每一張圖片轉換為Tensor的數據類型
)

在這里插入圖片描述

2.2 遍歷并可視化數據集

我們可以簡單的使用training_data[index]來獲取Datasets類中對應index的樣本。通常可以用matplotlib來可視化我們的一些訓練數據集:

labels_map = { # 定義一個標簽映射字典0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}figure = plt.figure(figsize=(8, 8)) # 創建一個新的畫布,大小為8x8英寸
cols, rows = 3, 3 # 定義展示網格尺寸 3x3的展示網格,每個網格展示i一個圖片for i in range(1, cols * rows + 1): # plt的索引從1開始,配合一下sample_idx = torch.randint(len(training_data), size=(1,)).item() # 生成一個包含1個元素的張量,item()回python數據類型之后為0到數據集大小-1的隨機整數img, label = training_data[sample_idx] # 本質上是在調用__getitem__函數figure.add_subplot(rows, cols, i) # 在之前創建的圖形窗口中,添加一個子圖(subplot),并將當前的畫筆操作對象設置為當前子圖plt.title(labels_map[label]) # 子圖的標題設置為對應的標簽字符串plt.axis("off") # 不顯示坐標軸plt.imshow(img.squeeze(), cmap="gray") # 把當前網格畫好
plt.show() # 展示畫布

這里我并不知道為啥要使用img.squeeze()這個方法, 直到我把img的shape的打印出來:
在這里插入圖片描述
現在img是一個3維的tensor,但是plt.imshow需要輸入二維的tensor,所以使用squeeze的目的是把所有的尺寸為1的維度給擠壓掉,將img維度降維到2維,然后就可以用plt可視化了。

在這里插入圖片描述

2.3 進階:如何制作一個自己的數據集類

自定義的 Dataset 類必須實現三個函數:__init____len____getitem__。請看下面的實現示例:FashionMNIST 圖像存儲在 img_dir 目錄中,而它們的標簽則單獨保存在 annotations_file 的 CSV 文件里。

import os
import pandas as pd
from torchvision.io import decode_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitemm__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # iloc全寫為“integer location”, 表明你要通過數據的行和列的整數索引來選擇數據image = decode_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

在接下來的部分將詳細解釋每個方法的作用。

__init__

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

這個方法會在初始化數據集的時候調用。其主要完成如下工作:

  1. 讀取標簽文件
  2. 指定圖片文件夾路徑
  3. 指定樣本和標簽的transform(這個下面細講)

一個Fashion-MNIST是一個分類任務,其標簽文件annotations大概長這樣:

tshirt1.jpg, 0 # 樣本-標簽對
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

__len__

這個方法是簡單返回數據集的樣本數量:

def __len__(self):return len(self.img_labels)

__getitem__

這個方法是Dataset類的核心,當此方法被Dataloader調用,請求特定idx的數據時,Dataset會根據idx,讀取對應的圖片和標簽,并對它們做出各自的transform之后,返回給Dataloader,讓它把圖片和標簽搬運到內存.

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

3. Dataloader類的使用方法

3.1 對數據集對象配置Dataloader

Dataset類的__getitem__方法被調用的時候,他會返回一個樣本-標簽對。

但是在實際的模型訓練中,我們還有一些別的要求,例如:

  1. 以“小批量(minibatches)”的方式傳遞樣本。(減少單樣本噪聲帶來的震蕩,讓梯度更新的方向更加穩定)
  2. 在每個周期(epoch)對數據進行重新洗牌(reshuffle),以減少模型過擬合。
  3. 使用 Python 的多進程(multiprocessing)來加快數據檢索速度。

以上的要求可以通過如下的參數設定來滿足:

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=5)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=5)
  • batch_size=64 設定批量大小為64
  • shuffle=True 指定一個epoch之后dataloader持有的索引要重新洗牌
  • num_workers=5 指定dataloader會同時開啟5個進程去調用dataset的__getitem__方法

以上是Dataloader最基本的用法,不過,當你有GPU的時候,我推薦你也把下面兩個參數打開:
pin_memory=True 開啟鎖頁內存,減少CPU到GPU的數據傳遞延遲
persistent_workers=True 每個epoch結束后不銷毀dataloader所開啟的worker進程,而是接著用,這樣剩下了worker的初始化時間

3.2 使用Dataloader遍歷數據集

給Dataset配置好對應的Dataloader后,就可以開始用dataloader遍歷它了。每次遍歷都會返回一個batch_size的訓練圖片和訓練標簽對(這里就是64個)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader)) # 先從train_dataloader中獲得一個迭代器,然后調用next獲取其下一個元素
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在這里插入圖片描述

由于開啟了shuffle=True,所以每次遍歷完整個數據集后train_dataloader持有的索引會被打亂。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/91740.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/91740.shtml
英文地址,請注明出處:http://en.pswp.cn/web/91740.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

無圖形界面的CentOS 7網絡如何配置

進入虛擬機輸入ip addr命令:從 ip addr命令的輸出可以明確看出 ??lo和 ens33是兩個不同的網絡接口(網卡)lo(回環接口)????作用??:虛擬的本地回環網卡,用于本機內部通信(如 1…

機器學習之線性回歸的入門學習

線性回歸是一種監督學習算法,用于解決回歸問題。它的目標是找到一個線性關系(一條直線或一個超平面),能夠最好地描述一個或多個自變量(特征)與一個因變量(目標)之間的關系。利用回歸…

2-5 Dify案例實踐—利用RAG技術構建企業私有知識庫

目錄 一、RAG技術的定義與作用 二、RAG技術的關鍵組件 三、RAG技術解決的問題 四、RAG技術的核心價值與應用場景 五、如何實現利用RAG技術構建企業私有知識庫 六、Dify知識庫實現詳解 七、創建知識庫 1、創建知識庫 2、上傳文檔 3、文本分段與清洗 4、索引方式 5、…

斷路器瞬時跳閘曲線數據獲取方式

斷路器瞬時短路電流時,時間是在60ms內的,仿真器去直接捕獲電流有效值很難。按照電流互感器的電流曲線特性,電流越大,由于互感器飽和,到達一定電流值的時候,電流會趨于平穩不再上升,ADC-I曲線由線…

技巧|SwanLab記錄混淆矩陣攻略

繪制混淆矩陣(Confusion Matrix),用于評估分類模型的性能。混淆矩陣展示了模型預測結果與真實標簽之間的對應關系,能夠直觀地顯示各類別的預測準確性和錯誤類型。 混淆矩陣是評估分類模型性能的基礎工具,特別適用于多…

HTTPS的工作原理

文章目錄HTTP有什么問題?1. 明文傳輸,容易被竊聽2. 無法驗證通信方身份3. 數據完整性無法保證HTTPS是如何解決這些問題的?HTTPS的工作原理1. SSL/TLS握手2. 數據加密傳輸3. 完整性保護4. 連接關閉總結HTTP有什么問題? 1. 明文傳輸…

ECMAScript2020(ES11)新特性

概述 ECMAScript2020于2020年6月正式發布, 本文會介紹ECMAScript2020(ES11),即ECMAScript的第11個版本的新特性。 以下摘自官網:ecma-262 ECMAScript 2020, the 11th edition, introduced the matchAll method for Strings, to produce an …

機器視覺引導機器人修磨加工系統助力芯片封裝

芯片制造中,劈刀同軸度精度對封裝質量至關重要。傳統加工在精度、效率、穩定性、良率及操作便捷性上存在不足:精度不足:劈刀同軸度需控在 0.003mm 內,傳統手段難達標,致芯片封裝良率低;效率良率低 &#xf…

Python編程基礎與實踐:Python模塊與包入門實踐

Python模塊與包的深度探索 學習目標 通過本課程的學習,學員將掌握Python中模塊和包的基本概念,了解如何導入和使用標準庫中的模塊,以及如何創建和組織自己的模塊和包。本課程將通過實際操作,幫助學員加深對Python模塊化編程的理解…

【Django】-4- 數據庫存儲和管理

一、關于ORM ORM 是啥呀ORM 就是用 面向對象 的方式,把數據庫里的數據還有它們之間的關系映射起來~就好像給數據庫和面向對象之間搭了一座小橋梁🎀對應關系大揭秘面向對象和數據庫里的東西,有超有趣的對應呢👇類 → 數…

深入 Go 底層原理(四):GMP 模型深度解析

1. 引言在上一篇文章中,我們宏觀地了解了 Go 的調度策略。現在,我們將深入到構成這個調度系統的三大核心組件:G、M、P。理解 GMP 模型是徹底搞懂 Go 并發調度原理的關鍵。本文將詳細解析 G、M、P 各自的職責以及它們之間是如何協同工作的。2.…

AI賦能測試:技術變革與應用展望

AI 在測試中的應用:技術賦能與未來展望 目錄 AI 在測試中的應用:技術賦能與未來展望 1. 引言 1.1 測試在軟件開發中的重要性 1.2 AI 技術如何改變傳統測試模式 1.3 文章結構概述 2. AI 在測試中的核心應用場景 2.1 自動化測試優化 2.1.1 智能測…

Mujoco(MuJoCo,全稱Multi - Joint dynamics with Contact)一種高性能的物理引擎

Mujoco(MuJoCo,全稱Multi - Joint dynamics with Contact)是一種高性能的物理引擎,主要用于模擬多體動力學系統,廣泛應用于機器人仿真、運動學研究、人工智能等領域。以下是關于Mujoco仿真的一些詳細介紹: …

winform-窗體應用的功能介紹(部分)

1--Point實現在窗口(Form)中一個按鈕(控件)的固定位置(所在位置)一個按鈕(控件)的位置一般是固定的,另一個按鈕在窗口中位置是隨機產生的Location屬性:Location new Point(X,Y);在C#的Winform應用程序里,Button控件的鼠標懸標懸浮事件是不存在內置延遲時間的。當鼠標指針進入按…

最新Windows11系統鏡像,23H2 64位ISO鏡像

Windows 11 主要分為 Consumer Editions(消費者版)和 Business Editions(商業版)兩大類別 。消費者版主要面向家庭和個人用戶,商業版則側重于企業和商業用戶。這兩大類別中存在部分重疊的版本,比如專業版和…

linux基本系統服務——DNS服務

一、DNS域名解析原理DNS&#xff0c;Domain Name System&#xff0c;域名系統&#xff1a;在互聯網中由大量域名解析服務器共同提供的一整套關于“域名 <--> IP地址”信息查詢的數據系統!!!! C/S架構&#xff1a;DNS服務端監聽UDP 53端口&#xff08;處理客戶端查詢&…

數據處理和統計分析——08 apply自定義函數

1 apply()函數 1.1 apply()函數簡介 Pandas提供了很多數據處理的API&#xff0c;但當提供的API不能滿足需求的時候&#xff0c;需要自己編寫數據處理函數, 這個時候可以使用apply()函數&#xff1b;apply()函數可以接收一個自定義函數&#xff0c;可以將DataFrame的行或列數據傳…

C++冰箱管理實戰代碼

基于C++的冰箱管理實例 以下是一些基于C++的冰箱管理實例示例,涵蓋不同功能場景,每個示例聚焦特定實現點,代碼可直接擴展或整合到項目中。 示例1:基礎冰箱類定義 class Refrigerator { private:int capacity;std::vector<std::string> items; public:Refrigerator(…

【Python】【數據分析】Python 數據分析與可視化:全面指南

目錄1. 環境準備2. 數據處理與清洗2.1 導入數據2.2 數據清洗示例&#xff1a;處理缺失值示例&#xff1a;處理異常值2.3 數據轉換3. 數據分析3.1 描述性統計3.2 分組分析示例&#xff1a;按年齡分組計算工資的平均值3.3 時間序列分析4. 數據可視化4.1 基本繪圖示例&#xff1a;…

【AI】AIService(基本使用與指令定制)

【AI】AIService(基本使用與指令定制) 文章目錄【AI】AIService(基本使用與指令定制)1. 簡介2. AIService2.1 引入依賴2.2 編寫AIService接口2.3 測試代碼3. 指令定制3.1 系統提示詞3.2 用戶提示詞1. 簡介 AIService可以被視為應用程序服務層的一個組件&#xff0c;提供對應的…