參考
5.8 網絡中的網絡(NiN)
前幾節介紹的LeNet、AlexNet和VGG在設計上的共同之處是:先以由卷積層構成的模塊充分抽取空間特征,再以由全連接層構成的模塊來輸出分類結果。其中,AlexNet和VGG對LeNet的改進主要在于如何對這兩個模塊加寬(增加通道數)和加深。本節我們介紹網絡中的網絡(NiN)。它提出了另外一個思路,即串聯多個由卷積層和“全連接”層構成的小網絡來構建一個深層網絡。
5.8.1 NiN塊
我們知道,卷積層的輸入和輸出通常是四維數組(樣本, 通道, 高, 寬),而全連接層的輸入和輸出則通常是二維數組(樣本、特征)。如果想在全連接層后再接上卷積層,則需要將全連接層的輸出變成四維。
NiN塊是NiN中的基礎塊。它由一個卷積層加兩個充當全連接層的 1 * 1 卷積層串聯而成。其中第一個卷積層的超參數可以自行設置,而第二和第三個卷積層的超參數一般是固定的。
import time
import torch
from torch import nn, optimimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def nin_block(in_channels, out_channels, kernel_size, stride, padding):blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())return blk
5.8.2 NiN模型
NiN是在AlexNet問世不久后提出的。它們的卷積層設定有類似之處。NiN使用卷積窗口形狀分別為 11×11、 5×5和3×3的卷積層,相應的輸出通道也與AlexNet中的一致。每個NiN塊后接一個步幅為2、窗口形狀為3×3的最大池化層。
除使用NiN塊以外,NiN還有一個設計與AlexNet顯著不同:NiN去掉了AlexNet最后的3個全連接層,取而代之地,NiN使用了輸出通道數等于標簽類別數的NiN塊,然后使用全局平均池化層對每個通道中所有元素求平均并直接用于分類。這里的全局平均池化層即窗口形狀等于輸入空間維形狀的平均池化層。NiN的這個設計的好處是可以顯著減小模型參數尺寸,從而緩解過擬合。然而,該設計有時會造成獲得有效模型的訓練時間的增加。
import torch.nn.functional as Fclass GlobalAvgPool2d(nn.Module):# 全局平均池化層可通過將池化窗口形狀設置成輸入的高和寬實現def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])net = nn.Sequential(nin_block(1, 96, kernel_size = 11, stride = 4, padding = 0),nn.MaxPool2d(kernel_size = 3, stride = 2),nin_block(96, 256, kernel_size = 5, stride = 1, padding = 2),nn.MaxPool2d(kernel_size = 3, stride = 2),nin_block(256, 384, kernel_size = 3, stride = 1, padding = 1),nn.MaxPool2d(kernel_size=3, stride =2),nn.Dropout(0.5),# 標簽類別數是10nin_block(384, 10, kernel_size = 3, stride=1, padding = 1),GlobalAvgPool2d(),# 將四維的輸出轉成二維的輸出,其形狀為(批量, 10)d2l.FlattenLayer()
)print(net)
構建數據觀察每一層的結構
X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children():X = blk(X)print(name, 'output shape: ', X.shape)
5.8.3 獲取數據和訓練模型
batch_size = 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)