PyTorch 數據加載實戰:從 CSV 到圖像的全流程解析

目錄

一、PyTorch 數據加載的核心組件

1.1 Dataset 類的核心方法

1.2 DataLoader 的作用

二、加載 CSV 數據實戰

2.1 自定義 CSV 數據集

2.2 使用 TensorDataset 快速加載

三、加載圖像數據實戰

3.1 自定義圖像數據集

3.2 使用 ImageFolder 快速加載

四、加載官方數據集

五、總結


在深度學習項目中,數據加載是模型訓練的第一步,也是至關重要的一步。PyTorch 提供了靈活的數據加載工具,讓我們能夠輕松處理各種類型的數據。本文將結合實際代碼,詳細講解如何使用 PyTorch 加載 CSV 數據和圖像數據,幫助初學者快速掌握數據加載的核心技巧。

一、PyTorch 數據加載的核心組件

PyTorch 的數據加載主要依賴兩個核心類:DatasetDataLoader

  • Dataset:負責數據的讀取和預處理,是所有自定義數據集的基類
  • DataLoader:負責批量加載數據,支持打亂順序、多線程加載等功能

1.1 Dataset 類的核心方法

自定義數據集需要繼承Dataset類,并實現以下三個方法:

class CustomDataset(Dataset):def __init__(self, ...):  # 初始化數據集,加載文件路徑等passdef __len__(self):  # 返回數據集大小return len(self.data)def __getitem__(self, index):  # 根據索引返回樣本return sample, label

1.2 DataLoader 的作用

DataLoader像是一個 "搬運工",將Dataset中的數據按批次搬運給模型:

dataloader = DataLoader(dataset=dataset,  # 要加載的數據集batch_size=32,    # 批次大小shuffle=True,     # 是否打亂數據num_workers=2     # 多線程加載
)

二、加載 CSV 數據實戰

CSV 文件是存儲表格數據的常用格式,比如學生成績表、特征數據表等。下面我們通過實際代碼講解如何加載 CSV 數據。

2.1 自定義 CSV 數據集

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pdclass CsvDataset(Dataset):def __init__(self, filepath):# 讀取CSV文件df = pd.read_csv(filepath)# 刪除不需要的列(學號、姓名)df.drop(['學號', '姓名'], axis=1, inplace=True)# 提取特征和標簽x = df.iloc[1:, :-1]  # 從第二行開始,取除最后一列外的所有列作為特征y = df.iloc[1:, -1]   # 從第二行開始,取最后一列作為標簽# 轉換為Tensorself.data = torch.tensor(x.values, dtype=torch.float)self.labels = torch.tensor(y.values, dtype=torch.float)def __len__(self):return len(self.data)def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label# 測試代碼
def test_csv_dataset():filepath = '大數據答辯成績表.csv'dataset = CsvDataset(filepath)print(f"數據集大小: {len(dataset)}")print(f"第一個樣本: {dataset[0]}")test_csv_dataset()

2.2 使用 TensorDataset 快速加載

如果數據已經是 Tensor 格式,可以使用TensorDataset快速創建數據集,無需自定義類:

def test_tensor_dataset():filepath = '大數據答辯成績表.csv'df = pd.read_csv(filepath)df.drop(['學號', '姓名'], axis=1, inplace=True)x = df.iloc[1:, :-1]y = df.iloc[1:, -1]# 轉換為Tensordata = torch.tensor(x.values, dtype=torch.float)labels = torch.tensor(y.values, dtype=torch.float)# 使用TensorDatasetdataset = TensorDataset(data, labels)print(f"第一個樣本: {dataset[0]}")

三、加載圖像數據實戰

處理圖像數據時,我們需要考慮圖像的讀取、大小調整、格式轉換等問題。下面介紹兩種加載圖像數據的方法。

3.1 自定義圖像數據集

import os
import cv2
from torch.utils.data import Datasetclass PicDataset(Dataset):def __init__(self, filepath):self.filepaths = []  # 存儲圖像路徑self.labels = []     # 存儲標簽dirnames = []        # 存儲類別名稱# 遍歷文件夾for root, dirs, files in os.walk(filepath):if len(dirs) > 0:dirnames = dirs  # 獲取類別文件夾名稱for file in files:f_path = os.path.join(root, file)self.filepaths.append(f_path)# 根據文件夾名稱確定標簽classname = root.split('\\')[-1]self.labels.append(dirnames.index(classname))def __len__(self):return len(self.filepaths)def __getitem__(self, index):filepath = self.filepaths[index]# 讀取圖像img = cv2.imread(filepath)# 調整圖像大小為112x112img = cv2.resize(img, (112, 112))# 轉換為Tensor并調整維度 (HWC -> CHW)t_img = torch.tensor(img)t_img = t_img.permute(2, 0, 1)label = self.labels[index]return t_img, label# 測試代碼
def test_pic_dataset():filepath = r'E:\人工智能\深度學習\dataset\butterfly'dataset = PicDataset(filepath)print(f"數據集大小: {len(dataset)}")img, label = dataset[0]print(f"圖像形狀: {img.shape}, 標簽: {label}")

