pytorch中dataloader自定義數據集

前言

在深度學習中我們需要使用自己的數據集做訓練,因此需要將自定義的數據和標簽加載到pytorch里面的dataloader里,也就是自實現一個dataloader。

數據集處理

以花卉識別項目為例,我們分別做出圖片的訓練集和測試集,訓練集的標簽和測試集的標簽

flower_data/
├── train_filelist/
│   ├── image_0001.jpg
│   └── ...
├── val_filelist/
│   ├── image_1001.jpg
│   └── ...
├── train.txt  # 格式:文件名 標簽
└── val.txt

?數據目錄的組織方式如上所示。

首先看圖片的處理。圖片只要做好編號放在同一個文件夾里就好了。

再看標簽的處理。標簽處理我們自己規定了一種形式,就是圖像文件的名稱+空格+分類標簽。

可以看到前面第一列數據是圖像名稱,第二列數據是圖像的分組,同樣的數字為一組。比如分組為0的圖像就是同一種花朵。

自定義dataset

源碼

import os.path
import numpy as np
import torch
from PIL import Image  # 從PIL庫導入Image類
from torch.utils.data import Datasetclass FlowerDataSet(Dataset):"""花朵分類任務數據集類,繼承自torch的Dataset類"""def __init__(self, root_dir, ann_file, transform=None):"""初始化數據集實例Args:root_dir (str): 數據集根目錄路徑ann_file (str): 標注文件路徑transform (callable, optional): 數據預處理變換函數"""self.ann_file = ann_fileself.root_dir = root_dir# 加載圖片路徑與標簽的映射字典 {文件名: 標簽}self.image_label = self.load_annotations()# 構建完整圖片路徑列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 構建標簽列表 [標簽1, 標簽2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名為lbl避免與導入的label沖突self.transform = transformdef __len__(self):"""返回數據集樣本數量"""return len(self.image)def __getitem__(self, index):"""獲取單個樣本數據Args:index (int): 樣本索引Returns:tuple: (預處理后的圖像數據, 對應的標簽)"""# 打開圖片文件image = Image.open(self.image[index])# 獲取對應標簽label = self.label[index]# 應用數據預處理if self.transform:image = self.transform(image)# 將標簽轉換為torch張量label = torch.from_numpy(np.array(label))return image, labeldef load_annotations(self):"""加載標注文件,解析圖片文件名和標簽的映射關系Returns:dict: {圖片文件名: 對應標簽} 的字典"""data_infos = {}with open(self.ann_file) as f:# 讀取所有行并分割,每行格式應為 "文件名 標簽"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 將標簽轉換為int64類型的numpy數組data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

解析

1、將標簽數據進行讀取,組成一個哈希表,哈希表的鍵是圖像的文件名稱,哈希表的值是分組標簽。

    def load_annotations(self):"""加載標注文件,解析圖片文件名和標簽的映射關系Returns:dict: {圖片文件名: 對應標簽} 的字典"""data_infos = {}with open(self.ann_file) as f:# 讀取所有行并分割,每行格式應為 "文件名 標簽"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 將標簽轉換為int64類型的numpy數組data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

上面的代碼里,在錄入標簽的時候使用數組進行記錄,這是為了兼容多標簽的場景。如果不考慮兼容問題,僅考慮在單標簽場景下的簡單實現,可以用下面的代碼:

def load_annotations(self):data_infos = {}with open(self.ann_file) as f:for line in f:filename, label = line.strip().split()  # 直接解包data_infos[filename] = int(label)        # 存為 Python 整數return data_infos# 在 __getitem__ 中直接轉為張量
label = torch.tensor(self.labels[index], dtype=torch.long)

2、遍歷哈希表,將文件名和標簽分別存在兩個數組里。這里注意,為了方便后面dataloader按照batch去讀取圖片,這里要將圖片的全路徑加到文件名里。

        # 構建完整圖片路徑列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 構建標簽列表 [標簽1, 標簽2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名為lbl避免與導入的label沖突

