Dataset和DataLoader構建數據通道

重點在第二部分的構建數據通道和第三部分的加載數據集

Pytorch通常使用Dataset和DataLoader這兩個工具類來構建數據管道。

Dataset定義了數據集的內容,它相當于一個類似列表的數據結構,具有確定的長度,能夠用索引獲取數據集中的元素。

而DataLoader定義了按batch加載數據集的方法,它是一個實現了__iter__方法的可迭代對象,每次迭代輸出一個batch的數據。

DataLoader能夠控制batch的大小,batch中元素的采樣方法,以及將batch結果整理成模型所需輸入形式的方法,并且能夠使用多進程讀取數據。

在絕大部分情況下,用戶只需實現Dataset的__len__方法和__getitem__方法,就可以輕松構建自己的數據集,并用默認數據管道進行加載。

一,Dataset和DataLoader概述

1,獲取一個batch數據的步驟

讓我們考慮一下從一個數據集中獲取一個batch的數據需要哪些步驟。

(假定數據集的特征和標簽分別表示為張量X和Y,數據集可以表示為(X,Y), 假定batch大小為m)

1,首先我們要確定數據集的長度n。

結果類似:n = 1000。

2,然后我們從0到n-1的范圍中抽樣出m個數(batch大小)。

假定m=4, 拿到的結果是一個列表,類似:indices = [1,4,8,9]

3,接著我們從數據集中去取這m個數對應下標的元素。

拿到的結果是一個元組列表,類似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]

4,最后我們將結果整理成兩個張量作為輸出。

拿到的結果是兩個張量,類似batch = (features,labels),

其中 features = torch.stack([X[1],X[4],X[8],X[9]])

labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

2,Dataset和DataLoader的功能分工

上述第1個步驟確定數據集的長度是由 Dataset的__len__ 方法實現的。

第2個步驟從0到n-1的范圍中抽樣出m個數的方法是由 DataLoader的 sampler和 batch_sampler參數指定的。

sampler參數指定單個元素抽樣方法,一般無需用戶設置,程序默認在DataLoader的參數shuffle=True時采用隨機抽樣,shuffle=False時采用順序抽樣。

batch_sampler參數將多個抽樣的元素整理成一個列表,一般無需用戶設置,默認方法在DataLoader的參數drop_last=True時會丟棄數據集最后一個長度不能被batch大小整除的批次,在drop_last=False時保留最后一個批次。

第3個步驟的核心邏輯根據下標取數據集中的元素 是由 Dataset的 __getitem__方法實現的。

第4個步驟的邏輯由DataLoader的參數collate_fn指定。一般情況下也無需用戶設置。

3,Dataset和DataLoader的主要接口

偽代碼,實際應用意義不大

import torch 
class Dataset(object):def __init__(self):passdef __len__(self):raise NotImplementedErrordef __getitem__(self,index):raise NotImplementedErrorclass DataLoader(object):def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):self.dataset = datasetself.sampler =torch.utils.data.RandomSampler if shuffle else \torch.utils.data.SequentialSamplerself.batch_sampler = torch.utils.data.BatchSamplerself.sample_iter = self.batch_sampler(self.sampler(range(len(dataset))),batch_size = batch_size,drop_last = drop_last)def __next__(self):indices = next(self.sample_iter)batch = self.collate_fn([self.dataset[i] for i in indices])return batch

二,使用Dataset創建數據集

Dataset創建數據集常用的方法有:

使用 torch.utils.data.TensorDataset 根據Tensor創建數據集(numpy的array,Pandas的DataFrame需要先轉換成Tensor)。

使用 torchvision.datasets.ImageFolder 根據圖片目錄創建圖片數據集。

繼承 torch.utils.data.Dataset 創建自定義數據集。

此外,還可以通過

torch.utils.data.random_split 將一個數據集分割成多份,常用于分割訓練集,驗證集和測試集。

調用Dataset的加法運算符(+)將多個數據集合并成一個數據集。

1,根據Tensor創建數據集

  1. 頭文件:
import numpy as np 
import torch 
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split 
  1. 根據Tensor創建數據集
from sklearn import datasets 
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
  1. 分割成訓練集和預測集
n_train = int(len(ds_iris)*0.8)
n_valid = len(ds_iris) - n_train
ds_train,ds_valid = random_split(ds_iris,[n_train,n_valid])
  1. 使用DataLoader加載數據集
