1.torch.utils.data.TensorDataset
功能定位
torch.utils.data.TensorDataset
是一個將多個張量(Tensor)數據進行簡單包裝整合的數據集類,它主要的作用是將相關聯的數據(比如特征數據和對應的標簽數據等)組合在一起,形成一個方便后續用于訓練等操作的數據集對象。
例如,如果你有輸入特征數據 x
(形狀為 [n_samples, feature_dim]
)和對應的標簽數據 y
(形狀為 [n_samples]
),且它們都是 torch.Tensor
類型,可以這樣創建 TensorDataset
:
import torch
from torch.utils.data import TensorDatasetx = torch.randn(100, 10) # 模擬100個樣本,每個樣本特征維度為10
y = torch.randint(0, 2, (100,)) # 模擬二分類標簽dataset = TensorDataset(x, y)
特點
-
簡單包裝:只是把傳入的張量按照樣本維度進行了對應組合,并沒有對數據做復雜的預處理、采樣等額外操作。
-
索引支持:支持像普通列表那樣通過索引訪問其中的數據元素,例如
dataset[0]
會返回由對應索引的特征和標簽組成的元組(按照傳入構造函數的張量順序)。 -
適用于小型數據集直接使用:當數據量不大且數據格式已經整理為張量形式時,可以直接基于它來進行簡單的模型訓練循環等操作,不過對于批量處理等更復雜的情況支持有限,需要進一步配合其他工具。
2.torch.utils.data.DataLoader
功能定位
torch.utils.data.DataLoader
是一個用于加載數據的工具類,它圍繞著給定的數據集(比如 TensorDataset
或者自定義的繼承自 Dataset
的類實例等),實現了諸如批量加載數據、打亂數據順序、并行加載數據等功能,旨在讓數據能夠以合適的方式、合適的批量大小等被送入到模型中進行訓練、驗證或測試等操作。
示例:
from torch.utils.data import DataLoaderbatch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for batch_x, batch_y in dataloader:# 這里的batch_x和batch_y就是每次迭代取出的一個批量的特征和標簽數據pass
特點
-
批量處理:可以按照設定的
batch_size
參數,將數據集中的數據劃分為一個個的小批量(mini-batch),方便模型以批量的方式進行梯度計算更新,有助于優化訓練過程和提升效率,尤其在大數據集場景下優勢明顯。 -
數據打亂:通過設置
shuffle=True
可以在每個訓練輪次(epoch)開始時對數據集里面的數據順序進行隨機打亂,使得數據的輸入順序具有隨機性,這有助于提升模型訓練的泛化能力,避免模型因數據順序固定而產生過擬合等問題。 -
并行加載:支持多進程加載數據(通過設置
num_workers
參數大于 0),能夠利用多核 CPU 的優勢加快數據讀取和預處理的速度,特別是在處理大規模數據集或者數據加載比較耗時的情況下,能顯著提升整體訓練效率。 -
靈活性和通用性:它可以適配各種不同類型的數據集,只要這些數據集繼承自
torch.utils.data.Dataset
抽象類并實現了必要的__len__
和__getitem__
等方法,因此無論是簡單的TensorDataset
還是復雜的自定義數據集都可以用它來加載數據。
總的來說,TensorDataset 側重于對已有張量數據進行簡單的整合包裝形成數據集;而 DataLoader 側重于圍繞數據集實現數據的批量加載、打亂順序、并行化等復雜的數據加載相關功能,它們通常配合使用,先使用 TensorDataset 組織好數據,再通過 DataLoader 按照訓練需求來加載和處理這些數據并送入模型中。
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs)