Pytorch的Dataloader使用詳解

PyTorch 的 DataLoader 是數據加載的核心組件,它能高效地批量加載數據并進行預處理。

Pytorch DataLoader基礎概念

DataLoader基礎概念
DataLoader是PyTorch基礎概念
DataLoader是PyTorch中用于加載數據的工具,它可以:批量加載數據(batch loading)打亂數據(shuffling)并行加載數據(多線程)
自定義數據加載方式Dataloader的基本使用from torch.utils.data import Dataset, DataLoader

自定義數據集類

class MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __getitem__(self, index):return self.data[index], self.labels[index]def __len__(self):return len(self.data)

創建數據集實例

dataset = MyDataset(data, labels)

創建DataLoader

dataloader = DataLoader(dataset=dataset,      # 數據集batch_size=32,        # 批次大小shuffle=True,         # 是否打亂數據num_workers=4,        # 多進程加載數據的線程數drop_last=False       # 當樣本數不能被batch_size整除時,是否丟棄最后一個不完整的batch
)
# 使用DataLoader迭代數據
for batch_data, batch_labels in dataloader:# 訓練或推理代碼pass

DataLoader重要參數詳解

  1. dataset: 要加載的數據集,必須是Dataset類的實例 batch_size: 每個批次的樣本數
  2. shuffle:是否在每個epoch重新打亂數據
  3. sampler:自定義從數據集中抽取樣本的策略,如果指定了sampler,則shuffle必須為False
  4. num_workers:使用多少個子進程加載數據,0表示在主進程中加載。
  5. collate_fn:將一批數據整合成一個批次的函數,特別使用于處理不同長度的序列數據
  6. Pin_memory:如果為True,數據加載器會將張量復制到CUDA固定內存中,加速CPU到GPU的數據傳輸
  7. drop_last: 如果數據集大小不能被batch_size整除,是否丟棄最后一個不完整的批次。
  8. timeout:收集一個批次的超時值
  9. worker_init_fn:每個worker初始化時被調用的函數
  10. weight_sampler:參數決定是都使用加權采樣器來平衡類別分布
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class
這段代碼決定了如何創建數據加載器,根據infinite_data_loader參數選擇不同的加載器類型:
if infinite_data_loader:data_loader = InfiniteDataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)
else:data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,sampler=sampler,**kwargs)n_class = len(data.classes)
return data_loader, n_class

代碼解析

這段代碼基于infinite_data_loader參數創建不同類型的數據加載器:
當infinite_data_loader為True時:
創建InfiniteDataLoader實例
自定義的無限循環數據加載器,會持續提供數據而不會在一個epoch結束時停止
當infinite_data_loader為False時:
創建標準的PyTorch DataLoader實例
這是普通的數據加載器,一個epoch結束后會停止

共同參數:

dataset=data:要加載的數據集
batch_size=batch_size:每批數據的大小
shuffle=shuffle:是否打亂數據(之前代碼中已設置)
num_workers=num_workers:用于并行加載數據的線程數
sampler=sampler:用于采樣的策略(之前代碼中已設置,可能是加權采樣器)
**kwargs:其他可能的參數,如pin_memory、drop_last等

返回值:

data_loader:創建好的數據加載器
n_class = len(data.classes):數據集中的類別數量
InfiniteDataLoader的作用
在您的代碼中定義了兩種InfiniteDataLoader實現:一種作為DataLoader的子類,另一種是完全自定義的類。它們的共同目的是:
持續提供數據:當一個epoch結束后,自動重新開始,不會引發StopIteration異常
支持長時間訓練:在需要長時間訓練的場景中特別有用,如半監督學習或者領域適應
避免手動重置:不需要在每個epoch結束后手動重置數據加載器

使用場景

無限數據加載器特別適用于:
持續訓練:模型需要無限期地訓練,如自監督學習或強化學習
不均勻更新:源域和目標域數據需要不同頻率的更新
流式訓練:數據以流的形式到達,不需要明確的epoch邊界
基于迭代而非epoch的訓練:訓練基于迭代次數而非數據epoch
最后的返回值n_class提供了數據集的類別數量,這對模型構建和評估都很重要,比如設置分類層的輸出維度或計算平均類別準確率。
高級用法

1.自定義collate_fn處理變長序列

def collate_fn(batch):# 排序批次數據,按序列長度降序batch.sort(key=lambda x: len(x[0]), reverse=True)# 分離數據和標簽sequences, labels = zip(*batch)# 計算每個序列的長度lengths = [len(seq) for seq in sequences]# 填充序列到相同長度padded_seqs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)return padded_seqs, torch.tensor(labels), lengths

