【DL學習筆記】Dataset類功能以及自定義

文章目錄

  • 一、Dataset 與 DataLoader 功能介紹
    • 抽象類Dataset的作用
    • DataLoader 作用
    • 兩者關系
  • 二、自定義Dataset類
    • Dataset的三個重要方法
      • `__len__()`方法
      • `_getitem__()`方法
      • `__init__` 方法
  • 三、現成的torchvision.datasets模塊
    • MNIST舉例
    • COCODetection舉例
    • `torchvision.datasets.MNIST`使用舉例
    • `torchvision.datasets.CocoDetection`舉例

一、Dataset 與 DataLoader 功能介紹

抽象類Dataset的作用

簡單來說,就是將原始數據(可能是圖片、文本、音頻等各種格式)整理成模型可以處理的格式,為后續的數據加載和處理做準備。功能是定義數據集的基本屬性數據獲取方式

  • 初始化數據路徑:在Dataset類的__init__方法中,通常會初始化數據存放的路徑,以及一些數據預處理的操作,比如指定圖片數據集圖片所在文件夾路徑,文本數據集文本文件路徑等 。包含 加載數據/讀取數據、預處理數據、圖像增強 等一系列操作
  • 獲取單個樣本及其標簽:通過實現__getitem__方法,根據給定的索引(dataloader返回的),返回相應的數據樣本和對應的標簽。例如在圖片分類任務中,給定索引后,返回該索引對應的圖片數據(經過預處理,如調整尺寸、歸一化等)以及圖片的類別標簽。
  • 統計樣本數量:通過實現__len__方法,返回數據集中樣本的總數,方便在訓練和評估過程中知道數據規模 。

DataLoader 作用

DataLoader是在Dataset的基礎上,提供了一種更加高效、便捷地加載數據的方式,它可以將Dataset返回的單個樣本,按照指定的方式進行打包(如組成batch)、打亂順序等操作,從而滿足模型訓練和評估的需求。

  • 創建數據批次,指定數據打包輸出規則:通過batch_size參數,將Dataset中的單個樣本打包成一個個批次(batch)的數據。

    • collate_fn指定如何從NNN張訓練集選出一個batch的Nbatch_size\frac{N}{batch\_size}batch_sizeN?張圖片。
    • 例如batch_size=32,那么DataLoader每次會從Dataset中取出32個樣本組成一個batch。每次迭代,返回的是 一個batch 的數據
  • 自定義數據采樣,指定數據迭代讀取規則:

    • 一般使用自定義的采樣器(Sampler),實現對數據的特殊采樣方式,比如分層采樣(在類別不均衡的數據集中,保證每個batch中各類別的樣本比例與原始數據集相似)等。
    • dataset對象是dataloader的一個參數,通過dataset讓dataloader知道訓練集一共多少圖片,從而知道共跌代多少次。
  • 數據打亂:通過shuffle參數設置是否在每個epoch開始時打亂數據順序,這樣可以避免模型在訓練時對數據產生特定的依賴,有助于模型學習到更通用的特征,提高模型的泛化能力 。

  • 多進程加載:通過num_workers參數設置多進程加載數據,從而加快數據加載速度,尤其是在數據量較大、數據預處理較為復雜的情況下,多進程可以充分利用CPU資源,減少數據加載時間,避免數據加載成為訓練過程中的瓶頸 。

兩者關系

  • Dataset是數據的基礎容器,定義了如何獲取數據集中的單個樣本;

  • DataLoader則是Dataset的上層應用,負責按照特定規則(如批量處理、打亂順序等)從Dataset中高效地加載數據,供模型進行訓練、驗證和測試等操作。

  • 可以說,Dataset是數據的來源和基本操作接口,DataLoader則是為了更好地適配模型訓練需求,對Dataset的數據進行進一步處理和組織的工具。

二、自定義Dataset類

所謂的 自定義 dataset ,即自己去寫一個 Dataset 類,要滿足兩個要求:

  • 一般需要繼承自 torch.utils.data.Dataset
    • 繼承 torch.utils.data.Dataset 主要目的是為了與 DataLoader 保持兼容,確保數據集遵循 DataLoader 的接口標準,方便后續使用 PyTorch 提供的工具,比如 :批量加載、打亂數據、并行處理等功能
  • 并且滿足和DataLoader進行交互的規范 :
    • 因為DataLoader會調用 Datasetlen()getitem() 方法,所以自定義 Dataset 類必須實現這兩個方法,如此才能保證 DataLoader 可以正確地加載和操作你的數據集
  • 兼容訓練和推理階段