3、在dataloader向顯卡/cpu加載數據的時候會調用getitem方法。比如一個batch里有64個數據,dataloader就會調用64次該方法,將64組圖片和標簽全部獲取后交給運算單元去處理。

    def __getitem__(self, index):"""獲取單個樣本數據Args:index (int): 樣本索引Returns:tuple: (預處理后的圖像數據, 對應的標簽)"""# 打開圖片文件image = Image.open(self.image[index])# 獲取對應標簽label = self.label[index]# 應用數據預處理if self.transform:image = self.transform(image)# 將標簽轉換為torch張量label = torch.from_numpy(np.array(label))return image, label

測試dataloader

import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import FlowerDataSet  # 假設你的數據集類在dataloader.py中def denormalize(image_tensor):"""將歸一化的圖像張量轉換為可顯示的格式"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = image_tensor.numpy().transpose((1, 2, 0))  # 轉換維度順序image = std * image + mean  # 反歸一化image = np.clip(image, 0, 1)  # 限制像素值范圍return imagedef test_dataloader():# 定義數據預處理data_transforms = {'train': transforms.Compose([transforms.Resize(64),transforms.RandomRotation(45),transforms.CenterCrop(64),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 檢查文件路徑是否存在print("[1/5] 檢查文件路徑...")required_files = {'train_txt': './flower_data/train.txt','val_txt': './flower_data/val.txt','train_dir': './flower_data/train_filelist','val_dir': './flower_data/val_filelist'}for name, path in required_files.items():if not os.path.exists(path):print(f"? 文件/目錄不存在: {path}")returnprint(f"? {name}: {path} 存在")# 初始化數據集print("\n[2/5] 加載數據集...")try:train_dataset = FlowerDataSet(root_dir=required_files['train_dir'],ann_file=required_files['train_txt'],transform=data_transforms['train'])val_dataset = FlowerDataSet(root_dir=required_files['val_dir'],ann_file=required_files['val_txt'],transform=data_transforms['valid'])print("? 數據集加載成功")except Exception as e:print(f"? 數據集加載失敗: {str(e)}")return# 打印數據集信息print("\n[3/5] 數據集統計:")print(f"訓練集樣本數: {len(train_dataset)}")print(f"驗證集樣本數: {len(val_dataset)}")# 檢查單個樣本print("\n[4/5] 檢查單個樣本:")sample_idx = 0try:img, label = train_dataset[sample_idx]print(f"圖像張量形狀: {img.shape} (應接近 torch.Size([3, 64, 64]))")print(f"標簽類型: {type(label)} (應為 torch.Tensor)")print(f"標簽值: {label.item()} (應為整數)")except Exception as e:print(f"? 樣本檢查失敗: {str(e)}")# 可視化樣本print("\n[5/5] 可視化訓練集樣本...")try:plt.figure(figsize=(8, 8))img_show = denormalize(img)plt.imshow(img_show)plt.title(f"Label: {label.item()}")plt.axis('off')plt.show()except Exception as e:print(f"? 可視化失敗: {str(e)}")# 檢查DataLoaderprint("\n[附加] 檢查DataLoader:")train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)for loader, name in [(train_loader, '訓練集'), (val_loader, '驗證集')]:print(f"\n{name} DataLoader測試:")try:batch = next(iter(loader))images, labels = batchprint(f"批次圖像形狀: {images.shape} (應接近 [batch, 3, 64, 64])")print(f"批次標簽示例: {labels[:5].numpy()}")print(f"像素值范圍: [{images.min():.3f}, {images.max():.3f}]")except Exception as e:print(f"? {name} DataLoader錯誤: {str(e)}")if __name__ == '__main__':test_dataloader()

在測試代碼中,分別測試了文件路徑,dataset是否正常創建,dataset樣本數量,dataset樣本格式,dataset數據可視化,dataloader數據樣式。

在打印日志的時候需要注意,dataset和dataloader里面的變量都是張量形式的,所以需要轉換成python標量再打印。比如從dataset里取出的標簽label是一個一維張量,需要通過label.item()進行轉換。

?在遍歷的時候為了簡化代碼,將兩個dataloader放在同一個循環語句中處理,并且通過增加name變量來區分兩個dataloader。