使用自定義的collate_fn

dataloader = DataLoader(dataset=text_dataset,batch_size=16,shuffle=True,collate_fn=collate_fn
)

2.使用Sampler進行不均衡數據采樣
from torch.utils.data import WeightedRandomSampler

假設我們有類別不平衡問題,計算采樣權重

class_count = [100, 1000, 500]  # 每個類別的樣本數量
weights = 1.0 / torch.tensor(class_count, dtype=torch.float)
sample_weights = weights[target_list]  # target_list是每個樣本的類別索引

創建WeightedRandomSampler

sampler = WeightedRandomSampler(weights=sample_weights,num_samples=len(sample_weights),replacement=True
)

使用sampler

dataloader = DataLoader(dataset=dataset,batch_size=32,sampler=sampler,  # 使用sampler時,shuffle必須為Falsenum_workers=4
)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/80768.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/80768.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/80768.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

HTML、CSS 和 JavaScript 基礎知識點

HTML、CSS 和 JavaScript 基礎知識點 一、HTML 基礎 1. HTML 文檔結構 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.…

亞遠景-對ASPICE評估體系的深入研究與分析

一、ASPICE評估體系的定義與背景 ASPICE&#xff08;Automotive Software Process Improvement and Capability Determination&#xff09;即汽車軟件過程改進及能力測定模型&#xff0c;是由歐洲20多家主要汽車制造商共同制定的&#xff0c;專門針對汽車行業的軟件開發過程評…

灰度圖像和RGB圖像在數據大小和編碼處理方式差別

技術背景 好多開發者對灰度圖像和RGB圖像有些認知差異&#xff0c;今天我們大概介紹下二者差別。灰度圖像&#xff08;Grayscale Image&#xff09;和RGB圖像在編碼處理時&#xff0c;數據大小和處理方式的差別主要體現在以下幾個方面&#xff1a; 1. 通道數差異 圖像類型通道…

從爬蟲到網絡---<基石9> 在VPS上沒搞好Docker項目,把他卸載干凈

1.停止并刪除所有正在運行的容器 docker ps -a # 查看所有容器 docker stop $(docker ps -aq) # 停止所有容器 docker rm $(docker ps -aq) # 刪除所有容器如果提示沒有找到容器&#xff0c;可以忽略這些提示。 2.刪除所有鏡像 docker images # 查看所有鏡像 dock…

Centos 上安裝Klish(clish)的編譯和測試總結

1&#xff0c;介紹 clish是一個類思科命令行補全與執行程序&#xff0c;它可以幫助程序員在nix操作系統上實現功能導引、命令補全、命令執行的程序。支持&#xff1f;&#xff0c;help, Tab按鍵。本文基于klish-2.2.0介紹編譯和測試。 2&#xff0c;klish的編譯 需要安裝的庫&…

理解計算機系統_并發編程(3)_基于I/O復用的并發(二):基于I/O多路復用的并發事件驅動服務器

前言 以<深入理解計算機系統>(以下稱“本書”)內容為基礎&#xff0c;對程序的整個過程進行梳理。本書內容對整個計算機系統做了系統性導引,每部分內容都是單獨的一門課.學習深度根據自己需要來定 引入 接續上一帖理解計算機系統_并發編程(2)_基于I/O復用的并發…

系統可靠性分析:指標解析與模型應用全覽

以下是關于系統可靠性分析中可靠性指標、串聯系統與并聯系統、混合系統、系統可靠性模型的相關內容&#xff1a; 一、可靠性指標 可靠度&#xff1a;是系統、設備或元件在規定條件和規定時間內完成規定功能的概率。假設一個系統由多個部件組成&#xff0c;每個部件都有其自身…

數字高程模型(DEM)公開數據集介紹與下載指南

數字高程模型&#xff08;DEM&#xff09;公開數據集介紹與下載指南 數字高程模型&#xff08;Digital Elevation Model, DEM&#xff09;廣泛應用于地理信息系統&#xff08;GIS&#xff09;、水文模擬、城市規劃、環境分析、災害評估等領域。本文系統梳理了主流的DEM公開數據…

Python+大模型 day01

Python基礎 計算機系統組成 基礎語法 如:student_num 4.標識符要做到見名知意,增強代碼的可讀性 關鍵字 系統或者Python定義的,有特殊功能的字符組合 在學習過程中,文件名沒有遵循標識符命名規則,是為了按序號編寫文件方便查找復習 但是,在開發中,所有的Python文件名稱必須…

