1. 定義最簡單的Dataset
import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data):self.data = data # 假設data是一個列表,如[10, 20, 30, 40]def __len__(self):return len(self.data) # 返回數據總量def __getitem__(self, idx):return self.data[idx] # 返回單個數據樣本# 示例數據
my_data = [10, 20, 30, 40]
dataset = MyDataset(my_data)
2. 創建DataLoader
loader = DataLoader(dataset, batch_size=2, # 每批2個樣本shuffle=True) # 打亂數據順序
3. 遍歷DataLoader時的內部操作
當執行以下代碼時:
for batch in loader:print(batch)
實際發生的步驟:
- DataLoader自動調用
dataset.__len__()
獲取數據總量(這里是4) - 根據
batch_size=2
生成索引序列(如[1,3]
、[0,2]
,因shuffle=True而隨機)- 索引生成邏輯:
- PyTorch通過以下設計保證索引不重復:
- 采樣器隔離:每個epoch生成獨立的隨機排列。
- 批次切割:按固定步長切分排列,避免交叉。
- 全局控制:
Sampler
嚴格管理索引分配。
- 對每個索引調用
dataset.__getitem__(idx)
:- 第一次取
idx=1
和idx=3
→ 返回20
和40
- 自動堆疊為張量
tensor([20, 40])
- 第一次取
- 輸出結果示例:
tensor([20, 40]) # 第一批 tensor([10, 30]) # 第二批
4. 關鍵點圖解
數據集: [10, 20, 30, 40]│ │ │ │
索引: 0 1 2 3DataLoader操作:
1. 隨機選索引(如[1,3]) → 取數據20和40 → 堆疊為tensor([20, 40])
2. 隨機選索引(如[0,2]) → 取數據10和30 → 堆疊為tensor([10, 30])
5. 如果數據是元組
假設每個樣本是(用戶ID, 物品ID)
:
class PairDataset(Dataset):def __init__(self):self.pairs = [(1,101), (2,102), (3,103)] # (用戶, 物品)def __len__(self):return len(self.pairs) # 必須實現:返回數據總量def __getitem__(self, idx):return self.pairs[idx] # 返回一個元組loader = DataLoader(PairDataset(), batch_size=2)
for batch in loader:print(batch)
輸出:
# 每個元組字段自動堆疊
[tensor([1, 2]), tensor([101, 102])] # 第一批
[tensor([3]), tensor([103])] # 第二批(最后不足batch_size)
總結
- Dataset:定義數據存儲和單個樣本獲取方式(必須實現
__len__
和__getitem__
) - DataLoader:
- 根據
batch_size
生成索引 - 自動調用
__getitem__
獲取數據 - 將樣本堆疊成批次張量
- 根據
- 核心特性:
- 支持多進程加速(
num_workers
參數) - 自動打亂數據(
shuffle=True
) - 靈活處理各種數據結構(標量、元組、字典等)
- 支持多進程加速(
這就是PyTorch數據加載的核心機制!其他復雜功能都是基于這個簡單流程的擴展。