3.2 使用 ImageFolder 快速加載

PyTorch 的ImageFolder是加載圖像數據集的便捷工具,特別適合以下結構的數據集:

root/class1/img1.jpgimg2.jpgclass2/img1.jpgimg2.jpg

使用方法如下:

from torchvision.datasets import ImageFolder
from torchvision import transformsdef test_image_folder():filepath = r'E:\人工智能\深度學習\dataset\butterfly'# 定義圖像轉換transform = transforms.Compose([transforms.Resize((112, 112)),  # 調整大小transforms.ToTensor(),          # 轉換為Tensor])# 使用ImageFolder加載數據dataset = ImageFolder(root=filepath, transform=transform)print(f"類別: {dataset.classes}")print(f"數據集大小: {len(dataset)}")# 創建DataLoaderdataloader = DataLoader(dataset=dataset,batch_size=1,shuffle=True)# 顯示一張圖像for img, label in dataloader:print(f"圖像形狀: {img.shape}")print(f"標簽: {label}")breaktest_image_folder()

四、加載官方數據集

PyTorch 提供了許多常用的公開數據集(如 MNIST、CIFAR 等),可以直接下載使用:

from torchvision import datasets, transformsdef test_mnist_dataset():# 定義轉換transform = transforms.Compose([transforms.ToTensor()])# 加載MNIST訓練集dataset = datasets.MNIST(root='../dataset',  # 數據保存路徑train=True,         # 訓練集download=True,      # 如果沒有數據則下載transform=transform)# 創建DataLoaderdataloader = DataLoader(dataset=dataset,batch_size=1,shuffle=True)# 顯示一張圖像for img, label in dataloader:print(f"圖像形狀: {img.shape}")print(f"標簽: {label}")breaktest_mnist_dataset()

五、總結

本文介紹了 PyTorch 加載不同類型數據的方法,包括:

  1. 加載 CSV 數據:可以自定義CsvDataset,或使用TensorDataset快速加載
  2. 加載圖像數據:可以自定義PicDataset,或使用ImageFolder加載按類別組織的圖像
  3. 加載官方數據集:直接使用torchvision.datasets中的類

掌握數據加載的技巧,可以為后續的模型訓練打下堅實基礎。在實際項目中,需要根據數據的具體格式和特點,選擇合適的加載方式,并進行必要的預處理。

希望本文能幫助大家快速上手 PyTorch 的數據加載,如果你有任何問題或建議,歡迎在評論區留言討論!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/89716.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/89716.shtml
英文地址,請注明出處:http://en.pswp.cn/web/89716.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

程序人生,開啟2025下半年

時光匆匆,2025年已然過去一半。轉眼來到了7月份。 回望過去上半年,可能你也經歷了職場的浮沉、生活的跌宕、家庭的變故。 而下半年,生活依舊充滿了各種變數。 大環境的起起伏伏、生活節奏的加快,都讓未來的不確定性愈發凸顯。 在這…

在 .NET Core 中創建 Web Socket API

要在 ASP.NET Core 中創建 WebSocket API,您可以按照以下步驟操作:設置新的 ASP.NET Core 項目打開 Visual Studio 或您喜歡的 IDE。 創建一個新的 ASP.NET Core Web 應用程序項目。 選擇API模板,因為這將成為您的 WebSocket API 的基礎。在啟…

Python 之地址編碼識別

