目錄
一、PyTorch 數據加載的核心組件
1.1 Dataset 類的核心方法
1.2 DataLoader 的作用
二、加載 CSV 數據實戰
2.1 自定義 CSV 數據集
2.2 使用 TensorDataset 快速加載
三、加載圖像數據實戰
3.1 自定義圖像數據集
3.2 使用 ImageFolder 快速加載
四、加載官方數據集
五、總結
在深度學習項目中,數據加載是模型訓練的第一步,也是至關重要的一步。PyTorch 提供了靈活的數據加載工具,讓我們能夠輕松處理各種類型的數據。本文將結合實際代碼,詳細講解如何使用 PyTorch 加載 CSV 數據和圖像數據,幫助初學者快速掌握數據加載的核心技巧。
一、PyTorch 數據加載的核心組件
PyTorch 的數據加載主要依賴兩個核心類:Dataset
和DataLoader
。
Dataset
:負責數據的讀取和預處理,是所有自定義數據集的基類DataLoader
:負責批量加載數據,支持打亂順序、多線程加載等功能
1.1 Dataset 類的核心方法
自定義數據集需要繼承Dataset
類,并實現以下三個方法:
class CustomDataset(Dataset):def __init__(self, ...): # 初始化數據集,加載文件路徑等passdef __len__(self): # 返回數據集大小return len(self.data)def __getitem__(self, index): # 根據索引返回樣本return sample, label
1.2 DataLoader 的作用
DataLoader
像是一個 "搬運工",將Dataset
中的數據按批次搬運給模型:
dataloader = DataLoader(dataset=dataset, # 要加載的數據集batch_size=32, # 批次大小shuffle=True, # 是否打亂數據num_workers=2 # 多線程加載
)
二、加載 CSV 數據實戰
CSV 文件是存儲表格數據的常用格式,比如學生成績表、特征數據表等。下面我們通過實際代碼講解如何加載 CSV 數據。
2.1 自定義 CSV 數據集
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pdclass CsvDataset(Dataset):def __init__(self, filepath):# 讀取CSV文件df = pd.read_csv(filepath)# 刪除不需要的列(學號、姓名)df.drop(['學號', '姓名'], axis=1, inplace=True)# 提取特征和標簽x = df.iloc[1:, :-1] # 從第二行開始,取除最后一列外的所有列作為特征y = df.iloc[1:, -1] # 從第二行開始,取最后一列作為標簽# 轉換為Tensorself.data = torch.tensor(x.values, dtype=torch.float)self.labels = torch.tensor(y.values, dtype=torch.float)def __len__(self):return len(self.data)def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label# 測試代碼
def test_csv_dataset():filepath = '大數據答辯成績表.csv'dataset = CsvDataset(filepath)print(f"數據集大小: {len(dataset)}")print(f"第一個樣本: {dataset[0]}")test_csv_dataset()
2.2 使用 TensorDataset 快速加載
如果數據已經是 Tensor 格式,可以使用TensorDataset
快速創建數據集,無需自定義類:
def test_tensor_dataset():filepath = '大數據答辯成績表.csv'df = pd.read_csv(filepath)df.drop(['學號', '姓名'], axis=1, inplace=True)x = df.iloc[1:, :-1]y = df.iloc[1:, -1]# 轉換為Tensordata = torch.tensor(x.values, dtype=torch.float)labels = torch.tensor(y.values, dtype=torch.float)# 使用TensorDatasetdataset = TensorDataset(data, labels)print(f"第一個樣本: {dataset[0]}")
三、加載圖像數據實戰
處理圖像數據時,我們需要考慮圖像的讀取、大小調整、格式轉換等問題。下面介紹兩種加載圖像數據的方法。
3.1 自定義圖像數據集
import os
import cv2
from torch.utils.data import Datasetclass PicDataset(Dataset):def __init__(self, filepath):self.filepaths = [] # 存儲圖像路徑self.labels = [] # 存儲標簽dirnames = [] # 存儲類別名稱# 遍歷文件夾for root, dirs, files in os.walk(filepath):if len(dirs) > 0:dirnames = dirs # 獲取類別文件夾名稱for file in files:f_path = os.path.join(root, file)self.filepaths.append(f_path)# 根據文件夾名稱確定標簽classname = root.split('\\')[-1]self.labels.append(dirnames.index(classname))def __len__(self):return len(self.filepaths)def __getitem__(self, index):filepath = self.filepaths[index]# 讀取圖像img = cv2.imread(filepath)# 調整圖像大小為112x112img = cv2.resize(img, (112, 112))# 轉換為Tensor并調整維度 (HWC -> CHW)t_img = torch.tensor(img)t_img = t_img.permute(2, 0, 1)label = self.labels[index]return t_img, label# 測試代碼
def test_pic_dataset():filepath = r'E:\人工智能\深度學習\dataset\butterfly'dataset = PicDataset(filepath)print(f"數據集大小: {len(dataset)}")img, label = dataset[0]print(f"圖像形狀: {img.shape}, 標簽: {label}")
3.2 使用 ImageFolder 快速加載
PyTorch 的ImageFolder
是加載圖像數據集的便捷工具,特別適合以下結構的數據集:
root/class1/img1.jpgimg2.jpgclass2/img1.jpgimg2.jpg
使用方法如下:
from torchvision.datasets import ImageFolder
from torchvision import transformsdef test_image_folder():filepath = r'E:\人工智能\深度學習\dataset\butterfly'# 定義圖像轉換transform = transforms.Compose([transforms.Resize((112, 112)), # 調整大小transforms.ToTensor(), # 轉換為Tensor])# 使用ImageFolder加載數據dataset = ImageFolder(root=filepath, transform=transform)print(f"類別: {dataset.classes}")print(f"數據集大小: {len(dataset)}")# 創建DataLoaderdataloader = DataLoader(dataset=dataset,batch_size=1,shuffle=True)# 顯示一張圖像for img, label in dataloader:print(f"圖像形狀: {img.shape}")print(f"標簽: {label}")breaktest_image_folder()
四、加載官方數據集
PyTorch 提供了許多常用的公開數據集(如 MNIST、CIFAR 等),可以直接下載使用:
from torchvision import datasets, transformsdef test_mnist_dataset():# 定義轉換transform = transforms.Compose([transforms.ToTensor()])# 加載MNIST訓練集dataset = datasets.MNIST(root='../dataset', # 數據保存路徑train=True, # 訓練集download=True, # 如果沒有數據則下載transform=transform)# 創建DataLoaderdataloader = DataLoader(dataset=dataset,batch_size=1,shuffle=True)# 顯示一張圖像for img, label in dataloader:print(f"圖像形狀: {img.shape}")print(f"標簽: {label}")breaktest_mnist_dataset()
五、總結
本文介紹了 PyTorch 加載不同類型數據的方法,包括:
- 加載 CSV 數據:可以自定義
CsvDataset
,或使用TensorDataset
快速加載 - 加載圖像數據:可以自定義
PicDataset
,或使用ImageFolder
加載按類別組織的圖像 - 加載官方數據集:直接使用
torchvision.datasets
中的類
掌握數據加載的技巧,可以為后續的模型訓練打下堅實基礎。在實際項目中,需要根據數據的具體格式和特點,選擇合適的加載方式,并進行必要的預處理。
希望本文能幫助大家快速上手 PyTorch 的數據加載,如果你有任何問題或建議,歡迎在評論區留言討論!