《動手學深度學習》-3.5-學習筆記
# 通過ToTensor實例將圖像數據從PIL類型變換成32位浮點數格式,
# 并除以255使得所有像素的數值均在0~1之間
trans = transforms.ToTensor()#用于將圖像數據從 PIL 圖像格式(Python Imaging Library,Python 的圖像處理庫)轉換為 PyTorch 張量(Tensor)。
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)#加載訓練數據集
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)#加載測試數據集
-
torchvision.datasets.FashionMNIST
是 PyTorch 提供的用于加載 FashionMNIST 數據集的類。 -
參數解釋:
-
root="../data"
:指定數據集的存儲路徑。如果數據集不存在,PyTorch 會自動下載到這個路徑。 -
train=True
:表示加載訓練數據集。 -
transform=trans
:指定對圖像數據應用的預處理操作,這里是transforms.ToTensor()
,即將圖像轉換為歸一化的張量。 -
download=True
:如果指定路徑下沒有數據集,會自動從網絡下載。 - ?了解基礎情況:在 PyTorch 中,
mnist_train
是一個torchvision.datasets.FashionMNIST
數據集對象,它是一個可迭代的集合,包含了所有訓練樣本的圖像和標簽。mnist_train[3]
表示獲取數據集中的第四個樣本(索引從 0 開始),包括第四個樣本的圖像和標簽。 -
image.shape
輸出torch.Size([1, 28, 28])
,表示圖像是一個張量(Tensor),形狀為:-
1:表示圖像有 1 個通道(灰度圖)。
-
28:圖像的寬度為 28 像素。
-
28:圖像的高度為 28 像素。
-
-
label
輸出的是一個整數,表示圖像的類別標簽。FashionMNIST 數據集有 10 個類別,每個類別對應一個整數標簽(從 0 到 9)。
-
-
打印出來看了一下
?def get_fashion_mnist_labels(labels): """返回Fashion-MNIST數據集的文本標簽"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
-
這是一個列表推導式,用于將輸入的整數標簽列表
labels
轉換為對應的文本標簽列表。 -
對于
labels
中的每個元素i
:-
int(i)
確保i
是整數(雖然通常labels
已經是整數,但這里加了保險)。 -
text_labels[int(i)]
從text_labels
列表中獲取對應的文本標簽。
對text_labels -
列表的索引(從 0 到 9)對應于數據集中的整數標簽。例如:
-
0
對應't-shirt'
-
1
對應'trouser'
-
9
對應'ankle boot'
下面這段 僅僅是 使用這個函數,應用場景 -
-
-
-
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): """繪制圖像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 圖片張量ax.imshow(img.numpy())else:# PIL圖片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
?show_images
是一個用于批量顯示圖像的工具函數,
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
從 FashionMNIST 數據集中加載一批圖像,使用 show_images
函數將圖像以 2 行 9 列的網格形式顯示,并為每張圖像添加文本標簽。
?
創建Dataloader
batch_size = 256def get_dataloader_workers(): """使用4個進程來讀取數據"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
def load_data_fashion_mnist(batch_size, resize=None): """下載Fashion-MNIST數據集"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
用于下載并加載 FashionMNIST 數據集,并將其轉換為適合訓練和測試的 DataLoader
對象。