文章目錄
- 一、前置知識
- 如何查看torchvision的數據集
- 二、代碼(附注釋)及運行結果
一、前置知識
如何查看torchvision的數據集
(1)打開官網 https://pytorch.org/
pytorch官網
(2)打開torchvision
在Docs下拉后選擇torchvision
(3)左側點擊Datasets
本次用的數據集是CIFAR10:
可以看到,要輸入的參數有:
root(字符串):數據集的根目錄,其中存在 cifar-10-batches-py 目錄,如果設置 download 為 True,則數據集將保存在此目錄中。
train(bool,可選):如果為 True,則從訓練集創建數據集,否則從測試集創建數據集。
transform(callable,可選):接受 PIL 圖像并返回轉換后版本的函數/轉換。例如,transforms.RandomCrop。
target_transform(callable,可選):接受目標并對其進行轉換的函數/轉換。
download(bool,可選):如果為 True,則從互聯網下載數據集并將其放在根目錄中。如果數據集已經下載,則不會重新下載。
二、代碼(附注釋)及運行結果
import torchvision
from torch.utils.tensorboard import SummaryWriter# 定義導入數據時進行的變換
data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])# 創建訓練集和測試集
train_set = torchvision.datasets.CIFAR10("./dataset1", train=True, transform=data_transform, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=data_transform, download=True)# 打印test_set第一個數據
# 結果為:(<PIL.Image.Image image mode=RGB size=32x32 at 0x10C7177C190>, 3)
print(test_set[0])
# 打印test_set數據的類別
# 結果為:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(test_set.classes)# 將test_set的第一個數據拆分為img和target
img, target = test_set[0]
# 打印test_set第一個數據的img
# 結果為<PIL.Image.Image image mode=RGB size=32x32 at 0x10C7177C190>
print(img)
# 打印test_set第一個數據的target,結果為3
print(target)
# 打印test_set第target個類別
print(test_set.classes[target])# 創建一個 TensorBoard 的 SummaryWriter 對象,用于記錄測試集中的圖像
writer = SummaryWriter("logs")
for i in range(10):img, target = test_set[i]writer.add_image("test_set", img, i)
運行結果: