Python訓練打卡Day38

Dataset和Dataloader類

知識點回顧:

  1. Dataset類的__getitem__和__len__方法(本質是python的特殊方法)
  2. Dataloader類
  3. minist手寫數據集的了解

????????在遇到大規模數據集時,顯存常常無法一次性存儲所有數據,所以需要使用分批訓練的方法。為此,PyTorch提供了DataLoader類,該類可以自動將數據集切分為多個批次batch,并支持多線程加載數據。此外,還存在Dataset類,該類可以定義數據集的讀取方式和預處理方式。

1. DataLoader類:決定數據如何加載

2. Dataset類:告訴程序去哪里找數據,如何讀取單個樣本,以及如何預處理。

使用的數據集為MNIST手寫數字數據集。該數據集包含60000張訓練圖片和10000張測試圖片,每張圖片大小為28*28像素,共包含10個類別。因為每個數據的維度比較小,所以既可以視為結構化數據,用機器學習、MLP訓練,也可以視為圖像數據,用卷積神經網絡訓練。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加載數據的工具
from torchvision import datasets, transforms # torchvision 是一個用于計算機視覺的庫,datasets 和 transforms 是其中的模塊
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
# 1. 數據預處理,該寫法非常類似于管道pipeline
# transforms 模塊提供了一系列常用的圖像預處理操作# 先歸一化,再標準化
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST數據集的均值和標準差,這個值很出名,所以直接使用
])# 2. 加載MNIST數據集,如果沒有會自動下載
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

Dataset類

import matplotlib.pyplot as plt# 隨機選擇一張圖片,可以重復運行,每次都會隨機選擇
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 隨機選擇一張圖片的索引
# len(train_dataset) 表示訓練集的圖片數量;size=(1,)表示返回一個索引;torch.randint() 函數用于生成一個指定范圍內的隨機數,item() 方法將張量轉換為 Python 數字
image, label = train_dataset[sample_idx] # 獲取圖片和標簽

為什么train_dataset[sample_idx]可以獲取到圖片和標簽,是因為 datasets.MNIST這個類繼承了torch.utils.data.Dataset類,這個類中有一個方法__getitem__,這個方法會返回一個tuple,tuple中第一個元素是圖片,第二個元素是標簽。

torch.utils.data.Dataset類

一個抽象基類,所有自定義數據集都需要繼承它并實現兩個核心方法:

- __len__():返回數據集的樣本總數。

- __getitem__(idx):根據索引idx返回對應樣本的數據和標簽。

PyTorch 要求所有數據集必須實現__getitem__和__len__,這樣才能被DataLoader等工具兼容。這是一種接口約定,類似函數參數的規范。這意味著,如果你要創建一個自定義數據集,你需要實現這兩個方法,否則PyTorch將無法識別你的數據集。

?__getitem__和__len__ 是類的特殊方法(也叫魔術方法 ),它們不是像普通函數那樣直接使用,而是需要在自定義類中進行定義,來賦予類特定的行為。

__getitem__方法

用于讓對象支持索引操作,當使用[]語法訪問對象元素時,Python 會自動調用該方法。

# 示例代碼
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 創建類的實例
my_list_obj = MyList()
# 此時可以使用索引訪問元素,這會自動調用__getitem__方法
print(my_list_obj[2])  # 輸出:30

__len__方法

用于返回對象中元素的數量,當使用內置函數len()作用于對象時,Python 會自動調用該方法。

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 創建類的實例
my_list_obj = MyList()
# 使用len()函數獲取元素數量,這會自動調用__len__方法
print(len(my_list_obj))  # 輸出:5
# minist數據集的簡化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加載圖片路徑和標簽self.data, self.targets = fetch_mnist_data(root, train) # 這里假設 fetch_mnist_data 是一個函數,用于加載 MNIST 數據集的圖片路徑和標簽self.transform = transform # 預處理操作def __len__(self): return len(self.data)  # 返回樣本總數def __getitem__(self, idx): # 獲取指定索引的樣本# 獲取指定索引的圖像和標簽img, target = self.data[idx], self.targets[idx]# 應用圖像預處理(如ToTensor、Normalize)if self.transform is not None: # 如果有預處理操作img = self.transform(img) # 轉換圖像格式# 這里假設 img 是一個 PIL 圖像對象,transform 會將其轉換為張量并進行歸一化return img, target  # 返回處理后的圖像和標簽

?通俗地類比:

Dataset = 廚師(準備單個菜品)

DataLoader = 服務員(將菜品按訂單組合并上桌)

預處理(如切菜、調味)屬于廚師的工作,而非服務員。所以在dataset就需要添加預處理步驟。

# 可視化原始圖像(需要反歸一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反標準化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 顯示灰度圖像plt.show()print(f"Label: {label}")
imshow(image)

