一、概念
Pytorch的標準數據集包括很多種類型,如CIFAR,COCO,KITTI,MNIST等,我們可以在官網查看。當然我們也可以做數據集,但需要自己標注。
二、如何調用數據集
一、調用torchvision
在程序中調用torchvision.datasets,下面用程序示例如何下載CIFAR10數據集。
import torchvisiontrain_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
也可以復制路徑通過其他方式下載,然后將下載文件放入py文件路徑下,可運行程序可自動解壓。
如果想顯示數據集的圖片,可以直接調用imshow方法。
img, target = test_set[0]
img.show()
如果想通過tensorboard顯示圖片,需要先將圖片格式轉化為tensor,然后調用SummaryWriter類。
二、調用dataset類
dataset類屬于抽象類,需要通過創建子類來繼承,從而創建數據集。
from torch.utils.data import Dataset
from PIL import Image
import osclass MyData(Dataset):def __init__(self, root_dir, label_dir):self.root_dir = root_dirself.label_dir = label_dirself.path = os.path.join(self.root_dir, self.label_dir)self.img_path = os.listdir(self.path)def __getitem__(self, idx):img_name = self.img_path[idx]img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)img = Image.open(img_item_path)label = self.label_dirreturn img, labeldef __len__(self):return len(self.img_path)
_ init _可以初始化子類的基礎參數,可以自定義,相當于構造函數。
_ getitem _根據索引返回數據和標簽。
_ len _返回數據大小
三、加載數據集
一般用DataLoader類來加載數據集。常見的參數包括:batch_size, shuffle num_workers。
這些參數的意義如下:
batch_size:指批大小,在訓練時每次在訓練集中取batchsize個樣本。
epoch:指使用所有訓練集的樣本訓練一次。
shuffle :指將訓練集進行打亂的操作,一般生成數據集的時候要shuffle一下圖片順序,防止過擬合。
num_workers:設定DataLoader要使用多少個子進程進行加載。
drop_last:指訓練集經過批處理后剩余的部分數據的處理模式。ture代表丟棄,false代表繼續執行,只是batch_size會相對變小。
簡單例子:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())test_load = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)