1. 為什么要有數據集類和數據加載器類?
一萬個人會有一萬種獲取并處理原始數據樣本的代碼,這會導致對數據的操作代碼標準不一,并且很難復用。
為了解決這個問題,Pytorch提供了兩種最基本的數據相關類:
torch.utils.data.Dataset
: 一個數據集對象,包含每個數據樣本路徑以及對應標簽torch.utils.data.DataLoader
:持有一個對Dataloader
的迭代器,通過調用Dataset
的__getitem__
函數方便地獲取實際的樣本-標簽對
。
PyTorch 為不同的任務類型提供了方便的預加載數據集,例如 torchvision.datasets、torchaudio.datasets 等。這些數據集都是 torch.utils.data.Dataset 的子類,可以直接通過
dataset.數據集名稱
的方式來方便的下載經典的數據集,在下面你會看到它的使用例。
2. Dataset類的使用方法
2.1 加載一個Fashion-MNIST數據集
Fashion-MNIST 是一個來自 Zalando 的文章圖像數據集,包含 60,000 個訓練樣本和 10,000 個測試樣本。每個樣本由一張 28×28 的灰度圖像和其對應的 10 個類別中的一個標簽組成。
這是一個使用TorchVision
的預加載數據集類加載Fashion-MNIST 數據集的例子,如下是每個參數代表的意思:
- root:是存儲訓練/測試數據的路徑。
- train:指定是訓練數據集還是測試數據集。
- download=True:如果數據在 root 路徑下不可用,則從互聯網下載。
- transform 和 target_transform:分別指定特征和標簽的轉換。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data", # 指定數據集實際存放的路徑(相對于本代碼文件)train=True, # 指定這是訓練集還是測試集download=True, # 如果在root下沒有數據,從網絡上自動下載transform=ToTensor() # 給每一張圖片轉換為Tensor的數據類型
)test_data = datasets.FashionMNIST(root="data", # 指定數據集實際存放的路徑(相對于本代碼文件)train=False, # 指定這是訓練集還是測試集download=True, # 如果在root下沒有數據,從網絡上自動下載transform=ToTensor() # 給每一張圖片轉換為Tensor的數據類型
)
2.2 遍歷并可視化數據集
我們可以簡單的使用training_data[index]
來獲取Datasets
類中對應index的樣本。通常可以用matplotlib來可視化我們的一些訓練數據集:
labels_map = { # 定義一個標簽映射字典0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}figure = plt.figure(figsize=(8, 8)) # 創建一個新的畫布,大小為8x8英寸
cols, rows = 3, 3 # 定義展示網格尺寸 3x3的展示網格,每個網格展示i一個圖片for i in range(1, cols * rows + 1): # plt的索引從1開始,配合一下sample_idx = torch.randint(len(training_data), size=(1,)).item() # 生成一個包含1個元素的張量,item()回python數據類型之后為0到數據集大小-1的隨機整數img, label = training_data[sample_idx] # 本質上是在調用__getitem__函數figure.add_subplot(rows, cols, i) # 在之前創建的圖形窗口中,添加一個子圖(subplot),并將當前的畫筆操作對象設置為當前子圖plt.title(labels_map[label]) # 子圖的標題設置為對應的標簽字符串plt.axis("off") # 不顯示坐標軸plt.imshow(img.squeeze(), cmap="gray") # 把當前網格畫好
plt.show() # 展示畫布
這里我并不知道為啥要使用img.squeeze()這個方法, 直到我把img的shape的打印出來:
現在img是一個3維的tensor,但是plt.imshow需要輸入二維的tensor,所以使用squeeze的目的是把所有的尺寸為1的維度給擠壓掉,將img維度降維到2維,然后就可以用plt可視化了。
2.3 進階:如何制作一個自己的數據集類
自定義的 Dataset
類必須實現三個函數:__init__
、__len__
和 __getitem__
。請看下面的實現示例:FashionMNIST 圖像存儲在 img_dir
目錄中,而它們的標簽則單獨保存在 annotations_file
的 CSV 文件里。
import os
import pandas as pd
from torchvision.io import decode_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitemm__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # iloc全寫為“integer location”, 表明你要通過數據的行和列的整數索引來選擇數據image = decode_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label
在接下來的部分將詳細解釋每個方法的作用。
__init__
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform
這個方法會在初始化數據集的時候調用。其主要完成如下工作:
- 讀取標簽文件
- 指定圖片文件夾路徑
- 指定樣本和標簽的transform(這個下面細講)
一個Fashion-MNIST是一個分類任務,其標簽文件annotations
大概長這樣:
tshirt1.jpg, 0 # 樣本-標簽對
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
__len__
這個方法是簡單返回數據集的樣本數量:
def __len__(self):return len(self.img_labels)
__getitem__
這個方法是Dataset
類的核心,當此方法被Dataloader
調用,請求特定idx的數據時,Dataset
會根據idx,讀取對應的圖片和標簽,并對它們做出各自的transform之后,返回給Dataloader
,讓它把圖片和標簽搬運到內存.
def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label
3. Dataloader類的使用方法
3.1 對數據集對象配置Dataloader
當Dataset
類的__getitem__
方法被調用的時候,他會返回一個樣本-標簽對。
但是在實際的模型訓練中,我們還有一些別的要求,例如:
- 以“小批量(minibatches)”的方式傳遞樣本。(減少單樣本噪聲帶來的震蕩,讓梯度更新的方向更加穩定)
- 在每個周期(epoch)對數據進行重新洗牌(reshuffle),以減少模型過擬合。
- 使用 Python 的多進程(multiprocessing)來加快數據檢索速度。
以上的要求可以通過如下的參數設定來滿足:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=5)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=5)
- batch_size=64 設定批量大小為64
- shuffle=True 指定一個epoch之后dataloader持有的索引要重新洗牌
- num_workers=5 指定dataloader會同時開啟5個進程去調用dataset的
__getitem__
方法
以上是
Dataloader
最基本的用法,不過,當你有GPU的時候,我推薦你也把下面兩個參數打開:
pin_memory=True 開啟鎖頁內存,減少CPU到GPU的數據傳遞延遲
persistent_workers=True 每個epoch結束后不銷毀dataloader所開啟的worker進程,而是接著用,這樣剩下了worker的初始化時間
3.2 使用Dataloader遍歷數據集
給Dataset配置好對應的Dataloader后,就可以開始用dataloader遍歷它了。每次遍歷都會返回一個batch_size的訓練圖片和訓練標簽對(這里就是64個)。
# Display image and label.
train_features, train_labels = next(iter(train_dataloader)) # 先從train_dataloader中獲得一個迭代器,然后調用next獲取其下一個元素
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
由于開啟了shuffle=True,所以每次遍歷完整個數據集后train_dataloader持有的索引會被打亂。