%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
代碼執行流程圖
3.5.1 讀取數據集
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
下載并加載FashionMNIST數據集
?關鍵參數?:
transform=trans:將圖像轉換為張量(形狀 [1, 28, 28],值域 [0,1])。
download=True:若本地無數據則自動下載。
?數據集結構?:
訓練集:60,000 張 28x28 灰度圖像。
測試集:10,000 張 28x28 灰度圖像。
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
標簽映射
將數字標簽(0-9)轉換為可讀的文本標簽(如 0 → ‘t-shirt’)。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
輸入 imgs 可以是張量或PIL圖像。
squeeze():移除單通道維度(1x28x28 → 28x28),否則 imshow 可能報錯。
cmap=‘gray’:確保灰度圖正確顯示(默認可能為彩色)。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
?輸出?:顯示 2行x9列 的圖像網格,標題為對應的文本標簽。
X.reshape(18, 28, 28):調整形狀以匹配 imshow 的輸入要求(原始形狀為 18x1x28x28)。
3.5.2 讀取小批量
batch_size = 256def get_dataloader_workers():return 4 # 根據CPU核心數調整(通常設為4-8)train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())
shuffle=True:打亂訓練數據順序,避免模型記憶批次。
num_workers=4:啟用4個進程并行加載數據,加速數據讀取。
timer = d2l.Timer()
for X, y in train_iter:continue
print(f'加載時間:{timer.stop():.2f} sec')
‘2.30 sec’
3.5.3 整合所有組件
def load_data_fashion_mnist(batch_size, resize=None):trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) # Resize必須在ToTensor前trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
?功能擴展?:支持調整圖像尺寸(如 resize=64 將圖像縮放為 64x64)。
?預處理順序?:
Resize(若指定)
ToTensor(轉為張量并歸一化)
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(f'X形狀: {X.shape}, 數據類型: {X.dtype}') # 輸出如 torch.Size([32,1,64,64])print(f'y形狀: {y.shape}, 數據類型: {y.dtype}') # 輸出如 torch.int64break
X形狀: torch.Size([32, 1, 64, 64]), 數據類型: torch.float32
y形狀: torch.Size([32]), 數據類型: torch.int64
X.shape = [batch_size, channels, height, width]
y 為標簽張量,形狀 [batch_size]