知識點回顧:
- Dataset類的__getitem__和__len__方法(本質是python的特殊方法)
- Dataloader類
- minist手寫數據集的了解
作業:了解下cifar數據集,嘗試獲取其中一張圖片
一、首先加載CIFAR數據集
import torch
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt# 定義數據轉換
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加載訓練集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform
)
二、創建DataLoader并獲取單張圖片
# 創建DataLoader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True
)# 獲取一個batch的數據
dataiter = iter(trainloader)
images, labels = next(dataiter)# 顯示第一張圖片
def imshow(img):img = img / 2 + 0.5 # 反歸一化npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()imshow(images[0])
print('Label:', trainset.classes[labels[0]])
三、直接通過Dataset獲取單張圖片
# 直接通過Dataset獲取第100張圖片
image, label = trainset[100]# 顯示圖片
imshow(image)
print('Label:', trainset.classes[label])
說明:
1. Dataset 類的兩個核心方法:
? ?
? ?- __len__ : 返回數據集大小
? ?- __getitem__ : 根據索引返回單個樣本
2. DataLoader 主要參數:
? ?
? ?- batch_size : 每次加載的樣本數
? ?- shuffle : 是否打亂數據順序
3. CIFAR-10數據集包含10個類別:
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']