[數據處理] 3. 數據集讀取

👋 你好!這里有實用干貨與深度分享?? 若有幫助,歡迎:?
👍 點贊 | ? 收藏 | 💬 評論 | ? 關注 ,解鎖更多精彩!?
📁 收藏專欄即可第一時間獲取最新推送🔔。?
📖后續我將持續帶來更多優質內容,期待與你一同探索知識,攜手前行,共同進步🚀。?



?人工智能

數據集讀取

本文使用PyTorch框架,介紹PyTorch中數據讀取的相關知識。

本文目標:

  1. 了解PyTorch中數據讀取的基本概念
  2. 了解PyTorch中集成的開源數據集的讀取方法
  3. 了解PyTorch中自定義數據集的讀取方法
  4. 了解PyTorch中數據讀取的流程

一、數據的準備

使用開源數據集或者自己采集數據后進行數據標注。

PyTorch中數據讀取的基本概念

PyTorch中數據讀取的基本概念是DatasetDataLoader

Dataset是一個抽象類,用于表示數據集。它包含了數據集的長度、索引、數據獲取等方法。

DataLoader是一個類,用于將數據集按批次加載到模型中。它包含了數據讀取、數據轉換、數據打亂等方法。

實現數據集讀取的步驟:

  1. 繼承Dataset類,實現__len____getitem__方法
  2. 使用DataLoader類,將數據集按批次加載到模型中

示例代碼:

import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index], self.labels[index]data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)for batch_data, batch_labels in dataloader:print(batch_data.shape, batch_labels.shape)

PyTorch中集成的開源數據集的讀取方法

使用開源數據MNIST作為示范。

數據集鏈接:MNIST數據集

PyTorch中以及集成了很多開源數據集,我們可以直接使用。MNIST也包括在其中。

只需要使用PyTorch中的torchvision.datasets模塊即可。

示例代碼:

  1. 引入必要的庫:
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
  1. 加載數據集:
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)

參數說明:

  • root:數據集保存的路徑
  • train:是否為訓練集
  • download:是否下載數據集
  1. 查看數據集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
  1. 可視化數據集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
  1. 數據加載:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break

參數說明:

  • batch_size:批次大小
  • shuffle:是否打亂數據,訓練集一般需要打亂數據,測試集一般不需要打亂數據

其實,真實的訓練過程只需要步驟1、2、5即可,3、4步驟是為了驗證數據集是否正確。

二、PyTorch中自定義數據集的讀取方法

自定義數據集的讀取方法是指,我們自己定義一個數據集,然后使用PyTorch中的DatasetDataLoader類來讀取數據集。因為不是所有的數據集都在PyTorch中集成了,當我們有擁有(自己標注或下載)一個新的數據集時,就需要自己定義數據集的讀取方法。

這時候需要將數據集以一定的規則保存起來,然后使用PyTorch中的DatasetDataLoader類來讀取數據集。

示例代碼:

  1. 引入必要的庫:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
  1. 定義數據集類:
class MyDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.data_list = os.listdir(data_dir)def __len__(self):return len(self.data_list)def __getitem__(self, index):data_path = os.path.join(self.data_dir, self.data_list[index])data = np.load(data_path)label = data['label']if self.transform is not None:data = self.transform(data)return data, label

參數說明:

  • data_dir:數據集保存的路徑
  • transform:數據轉換函數,可選。1. 用于數據增強,一般的數據增強方法有:隨機裁剪、隨機旋轉、隨機翻轉、隨機縮放等。2. 也可以用于數據預處理,如歸一化、標準化等。
  1. 定義數據轉換函數:
def transform(data):data = data['data']data = data.astype(np.float32)data = data / 255.0data = torch.from_numpy(data)return data
  1. 加載數據集:
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)
  1. 查看數據集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
  1. 可視化數據集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
  1. 數據加載:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
  1. 數據增強:
from torchvision import transformstransform = transforms.Compose([transforms.RandomCrop(28),  # 隨機裁剪,裁剪大小為28x28transforms.RandomHorizontalFlip(),  # 隨機水平翻轉transforms.RandomVerticalFlip(),  # 隨機垂直翻轉transforms.RandomRotation(10),  # 隨機旋轉transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # 隨機仿射變換transforms.ToTensor()  # 轉換為張量
])
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break

DataLoader核心參數詳解

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None,num_workers=0, collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)

關鍵參數解析

  • num_workers:數據預加載進程數(建議設為CPU核心數的70-80%)
  • pin_memory:啟用CUDA鎖頁內存加速GPU傳輸
  • prefetch_factor:每個worker預加載的batch數(PyTorch 1.7+)