Dataset的三個重要方法

創建自定義 Dataset類時,必須實現的3個方法 :__init__()__len__()__getitem__()
這些方法定義了數據集的基本結構和行為,也是 DataLoader 可以正確的從 Dataset 中讀取數據的基礎。

__len__()方法

DataLoader是通過Dataset的 __len__(),得知訓練集一共多少數據樣本的。

def __len__(self):return len(self.file_list)
  • 返回值:數據集中的樣本的總數。
  • 作用:
    • 方便通過調用 len(dataset) 來獲取數據量,其中 dataset 為 Dataset 對象
    • Dataloader 會用它和 batch_size 一起來計算一個 epoch 要迭代多少個 steps:
      steps=len(dataset)batch_sizesteps = \frac{len(dataset)}{batch\_size}steps=batch_sizelen(dataset)?
    • DataLoader調用len方法的代碼封裝在源碼了,所以看不到顯式調用。DataLoader得到一共NNN個數據樣本后,生成000 ~ N?1N-1N?1的索引。再根據batch_size和是否打亂,生成一個batch的索引列表,再將每個索引idx傳入到Dataset的_getitem__()方法中返回得到圖片和索引return image, label

_getitem__()方法

作用: 根據給定的索引返回數據集中的一個樣本。這是用于獲取數據集中單個樣本的方法。

def __getitem__(self, idx):# 通過索引idx,獲取圖片地址img_nameimg_name = os.path.join(self.data_folder, self.file_list[idx])# 根據圖片地址img_name讀取對應圖像original_imageoriginal_image = Image.open(img_name)# 通過索引idx獲取圖片對應的標簽(這里舉的例子的標簽含在圖片名中)label = img_name.split('_')[-1].split('.')[0]# 圖像預處理和數據增強(僅訓練階段)if self.train:image = self.transform(original_image)else:image = self.transform(original_image)# 返回處理好的一張圖像和標簽return image, label
  • 接收參數: index(idx)是單個數據樣本的索引,由DataLoader傳來的
  • 返回值: 返回數據集中索引指定的樣本。通常是一個包含輸入數據和對應標簽的元組。這里可以根據自己的需求,進行自定義。

DataLoader返回的是一個batch的數據,具體是:

  • DataLoader的采樣器sampler根據數據總量和batch_size=2,和采樣方法(舉例為順序采樣)得到第一次迭代結果為索引列表[0, 1]
  • DataLoader分別把索引0和1給Dataset,__getitem__()方法返回出對應單個索引的圖片和標簽。
  • 把得到的一個batch的兩組圖片和標簽給collate_fn函數進行打包并以一種數據結構儲存,由DataLoader返回

__init__ 方法

  • 參數: 根據需要傳遞一些參數,例如文件路徑、數據轉換等。
  • 作用: 構造方法,配好len和getitem方法做一些初始化工作,需要什么數據,就傳入進來賦值到成員屬性。
def __init__(self, data_folder, train, transform=None):self.data_folder = data_folderself.transform = transformself.file_list = os.listdir(data_folder)# 把文件名讀取出來,存入到file_list,方便len方法獲取數據量self.train = train

例如:設置文件路徑selfl.data_folder、定義數據轉換的transforms、當前是訓練階段還是驗證階段的布爾值train等。

三、現成的torchvision.datasets模塊

對于一些公開的數據集,可以直接用torchvision.datasets模塊的現成的Dataset類。

Pytorch官方文檔的torchvision的Dataset列出了可使用的數據集的Dataset,實現了getitem和len方法

在這里插入圖片描述

MNIST舉例

這里以Image classification任務的MNIST(mixed national institute of standards and technology database)數據集舉例,點入詳情頁課查看:
在這里插入圖片描述

在這里插入圖片描述

train_dataset = torchvision.datasets.MNIST(root,    train=True,               transform=None,  target_transform= None  download=True)

