文章目錄
- 一、Dataset 與 DataLoader 功能介紹
- 抽象類Dataset的作用
- DataLoader 作用
- 兩者關系
- 二、自定義Dataset類
- Dataset的三個重要方法
- `__len__()`方法
- `_getitem__()`方法
- `__init__` 方法
- 三、現成的torchvision.datasets模塊
- MNIST舉例
- COCODetection舉例
- `torchvision.datasets.MNIST`使用舉例
- `torchvision.datasets.CocoDetection`舉例
一、Dataset 與 DataLoader 功能介紹
抽象類Dataset的作用
簡單來說,就是將原始數據(可能是圖片、文本、音頻等各種格式)整理成模型可以處理的格式,為后續的數據加載和處理做準備。功能是定義數據集的基本屬性和數據獲取方式。
- 初始化數據路徑:在
Dataset
類的__init__
方法中,通常會初始化數據存放的路徑,以及一些數據預處理的操作,比如指定圖片數據集圖片所在文件夾路徑,文本數據集文本文件路徑等 。包含 加載數據/讀取數據、預處理數據、圖像增強 等一系列操作 - 獲取單個樣本及其標簽:通過實現
__getitem__
方法,根據給定的索引(dataloader返回的),返回相應的數據樣本和對應的標簽。例如在圖片分類任務中,給定索引后,返回該索引對應的圖片數據(經過預處理,如調整尺寸、歸一化等)以及圖片的類別標簽。 - 統計樣本數量:通過實現
__len__
方法,返回數據集中樣本的總數,方便在訓練和評估過程中知道數據規模 。
DataLoader 作用
DataLoader
是在Dataset
的基礎上,提供了一種更加高效、便捷地加載數據的方式,它可以將Dataset
返回的單個樣本,按照指定的方式進行打包(如組成batch)、打亂順序等操作,從而滿足模型訓練和評估的需求。
-
創建數據批次,指定數據打包輸出規則:通過
batch_size
參數,將Dataset
中的單個樣本打包成一個個批次(batch)的數據。collate_fn
指定如何從NNN張訓練集選出一個batch的Nbatch_size\frac{N}{batch\_size}batch_sizeN?張圖片。- 例如
batch_size=32
,那么DataLoader
每次會從Dataset
中取出32個樣本組成一個batch。每次迭代,返回的是 一個batch 的數據
-
自定義數據采樣,指定數據迭代讀取規則:
- 一般使用自定義的采樣器(
Sampler
),實現對數據的特殊采樣方式,比如分層采樣(在類別不均衡的數據集中,保證每個batch中各類別的樣本比例與原始數據集相似)等。 - dataset對象是dataloader的一個參數,通過dataset讓dataloader知道訓練集一共多少圖片,從而知道共跌代多少次。
- 一般使用自定義的采樣器(
-
數據打亂:通過
shuffle
參數設置是否在每個epoch開始時打亂數據順序,這樣可以避免模型在訓練時對數據產生特定的依賴,有助于模型學習到更通用的特征,提高模型的泛化能力 。 -
多進程加載:通過
num_workers
參數設置多進程加載數據,從而加快數據加載速度,尤其是在數據量較大、數據預處理較為復雜的情況下,多進程可以充分利用CPU資源,減少數據加載時間,避免數據加載成為訓練過程中的瓶頸 。
兩者關系
-
Dataset
是數據的基礎容器,定義了如何獲取數據集中的單個樣本; -
而
DataLoader
則是Dataset
的上層應用,負責按照特定規則(如批量處理、打亂順序等)從Dataset
中高效地加載數據,供模型進行訓練、驗證和測試等操作。 -
可以說,
Dataset
是數據的來源和基本操作接口,DataLoader
則是為了更好地適配模型訓練需求,對Dataset
的數據進行進一步處理和組織的工具。
二、自定義Dataset類
所謂的 自定義 dataset ,即自己去寫一個 Dataset 類,要滿足兩個要求:
- 一般需要繼承自
torch.utils.data.Dataset
類- 繼承
torch.utils.data.Dataset
主要目的是為了與DataLoader
保持兼容,確保數據集遵循DataLoader
的接口標準,方便后續使用 PyTorch 提供的工具,比如 :批量加載、打亂數據、并行處理等功能
- 繼承
- 并且滿足和
DataLoader
進行交互的規范 :- 因為
DataLoader
會調用Dataset
的len()
和getitem()
方法,所以自定義Dataset
類必須實現這兩個方法,如此才能保證DataLoader
可以正確地加載和操作你的數據集
- 因為
- 兼容訓練和推理階段
Dataset的三個重要方法
創建自定義 Dataset
類時,必須實現的3個方法 :__init__()
、__len__()
、 __getitem__()
。
這些方法定義了數據集的基本結構和行為,也是 DataLoader
可以正確的從 Dataset
中讀取數據的基礎。
__len__()
方法
DataLoader是通過Dataset的 __len__()
,得知訓練集一共多少數據樣本的。
def __len__(self):return len(self.file_list)
- 返回值:數據集中的樣本的總數。
- 作用:
- 方便通過調用
len(dataset)
來獲取數據量,其中 dataset 為 Dataset 對象 - Dataloader 會用它和 batch_size 一起來計算一個 epoch 要迭代多少個 steps:
steps=len(dataset)batch_sizesteps = \frac{len(dataset)}{batch\_size}steps=batch_sizelen(dataset)? - DataLoader調用len方法的代碼封裝在源碼了,所以看不到顯式調用。DataLoader得到一共NNN個數據樣本后,生成000 ~ N?1N-1N?1的索引。再根據batch_size和是否打亂,生成一個batch的索引列表,再將每個索引
idx
傳入到Dataset的_getitem__()
方法中返回得到圖片和索引return image, label
- 方便通過調用
_getitem__()
方法
作用: 根據給定的索引返回數據集中的一個樣本。這是用于獲取數據集中單個樣本的方法。
def __getitem__(self, idx):# 通過索引idx,獲取圖片地址img_nameimg_name = os.path.join(self.data_folder, self.file_list[idx])# 根據圖片地址img_name讀取對應圖像original_imageoriginal_image = Image.open(img_name)# 通過索引idx獲取圖片對應的標簽(這里舉的例子的標簽含在圖片名中)label = img_name.split('_')[-1].split('.')[0]# 圖像預處理和數據增強(僅訓練階段)if self.train:image = self.transform(original_image)else:image = self.transform(original_image)# 返回處理好的一張圖像和標簽return image, label
- 接收參數: index(idx)是單個數據樣本的索引,由DataLoader傳來的
- 返回值: 返回數據集中索引指定的樣本。通常是一個包含輸入數據和對應標簽的元組。這里可以根據自己的需求,進行自定義。
DataLoader返回的是一個batch的數據,具體是:
- DataLoader的采樣器
sampler
根據數據總量和batch_size=2
,和采樣方法(舉例為順序采樣)得到第一次迭代結果為索引列表[0, 1]
- DataLoader分別把索引0和1給Dataset,
__getitem__()
方法返回出對應單個索引的圖片和標簽。 - 把得到的一個batch的兩組圖片和標簽給
collate_fn
函數進行打包并以一種數據結構儲存,由DataLoader返回
__init__
方法
- 參數: 根據需要傳遞一些參數,例如文件路徑、數據轉換等。
- 作用: 構造方法,配好len和getitem方法做一些初始化工作,需要什么數據,就傳入進來賦值到成員屬性。
def __init__(self, data_folder, train, transform=None):self.data_folder = data_folderself.transform = transformself.file_list = os.listdir(data_folder)# 把文件名讀取出來,存入到file_list,方便len方法獲取數據量self.train = train
例如:設置文件路徑selfl.data_folder
、定義數據轉換的transforms
、當前是訓練階段還是驗證階段的布爾值train
等。
三、現成的torchvision.datasets模塊
對于一些公開的數據集,可以直接用torchvision.datasets模塊的現成的Dataset類。
Pytorch官方文檔的torchvision的Dataset列出了可使用的數據集的Dataset,實現了getitem和len方法
MNIST舉例
這里以Image classification任務的MNIST(mixed national institute of standards and technology database)數據集舉例,點入詳情頁課查看:
train_dataset = torchvision.datasets.MNIST(root, train=True, transform=None, target_transform= None download=True)
參數:
root
:數據集存放的路徑download
:是否下載數據集,默認為False
。配合root
參數:- 若設置
download=True
root
目錄下沒有該數據集,數據集將會被下載到root
指定的位置。root
目錄下已經存在該數據集,則不會重新下載,而是會直接使用已存在的數據,以節省時間
- 若設置
download=False
,程序將會在root
指定的位置查找數據集,如果數據集不存在,則會拋出錯誤。
- 若設置
train
:- 如果是
True
,下載訓練集trainin.pt
; - 如果是
False
,下載測試集test.pt
。默認是True
- 如果是
transform
:接收torchvision.transforms
的對象,一系列作用在PIL
圖片上的轉換操作,用于對數據集的圖像預處理和數據增強。target_transform
:對target處理,一般不用。因為出來target出來一般用自定義的Dataset,因為圖像處理和target處理要放一個transform里寫
COCODetection舉例
Image detection任務的COCO數據集
注意:對于一部分數據集比如torchvision.datasets.CocoDetection
,Pytorch不提供下載功能 (具體情況取決于數據集的來源和許可協議),就沒有download
參數。
所以在使用 torchvision.datasets.CocoDetection
這個現成的Dataset
類之前,需要確保已經下載并淮備好COCO數
據集的圖像和標注文件。然后使用torchvision.datasets.CocoDetection
類來加載 COCO數據集。
torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)
root
:指定圖片地址(本地已經下載下來的圖像地址)annFile
:指定標注文件地址(本地已經下載下來的標注文件地址)transform
:圖像處理 (用于PIL
)target_transform
:標注處理transforms
:圖像和標注的處理
torchvision.datasets.MNIST
使用舉例
訓練集和驗證集分別實例化一個Dataset類(torchvision.datasets.MNIST
)的對象,傳入的transforms參數都為實例化的transforms.Compose
對象my_transform
。數據集下載到當下文件所在目錄下。
import torchvision
from torchvision.transforms import transforms
import torch.utils.data as data
import matplotlib.pyplot as pltbatch_size = 5# transforms.Compose的對象,傳入到transforms參數
my_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], # mean=[0.485, 0.456, 0.406]std=[0.5])]) # std=[0.229, 0.224, 0.225]train_dataset = torchvision.datasets.MNIST(root="./",train=True,transform=my_transform,download=True)val_dataset = torchvision.datasets.MNIST(root="./",train=False,transform=my_transform,download=True)
- 可以看的在當下目錄下出現了一個MNIST文件夾,
.gz
后綴的是下載的壓縮文件,程序自動解壓為同名的二進制文件- Dataset會自動處理好二進制文件,最終從DataLoader跌代出來的是正常的單通道灰度圖。
將定義出的訓練集和驗證集的Dataset對象,分別作為參數傳入到兩個DataLoader,得到兩個DataLoader對象
train_loader = data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)val_loader = data.DataLoader(val_dataset,batch_size=batch_size,shuffle=True)
分別調用量Dataset的len方法,輸出數據量。再將train_loader
轉換為迭代器iter(train_loader)
,通過next
方法得到一個batch的image和label。
打印出一個batch的image的shape。[5, 1, 28, 28]分別指batch_size,圖片通道數,圖像長寬。
打印出標簽label列表。
最后可視化一個batch的圖和標簽。
print(len(train_dataset))
print(len(val_dataset))image, label = next(iter(train_loader))
print(image.shape)
print(label)for i in range(batch_size):plt.subplot(1, batch_size, i + 1)plt.title(label[i].item())plt.axis("off")plt.imshow(image[i].permute(1, 2, 0))plt.show()
torchvision.datasets.CocoDetection
舉例
需要把數據集的下載地址換掉,換成你的 COCO數據集地址
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import randomdef collate_fn_coco(batch):return tuple(zip(*batch))coco_det = datasets.CocoDetection(root="./COCO2017/train2017",annFile="./COCO2017/annotations/instances_train 2017.json")sampler = torch.utils.data.SequentialSampler(coco_det) # RandomSampler
batch_sampler = torch.utils.data.BatchSampler(sampler, 1, drop_last=True)
data_loader = torch.utils.data.DataLoader(coco_det,batch_sampler=batch_sampler,collate_fn=collate_fn_coco)# 可視化
iterator = iter(data_loader)
imgs, gts = next(iterator)
img, gts_one_img = imgs[0], gts[0]bboxes = []
ids = []
for gt in gts_one_img:bboxes.append([gt['bbox'][0],gt['bbox'][1],gt['bbox'][2],gt['bbox'][3]])ids.append(gt['category_id'])fig, ax = plt.subplots()
for box, id in zip(bboxes, ids):x = int(box[0])y = int(box[1])w = int(box[2])h = int(box[3])rect = plt.Rectangle((x, y), w, h, edgecolor='r', linewidth=2, facecolor='none')ax.add_patch(rect)ax.text(x, y, id, backgroundcolor="r")plt.axis("off")
plt.imshow(img)
plt.show()