本文將介紹以下內容:
數據集與數據加載器
數據遷移
如何建立神經網絡
數據集與數據加載器
處理數據樣本的代碼可能會變得混亂且難以維護;理想情況下,我們希望我們的數據集代碼與模型訓練代碼解耦,以獲得更好的可讀性和模塊化。PyTorch提供了兩個數據原語:torch.utils.data.DataLoader和torch.utils.data.Dataset,它們允許你使用預加載的數據集和你自己的數據。Dataset存儲了樣本及其相應的標簽,DataLoader在Dataset周圍包裝了一個可迭代對象,以便于訪問樣本。
PyTorch域庫提供了許多預加載的數據集(如FashionMNIST),這些數據集是torch.utils.data.Dataset的子類,并實現了特定數據的特定函數。它們可用于原型化和基準化模型。你可以在這里找到它們:圖像數據集,文本數據集和音頻數據集
加載數據集
下面是一個如何從TorchVision加載Fashion-MNIST數據集的示例。Fashion-MNIST是Zalando文章圖像的數據集,由60,000個訓練樣例和10,000個測試樣例組成。每個示例都包含一個28×28灰度圖像和來自10個類之一的關聯標簽。
我們用以下參數加載FashionMNIST數據集:
-
root是存儲訓練/測試數據的路徑,
-
Train指定訓練或測試數據集,
-
download=True從互聯網上下載數據,如果它在根不可用。
-
Transform和target_transform指定特征和標簽轉換
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)
迭代和可視化數據集
我們可以像列表一樣手動索引數據集:training_data[index]。我們使用matplotlib來可視化訓練數據中的一些樣本。
labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()
輸出為:
為文件創建自定義數據集
自定義Dataset類必須實現三個函數:init, len__和__getitem。看看這個實現;FashionMNIST圖像存儲在目錄img_dir中,它們的標簽單獨存儲在CSV文件annotations_file中。
在接下來的部分中,我們將分解這些函數中發生的事情。
class CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label
init
__init__函數在實例化Dataset對象時運行一次。我們初始化包含圖像、注釋文件和兩個轉換(下一節將詳細介紹)的目錄。
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file