數據加載性能優化公式

理論最大吞吐量
T h r o u g h p u t = min ? ( B a t c h S i z e × n u m _ w o r k e r s D a t a L o a d T i m e , G P U C o m p u t e T i m e ? 1 ) Throughput = \min\left(\frac{BatchSize \times num\_workers}{DataLoadTime}, GPUComputeTime^{-1}\right) Throughput=min(DataLoadTimeBatchSize×num_workers?,GPUComputeTime?1)

三、拓展:多模態數據加載示例

class MultiModalDataset(Dataset):def __init__(self, img_dir, text_path):self.img_dir = img_dirself.text_data = pd.read_csv(text_path)self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')def __getitem__(self, idx):# 圖像處理img_path = os.path.join(self.img_dir, self.text_data.iloc[idx]['image_id'])image = Image.open(img_path).convert('RGB')image = transforms.ToTensor()(image)# 文本處理text = self.text_data.iloc[idx]['description']inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=128)return {'image': image,'input_ids': torch.tensor(inputs['input_ids']),'attention_mask': torch.tensor(inputs['attention_mask'])}

四、總結

本文介紹了PyTorch中數據讀取的基本概念、集成的開源數據集的讀取方法、自定義數據集的讀取方法和數據讀取的流程。

數據讀取是深度學習訓練的重要環節,數據讀取的流程是:

  1. 定義數據集類
  2. 定義數據轉換函數、數據增強函數
  3. 加載數據集
    ?
    ?


📌 感謝閱讀!若文章對你有用,別吝嗇互動~?
👍 點個贊 | ? 收藏備用 | 💬 留下你的想法 ,關注我,更多干貨持續更新!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/904529.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/904529.shtml
英文地址,請注明出處:http://en.pswp.cn/news/904529.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

IIS配置SSL

打開iis 如果搜不到iis,要先開 再搜就打得開了 cmd中找到本機ip 用http訪問本機ip 把原本的http綁定刪了 再用http訪問本機ip就不行了 只能用https訪問了

RabbitMQ的交換機

一、三種交換機模式 核心區別對比?? ??特性????廣播模式(Fanout)????路由模式(Direct)????主題模式(Topic)????路由規則??無條件復制到所有綁定隊列精確匹配 Routing Key通配符匹配…

(2025,AR,NAR,GAN,Diffusion,模型對比,數據集,評估指標,性能對比)文本到圖像的生成和編輯:綜述

【本文為我在去年完成的綜述,因某些原因未能及時投稿,但本文仍能為想要全面了解文本到圖像的生成和編輯的學習者提供可靠的參考。目前本文已投稿 ACM Computing Surveys。 完整內容可在如下鏈接獲取,或在 Q 群群文件獲取。 中文版為論文初稿&…

MCU怎么運行深度學習模型

Gitee倉庫 git clone https://gitee.com/banana-peel-x/freedom-learn.git項目場景: 解決面試時遺留的問題,面試官提了兩個問題:1.單片機能跑深度學習的模型嗎? 2.為什么FreeRTOS要采用SVC去觸發第一個任務,只用Pend…

多模態學習(一)——從 Image-Text Pair 到 Instruction-Following 格式

前言 在多模態任務中(例如圖像問答、圖像描述等),為了使用指令微調(Instruction Tuning)提升多模態大模型的能力,我們需要構建成千上萬條**指令跟隨(instruction-following)**格式的…

MySQL基礎關鍵_011_視圖

目 錄 一、說明 二、操作 1.創建視圖 2.創建可替換視圖 3.修改視圖 4.刪除視圖 5.對視圖內容的增、刪、改 (1)增 (2)改 (3)刪 一、說明 只能將 DQL 語句創建為視圖;作用: …

『深夜_MySQL』數據庫操作 字符集與檢驗規則

2.庫的操作 2.1 創建數據庫 語法: CREATE DATABASE [IF NOT EXISTS] db_name [create_specification [,create_specification]….]create_spcification:[DEFAULT] CHARACTER SET charset_nam[DEFAULT] COLLATE collation_name說明: 大寫的表示關鍵字 …

Spark jdbc寫入崖山等國產數據庫失敗問題

隨著互聯網、信息產業的大發展、以及地緣政治的變化,網絡安全風險日益增長,網絡安全關乎國家安全。因此很多的企業,開始了國產替代的腳步,從服務器芯片,操作系統,到數據庫,中間件,逐步實現信息技術自主可控,規避外部技術制裁和風險。 就數據庫而言,目前很多的國產數據…

