在遇到大規模數據集時,顯存常常無法一次性存儲所有數據,所以需要使用分批訓練的方法。為此,PyTorch提供了DataLoader類,該類可以自動將數據集切分為多個批次batch,并支持多線程加載數據。此外,還存在Dataset類,該類可以定義數據集的讀取方式和預處理方式。
1. DataLoader類:決定數據如何加載
2. Dataset類:告訴程序去哪里找數據,如何讀取單個樣本,以及如何預處理。
為了引入這些概念,我們現在接觸一個新的而且非常經典的數據集:MNIST手寫數字數據集。該數據集包含60000張訓練圖片和10000張測試圖片,每張圖片大小為28*28像素,共包含10個類別。因為每個數據的維度比較小,所以既可以視為結構化數據,用機器學習、MLP訓練,也可以視為圖像數據,用卷積神經網絡訓練。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加載數據的工具
from torchvision import datasets, transforms # torchvision 是一個用于計算機視覺的庫,datasets 和 transforms 是其中的模塊
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
?## 一、Dataset類
現在我們想要取出來一個圖片,看看長啥樣,因為datasets.MNIST本質上集成了torch.utils.data.Dataset,所以自然需要有對應的方法。
import matplotlib.pyplot as plt# 隨機選擇一張圖片,可以重復運行,每次都會隨機選擇
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 隨機選擇一張圖片的索引
# len(train_dataset) 表示訓練集的圖片數量;size=(1,)表示返回一個索引;torch.randint() 函數用于生成一個指定范圍內的隨機數,item() 方法將張量轉換為 Python 數字
image, label = train_dataset[sample_idx] # 獲取圖片和標簽
### __getitem__方法
__getitem__方法用于讓對象支持索引操作,當使用[]語法訪問對象元素時,Python 會自動調用該方法。
# 示例代碼
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 創建類的實例
my_list_obj = MyList()
# 此時可以使用索引訪問元素,這會自動調用__getitem__方法
print(my_list_obj[2]) # 輸出:30
?### __len__方法
__len__方法用于返回對象中元素的數量,當使用內置函數len()作用于對象時,Python 會自動調用該方法。
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 創建類的實例
my_list_obj = MyList()
# 使用len()函數獲取元素數量,這會自動調用__len__方法
print(len(my_list_obj)) # 輸出:5
?## 二、Dataloader類
# 3. 創建數據加載器
train_loader = DataLoader(train_dataset,batch_size=64, # 每個批次64張圖片,一般是2的冪次方,這與GPU的計算效率有關shuffle=True # 隨機打亂數據
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每個批次1000張圖片# shuffle=False # 測試時不需要打亂數據
)
?@浙大疏錦行