Dataloader類

# 3. 創建數據加載器
train_loader = DataLoader(train_dataset,batch_size=64, # 每個批次64張圖片,一般是2的冪次方,這與GPU的計算效率有關shuffle=True # 隨機打亂數據
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每個批次1000張圖片# shuffle=False # 測試時不需要打亂數據
)

總結:

Dataset 類:定義數據的內容和格式(即 “如何獲取單個樣本”),包括:
? 數據存儲路徑 / 來源(如文件路徑、數據庫查詢)。
? 原始數據的讀取方式(如圖像解碼為 PIL 對象、文本讀取為字符串)。
? 樣本的預處理邏輯(如裁剪、翻轉、歸一化等,通常通過 transform 參數實現)。
? 返回值格式(如 (image_tensor, label))。
? DataLoader 類:定義數據的加載方式和批量處理邏輯(即 “如何高效批量獲取數據”),包括:
? 批量大小(batch_size)。
? 是否打亂數據順序(shuffle)。

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/83691.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/83691.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/83691.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

web3-區塊鏈基礎:從區塊添加機制到哈希加密與默克爾樹結構

區塊鏈基礎:從區塊添加機制到哈希加密與默克爾樹結構 什么是區塊鏈 抽象的回答: 區塊鏈提供了一種讓多個參與方在沒有一個唯一可信方的情況下達成合作 若有可信第三方 > 不需要區塊鏈 [金融系統中常常沒有可信的參與方] 像股票市場,或者一個國家的…

MySQL 索引:為使用 B+樹作為索引數據結構,而非 B樹、哈希表或二叉樹?

在數據庫的世界里,性能是永恒的追求。而索引,作為提升查詢速度的利器,其底層數據結構的選擇至關重要。如果你深入了解過 MySQL(尤其是其主流存儲引擎 InnoDB),你會發現它不約而同地選擇了 B樹 作為索引的主…

Kafka broker 寫消息的過程

Producer → Kafka Broker → Replication → Consumer|Partition chosen (by key or round-robin)|Message appended to end of log (commit log)上面的流程是kafka 寫操作的大體流程。 kafka 不會特意保留message 在內存中,而是直接寫入了disk。 那么消費的時候&…

leetcode hot100(兩數之和、字母異位詞分組、最長連續序列)

兩數之和 題目鏈接 參考鏈接&#xff1a; 題目描述&#xff1a; 暴力法 雙重循環查找目標值 class Solution {public int[] twoSum(int[] nums, int target) {int[] res new int[2];for(int i 0 ; i < nums.length ; i){boolean isFind false;for(int j i 1 ; j …

SkyWalking架構深度解析:分布式系統監控的利器

一、SkyWalking概述 SkyWalking是一款開源的APM(應用性能監控)系統&#xff0c;專門為微服務、云原生和容器化架構設計。它由Apache軟件基金會孵化并畢業&#xff0c;已成為分布式系統監控領域的明星項目。 核心特性 ?分布式追蹤?&#xff1a;跨服務調用鏈路的完整追蹤?服務…

Matlab程序設計基礎

matlab程序設計基礎 程序設計函數文件1.函數文件的基本結構2.創建并使用函數文件的示例3.帶多個輸出的函數示例4.包含子函數的函數文件 流程控制1. if 條件語句2. switch 多分支選擇語句3. try-catch 異常處理語句ME與lasterr 4. while 循環語句5. for 循環語句break和continue…

Client-Side Path Traversal 漏洞學習筆記

近年來,隨著Web前端技術的飛速發展,越來越多的數據請求和處理邏輯被轉移到客戶端(瀏覽器)執行。這大大提升了用戶體驗,但也帶來了新的安全威脅。其中,Client-Side Path Traversal(客戶端路徑穿越,CSPT)作為一種新興的漏洞類型,逐漸受到安全研究者和攻擊者的關注。本文…

基于Socketserver+ThreadPoolExecutor+Thread構造的TCP網絡實時通信程序

目錄 介紹&#xff1a; 源代碼&#xff1a; Socketserver-服務端代碼 Socketserver客戶端代碼&#xff1a; 介紹&#xff1a; socketserver是一種傳統的傳輸層網絡編程接口&#xff0c;相比WebSocket這種應用層的協議來說&#xff0c;socketserver比較底層&#xff0c;soc…

【無標題】平面圖四色問題P類歸屬的嚴格論證——基于拓撲收縮與動態調色算法框架

平面圖四色問題P類歸屬的嚴格論證——基于拓撲收縮與動態調色算法框架 --- #### **核心定理** 任意平面圖 \(G (V, E)\) 的四色著色問題可在多項式時間 \(O(|V|^2)\) 內求解&#xff0c;且算法正確性由以下三重保證&#xff1a; 1. **拓撲不變性**&#xff08;Kuratowsk…