參數:

  • root:數據集存放的路徑
  • download:是否下載數據集,默認為False 。配合root參數:
    • 若設置download=True
      • root目錄下沒有該數據集,數據集將會被下載到root指定的位置。
      • root目錄下已經存在該數據集,則不會重新下載,而是會直接使用已存在的數據,以節省時間
    • 若設置download=False,程序將會在root指定的位置查找數據集,如果數據集不存在,則會拋出錯誤。
  • train
    • 如果是True,下載訓練集trainin.pt
    • 如果是False,下載測試集test.pt。默認是True
  • transform:接收torchvision.transforms的對象,一系列作用在PIL圖片上的轉換操作,用于對數據集的圖像預處理和數據增強。
  • target_transform:對target處理,一般不用。因為出來target出來一般用自定義的Dataset,因為圖像處理和target處理要放一個transform里寫

COCODetection舉例

Image detection任務的COCO數據集
在這里插入圖片描述

注意:對于一部分數據集比如torchvision.datasets.CocoDetection,Pytorch不提供下載功能 (具體情況取決于數據集的來源和許可協議),就沒有download參數。
所以在使用 torchvision.datasets.CocoDetection 這個現成的Dataset類之前,需要確保已經下載并淮備好COCO數
據集的圖像和標注文件。然后使用torchvision.datasets.CocoDetection 類來加載 COCO數據集。

torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)
  • root:指定圖片地址(本地已經下載下來的圖像地址)
  • annFile:指定標注文件地址(本地已經下載下來的標注文件地址)
  • transform:圖像處理 (用于PIL)
  • target_transform:標注處理
  • transforms:圖像和標注的處理

torchvision.datasets.MNIST使用舉例

訓練集和驗證集分別實例化一個Dataset類(torchvision.datasets.MNIST)的對象,傳入的transforms參數都為實例化的transforms.Compose對象my_transform。數據集下載到當下文件所在目錄下。

import torchvision
from torchvision.transforms import transforms
import torch.utils.data as data
import matplotlib.pyplot as pltbatch_size = 5# transforms.Compose的對象,傳入到transforms參數
my_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],  # mean=[0.485, 0.456, 0.406]std=[0.5])])  # std=[0.229, 0.224, 0.225]train_dataset = torchvision.datasets.MNIST(root="./",train=True,transform=my_transform,download=True)val_dataset = torchvision.datasets.MNIST(root="./",train=False,transform=my_transform,download=True)

在這里插入圖片描述

  • 可以看的在當下目錄下出現了一個MNIST文件夾,
  • .gz后綴的是下載的壓縮文件,程序自動解壓為同名的二進制文件
  • Dataset會自動處理好二進制文件,最終從DataLoader跌代出來的是正常的單通道灰度圖。

將定義出的訓練集和驗證集的Dataset對象,分別作為參數傳入到兩個DataLoader,得到兩個DataLoader對象

train_loader = data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)val_loader = data.DataLoader(val_dataset,batch_size=batch_size,shuffle=True)

分別調用量Dataset的len方法,輸出數據量。再將train_loader轉換為迭代器iter(train_loader),通過next方法得到一個batch的image和label。
打印出一個batch的image的shape。[5, 1, 28, 28]分別指batch_size,圖片通道數,圖像長寬。
打印出標簽label列表。
最后可視化一個batch的圖和標簽。

print(len(train_dataset))
print(len(val_dataset))image, label = next(iter(train_loader))
print(image.shape)
print(label)for i in range(batch_size):plt.subplot(1, batch_size, i + 1)plt.title(label[i].item())plt.axis("off")plt.imshow(image[i].permute(1, 2, 0))plt.show()

在這里插入圖片描述

torchvision.datasets.CocoDetection舉例

需要把數據集的下載地址換掉,換成你的 COCO數據集地址

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import randomdef collate_fn_coco(batch):return tuple(zip(*batch))coco_det = datasets.CocoDetection(root="./COCO2017/train2017",annFile="./COCO2017/annotations/instances_train 2017.json")sampler = torch.utils.data.SequentialSampler(coco_det)  # RandomSampler
batch_sampler = torch.utils.data.BatchSampler(sampler, 1, drop_last=True)
data_loader = torch.utils.data.DataLoader(coco_det,batch_sampler=batch_sampler,collate_fn=collate_fn_coco)# 可視化
iterator = iter(data_loader)
imgs, gts = next(iterator)
img,  gts_one_img = imgs[0], gts[0]bboxes = []
ids = []
for gt in gts_one_img:bboxes.append([gt['bbox'][0],gt['bbox'][1],gt['bbox'][2],gt['bbox'][3]])ids.append(gt['category_id'])fig, ax = plt.subplots()
for box, id in zip(bboxes, ids):x = int(box[0])y = int(box[1])w = int(box[2])h = int(box[3])rect = plt.Rectangle((x, y), w, h, edgecolor='r', linewidth=2, facecolor='none')ax.add_patch(rect)ax.text(x, y, id, backgroundcolor="r")plt.axis("off")
plt.imshow(img)
plt.show()

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/94476.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/94476.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/94476.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python爬蟲實戰:研究python_reference庫,構建技術研究數據系統

