PyTorch 數據加載全攻略:從自定義數據集到模型訓練

目錄

一、為什么需要數據加載器?

二、自定義 Dataset 類

1. 核心方法解析

2. 代碼實現

三、快速上手:TensorDataset

1. 代碼示例

2. 適用場景

四、DataLoader:批量加載數據的利器

1. 核心參數說明

2. 代碼示例

五、實戰:用數據加載器訓練線性回歸模型

1. 完整代碼

2. 代碼解析

六、總結與拓展


在深度學習實踐中,數據加載是模型訓練的第一步,也是至關重要的一環。高效的數據加載不僅能提高訓練效率,還能讓代碼更具可維護性。本文將結合 PyTorch 的核心 API,通過實例詳解數據加載的全過程,從自定義數據集到批量訓練,帶你快速掌握 PyTorch 數據處理的精髓。

一、為什么需要數據加載器?

在處理大規模數據時,我們不可能一次性將所有數據加載到內存中。PyTorch 提供了DatasetDataLoader兩個核心類來解決這個問題:

  • Dataset:負責數據的存儲和索引
  • DataLoader:負責批量加載、打亂數據和多線程處理

簡單來說,Dataset就像一個 "倉庫",而DataLoader是 "搬運工",負責把數據按批次運送到模型中進行訓練。

二、自定義 Dataset 類

當我們需要處理特殊格式的數據(如自定義標注文件、特殊預處理)時,就需要自定義數據集。自定義數據集需繼承torch.utils.data.Dataset,并實現三個核心方法:

1. 核心方法解析

  • __init__:初始化數據集,加載數據路徑或原始數據
  • __len__:返回數據集的樣本數量
  • __getitem__:根據索引返回單個樣本(特征 + 標簽)

2. 代碼實現

import torch
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, data, labels):# 初始化數據和標簽self.data = dataself.labels = labelsdef __len__(self):# 返回樣本總數return len(self.data)def __getitem__(self, index):# 根據索引返回單個樣本sample = self.data[index]label = self.labels[index]return sample, label# 使用示例
if __name__ == "__main__":# 生成隨機數據x = torch.randn(1000, 100, dtype=torch.float32)  # 1000個樣本,每個100個特征y = torch.randn(1000, 1, dtype=torch.float32)   # 對應的標簽# 創建自定義數據集dataset = MyDataset(x, y)print(f"數據集大小:{len(dataset)}")print(f"第一個樣本:{dataset[0]}")  # 查看第一個樣本

三、快速上手:TensorDataset

如果你的數據已經是 PyTorch 張量(Tensor),且不需要復雜的預處理,那么TensorDataset會是更好的選擇。它是 PyTorch 內置的數據集類,能快速將特征和標簽綁定在一起。

1. 代碼示例

from torch.utils.data import TensorDataset, DataLoader# 生成張量數據
x = torch.randn(1000, 100, dtype=torch.float32)
y = torch.randn(1000, 1, dtype=torch.float32)# 使用TensorDataset包裝數據
dataset = TensorDataset(x, y)  # 特征和標簽按索引對應# 查看樣本
print(f"樣本數量:{len(dataset)}")
print(f"第一個樣本特征:{dataset[0][0].shape}")
print(f"第一個樣本標簽:{dataset[0][1]}")

2. 適用場景

  • 數據已轉換為 Tensor 格式
  • 不需要復雜的預處理邏輯
  • 快速搭建訓練流程(如驗證代碼可行性)

四、DataLoader:批量加載數據的利器

