CIFAR-10是一個更接近普適物體的彩色圖像數據集。CIFAR-10 是由Hinton 的學生Alex Krizhevsky 和Ilya Sutskever 整理的一個用于識別普適物體的小型數據集。一共包含10 個類別的RGB 彩色圖片:飛機( airplane )、汽車( automobile )、鳥類( bird )、貓( cat )、鹿( deer )、狗( dog )、蛙類( frog )、馬( horse )、船( ship )和卡車( truck )。
每個圖片的尺寸為32 × 32 ,每個類別有6000個圖像,數據集中一共有50000 張訓練圖片和10000 張測試圖片。
import torchvision
import torchvision.transforms as transforms# 定義數據預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 自動下載訓練集
trainset = torchvision.datasets.CIFAR10(root='./data', # 數據保存路徑train=True,download=True, # 設置為True自動下載transform=transform
)# 自動下載測試集
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform
)
1. 導入必要的庫
import torchvision import torchvision.transforms as transforms |
- torchvision:PyTorch 的視覺庫,提供常用數據集、模型架構和圖像轉換工具。
- transforms:用于圖像預處理的模塊,如縮放、歸一化等。
2. 定義數據預處理流程
transform = transforms.Compose([ ??? transforms.ToTensor(), ??? transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) |
- transforms.Compose:將多個預處理操作按順序組合。
- transforms.ToTensor():
- 將 PIL 圖像或 NumPy 數組(H×W×C,范圍 0-255)轉換為 PyTorch 張量(C×H×W,范圍 0.0-0)。
- transforms.Normalize(mean, std):
- 對每個通道進行歸一化:output = (input - mean) / std。
- 這里mean=(0.5, 0.5, 0.5)和std=(0.5, 0.5, 0.5)將像素值從[0.0, 1.0]映射到[-1.0, 1.0](例如,0.0→-1.0,1.0→1.0)。
3. 下載并加載訓練集
trainset = torchvision.datasets.CIFAR10( ??? root='./data',? # 數據保存路徑 ??? train=True,???? # True表示訓練集(50,000張) ??? download=True,? # 自動下載(如果數據不存在) ??? transform=transform? # 應用預處理 ) |
- torchvision.datasets.CIFAR10:CIFAR-10 數據集類,包含 10 個類別(如飛機、汽車、鳥類等)的 60,000 張 32×32 彩色圖像。
- 參數說明:
- root='./data':數據將下載到當前目錄的data文件夾中。
- train=True:加載訓練集(50,000 張);若為False則加載測試集(10,000 張)。
- download=True:若數據不存在,自動從互聯網下載(約 170MB)。
- transform=transform:對圖像應用之前定義的預處理(轉為張量并歸一化)。
4. 下載并加載測試集
testset = torchvision.datasets.CIFAR10( ??? root='./data',? # 與訓練集路徑一致 ??? train=False,??? # 加載測試集 ??? download=True,? # 自動下載 ??? transform=transform? # 應用相同的預處理 ) |
- 測試集與訓練集結構相同,但用于模型評估,不參與訓練。
5. 數據驗證與使用
下載完成后,數據將存儲在./data/cifar-10-batches-py目錄中。你可以:
- 查看數據集大小:
print(len(trainset))? # 輸出: 50000 print(len(testset))?? # 輸出: 10000 |
- 訪問單個樣本:
image, label = trainset[0]? # 獲取第一張圖像及其標簽 print(image.shape)? # 輸出: (3, 32, 32) print(label)??????? # 輸出: 6(對應類別索引) |
- 使用數據加載器批量處理數據:
from torch.utils.data import DataLoader trainloader = DataLoader(trainset, batch_size=32, shuffle=True) testloader = DataLoader(testset, batch_size=32, shuffle=False) |
注意事項
- 下載路徑:
- 若指定路徑(如./data)已存在 CIFAR-10 數據,download=True不會重復下載。
- 若路徑錯誤或無寫入權限,會拋出異常(如PermissionError)。
- 網絡問題:
- 首次下載需聯網,可能需要幾分鐘。若下載中斷,可刪除./data目錄后重新運行。
- 數據預處理:
- 歸一化參數mean和std通常根據數據集的統計特性設定。對于 CIFAR-10,常用(0.5, 0.5, 0.5)進行簡單歸一化。
- 若需要更精確的歸一化,可計算數據集的真實均值和標準差(如mean=[0.4914, 0.4822, 0.4465],std=[0.2470, 0.2435, 0.2616])。
擴展應用
加載數據后,可用于訓練 CNN 模型(如之前創建的SimpleCNN):
# 假設model已定義 from torch import nn, optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 訓練循環 for epoch in range(5):? # 訓練5個輪次 ??? for inputs, labels in trainloader: ??????? optimizer.zero_grad() ??????? outputs = model(inputs) ??????? loss = criterion(outputs, labels) ??????? loss.backward() ??????? optimizer.step() ??? print(f"Epoch {epoch+1} completed") |