1. 引言 1.1 研究背景與意義 在大數據時代,數據已成為重要的生產要素。互聯網作為全球最大的信息庫,蘊含著海量有價值的數據。如何從紛繁復雜的網絡信息中快速、準確地提取所需數據,成為各行各業面臨的重要課題。網絡爬蟲技術作為數據獲取的關鍵手段,能夠模擬人類瀏覽網頁…

Web開發系列-第15章 項目部署-Docker

第15章 項目部署-Docker Docker技術能夠避免部署對服務器環境的依賴,減少復雜的部署流程。 輕松部署各種常見軟件、Java項目 參考文檔:???????????????????第十五章:…

微軟無界鼠標(Mouse without Borders)安裝及使用:多臺電腦共用鼠標鍵盤

文章目錄一、寫在前面二、下載安裝1、兩臺電腦都下載安裝2、被控端3、控制端主機三、使用一、寫在前面 在辦公中,我們經常會遇到這種場景,自己帶著筆記本電腦外加公司配置的臺式機。由于兩臺電腦,所以就需要搭配兩套鍵盤鼠標。對于有限的辦公…

nodejs 編程基礎01-NPM包管理

1:npm 包管理介紹 npm 是nodejs 的包管理工具,類似于java 的maven 和 gradle 等,用來解決nodejs 的依賴包問題 使用場景:1. 從NPM 服務騎上下載或拉去別人編寫好的第三方包到本地進行使用2. 將自己編寫代碼或軟件包發布到npm 服務器供他人使用…

基于Mediapipe_Unity_Plugin實現手勢識別

GitHub - homuler/MediaPipeUnityPlugin: Unity plugin to run MediaPipehttps://github.com/homuler/MediaPipeUnityPlugin 實現了以下: public enum HandGesture { None, Stop, ThumbsUp, Victory, OK, OpenHand } 核心腳本&#xff1a…

Android 項目構建編譯概述

主要內容是Android AOSP源碼的管理方式,項目源碼的構建和編譯,用到比如git、repo、gerrit一些命令工具,以及使用Soong編譯系統,編寫Android.bp文件的格式樣式。 1. Android操作系統堆棧概述 Android 是一個針對多種不同設備類型打…

Python爬蟲08_Requests聚焦批量爬取圖片

一、Requests聚焦批量爬取圖片 import re import requests import os import timeurl https://www.douban.com/ userAgent {User-Agent:Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:122.0) Gecko/20100101 Firefox/122.0}#獲取整個瀏覽頁面 page_text requests.get(urlur…

Spring Cloud系列—簡介

目錄 1 單體架構 2 集群與分布式 3 微服務架構 4 Spring Cloud 5 Spring Cloud環境和工程搭建 5.1 服務拆分 5.2 示例 5.2.1 數據庫配置 5.2.2 父子項目創建 5.2.3 order_service子項目結構配置 5.2.4 product_service子項目結構配置 5.2.5 服務之間的遠程調用 5.…

【普中STM32精靈開發攻略】--第 1 章 如何使用本攻略

學習本開發攻略主要參考的文檔有《STM32F1xx 中文參考手冊》和《Cortex M3權威指南(中文)》,這兩本都是 ST 官方手冊,尤其是《STM32F1xx 中文參考手冊》,里面包含了 STM32F1 內部所有外設介紹,非常詳細。大家在學習 STM32F103的時…

【Docker】RK3576-Debian上使用Docker安裝Ubuntu22.04+ROS2

1、簡述 RK3576自帶Debian12系統,如果要使用ROS2,可以在Debian上直接安裝ROS2,缺點是有的ROS包需要源碼編譯;當然最好是使用Ubuntu系統,可以使用Docker安裝,或者構建Ubuntu系統,替換Debian系統。 推薦使用Docker來安裝Ubuntu22.04,這里會有個疑問,是否可以直接使用Do…

