知識點回顧:
- Dataset類的__getitem__和__len__方法(本質是python的特殊方法)
- Dataloader類
- minist手寫數據集的了解
作業:了解下cifar數據集,嘗試獲取其中一張圖片
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加載數據的工具
from torchvision import datasets, transforms # torchvision 是一個用于計算機視覺的庫,datasets 和 transforms 是其中的模塊
import matplotlib.pyplot as plt# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
# 1. 數據預處理,該寫法非常類似于管道pipeline
# transforms 模塊提供了一系列常用的圖像預處理操作# 先歸一化,再標準化
transform = transforms.Compose([transforms.ToTensor(), # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST數據集的均值和標準差,這個值很出名,所以直接使用
])
# 2. 加載MNIST數據集,如果沒有會自動下載
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
一、Dataset類
現在我們想要取出來一個圖片,看看長啥樣,因為datasets.MNIST本質上集成了torch.utils.data.Dataset,所以自然需要有對應的方法。
import matplotlib.pyplot as plt# 隨機選擇一張圖片,可以重復運行,每次都會隨機選擇
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 隨機選擇一張圖片的索引
# len(train_dataset) 表示訓練集的圖片數量;size=(1,)表示返回一個索引;torch.randint() 函數用于生成一個指定范圍內的隨機數,item() 方法將張量轉換為 Python 數字
image, label = train_dataset[sample_idx] # 獲取圖片和標簽
__getitem__方法
# 示例代碼
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 創建類的實例
my_list_obj = MyList()
# 此時可以使用索引訪問元素,這會自動調用__getitem__方法
print(my_list_obj[2]) # 輸出:30
__len__方法
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 創建類的實例
my_list_obj = MyList()
# 使用len()函數獲取元素數量,這會自動調用__len__方法
print(len(my_list_obj)) # 輸出:5
# minist數據集的簡化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加載圖片路徑和標簽self.data, self.targets = fetch_mnist_data(root, train) # 這里假設 fetch_mnist_data 是一個函數,用于加載 MNIST 數據集的圖片路徑和標簽self.transform = transform # 預處理操作def __len__(self): return len(self.data) # 返回樣本總數def __getitem__(self, idx): # 獲取指定索引的樣本# 獲取指定索引的圖像和標簽img, target = self.data[idx], self.targets[idx]# 應用圖像預處理(如ToTensor、Normalize)if self.transform is not None: # 如果有預處理操作img = self.transform(img) # 轉換圖像格式# 這里假設 img 是一個 PIL 圖像對象,transform 會將其轉換為張量并進行歸一化return img, target # 返回處理后的圖像和標簽
# 可視化原始圖像(需要反歸一化)
def imshow(img):img = img * 0.3081 + 0.1307 # 反標準化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 顯示灰度圖像plt.show()print(f"Label: {label}")
imshow(image)
二、Dataloader類
# 3. 創建數據加載器
train_loader = DataLoader(train_dataset,batch_size=64, # 每個批次64張圖片,一般是2的冪次方,這與GPU的計算效率有關shuffle=True # 隨機打亂數據
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每個批次1000張圖片# shuffle=False # 測試時不需要打亂數據
)
@浙大疏錦行