C++引用編程練習

#include <iostream> using namespace std; double vals[] {10.1, 12.6, 33.1, 24.1, 50.0}; double& setValues(int i) { double& ref vals[i]; return ref; // 返回第 i 個元素的引用&#xff0c;ref 是一個引用變量&#xff0c;ref 引用 vals[i] } // 要調用…

機密虛擬機的威脅模型

本文將介紹近年興起的機密虛擬機&#xff08;Confidential Virtual Machine&#xff09;技術所旨在抵御的威脅模型&#xff0c;主要關注內存機密性&#xff08;confidentiality&#xff09;和內存完整性&#xff08;integrity&#xff09;兩個方面。在解釋該威脅可能造成的問題…

【Rust trait特質】如何在Rust中使用trait特質,全面解析與應用實戰

?? 歡迎大家來到景天科技苑?? &#x1f388;&#x1f388; 養成好習慣&#xff0c;先贊后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者簡介&#xff1a;景天科技苑 &#x1f3c6;《頭銜》&#xff1a;大廠架構師&#xff0c;華為云開發者社區專家博主&#xff0c;…

Simulink模型回調

Simulink 模型回調函數是一種特殊的 MATLAB 函數&#xff0c;可在模型生命周期的特定階段自動執行。它們允許用戶自定義模型行為、執行初始化任務、驗證參數或記錄數據。以下是各回調函數的詳細說明&#xff1a; 1. PreLoadFcn 觸發時機&#xff1a;Simulink 模型加載到內存之…

FPGA:Xilinx Kintex 7實現DDR3 SDRAM讀寫

在Xilinx Kintex 7系列FPGA上實現對DDR3 SDRAM的讀寫&#xff0c;主要依賴Xilinx提供的Memory Interface Generator (MIG) IP核&#xff0c;結合Vivado設計流程。以下是詳細步驟和關鍵點&#xff1a; 1. 準備工作 硬件需求&#xff1a; Kintex-7 FPGA&#xff08;如XC7K325T&…

Python爬蟲實戰:研究進制流數據,實現逆向解密

1. 引言 1.1 研究背景與意義 在現代網絡環境中,數據加密已成為保護信息安全的重要手段。許多網站和應用通過二進制流數據傳輸敏感信息,如視頻、金融交易數據等。這些數據通常經過復雜的加密算法處理,直接分析難度較大。逆向工程進制流數據不僅有助于合法的數據獲取與分析,…

Java Spring Boot項目目錄規范示例

以下是一個典型的 Java Spring Boot 項目目錄結構規范示例&#xff0c;結合了分層架構和模塊化設計的最佳實踐&#xff1a; text 復制 下載 src/ ├── main/ │ ├── java/ │ │ └── com/ │ │ └── example/ │ │ └── myapp/ │…

圖像顏色理論與數據挖掘應用的全景解析

文章目錄 一、圖像顏色系統的理論基礎1.1 圖像數字化的本質邏輯1.2 顏色空間的數學框架1.3 量化過程的技術原理 二、主要顏色空間的深度解析2.1 RGB顏色空間的加法原理2.2 HSV顏色空間的感知模型2.3 CMYK顏色空間的減色原理 三、圖像幾何屬性與高級特征3.1 分辨率與像素密度的關…

mysql兩張關聯表批量更新一張表存在數據,而另一張表不存在數據的sql

一、mysql兩張關聯表批量更新一張表存在、另一張表不存在的數據 創建user和user_order表 CREATE TABLE user (id varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL,id_card varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NU…

PNG轉ico圖標(支持圓角矩形/方形+透明背景)Python腳本 - 隨筆

摘要 在網站開發或應用程序設計中&#xff0c;常需將高品質PNG圖像轉換為ICO格式圖標。本文提供一份基于Pillow庫實現的&#xff0c;能夠完美保留透明背景且支持導出圓角矩形/方形圖標的格式轉換腳本。 源碼示例 圓角方形 from PIL import Image, ImageDraw, ImageOpsdef c…

在線SQL轉ER圖工具

在線SQL轉ER圖網站 在數據庫設計、軟件開發或學術研究中&#xff0c;ER圖&#xff08;實體-關系圖&#xff09; 是展示數據庫結構的重要工具。然而&#xff0c;手動繪制ER圖不僅耗時費力&#xff0c;還容易出錯。今天&#xff0c;我將為大家推薦一款非常實用的在線工具——SQL…