文章目錄
- 前言
- 一、VGG網絡簡介
- 1.1 VGG的核心特點
- 1.2 VGG的典型結構
- 1.3 優點與局限性
- 1.4 本文的實現目標
- 二、搭建VGG網絡
- 2.1 數據準備
- 2.2 定義VGG塊
- 2.3 構建VGG網絡
- 2.4 輔助工具
- 2.4.1 計時器和累加器
- 2.4.2 準確率計算
- 2.4.3 可視化工具
- 2.5 訓練模型
- 2.6 運行實驗
- 總結
前言
深度學習是近年來人工智能領域的重要突破,而卷積神經網絡(CNN)作為其核心技術之一,在圖像分類、目標檢測等領域展現了強大的能力。VGG(Visual Geometry Group)網絡是CNN中的經典模型之一,以其模塊化的“塊”設計和深層結構而聞名。本篇博客將通過PyTorch實現一個簡化的VGG網絡,并結合代碼逐步解析其構建、訓練和可視化過程,幫助讀者從代碼層面理解深度學習的基本原理和實踐方法。我們將使用Fashion-MNIST數據集進行實驗,展示如何從零開始搭建并訓練一個VGG模型。
本文的目標讀者是對深度學習有基本了解、希望通過代碼實踐加深理解的初學者或中級開發者。以下是博客的完整內容,包括代碼實現和詳細說明。
一、VGG網絡簡介
VGG網絡(Visual Geometry Group Network)是由牛津大學視覺幾何組在2014年提出的深度卷積神經網絡(CNN)模型,因其在ImageNet圖像分類競賽中的優異表現而廣為人知。VGG的設計理念是通過堆疊多個小卷積核(通常為3×3)和池化層,構建一個深層網絡,從而提取圖像中的復雜特征。與之前的模型(如AlexNet)相比,VGG顯著增加了網絡深度(常見版本包括VGG-16和VGG-19,分別有16層和19層),并采用統一的模塊化結構,使其易于理解和實現。
1.1 VGG的核心特點
- 小卷積核:VGG使用3×3的小卷積核替代傳統的大卷積核(如5×5或7×7)。兩個3×3卷積核的堆疊可以達到5×5的感受野,而參數量更少,計算效率更高,同時增加了非線性(通過更多ReLU激活)。
- 模塊化設計:網絡由多個“塊”(block)組成,每個塊包含若干卷積層和一個最大池化層。這種設計使得網絡結構清晰,便于擴展或調整。
- 深度增加:VGG通過加深網絡層數(從11層到19層不等)提升性能,證明了深度對特征提取的重要性。
- 全連接層:在卷積層之后,VGG使用多個全連接層(通常為4096、4096和1000神經元)進行分類,輸出對應ImageNet的1000個類別。
1.2 VGG的典型結構
以下是VGG-16的結構示意圖,展示了其卷積塊和全連接層的組織方式:
上圖中:
- 綠色方框表示卷積層(3×3卷積核,步幅1,padding=1),對應圖中的“convolution+ReLU”部分(以立方體表示)。這些卷積層負責提取圖像特征,padding=1確保特征圖尺寸在卷積后保持不變。
- 紅色方框表示最大池化層(2×2,步幅2),對應圖中的“max pooling”部分(以紅色立方體表示)。池化層將特征圖尺寸減半(例如從224×224到112×112),同時保留重要特征。
- 藍色部分為全連接層,最終輸出分類結果,對應圖中的“fully connected+ReLU”和“softmax”部分(以藍色線條表示)。全連接層將卷積特征展平后進行分類,輸出對應ImageNet的1000個類別。
VGG-16包含13個卷積層和3個全連接層,總計16層(池化層不計入層數)。每個卷積塊的通道數逐漸增加(從64到512),而池化層將特征圖尺寸逐步減半(從224×224到7×7)。
1.3 優點與局限性
優點:
- 結構簡單,易于實現和理解。
- 小卷積核和深層設計提高了特征提取能力。
- 在多種視覺任務中表現出色,可作為預訓練模型遷移學習。
局限性:
- 參數量巨大(VGG-16約有1.38億個參數),訓練和推理耗時。
- 深層網絡可能導致梯度消失問題(盡管ReLU和適當初始化緩解了部分問題)。
- 對內存和計算資源要求較高,不適合資源受限的設備。
1.4 本文的實現目標
在本文中,我們將基于PyTorch實現一個簡化的VGG網絡,針對Fashion-MNIST數據集(28×28灰度圖像,10個類別)進行調整。我們保留VGG的模塊化思想,但適當減少層數和參數量,以適應較小規模的數據和計算資源。通過代碼實踐,讀者可以深入理解VGG的設計原理及其在實際任務中的應用。
下一節將進入具體的代碼實現部分,逐步搭建VGG網絡并完成訓練。
二、搭建VGG網絡
2.1 數據準備
在開始構建VGG網絡之前,我們需要準備訓練和測試數據。這里使用Fashion-MNIST數據集,這是一個包含10類服裝圖像的灰度圖像數據集,每個圖像大小為28×28像素。以下是數據加載的代碼:
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
import multiprocessingdef get_dataloader_workers():"""使用電腦支持的最大進程數來讀取數據"""return multiprocessing.cpu_count()def load_data_fashion_mnist(batch_size, resize=None):"""下載Fashion-MNIST數據集,然后將其加載到內存中。參數:batch_size (int): 每個數據批次的大小。resize (int, 可選): 圖像的目標尺寸。如果為 None,則不調整大小。返回:tuple: 包含訓練 DataLoader 和測試 DataLoader 的元組。"""# 定義變換管道trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)# 加載 Fashion-MNIST 訓練和測試數據集mnist_train = torchvision.datasets.FashionMNIST(root="./data",train=True,transform=trans,download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data",train=False,transform=trans,download=True)# 返回 DataLoader 對象return (data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
這段代碼定義了load_data_fashion_mnist
函數,用于加載Fashion-MNIST數據集并將其封裝成PyTorch的DataLoader
對象。transforms.ToTensor()
將圖像轉換為張量格式,batch_size
控制每個批次的數據量,shuffle=True
確保訓練數據隨機打亂以提高模型泛化能力。num_workers
通過多進程加速數據加載。
2.2 定義VGG塊
VGG網絡的核心思想是將網絡分解為多個“塊”(block),每個塊包含若干卷積層和一個池化層。以下是VGG塊的實現:
import torch
from torch import nndef vgg_block(num_convs, in_channels, out_channels):layers = [] # 初始化一個空列表,用于存儲網絡層for _ in range(num_convs): # 循環 num_convs 次,構建卷積層layers.append(nn.Conv2d( # 添加一個二維卷積層in_channels, # 輸入通道數out_channels, # 輸出通道數kernel_size=3, # 卷積核大小為 3x3padding=1)) # 填充大小為 1,保持特征圖尺寸layers.append(nn.ReLU()) # 添加 ReLU 激活函數in_channels = out_channels # 更新輸入通道數為輸出通道數,用于下一次卷積layers.append(nn.MaxPool2d( # 添加一個最大池化層kernel_size=2, # 池化核大小為 2x2stride=2)) # 步幅為 2,縮小特征圖尺寸return nn.Sequential(*layers)