數字化轉型-4A架構之應用架構

系列文章 數字化轉型-4A架構(業務架構、應用架構、數據架構、技術架構)數字化轉型-4A架構之業務架構 前言 應用架構AA(Application Architecture)是規劃支撐業務的核心系統與功能模塊,實現端到端協同。 一、什么是應…

格雷狼優化算法`GWO 通過模擬和優化一個信號處理問題來最大化特定頻率下的功率

這段代碼是一個Python程序,它使用了多個科學計算庫,包括`random`、`numpy`、`matplotlib.pyplot`、`scipy.signal`和`scipy.signal.windows`。程序的主要目的是通過模擬和優化一個信號處理問題來最大化特定頻率下的功率。 4. **定義類`class_model`**: - 這個類包含了信號…

中級網絡工程師知識點1

1.1000BASE-CX:銅纜,最大傳輸距離為25米 1000BASE-LX:傳輸距離可達3000米 1000BASE-ZX:超過10km 2.RSA加密算法的安全性依賴于大整數分解問題的困難性 3.網絡信息系統的可靠性測度包括有效性,康毀性,生存性 4.VLAN技術所依據的協議是IEEE802.1q IEEE802.15標準是針…

2025年五一數學建模A題【支路車流量推測】原創論文講解

大家好呀,從發布賽題一直到現在,總算完成了2025年五一數學建模A題【支路車流量推測】完整的成品論文。 給大家看一下目錄吧: 摘 要: 一、問題重述 二.問題分析 2.1問題一 2.2問題二 2.3問題三 2.4問題四 2.5 …

性能優化實踐:渲染性能優化

性能優化實踐:渲染性能優化 在Flutter應用開發中,渲染性能直接影響用戶體驗。本文將從渲染流程分析入手,深入探討Flutter渲染性能優化的關鍵技術和最佳實踐。 一、Flutter渲染流程解析 1.1 渲染流水線 Flutter的渲染流水線主要包含以下幾…

linux基礎學習--linux磁盤與文件管理系統

linux磁盤與文件管理系統 1.認識linux系統 1.1 磁盤組成與分區的復習 首先了解磁盤的物理組成,主要有: 圓形的碟片(主要記錄數據的部分)。機械手臂,與在機械手臂上的磁頭(可擦寫碟片上的內容)。主軸馬達,可以轉動碟片,讓機械手臂的磁頭在碟片上讀寫數據。 數據存儲…

DIFY教程第五彈:科研論文翻譯與SEO翻譯應用

科研論文翻譯 我可以在工作流案例中結合聊天大模型來實現翻譯工具的功能,具體的設計如下 在開始節點中接收一個輸入信息 content 然后在 LLM 模型中我們需要配置一個 CHAT 模型,這里選擇了 DeepSeek-R1 64K 的聊天模型,注意需要在這里設置下…

【Redis】哨兵機制和集群

🔥個人主頁: 中草藥 🔥專欄:【中間件】企業級中間件剖析 一、哨兵機制 Redis的主從復制模式下,一旦主節點由于故障不能提供服務,需要人工的進行主從切換,同時需要大量的客戶端需要被通知切換到…

注意力機制(Attention)

1. 注意力認知和應用 AM: Attention Mechanism,注意力機制。 根據眼球注視的方向,采集顯著特征部位數據: 注意力示意圖: 注意力機制是一種讓模型根據任務需求動態地關注輸入數據中重要部分的機制。通過注意力機制&…

解鎖 AI 生產力:Google 四大免費工具全面解析20250507

🚀 解鎖 AI 生產力:Google 四大免費工具全面解析 在人工智能迅猛發展的今天,Google 推出的多款免費工具正在悄然改變我們的學習、工作和創作方式。本文將深入解析四款代表性產品:NotebookLM、Google AI Studio、Google Colab 和 …

知識圖譜:AI大腦中的“超級地圖”如何煉成?

人類看到“蘋果”一詞,會瞬間聯想到“iPhone”“喬布斯”“牛頓”,甚至“維生素C”——這種思維跳躍的背后,是大腦將概念連結成網的能力。而AI要模仿這種能力,需要一張動態的“數字地圖”來存儲和鏈接知識,這就是?知識…

Win11 24H2首個熱補丁下周推送!更新無需重啟

快科技5月7 日消息,微軟宣布,Windows 11 24H2的首個熱補丁更新將于下周通過Patch Tuesday發布,將為管理員帶來更高效的安全更新部署方式,同時減少設備停機時間。 為幫助IT管理員順利過渡到熱補丁模式,微軟還提供了豐富…