dl_train,dl_valid = DataLoader(ds_train,batch_size = 8),DataLoader(ds_valid,batch_size = 8)#查看數據集
for features,labels in dl_train:print(features,labels)break
  1. 演示加法運算符(+)的合并作用
ds_data = ds_train + ds_validprint('len(ds_train) = ',len(ds_train))
print('len(ds_valid) = ',len(ds_valid))
print('len(ds_train+ds_valid) = ',len(ds_data))print(type(ds_data))

2,根據圖片目錄創建圖片數據集

  1. 頭文件:
import numpy as np 
import torch 
from torch.utils.data import DataLoader
from torchvision import transforms,datasets 
  1. 圖片加載:
from PIL import Image
img = Image.open('./data/cat.jpeg')
  1. 隨機數值翻轉
transforms.RandomVerticalFlip()(img)
  1. 隨機旋轉
transforms.RandomRotation(45)(img)
  1. 定義圖片增強操作
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), #隨機水平翻轉transforms.RandomVerticalFlip(), #隨機垂直翻轉transforms.RandomRotation(45),  #隨機在45度角度內旋轉transforms.ToTensor() #轉換成張量]
) transform_valid = transforms.Compose([transforms.ToTensor()]
)
  1. 根據圖片目錄創建數據集
ds_train = datasets.ImageFolder("./data/cifar2/train/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./data/cifar2/test/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())print(ds_train.class_to_idx)
  1. 使用DataLoader加載數據集
#注意:windows用戶要把num_workers去掉,容易報錯
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)for features,labels in dl_train:print(features.shape)print(labels.shape)break

三,使用DataLoader加載數據集

DataLoader能夠控制batch的大小,batch中元素的采樣方法,以及將batch結果整理成模型所需輸入形式的方法,并且能夠使用多進程讀取數據。

DataLoader的函數簽名

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)

一般情況下,我們僅僅會配置 dataset, batch_size, shuffle, num_workers, drop_last這五個參數,其他參數使用默認值即可。
dataset : 數據集
batch_size: 批次大小
shuffle: 是否亂序
sampler: 樣本采樣函數,一般無需設置。
batch_sampler: 批次采樣函數,一般無需設置。
num_workers: 使用多進程讀取數據,設置的進程數。
collate_fn: 整理一個批次數據的函數。
pin_memory: 是否設置為鎖業內存。默認為False,鎖業內存不會使用虛擬內存(硬盤),從鎖業內存拷貝到GPU上速度會更快。
drop_last: 是否丟棄最后一個樣本數量不足batch_size批次數據。
timeout: 加載一個數據批次的最長等待時間,一般無需設置。
worker_init_fn: 每個worker中dataset的初始化函數,常用于 IterableDataset。一般不使用。

#構建輸入數據管道
ds = TensorDataset(torch.arange(1,50))
dl = DataLoader(ds,batch_size = 10,shuffle= True,num_workers=2,drop_last = True)
#迭代數據
for batch, in dl:print(batch)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389359.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389359.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389359.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

鐵拳nat映射_鐵拳如何重塑我的數據可視化設計流程

鐵拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

Django2 Web 實戰03-文件上傳

作者:Hubery 時間:2018.10.31 接上文:接上文:Django2 Web 實戰02-用戶注冊登錄退出 視頻是一種可視化媒介,因此視頻數據庫至少應該存儲圖像。讓用戶上傳文件是個很大的隱患,因此接下來會討論這倆話題&#…

BZOJ.2738.矩陣乘法(整體二分 二維樹狀數組)

題目鏈接 BZOJ洛谷 整體二分。把求序列第K小的樹狀數組改成二維樹狀數組就行了。 初始答案區間有點大,離散化一下。 因為這題是一開始給點,之后詢問,so可以先處理該區間值在l~mid的修改,再處理詢問。即二分標準可以直接用點的標號…

從數據庫里讀值往TEXT文本里寫

/// <summary> /// 把預定內容導入到Text文檔 /// </summary> private void ChangeDbToText() { this.RecNum.Visibletrue; //建立文件&#xff0c;并打開 string oneLine ""; string filename "Storage Card/YD" DateTime.Now.…

DengAI —如何應對數據科學競賽? (EDA)

