重點在第二部分的構建數據通道和第三部分的加載數據集
Pytorch通常使用Dataset和DataLoader這兩個工具類來構建數據管道。
Dataset定義了數據集的內容,它相當于一個類似列表的數據結構,具有確定的長度,能夠用索引獲取數據集中的元素。
而DataLoader定義了按batch加載數據集的方法,它是一個實現了__iter__方法的可迭代對象,每次迭代輸出一個batch的數據。
DataLoader能夠控制batch的大小,batch中元素的采樣方法,以及將batch結果整理成模型所需輸入形式的方法,并且能夠使用多進程讀取數據。
在絕大部分情況下,用戶只需實現Dataset的__len__方法和__getitem__方法,就可以輕松構建自己的數據集,并用默認數據管道進行加載。
一,Dataset和DataLoader概述
1,獲取一個batch數據的步驟
讓我們考慮一下從一個數據集中獲取一個batch的數據需要哪些步驟。
(假定數據集的特征和標簽分別表示為張量X和Y,數據集可以表示為(X,Y), 假定batch大小為m)
1,首先我們要確定數據集的長度n。
結果類似:n = 1000。
2,然后我們從0到n-1的范圍中抽樣出m個數(batch大小)。
假定m=4, 拿到的結果是一個列表,類似:indices = [1,4,8,9]
3,接著我們從數據集中去取這m個數對應下標的元素。
拿到的結果是一個元組列表,類似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4,最后我們將結果整理成兩個張量作為輸出。
拿到的結果是兩個張量,類似batch = (features,labels),
其中 features = torch.stack([X[1],X[4],X[8],X[9]])
labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])
2,Dataset和DataLoader的功能分工
上述第1個步驟確定數據集的長度是由 Dataset的__len__ 方法實現的。
第2個步驟從0到n-1的范圍中抽樣出m個數的方法是由 DataLoader的 sampler和 batch_sampler參數指定的。
sampler參數指定單個元素抽樣方法,一般無需用戶設置,程序默認在DataLoader的參數shuffle=True時采用隨機抽樣,shuffle=False時采用順序抽樣。
batch_sampler參數將多個抽樣的元素整理成一個列表,一般無需用戶設置,默認方法在DataLoader的參數drop_last=True時會丟棄數據集最后一個長度不能被batch大小整除的批次,在drop_last=False時保留最后一個批次。
第3個步驟的核心邏輯根據下標取數據集中的元素 是由 Dataset的 __getitem__方法實現的。
第4個步驟的邏輯由DataLoader的參數collate_fn指定。一般情況下也無需用戶設置。
3,Dataset和DataLoader的主要接口
偽代碼,實際應用意義不大
import torch
class Dataset(object):def __init__(self):passdef __len__(self):raise NotImplementedErrordef __getitem__(self,index):raise NotImplementedErrorclass DataLoader(object):def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):self.dataset = datasetself.sampler =torch.utils.data.RandomSampler if shuffle else \torch.utils.data.SequentialSamplerself.batch_sampler = torch.utils.data.BatchSamplerself.sample_iter = self.batch_sampler(self.sampler(range(len(dataset))),batch_size = batch_size,drop_last = drop_last)def __next__(self):indices = next(self.sample_iter)batch = self.collate_fn([self.dataset[i] for i in indices])return batch
二,使用Dataset創建數據集
Dataset創建數據集常用的方法有:
使用 torch.utils.data.TensorDataset 根據Tensor創建數據集(numpy的array,Pandas的DataFrame需要先轉換成Tensor)。
使用 torchvision.datasets.ImageFolder 根據圖片目錄創建圖片數據集。
繼承 torch.utils.data.Dataset 創建自定義數據集。
此外,還可以通過
torch.utils.data.random_split 將一個數據集分割成多份,常用于分割訓練集,驗證集和測試集。
調用Dataset的加法運算符(+)將多個數據集合并成一個數據集。
1,根據Tensor創建數據集
- 頭文件:
import numpy as np
import torch
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split
- 根據Tensor創建數據集
from sklearn import datasets
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
- 分割成訓練集和預測集
n_train = int(len(ds_iris)*0.8)
n_valid = len(ds_iris) - n_train
ds_train,ds_valid = random_split(ds_iris,[n_train,n_valid])
- 使用DataLoader加載數據集
dl_train,dl_valid = DataLoader(ds_train,batch_size = 8),DataLoader(ds_valid,batch_size = 8)#查看數據集
for features,labels in dl_train:print(features,labels)break
- 演示加法運算符(
+
)的合并作用
ds_data = ds_train + ds_validprint('len(ds_train) = ',len(ds_train))
print('len(ds_valid) = ',len(ds_valid))
print('len(ds_train+ds_valid) = ',len(ds_data))print(type(ds_data))
2,根據圖片目錄創建圖片數據集
- 頭文件:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
- 圖片加載:
from PIL import Image
img = Image.open('./data/cat.jpeg')
- 隨機數值翻轉
transforms.RandomVerticalFlip()(img)
- 隨機旋轉
transforms.RandomRotation(45)(img)
- 定義圖片增強操作
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), #隨機水平翻轉transforms.RandomVerticalFlip(), #隨機垂直翻轉transforms.RandomRotation(45), #隨機在45度角度內旋轉transforms.ToTensor() #轉換成張量]
) transform_valid = transforms.Compose([transforms.ToTensor()]
)
- 根據圖片目錄創建數據集
ds_train = datasets.ImageFolder("./data/cifar2/train/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./data/cifar2/test/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())print(ds_train.class_to_idx)
- 使用DataLoader加載數據集
#注意:windows用戶要把num_workers去掉,容易報錯
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)for features,labels in dl_train:print(features.shape)print(labels.shape)break
三,使用DataLoader加載數據集
DataLoader能夠控制batch的大小,batch中元素的采樣方法,以及將batch結果整理成模型所需輸入形式的方法,并且能夠使用多進程讀取數據。
DataLoader的函數簽名
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)
一般情況下,我們僅僅會配置 dataset, batch_size, shuffle, num_workers, drop_last這五個參數,其他參數使用默認值即可。
dataset : 數據集
batch_size: 批次大小
shuffle: 是否亂序
sampler: 樣本采樣函數,一般無需設置。
batch_sampler: 批次采樣函數,一般無需設置。
num_workers: 使用多進程讀取數據,設置的進程數。
collate_fn: 整理一個批次數據的函數。
pin_memory: 是否設置為鎖業內存。默認為False,鎖業內存不會使用虛擬內存(硬盤),從鎖業內存拷貝到GPU上速度會更快。
drop_last: 是否丟棄最后一個樣本數量不足batch_size批次數據。
timeout: 加載一個數據批次的最長等待時間,一般無需設置。
worker_init_fn: 每個worker中dataset的初始化函數,常用于 IterableDataset。一般不使用。
#構建輸入數據管道
ds = TensorDataset(torch.arange(1,50))
dl = DataLoader(ds,batch_size = 10,shuffle= True,num_workers=2,drop_last = True)
#迭代數據
for batch, in dl:print(batch)