? ? ? ?在 PyTorch 中,Dataset
?和?DataLoader
?是處理數據的核心工具。它們的作用是將數據高效地加載到模型中,支持批量處理、多線程加速和數據增強等功能。
一、Dataset:數據集的抽象?
Dataset
?是一個抽象類,用于表示數據集的接口。你需要繼承?torch.utils.data.Dataset
?并實現以下兩個方法:
__len__()
: 返回數據集的總樣本數。__getitem__(idx)
: 根據索引?idx
?返回一個樣本(數據和標簽)。
?示例:自定義 Dataset
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transform # 數據預處理/增強函數def __len__(self):return len(self.data)def __getitem__(self, idx):sample = {"data": self.data[idx], "label": self.labels[idx]}if self.transform:sample = self.transform(sample)return sample
?使用場景?
- 加載圖像、文本、表格數據等。
- 支持數據預處理(如歸一化、裁剪)和數據增強(如隨機翻轉)。
二、?DataLoader:高效加載數據?
DataLoader
?負責將?Dataset
?包裝成一個可迭代對象,支持批量加載、多線程加速和數據打亂。
?基本用法
from torch.utils.data import DataLoader# 假設 dataset 是你的 CustomDataset 實例
data_loader = DataLoader(dataset,batch_size=32, # 批量大小shuffle=True, # 是否打亂數據(訓練時建議開啟)num_workers=4, # 多線程加載數據的進程數drop_last=False # 是否丟棄最后不足一個 batch 的數據
)
??遍歷 DataLoader
for batch in data_loader:data = batch["data"] # 形狀:[batch_size, ...]labels = batch["label"] # 形狀:[batch_size]# 將數據送入模型訓練...
三、pytorch內置數據集
PyTorch 提供了一系列內置數據集,這些數據集可以直接用于訓練模型。這些數據集涵蓋了多種領域,如圖像、文本、音頻等。以下是一些常用的PyTorch內置數據集:
圖像數據集
-
MNIST: 手寫數字數據集,包含0到9的手寫數字圖片。
from torchvision import datasets mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
-
CIFAR10/CIFAR100: 包含彩色圖片的數據集,CIFAR10有60000張32x32的彩色圖片,分為10個類別;CIFAR100類似但有100個類別。
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
-
ImageNet: 包含超過1400萬張圖片的非常龐大的數據集,常用于圖像識別和分類任務。
import torchvision.datasets as datasets imagenet_train = datasets.ImageNet(root='./data', split='train', download=True)
-
STL10: 一個用于計算機視覺研究的小型圖像數據集,包含96x96的彩色圖片。
stl10_train = datasets.STL10(root='./data', split='train', download=True)
-
SVHN: 包含數字圖片的數據集,與MNIST類似但包含更多實際場景的圖片。
svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本數據集
? ? 1.Text8: 一個用于自然語言處理的小型文本數據集。
from torchtext.datasets import Text8
text8_train = Text8(split=('train',))
? ? 2.?AG_NEWS: 包含新聞文章的文本數據集,分為4個類別。
from torchtext.datasets import AG_NEWS
ag_news_train = AG_NEWS(split=('train',))
音頻數據集??
? 1. Speech Commands: 一個用于語音識別的數據集,包含約65,000個單詞發音的音頻文件。?
from torchaudio.datasets import SPEECHCOMMANDS
speech_commands = SPEECHCOMMANDS(root="./data", download=True)
?使用方法
要使用這些數據集,首先需要導入torchvision
(對于圖像數據集)、torchtext
(對于文本數據集)或torchaudio
(對于音頻數據集),然后使用其提供的類來加載數據。通常還包括一些數據預處理步驟,例如轉換(transforms)。
import torchvision.transforms as transforms
from torchvision import datasetstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
四、完整代碼示例
步驟 1:創建數據集
import numpy as np
from torch.utils.data import Dataset, DataLoader# 生成示例數據(假設是 10 個樣本,每個樣本是長度為 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,)) # 二分類標簽class MyDataset(Dataset):def __init__(self, data, labels):self.data = torch.tensor(data, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.data)def __getitem__(self, idx):return {"data": self.data[idx],"label": self.labels[idx]}dataset = MyDataset(data, labels)
?步驟 2:創建 DataLoader
data_loader = DataLoader(dataset,batch_size=2,shuffle=True,num_workers=2
)
??步驟 3:使用 DataLoader 訓練模型
model = ... # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch in data_loader:x = batch["data"]y = batch["label"]# 前向傳播outputs = model(x)loss = loss_fn(outputs, y)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()
五、常見問題解決
?(1)數據格式不匹配?
- ?問題?:
DataLoader
?返回的數據形狀與模型輸入不匹配。 - ?解決?:檢查?
Dataset
?的?__getitem__
?返回的數據類型和形狀,確保與模型輸入一致。
?(2)多線程加載卡頓?
- ?問題?:設置?
num_workers>0
?時程序卡死或報錯。 - ?解決?:在 Windows 系統中,多線程可能需要將代碼放在?
if __name__ == "__main__":
?塊中運行。
?(3)數據增強?
- 使用?
torchvision.transforms
?中的工具(如?RandomCrop
、RandomHorizontalFlip
)對圖像數據進行增強:from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5]), ])
?(4)內存不足?
- 對于大型數據集,使用?
torch.utils.data.DataLoader
?的?persistent_workers=True
(PyTorch 1.7+)或優化數據加載邏輯。
六、高級功能
- 分布式訓練?:使用?
torch.utils.data.distributed.DistributedSampler
?配合多 GPU。 - ?預加載數據?:使用?
torch.utils.data.TensorDataset
?直接加載 Tensor 數據。 - ?自定義采樣器?:通過?
sampler
?參數控制數據采樣順序(如平衡類別采樣)。