HALCON 深度學習訓練 3D 圖像的幾種方式優缺點

HALCON 深度學習訓練 3D 圖像的幾種方式優缺點 ** 在計算機視覺和工業檢測等領域&#xff0c;3D 圖像數據的處理和分析變得越來越重要&#xff0c;HALCON 作為一款強大的機器視覺軟件&#xff0c;提供了多種深度學習訓練 3D 圖像的方式。每種方式都有其獨特的設計思路和應用場…

pytest中的元類思想與實戰應用

在Python編程世界里&#xff0c;元類是一種強大而高級的特性&#xff0c;它能在類定義階段深度定制類的創建與行為。而pytest作為熱門的測試框架&#xff0c;雖然沒有直接使用元類&#xff0c;但在設計機制上&#xff0c;卻暗含了許多與元類思想相通的地方。接下來&#xff0c;…

以太網幀結構和封裝【三】-- TCP/UDP頭部信息

TCP頭部用于建立可靠連接、流量控制及數據完整性校驗。 Ipv4封裝tcp報&#xff1a; Ipv6封裝tcp報&#xff1a; UDP頭部信息 UDP關鍵協議特性&#xff1a; 1&#xff09;無連接&#xff1a;無需握手&#xff0c;直接發送數據。 2&#xff09;不可靠性&#xff1a;不保證數據…

MySQL補充知識點學習

書接上文&#xff1a;MySQL關系型數據庫學習&#xff0c;繼續看書補充MySQL知識點學習。 1. 基本概念學習 1.1 游標&#xff08;Cursor&#xff09; MySQL 游標是一種數據庫對象&#xff0c;它允許應用程序逐行處理查詢結果集&#xff0c;而不是一次性獲取所有結果。游標在需…

基于InternLM的情感調節大師FunGPT

基于書生系列大模型&#xff0c;社區用戶不斷創造出令人耳目一新的項目&#xff0c;從靈感萌發到落地實踐&#xff0c;每一個都充滿智慧與價值。“與書生共創”將陸續推出一系列文章&#xff0c;分享這些項目背后的故事與經驗。歡迎訂閱并積極投稿&#xff0c;一起分享經驗與成…

【拓撲】1639.拓撲排序

題目描述 這是 2018 2018 2018 年研究生入學考試中給出的一個問題&#xff1a; 以下哪個選項不是從給定的有向圖中獲得的拓撲序列&#xff1f; 現在&#xff0c;請你編寫一個程序來測試每個選項。 輸入格式 第一行包含兩個整數 N N N 和 M M M&#xff0c;分別表示有向圖…

macOS 上使用 Homebrew 安裝redis-cli

在 macOS 上使用 Homebrew 安裝 redis-cli&#xff08;Redis 命令行工具&#xff09;非常簡單&#xff0c;以下是詳細步驟&#xff1a; 1. 安裝 Redis&#xff08;包含 redis-cli&#xff09; 運行以下命令安裝 Redis&#xff1a; brew install redis這會安裝完整的 Redis 服…

Scratch節日 | 六一兒童節射擊游戲

六一兒童節快樂&#xff01;這款超有趣的 六一兒童節射擊游戲&#xff0c;讓你變身小貓弓箭手&#xff0c;守護節日的快樂時光&#xff01; &#x1f3ae; 游戲玩法 上下方向鍵&#xff1a;控制小貓的位置&#xff0c;自由移動&#xff0c;瞄準目標&#xff01; 空格鍵&#…

[AI Claude] 軟件測試2

好的&#xff0c;我現在為你準備一份預填充好大部分內容的測試報告和PPT內容。這里面的數據是我根據項目結構和常見的測試場景推理和編造的&#xff0c;你需要根據你的實際操作結果&#xff08;包括截圖、實際數據、發現的缺陷等&#xff09;進行替換和修改。 我將按照之前定義…

程序代碼篇---face_recognition庫實現的人臉檢測系統

以下是一個基于face_recognition庫的人臉管理系統,支持從文件夾加載人臉數據、實時識別并顯示姓名,以及動態添加新人臉。系統采用模塊化設計,代碼結構清晰,易于擴展。 一、系統架構 face_recognition_system/ ├── faces/ # 人臉數據庫(按姓名命名子…

Cursor 工具項目構建指南:Java 21 環境下的 Spring Boot Prompt Rules 約束

簡簡單單 Online zuozuo: 簡簡單單 Online zuozuo 簡簡單單 Online zuozuo 簡簡單單 Online zuozuo 簡簡單單 Online zuozuo :本心、輸入輸出、結果 簡簡單單 Online zuozuo : 文章目錄 Cursor 工具項目構建指南:Java 21 環境下的 Spring Boot Prompt Rules 約束前言項目簡…