Dataset和Dataloader類
知識點回顧:
- Dataset類的__getitem__和__len__方法(本質是python的特殊方法)
- Dataloader類
- minist手寫數據集的了解
????????在遇到大規模數據集時,顯存常常無法一次性存儲所有數據,所以需要使用分批訓練的方法。為此,PyTorch提供了DataLoader類,該類可以自動將數據集切分為多個批次batch,并支持多線程加載數據。此外,還存在Dataset類,該類可以定義數據集的讀取方式和預處理方式。
1. DataLoader類:決定數據如何加載
2. Dataset類:告訴程序去哪里找數據,如何讀取單個樣本,以及如何預處理。
使用的數據集為MNIST手寫數字數據集。該數據集包含60000張訓練圖片和10000張測試圖片,每張圖片大小為28*28像素,共包含10個類別。因為每個數據的維度比較小,所以既可以視為結構化數據,用機器學習、MLP訓練,也可以視為圖像數據,用卷積神經網絡訓練。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加載數據的工具
from torchvision import datasets, transforms # torchvision 是一個用于計算機視覺的庫,datasets 和 transforms 是其中的模塊
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
# 1. 數據預處理,該寫法非常類似于管道pipeline
# transforms 模塊提供了一系列常用的圖像預處理操作# 先歸一化,再標準化
transform = transforms.Compose([transforms.ToTensor(), # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST數據集的均值和標準差,這個值很出名,所以直接使用
])# 2. 加載MNIST數據集,如果沒有會自動下載
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
Dataset類
import matplotlib.pyplot as plt# 隨機選擇一張圖片,可以重復運行,每次都會隨機選擇
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 隨機選擇一張圖片的索引
# len(train_dataset) 表示訓練集的圖片數量;size=(1,)表示返回一個索引;torch.randint() 函數用于生成一個指定范圍內的隨機數,item() 方法將張量轉換為 Python 數字
image, label = train_dataset[sample_idx] # 獲取圖片和標簽
為什么train_dataset[sample_idx]可以獲取到圖片和標簽,是因為 datasets.MNIST這個類繼承了torch.utils.data.Dataset類,這個類中有一個方法__getitem__,這個方法會返回一個tuple,tuple中第一個元素是圖片,第二個元素是標簽。
torch.utils.data.Dataset類
一個抽象基類,所有自定義數據集都需要繼承它并實現兩個核心方法:
- __len__():返回數據集的樣本總數。
- __getitem__(idx):根據索引idx返回對應樣本的數據和標簽。
PyTorch 要求所有數據集必須實現__getitem__和__len__,這樣才能被DataLoader等工具兼容。這是一種接口約定,類似函數參數的規范。這意味著,如果你要創建一個自定義數據集,你需要實現這兩個方法,否則PyTorch將無法識別你的數據集。
?__getitem__和__len__ 是類的特殊方法(也叫魔術方法 ),它們不是像普通函數那樣直接使用,而是需要在自定義類中進行定義,來賦予類特定的行為。
__getitem__方法
用于讓對象支持索引操作,當使用[]語法訪問對象元素時,Python 會自動調用該方法。
# 示例代碼
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 創建類的實例
my_list_obj = MyList()
# 此時可以使用索引訪問元素,這會自動調用__getitem__方法
print(my_list_obj[2]) # 輸出:30
__len__方法
用于返回對象中元素的數量,當使用內置函數len()作用于對象時,Python 會自動調用該方法。
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 創建類的實例
my_list_obj = MyList()
# 使用len()函數獲取元素數量,這會自動調用__len__方法
print(len(my_list_obj)) # 輸出:5
# minist數據集的簡化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加載圖片路徑和標簽self.data, self.targets = fetch_mnist_data(root, train) # 這里假設 fetch_mnist_data 是一個函數,用于加載 MNIST 數據集的圖片路徑和標簽self.transform = transform # 預處理操作def __len__(self): return len(self.data) # 返回樣本總數def __getitem__(self, idx): # 獲取指定索引的樣本# 獲取指定索引的圖像和標簽img, target = self.data[idx], self.targets[idx]# 應用圖像預處理(如ToTensor、Normalize)if self.transform is not None: # 如果有預處理操作img = self.transform(img) # 轉換圖像格式# 這里假設 img 是一個 PIL 圖像對象,transform 會將其轉換為張量并進行歸一化return img, target # 返回處理后的圖像和標簽
?通俗地類比:
Dataset = 廚師(準備單個菜品)
DataLoader = 服務員(將菜品按訂單組合并上桌)
預處理(如切菜、調味)屬于廚師的工作,而非服務員。所以在dataset就需要添加預處理步驟。
# 可視化原始圖像(需要反歸一化)
def imshow(img):img = img * 0.3081 + 0.1307 # 反標準化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 顯示灰度圖像plt.show()print(f"Label: {label}")
imshow(image)
Dataloader類
# 3. 創建數據加載器
train_loader = DataLoader(train_dataset,batch_size=64, # 每個批次64張圖片,一般是2的冪次方,這與GPU的計算效率有關shuffle=True # 隨機打亂數據
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每個批次1000張圖片# shuffle=False # 測試時不需要打亂數據
)
總結:
Dataset 類:定義數據的內容和格式(即 “如何獲取單個樣本”),包括:
? 數據存儲路徑 / 來源(如文件路徑、數據庫查詢)。
? 原始數據的讀取方式(如圖像解碼為 PIL 對象、文本讀取為字符串)。
? 樣本的預處理邏輯(如裁剪、翻轉、歸一化等,通常通過 transform 參數實現)。
? 返回值格式(如 (image_tensor, label))。
? DataLoader 類:定義數據的加載方式和批量處理邏輯(即 “如何高效批量獲取數據”),包括:
? 批量大小(batch_size)。
? 是否打亂數據順序(shuffle)。
@浙大疏錦行