for loader, name in [(train_loader, '訓練集'), (val_loader, '驗證集')]:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/899825.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/899825.shtml
英文地址,請注明出處:http://en.pswp.cn/news/899825.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Blender模型導入虛幻引擎設置

單位系統不一致 Blender默認單位是米(Meters),而虛幻引擎默認使用**厘米(Centimeters)**作為單位。 當模型從Blender導出為FBX或其他格式時,如果沒有調整單位,虛幻引擎會將1米(Blen…

Docker基礎詳解

Docker 技術詳解 一、概述 Docker官網:https://docs.docker.com/ 菜鳥教程:https://www.runoob.com/docker/docker-tutorial.html 1.1 什么是Docker? Docker 是一個開源的容器化平臺,它允許開發者將應用程序和其依賴項打包到…

FastPillars:一種易于部署的基于支柱的 3D 探測器

FastPillars:一種易于部署的基于支柱的 3D 探測器Report issue for preceding element Sifan Zhou 1 , Zhi Tian 2 , Xiangxiang Chu 2 , Xinyu Zhang 2 , Bo Zhang 2 , Xiaobo Lu11{}^{1}start_FLOATSUPERSCRIPT 1 end_FLOATSUPERSCRIPT11footnotemark: 1 Chengji…

NLP語言模型訓練里的特殊向量

1. CLS 向量和 DEC 向量的區別及訓練方式 (1) CLS 向量與 DEC 向量是否都是特殊 token? CLS 向量([CLS] token)和 DEC 向量(Decoder Input token)都是特殊的 token,但它們出現在不同類型的 NLP 模型中&am…

字節跳動 UI-TARS 匯總整理報告

1. 摘要 UI-TARS 是字節跳動開發的一種原生圖形用戶界面(GUI)代理模型 。它將感知、行動、推理和記憶整合到一個統一的視覺語言模型(VLM)中 。UI-TARS 旨在跨桌面、移動和 Web 平臺實現與 GUI 的無縫交互 。實驗結果表明&#xf…

基于Python深度學習的鯊魚識別分類系統

摘要:鯊魚是海洋環境健康的指標,但受到過度捕撈和數據缺乏的挑戰。傳統的觀察方法成本高昂且難以收集數據,特別是對于具有較大活動范圍的物種。論文討論了如何利用基于媒體的遠程監測方法,結合機器學習和自動化技術,來…

【漫話機器學習系列】168.最大最小值縮放(Min-Max Scaling)

在機器學習和數據預處理中,特征縮放(Feature Scaling) 是一個至關重要的步驟,它可以使模型更穩定,提高訓練速度,并優化收斂效果。最大最小值縮放(Min-Max Scaling) 是其中最常見的方…

開源測試用例管理平臺

不可錯過的10個開源測試用例管理平臺: PingCode、TestLink、Kiwi TCMS、Squash TM、FitNesse、Tuleap、Robot Framework、SpecFlow、TestMaster、Nitrate。 開源測試用例管理工具提供了一種透明、靈活的解決方案,使團隊能夠在不受限的情況下適應具體的測…

鴻蒙闊折疊Pura X外屏開發適配

首先看下鴻蒙中斷點分類 內外屏開合規則 Pura X開合連續規則: 外屏切換到內屏,界面可以直接接續。內屏(鎖屏或非鎖屏狀態)切換到外屏,默認都顯示為鎖屏的亮屏狀態。用戶解鎖后:對于應用已適配外屏的情況下,應用界面可以接續到外屏。折疊外屏顯示展開內屏顯示折疊狀態…

DRM_CLIENT_CAP_UNIVERSAL_PLANES和DRM_CLIENT_CAP_ATOMIC

drmSetClientCap(fd, DRM_CLIENT_CAP_UNIVERSAL_PLANES, 1); drmSetClientCap(fd, DRM_CLIENT_CAP_ATOMIC, 1); 這兩行代碼用于啟用 Linux DRM(Direct Rendering Manager)客戶端的兩個關鍵特性,具體作用如下: 1. drmSetClientCap…

敏捷開發10:精益軟件開發和看板kanban開發方法的區別是什么

簡介 精益生產起源于豐田生產系統,核心是消除浪費,而看板最初是由豐田用于物料管理的信號卡片,后來被引入軟件開發。 Kanban 后來引入到敏捷開發中,強調持續交付和流程可視化。 精益軟件開發原則是基于精益生產的原則&#xff0…

用matlab探索卷積神經網絡(Convolutional Neural Networks)-3

5.GoogLeNet中的Filters 這里我們探索GoogLeNet中的Filters,首先你需要安裝GoogLeNet.在Matlab的APPS里找到Deep Network Designer,然后找到GoogLeNet,安裝后的網絡是沒有右下角的黃色感嘆號的,沒有安裝的神經網絡都有黃色感嘆號。 一個層&a…

Verilog中X態的危險:仿真漏掉的bug

由于Verilog中X態的微妙語義,RTL仿真可能PASS,而網表仿真卻會fail。 目前進行的網表仿真越來越少,這個問題尤其嚴重,主要是網表仿真比RTL仿真慢得多,因此對整個回歸測試而言成本效益不高。 上面的例子中,用Verilog RTL中的case語句描述了一個簡單的AND函數,它被綜合成AN…

PyTorch中知識蒸餾淺講

知識蒸餾 在 PyTorch 中,使用 teacher_model.eval() 和凍結教師模型參數是知識蒸餾(Knowledge Distillation)中的關鍵步驟。 ?1. teacher_model.eval() 的作用 目的: 將教師模型切換到評估模式,影響某些特定層(如 Dropout、BatchNorm)的行為。 ?具體影響: ?Dropo…

Odoo/OpenERP 和 psql 命令行的快速參考總結

Odoo/OpenERP 和 psql 命令行的快速參考總結 psql 命令行選項 選項意義-a從腳本中響應所有輸入-A取消表數據輸出的對齊模式-c <查詢>僅運行一個簡單的查詢&#xff0c;然后退出-d <數據庫名>指定連接的數據庫名&#xff08;默認為當前登錄用戶名&#xff09;-e回顯…

ChatGPT 迎來 4o模型:更強大的圖像生成能力與潛在風險

OpenAI 對 ChatGPT 進行重大升級&#xff0c;圖像生成功能即將迎來新的 4o 模型&#xff0c;并取代原本的 DALLE。此次更新不僅提升了圖像生成質量&#xff0c;還增強了對話內容和上傳文件的融合能力&#xff0c;使 AI 生成的圖像更加智能化和精準化。 4o 模型帶來的革新 Ope…

Python 實現的運籌優化系統代碼詳解(整數規劃問題)

一、引言 在數學建模的廣袤領域里&#xff0c;整數規劃問題占據著極為重要的地位。它廣泛應用于工業生產、資源分配、項目管理等諸多實際場景&#xff0c;旨在尋求在一系列約束條件下&#xff0c;使目標函數達到最優&#xff08;最大或最小&#xff09;且決策變量取整數值的解決…

Visual Studio Code配置自動規范代碼格式

目錄 前言1. 插件安裝2. 配置個性化設置2.1 在左下角點擊設置按鈕 &#xff0c;點擊命令面板&#xff08;或者也可以之間按快捷鍵CtrlShiftP&#xff09;2.2 在彈出的搜索框輸入 settings.json&#xff0c;打開首選項&#xff1a;打開工作區設置&#xff1b;2.3 在settings.jso…

【分布式】Hystrix 的核心概念與工作原理?

熔斷機制? Hystrix 的熔斷機制就像是電路中的保險絲。當某個服務的失敗請求達到一定比例&#xff08;例如 50%&#xff09;或者在一定時間內&#xff08;如 20 秒&#xff09;失敗請求數量超過一定閾值&#xff08;如 20 個&#xff09;時&#xff0c;熔斷開關就會打開。此時…

TypeScript 中 await 的詳解

TypeScript 中 await 的詳解 1. 基本概念2. 語法要求3. 工作原理4. 與 Promise 的比較5. 實踐中的注意事項總結 本文詳細介紹了 TypeScript 中 await 的工作原理、語法要求、與 Promise 的關系以及實踐中需要注意的問題&#xff0c;同時針對代碼示例進行了優化和補充說明。 1.…