解決docker load加載tar鏡像報json no such file or directory的錯誤

在使用docker加載離線鏡像文件時,出現了json no such file or directory的錯誤,剛開始以為是壓縮包拷貝壞了,重新拷貝了以后還是出現了問題。經過網上查找方案,并且自己實踐,采用下面的簡單方法就可以搞定。 歸結為一句…

《協作畫布的深層架構:React與TypeScript構建多人實時繪圖應用的核心邏輯》

多人在線協作繪圖應用的構建不僅是技術棧的簡單組合,更是對實時性、一致性與用戶體驗的多維挑戰。基于React與TypeScript開發這類應用,需要在圖形繪制的基礎功能之外,解決多用戶并發操作的同步難題、狀態回溯的邏輯沖突以及大規模協作的性能瓶頸。每一層架構的設計,都需兼顧…

智慧社區(八)——社區人臉識別出入管理系統設計與實現

在社區安全管理日益智能化的背景下,傳統的人工登記方式已難以滿足高效、精準的管理需求。本文將詳細介紹一套基于人臉識別技術的社區出入管理系統,該系統通過整合騰訊云 AI 接口、數據庫設計與業務邏輯,實現了居民出入自動識別、記錄追蹤與訪…

嵌入式開發學習———Linux環境下IO進程線程學習(四)

進程相關函數fork創建一個子進程,子進程復制父進程的地址空間。父進程返回子進程PID,子進程返回0。pid_t pid fork(); if (pid 0) { /* 子進程代碼 */ } else { /* 父進程代碼 */ }getpid獲取當前進程的PID。pid_t pid getpid();getppid獲取父進程的P…

標記-清除算法中的可達性判定與Chrome DevTools內存分析實踐

引言 在現代前端開發中,內存管理是保證應用性能與用戶體驗的核心技術之一。作為JavaScript運行時的基礎機制,標記-清除算法(Mark-and-Sweep) 通過可達性判定決定哪些內存需要回收,而Chrome DevTools提供的Memory工具則為開發者提供了深度的內…

微算法科技(NASDAQ:MLGO)基于量子重加密技術構建區塊鏈數據共享解決方案

隨著信息技術的飛速發展,數據已成為數字經濟時代的核心生產要素。數據的共享和安全往往是一對難以調和的矛盾。傳統的加密方法在面對日益強大的計算能力和復雜的網絡攻擊時,安全性受到了挑戰。微算法科技(NASDAQ:MLGO)通過引入量子重加密技術…

FastAPI快速入門P2:與SpringBoot比較

歡迎來到啾啾的博客🐱。 記錄學習點滴。分享工作思考和實用技巧,偶爾也分享一些雜談💬。 有很多很多不足的地方,歡迎評論交流,感謝您的閱讀和評論😄。 目錄引言1 FastAPI事件管理2 類的使用2.1 初始化方法對…

SAP-ABAP: Open SQL集合函數COUNT(統計行數)、SUM(數值求和)、AVG(平均值)、MAX/MIN(極值)深度指南

SAP Open SQL集合函數深度指南 1. 核心價值與特性函數作用關鍵特性COUNT統計行數用COUNT(*)包含NULL值行,COUNT(字段)排除NULLSUM數值求和自動過濾NULL值,結果類型與源字段相同AVG平均值必須用TYPE f接收,否則四舍五入導致精度丟失MAX/MIN極值…

【docker】UnionFS聯合操作系統

Linux 的 Namespace、CGroups 和 UnionFS 三大技術支撐了 Docker 的實現。 一、為什么需要聯合文件系統?在傳統操作系統中,每個文件系統都是獨立的孤島。但當我們需要:合并多個目錄的內容保持基礎系統不變的同時進行修改高效共享重復文件內容…

CTF-XXE 漏洞解題思路總結

一、XXE 漏洞簡介XXE (XML External Entity) 漏洞允許攻擊者通過構造惡意的 XML 輸入,強迫服務器的 XML 解析器執行非預期的操作。在 CTF 場景中,最常見的利用方式是讓解析器讀取服務器上的敏感文件,并將其內容返回給攻擊者。二、核心攻擊載荷…