👋 你好!這里有實用干貨與深度分享?? 若有幫助,歡迎:?
👍 點贊 | ? 收藏 | 💬 評論 | ? 關注 ,解鎖更多精彩!?
📁 收藏專欄即可第一時間獲取最新推送🔔。?
📖后續我將持續帶來更多優質內容,期待與你一同探索知識,攜手前行,共同進步🚀。?
?
數據集讀取
本文使用PyTorch框架,介紹PyTorch中數據讀取的相關知識。
本文目標:
- 了解PyTorch中數據讀取的基本概念
- 了解PyTorch中集成的開源數據集的讀取方法
- 了解PyTorch中自定義數據集的讀取方法
- 了解PyTorch中數據讀取的流程
一、數據的準備
使用開源數據集或者自己采集數據后進行數據標注。
PyTorch中數據讀取的基本概念
PyTorch中數據讀取的基本概念是Dataset
和DataLoader
。
Dataset
是一個抽象類,用于表示數據集。它包含了數據集的長度、索引、數據獲取等方法。
DataLoader
是一個類,用于將數據集按批次加載到模型中。它包含了數據讀取、數據轉換、數據打亂等方法。
實現數據集讀取的步驟:
- 繼承
Dataset
類,實現__len__
和__getitem__
方法 - 使用
DataLoader
類,將數據集按批次加載到模型中
示例代碼:
import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index], self.labels[index]data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)for batch_data, batch_labels in dataloader:print(batch_data.shape, batch_labels.shape)
PyTorch中集成的開源數據集的讀取方法
使用開源數據MNIST作為示范。
數據集鏈接:MNIST數據集
PyTorch中以及集成了很多開源數據集,我們可以直接使用。MNIST也包括在其中。
只需要使用PyTorch中的torchvision.datasets
模塊即可。
示例代碼:
- 引入必要的庫:
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
- 加載數據集:
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)
參數說明:
root
:數據集保存的路徑train
:是否為訓練集download
:是否下載數據集
- 查看數據集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
- 可視化數據集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
- 數據加載:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
參數說明:
batch_size
:批次大小shuffle
:是否打亂數據,訓練集一般需要打亂數據,測試集一般不需要打亂數據
其實,真實的訓練過程只需要步驟1、2、5即可,3、4步驟是為了驗證數據集是否正確。
二、PyTorch中自定義數據集的讀取方法
自定義數據集的讀取方法是指,我們自己定義一個數據集,然后使用PyTorch中的Dataset
和DataLoader
類來讀取數據集。因為不是所有的數據集都在PyTorch中集成了,當我們有擁有(自己標注或下載)一個新的數據集時,就需要自己定義數據集的讀取方法。
這時候需要將數據集以一定的規則保存起來,然后使用PyTorch中的Dataset
和DataLoader
類來讀取數據集。
示例代碼:
- 引入必要的庫:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
- 定義數據集類:
class MyDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.data_list = os.listdir(data_dir)def __len__(self):return len(self.data_list)def __getitem__(self, index):data_path = os.path.join(self.data_dir, self.data_list[index])data = np.load(data_path)label = data['label']if self.transform is not None:data = self.transform(data)return data, label
參數說明:
data_dir
:數據集保存的路徑transform
:數據轉換函數,可選。1. 用于數據增強,一般的數據增強方法有:隨機裁剪、隨機旋轉、隨機翻轉、隨機縮放等。2. 也可以用于數據預處理,如歸一化、標準化等。
- 定義數據轉換函數:
def transform(data):data = data['data']data = data.astype(np.float32)data = data / 255.0data = torch.from_numpy(data)return data
- 加載數據集:
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)
- 查看數據集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
- 可視化數據集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
- 數據加載:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
- 數據增強:
from torchvision import transformstransform = transforms.Compose([transforms.RandomCrop(28), # 隨機裁剪,裁剪大小為28x28transforms.RandomHorizontalFlip(), # 隨機水平翻轉transforms.RandomVerticalFlip(), # 隨機垂直翻轉transforms.RandomRotation(10), # 隨機旋轉transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # 隨機仿射變換transforms.ToTensor() # 轉換為張量
])
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
DataLoader核心參數詳解
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None,num_workers=0, collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)
關鍵參數解析:
num_workers
:數據預加載進程數(建議設為CPU核心數的70-80%)pin_memory
:啟用CUDA鎖頁內存加速GPU傳輸prefetch_factor
:每個worker預加載的batch數(PyTorch 1.7+)
數據加載性能優化公式
理論最大吞吐量:
T h r o u g h p u t = min ? ( B a t c h S i z e × n u m _ w o r k e r s D a t a L o a d T i m e , G P U C o m p u t e T i m e ? 1 ) Throughput = \min\left(\frac{BatchSize \times num\_workers}{DataLoadTime}, GPUComputeTime^{-1}\right) Throughput=min(DataLoadTimeBatchSize×num_workers?,GPUComputeTime?1)
三、拓展:多模態數據加載示例
class MultiModalDataset(Dataset):def __init__(self, img_dir, text_path):self.img_dir = img_dirself.text_data = pd.read_csv(text_path)self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')def __getitem__(self, idx):# 圖像處理img_path = os.path.join(self.img_dir, self.text_data.iloc[idx]['image_id'])image = Image.open(img_path).convert('RGB')image = transforms.ToTensor()(image)# 文本處理text = self.text_data.iloc[idx]['description']inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=128)return {'image': image,'input_ids': torch.tensor(inputs['input_ids']),'attention_mask': torch.tensor(inputs['attention_mask'])}
四、總結
本文介紹了PyTorch中數據讀取的基本概念、集成的開源數據集的讀取方法、自定義數據集的讀取方法和數據讀取的流程。
數據讀取是深度學習訓練的重要環節,數據讀取的流程是:
- 定義數據集類
- 定義數據轉換函數、數據增強函數
- 加載數據集
?
?
📌 感謝閱讀!若文章對你有用,別吝嗇互動~?
👍 點個贊 | ? 收藏備用 | 💬 留下你的想法 ,關注我,更多干貨持續更新!