了解機器學習 (Understanding ML) This article is based on my entry into DengAI competition on the DrivenData platform. I’ve managed to score within 0.2% (14/9069 as on 02 Jun 2020). Some of the ideas presented here are strictly designed for competitions li…

Pytorch模型層簡單介紹

模型層layers 深度學習模型一般由各種模型層組合而成。 torch.nn中內置了非常豐富的各種模型層。它們都屬于nn.Module的子類&#xff0c;具備參數管理功能。 例如&#xff1a; nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.Co…

有效溝通的技能有哪些_如何有效地展示您的數據科學或軟件工程技能

有效溝通的技能有哪些What is the most important thing to do after you got your skills to be a data scientist? It has to be to show off your skills. Otherwise, there is no use of your skills. If you want to get a job or freelance or start a start-up, you ha…

java.net.SocketException: Software caused connection abort: socket write erro

場景&#xff1a;接口測試 編輯器&#xff1a;eclipse 版本&#xff1a;Version: 2018-09 (4.9.0) testng版本&#xff1a;TestNG version 6.14.0 執行testng.xml時報錯信息&#xff1a; 出現此報錯原因之一&#xff1a;網上有人說是testng版本與eclipse版本不一致造成的&#…

[博客..配置?]博客園美化

博客園搞定時間 -> 18年6月27日 [讓我歇會兒 搞這個費腦子 代碼一個都看不懂] 轉載于:https://www.cnblogs.com/Steinway/p/9235437.html

使用K-Means對美因河畔法蘭克福的社區進行聚類

介紹 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

Pytorch損失函數losses簡介

一般來說&#xff0c;監督學習的目標函數由損失函數和正則化項組成。(Objective Loss Regularization) Pytorch中的損失函數一般在訓練模型時候指定。 注意Pytorch中內置的損失函數的參數和tensorflow不同&#xff0c;是y_pred在前&#xff0c;y_true在后&#xff0c;而Ten…

讀取Mc1000的 唯一 ID 機器號

先引用Symbol.ResourceCoordination 然后引用命名空間 using System;using System.Security.Cryptography;using System.IO; 以下為類程序 /// <summary> /// 獲取設備id /// </summary> /// <returns></returns> public static string GetDevi…

樣本均值的抽樣分布_抽樣分布樣本均值

樣本均值的抽樣分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩轉ceph性能測試---對象存儲(一)

筆者最近在工作中需要測試ceph的rgw&#xff0c;于是邊測試邊學習。首先工具采用的intel的一個開源工具cosbench&#xff0c;這也是業界主流的對象存儲測試工具。 1、cosbench的安裝&#xff0c;啟動下載最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]絕世好題

Description 題庫鏈接 給定一個長度為 \(n\) 的數列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最長長度&#xff0c;滿足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位與&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 為二進制第 \(i…

因果關系和相關關系 大數據_數據科學中的相關性與因果關系

因果關系和相關關系 大數據Let’s jump into it right away.讓我們馬上進入。 相關性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

Pytorch構建模型的3種方法

這個地方一直是我思考的地方&#xff01;因為學的代碼太多了&#xff0c;構建的模型各有不同&#xff0c;這里記錄一下&#xff01; 可以使用以下3種方式構建模型&#xff1a; 1&#xff0c;繼承nn.Module基類構建自定義模型。 2&#xff0c;使用nn.Sequential按層順序構建模…

vue取數據第一個數據_我作為數據科學家的第一個月

vue取數據第一個數據A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了數據科學家的第一份工作&#xff0c;并且像任何新工作一樣&#xff0c;一…

Flask-SocketIO 簡單使用指南

Flask-SocketIO 使 Flask 應用程序能夠訪問客戶端和服務器之間的低延遲雙向通信。客戶端應用程序可以使用 Javascript&#xff0c;C &#xff0c;Java 和 Swift 中的任何 SocketIO 官方客戶端庫或任何兼容的客戶端來建立與服務器的永久連接。 安裝 直接使用 pip 來安裝&#xf…

STL-開篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;標準模板庫 定義&#xff1a; c引入的一個標準類庫 特點&#xff1a;1&#xff09;數據結構和算法的 c實現&#xff08; 采用模板類和模板函數&#xff09;2&#xff09;數據的存儲和算法的分離3&#xff09;高…