根據輸入地址,利用已有的地址編碼文件,構造處理規則策略識別地址的編碼。 lib/address.json 地址編碼文件(這個文件太大,博客里放不下,需要的話可以到 gitcode 倉庫獲取:https://gitcode.com/TomorrowAndT…

kafka的部署

目錄 一、kafka簡介 1.1、概述 1.2、消息系統介紹 1.3、點對點消息傳遞模式 1.4、發布-訂閱消息傳遞模式 二、kafka術語解釋 2.1、結構概述 2.2、broker 2.3、topic 2.4、producer 2.5、consumer 2.6、consumer group 2.7、leader 2.8、follower 2.9、partition…

小語種OCR識別技術實現原理

小語種OCR(光學字符識別)技術的實現原理涉及計算機視覺、自然語言處理(NLP)和深度學習等多個領域的融合,其核心目標是讓計算機能夠準確識別并理解不同語言的印刷或手寫文本。以下是其關鍵技術實現原理的詳細解析&#…

GPT:讓機器擁有“創造力”的語言引擎

當ChatGPT寫出莎士比亞風格的十四行詩,當GitHub Copilot自動生成編程代碼,背后都源于同一項革命性技術——**GPT(Generative Pre-trained Transformer)**。今天,我們將揭開這項“語言魔術”背后的科學原理!…

LeetCode|Day19|14. 最長公共前綴|Python刷題筆記

LeetCode|Day19|14. 最長公共前綴|Python刷題筆記 🗓? 本文屬于【LeetCode 簡單題百日計劃】系列 👉 點擊查看系列總目錄 >> 📌 題目簡介 題號:14. 最長公共前綴 難度:簡單…

安全事件響應分析--基礎命令

----萬能密碼oror1 or # 1or11 1 or 11安全事件響應分析------***windoes***------方法開機啟動有無異常文件 【開始】?【運行】?【msconfig】文件排查 各個盤下的temp(tmp)相關目錄下查看有無異常文件 :Windows產生的 臨時文件 可以通過查看日志且通過篩…

基于C#+SQL Server實現(Web)學生選課管理系統

學生選課管理系統的設計與開發一、項目背景學生選課管理系統是一個學校不可缺少的部分,傳統的人工管理檔案的方式存在著很多的缺點,如:效率低、保密性差等,所以開發一套綜合教務系統管理軟件很有必要,它應該具有傳統的…

垃圾回收(GC)

內存管理策略,在業務進程運行的過程中,由垃圾收集器以類似守護協程的方式在后臺運行,按照指定策略回收不再被使用的對象,釋放內存空間進行回收 優勢: 屏蔽內存回收的細節:屏蔽復雜的內存管理工作&#xff0…

Datawhale AI夏令營-機器學習

比賽簡介 「用戶新增預測挑戰賽」是由科大訊飛主辦的一項數據科學競賽,旨在通過機器學習方法預測用戶是否為新增用戶 比賽屬于二分類任務,評價指標采用F1分數,分數越高表示模型性能越好。 如果你有一份帶標簽的表格型數據,只要…

Spring IOC容器在Web環境中是如何啟動的(源碼級剖析)?

文章目錄一、Web 環境中的 Spring MVC 框架二、Web 應用部署描述配置傳統配置(web.xml):Java配置類(Servlet 3.0):三、核心啟動流程詳解1. 啟動流程圖2. ★容器初始化入口:ContextLoaderListene…

18個優質Qt開源項目匯總

1,Clementine Music Player Clementine Music Player 是一個功能完善、跨平臺的開源音樂播放器,非常適合用于學習如何開發媒體類應用,尤其是跨平臺桌面應用。它基于 Qt 框架開發,支持多種操作系統,包括 Windows、macO…

計算機視覺:AI 的 “眼睛” 如何看懂世界?

1. 什么是計算機視覺:讓機器 “看見” 并 “理解” 的技術1.1 計算機視覺的核心目標計算機視覺(CV)是人工智能的一個重要分支,它讓計算機能夠 “看懂” 圖像和視頻 —— 不僅能捕捉像素信息,還能分析內容、提取語義&am…

華為OD刷題記錄

華為OD刷題記錄 刷過的題 入門 1、進制 2、NC61 doing 訂閱專欄

QT學習教程(二十五)

雙緩沖技術&#xff08;Double Buffering&#xff09;&#xff08; 2、公有函數實現&#xff09;#include <QtGui> #include <cmath> using namespace std; #include "plotter.h"以上代碼為文件的開頭&#xff0c;在這里把std 的名空間加入到當前的全…

設計模式筆記_結構型_裝飾器模式

1.裝飾器模式介紹裝飾器模式是一種結構型設計模式&#xff0c;允許你動態地給對象添加行為&#xff0c;而無需修改其代碼。它的核心思想是將對象放入一個“包裝器”中&#xff0c;這個包裝器提供了額外的功能&#xff0c;同時保持原有對象的接口不變。想象一下&#xff0c;你有…

day25 力扣90.子集II 力扣46.全排列 力扣47.全排列 II

子集II給你一個整數數組 nums &#xff0c;找出并返回所有該數組中不同的遞增子序列&#xff0c;遞增子序列中 至少有兩個元素 。你可以按 任意順序 返回答案。數組中可能含有重復元素&#xff0c;如出現兩個整數相等&#xff0c;也可以視作遞增序列的一種特殊情況。示例 1&…

Solidity 中的`bytes`

在 Solidity 中&#xff0c;bytes 和 bytes32 都是用來保存二進制數據的類型&#xff0c;但它們的長度、使用場景、Gas 成本完全不同。? 一句話區分類型一句話總結bytes32定長 32 字節&#xff0c;適合做哈希、地址、標識符等固定長度數據。bytes動態長度字節數組&#xff0c;…

初學者STM32—PWM驅動電機與舵機

一、簡介 上一節課主要學習了輸出比較和PWM的基本原理和結構&#xff0c;本節課就主要以實踐為主通過STM32最小系統板和驅動器控制舵機和直流電機。 上一節課的坐標 初學者STM32—輸出比較與PWM-CSDN博客 二、舵機 舵機是一種根據輸入PWM信號占空比來控制輸出角度的裝置 輸…