有了數據集,還需要高效的批量加載工具。DataLoader可以實現:

  • 批量讀取數據(batch_size
  • 打亂數據順序(shuffle
  • 多線程加載(num_workers

1. 核心參數說明

參數作用
dataset要加載的數據集
batch_size每批樣本數量(常用 32/64/128)
shuffle每個 epoch 是否打亂數據(訓練時設為 True)
num_workers加載數據的線程數(加速數據讀取)

2. 代碼示例

# 創建DataLoader
dataloader = DataLoader(dataset=dataset,batch_size=32,      # 每批32個樣本shuffle=True,       # 訓練時打亂數據num_workers=2       # 2個線程加載
)# 遍歷數據
for batch_idx, (batch_x, batch_y) in enumerate(dataloader):print(f"第{batch_idx}批:")print(f"特征形狀:{batch_x.shape}")  # (32, 100)print(f"標簽形狀:{batch_y.shape}")  # (32, 1)if batch_idx == 2:  # 只看前3批break

五、實戰:用數據加載器訓練線性回歸模型

下面結合一個完整案例,展示如何使用TensorDatasetDataLoader訓練模型。我們將實現一個線性回歸任務,預測生成的隨機數據。

1. 完整代碼

from sklearn.datasets import make_regression
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn, optim# 生成回歸數據
def build_data():bias = 14.5# 生成1000個樣本,100個特征x, y, coef = make_regression(n_samples=1000,n_features=100,n_targets=1,bias=bias,coef=True,random_state=0  # 固定隨機種子,保證結果可復現)# 轉換為Tensor并調整形狀x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32).view(-1, 1)  # 轉為列向量bias = torch.tensor(bias, dtype=torch.float32)coef = torch.tensor(coef, dtype=torch.float32)return x, y, coef, bias# 訓練函數
def train():x, y, true_coef, true_bias = build_data()# 構建數據集和數據加載器dataset = TensorDataset(x, y)dataloader = DataLoader(dataset=dataset,batch_size=100,  # 每批100個樣本shuffle=True     # 訓練時打亂數據)# 定義模型、損失函數和優化器model = nn.Linear(in_features=x.size(1), out_features=y.size(1))  # 線性層criterion = nn.MSELoss()  # 均方誤差損失optimizer = optim.SGD(model.parameters(), lr=0.01)  # 隨機梯度下降# 訓練50個epochepochs = 50for epoch in range(epochs):for batch_x, batch_y in dataloader:# 前向傳播y_pred = model(batch_x)loss = criterion(batch_y, y_pred)# 反向傳播和參數更新optimizer.zero_grad()  # 清空梯度loss.backward()        # 計算梯度optimizer.step()       # 更新參數# 打印結果print(f"真實權重:{true_coef[:5]}...")  # 只顯示前5個print(f"預測權重:{model.weight.detach().numpy()[0][:5]}...")print(f"真實偏置:{true_bias}")print(f"預測偏置:{model.bias.item()}")if __name__ == "__main__":train()

2. 代碼解析

  1. 數據生成:用make_regression生成帶噪聲的回歸數據,并轉換為 PyTorch 張量。
  2. 數據集構建:用TensorDataset將特征和標簽綁定,方便后續加載。
  3. 批量加載DataLoader按批次讀取數據,每次訓練用 100 個樣本。
  4. 模型訓練:線性回歸模型通過梯度下降優化,最終輸出預測的權重和偏置,與真實值對比。

六、總結與拓展

本文介紹了 PyTorch 中數據加載的核心工具:

  • 自定義 Dataset:靈活處理特殊數據格式
  • TensorDataset:快速包裝張量數據
  • DataLoader:高效批量加載,支持多線程和數據打亂

在實際項目中,你可以根據數據類型選擇合適的工具:

  • 處理圖片:用ImageFolder(PyTorch 內置,支持按文件夾分類)
  • 處理文本:自定義 Dataset 讀取文本文件并轉換為張量
  • 大規模數據:結合num_workerspin_memory(針對 GPU 加速)

掌握數據加載是深度學習的基礎,用好這些工具能讓你的訓練流程更高效、更易維護。快去試試用它們處理你的數據吧!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/89341.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/89341.shtml
英文地址,請注明出處:http://en.pswp.cn/web/89341.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python--plist文件的讀取

Python練習:讀取Apple Plist文件 Plist文件簡介 ??定義??:Apple公司創建的基于XML結構的文件格式??特點??:采用XML語法組織數據,可存儲鍵值對、數組等結構化信息文件擴展名??:.plist應用場景: ??iOS系統:?…

JAVA幾個注解記錄

在Java中,Data、AllArgsConstructor和NoArgsConstructor是Lombok庫提供的注解,用于自動生成Java類中的樣板代碼(如getter、setter、構造函數等),從而減少冗余代碼,提高開發效率。以下是它們的詳細功能和使用…

js對象簡介、內置對象

對象、內置對象 jarringslee 對象 對象(object)是js的一種引用數據類型,是一種無序的數據集合“ul”(類比于數組,有序的數據集合“ol”)。 基本上等于結構體。 對象的聲明 //基本方法 let 對象名 {聲…

【工程篇】07:如何打包conda環境并拷貝到另一臺服務器上

這是一份以名為 qwen2.5-vl 的 Conda 環境為例的詳細操作手冊,指導您如何將其打包并遷移至另一臺服務器。操作手冊:遷移 Conda 環境 qwen2.5-vl 至新服務器 本文檔將提供兩種有效的方法來遷移您的 qwen2.5-vl 環境。請根據您的具體需求和服務器條件選擇最…

rustdesk遠控電腦替代todesk,平替向日葵等軟件

rustdesk網頁端遠控電腦docker run --restart always \ --privileged \ -p 9000:9000 \ -p 21114:21114 \ -p 21115:21115 \ -p 21116:21116 \ -p 21116:21116/udp \ -p 21117:21117 \ -p 21118:21118 \ -p 21119:21119 \ -e KEYj8muHpzr2HK00zm9D94b1UFkaJ1bEiWsyA1qxb1nOA \ …

板凳-------Mysql cookbook學習 (十二--------1)

第9章 存儲例程,觸發器和計劃事件 326 9.0 概述 326 9.1 創建復合語句對象 329 mysql> -- 恢復默認分隔符 mysql> DELIMITER ; mysql>mysql> DROP FUNCTION IF EXISTS avg_mail_size; Query OK, 0 rows affected (0.02 sec)mysql> DELIMITER $$ mysq…

密碼學系列文(3)--分組密碼

一、分組密碼概述分組密碼是許多系統安全的一個重要組成部分,可用于構造:擬隨機數生成器流密碼消息認證碼(MAC)和雜湊函數消息認證技術、數據完整性機構、實體認證協議以及單鑰數字簽字體制的核心組成部分應用中對于分組密碼的要求:安全性運行…

WCDB soci 查詢語句

測試代碼 #pragma once #include <string> #include <vector>// Assume OperationLog is a struct representing a row in the table struct OperationLog {int id;std::string op_type;std::string op_subtype;std::string details;std::string timestamp; };clas…

lesson16:Python函數的認識

目錄 一、為什么需要函數&#xff1f; 1. 拒絕重復造輪子 2. 讓代碼像句子一樣可讀 3. 隔離變化&#xff0c;降低維護成本 二、函數的定義&#xff1a;編寫高質量函數的5個要素 基本語法框架 1. 函數命名的黃金法則&#xff08;PEP8規范&#xff09; 2. 不可或缺的文檔…

通過輪詢方式使用LoRa DTU有什么缺點?

在物聯網系統中&#xff0c;DTU&#xff08;Data Transfer Unit&#xff09;通常用于通過485或M-Bus等接口抄讀子設備的數據&#xff0c;并將這些數據傳輸到平臺側。然而&#xff0c;如果DTU采用輪詢方式與平臺通信&#xff0c;會帶來一系列問題&#xff0c;尤其是在功耗和系統…

Syntax Error: Error: PostCSS received undefined instead of CSS string

報錯&#xff1a;Syntax Error: Error: PostCSS received undefined instead of CSS string npm rebuild node-sass報錯&#xff1a;npm i canvas 報錯 canvas2.11.2 run install node-pre-gyp install --fallback-to-build --update-binary npm install canvas --canvas_binar…

人工智能之數學基礎:概率論和數理統計在機器學習的地位

概率和統計的概念概率統計是各類學科中唯一一門專門研究隨機現象的規律性的學科&#xff0c;隨機現象的廣泛性決定了這一學科的重要性。概率論是數學的分支&#xff0c;它研究的是如何定量描述隨機現象及其規律。我們之前經常在天氣軟件上看到&#xff1a;“今天下雨的概率是95…

第十四章 Stream API

JAVA語言引入了一個流式Stream API,這個API對集合數據進行操作&#xff0c;類似于使用SQL執行的數據庫查詢&#xff0c;同樣可以使用Stream API并行執行操作。Stream和Collection的區別Collection:靜態的內存數據結構&#xff0c;強調的是數據。Stream API:和集合相關的計算操作…

Oracle數據庫各版本間的技術迭代詳解

今天我想和大家聊聊一個我們可能每天都在用&#xff0c;但未必真正了解的技術——Oracle數據庫的版本。如果你是企業的IT工程師&#xff0c;可能經歷過“升級數據庫”的頭疼&#xff1b;如果你是業務負責人&#xff0c;可能疑惑過“為什么一定要換新版本”&#xff1b;甚至如果…

論文reading學習記錄3 - weekly - 模塊化視覺端到端ST-P3

文章目錄前言一、摘要與引言二、Related Word2.1 可解釋的端到端架構2.2 鳥瞰圖2.3 未來預測2.4 規劃三、方法3.1 感知bev特征積累3.1.1 空間融合&#xff08;幀的對齊&#xff09;3.1.2 時間融合3.2 預測&#xff1a;雙路徑未來建模3.3 規劃&#xff1a;先驗知識的整合與提煉4…

crawl4ai--bitcointalk爬蟲實戰項目

&#x1f4cc; 項目目標本項目旨在自動化抓取 Bitcointalk 論壇中指定板塊的帖子數據&#xff08;包括主貼和所有回復&#xff09;&#xff0c;并提取出結構化信息如標題、作者、發帖時間、用戶等級、活躍度、Merit 等&#xff0c;以便進一步分析或使用。本項目只供科研學習使用…

調用 System.gc() 的弊端及修復方式

弊端分析不可控的執行時機System.gc() 僅是 建議 JVM 執行垃圾回收&#xff0c;但 JVM 可自由忽略該請求&#xff08;尤其是高負載時&#xff09;。實際回收時機不確定&#xff0c;無法保證內存及時釋放。嚴重的性能問題Stop-The-World 停頓&#xff1a;觸發 Full GC 時會暫停所…

git merge 和 git rebase 的區別

主要靠一張圖&#xff1a;區別 git merge git checkout feature git merge master此時在feature上git會自動產生一個新的commit 修改的是當前分支 feature。 git rebase git checkout feature git rebase master&#xff08;在feature分支上執行&#xff0c;修改的是master分支…

Java學習--JVM(2)

JVM提供垃圾回收機制&#xff0c;其也是JVM的核心機制&#xff0c;其主要是實現自動回收不再被引用的對象所占用的內存&#xff1b;對內存進行整理&#xff0c;防止內存碎片化&#xff1b;以及對內存分配配進行管理。JVM 通過兩種主要算法判斷對象是否可回收&#xff1a;引用計…

用大模型(qwen)提取知識三元組并構建可視化知識圖譜:從文本到圖譜的完整實現

引言 知識圖譜作為一種結構化的知識表示方式&#xff0c;在智能問答、推薦系統、數據分析等領域有著廣泛應用。在信息爆炸的時代&#xff0c;如何從非結構化文本中提取有價值的知識并進行結構化展示&#xff0c;是NLP領域的重要任務。知識三元組&#xff08;Subject-Relation-O…