LeNet卷積神經網絡
- 一、理論部分
- 1.1 核心理論
- 1.2 LeNet-5 網絡結構
- 1.3 關鍵細節
- 1.4 后期改進
- 1.6 意義與局限性
- 二、代碼實現
- 2.1 導包
- 2.1 數據加載和處理
- 2.3 網絡構建
- 2.4 訓練和測試函數
- 2.4.1 訓練函數
- 2.4.2 測試函數
- 2.5 訓練和保存模型
- 2.6 模型加載和預測
一、理論部分
LeNet是一種經典的
卷積神經網絡
(CNN),由Yann LeCun等人于1998年提出,最初用于手寫數字識別(如MNIST數據集)。它是CNN的奠基性工作之一,其核心思想
是通過局部感受野、共享權重和空間下采樣來提取有效特征
1.1 核心理論
-
局部感受野(Local Receptive Fields):
卷積層通過小尺寸的濾波器(如5×5)掃描輸入圖像,每個神經元僅連接輸入圖像的局部區域,從而捕捉局部特征(如邊緣、紋理) -
共享權重(Weight Sharing):
同一卷積層的濾波器在整張圖像上共享參數,顯著減少參數量,增強平移不變性 -
空間下采樣(Subsampling):
池化層(如平均池化)降低特征圖的分辨率,減少計算量并增強對微小平移的魯棒性 -
多層特征組合:
通過交替的卷積和池化層,逐步組合低層特征(邊緣)為高層特征(數字形狀)
1.2 LeNet-5 網絡結構
LeNet-5是LeNet系列中最著名的版本,其結構如下(輸入為32×32灰度圖像):
層類型 | 參數說明 | 輸出尺寸 |
---|---|---|
輸入層 | 灰度圖像 | 32×32×1 |
C1層 | 卷積層:6個5×5濾波器,步長1,無填充 | 28×28×6 |
S2層 | 平均池化:2×2窗口,步長2 | 14×14×6 |
C3層 | 卷積層:16個5×5濾波器,步長1 | 10×10×16 |
S4層 | 平均池化:2×2窗口,步長2 | 5×5×16 |
C5層 | 卷積層:120個5×5濾波器 | 1×1×120 |
F6層 | 全連接層:84個神經元 | 84 |
輸出層 | 全連接 + Softmax(10類) | 10 |
1.3 關鍵細節
-
激活函數
:
原始LeNet使用Tanh或Sigmoid,現代實現常用ReLU -
池化方式
:
原始版本使用平均池化,后續改進可能用最大池化 -
參數量優化
:
C3層并非全連接至S2的所有通道,而是采用部分連接(如論文中的連接表),減少計算量 -
輸出處理
:
最后通過全連接層(F6)和Softmax輸出分類概率(如0-9數字)
1.4 后期改進
- ReLU替代Tanh:解決梯度消失問題,加速訓練
- 最大池化:更關注顯著特征,抑制噪聲
- Batch Normalization:穩定訓練過程
- Dropout:防止過擬合(原LeNet未使用)
1.6 意義與局限性
-
意義:
證明了CNN在視覺任務中的有效性,啟發了現代深度學習模型(如AlexNet、ResNet) -
局限性:
參數量小、層數淺,對復雜數據(如ImageNet)表現不足,需更深的網絡結構
LeNet的設計思想至今仍是CNN的基礎,理解它有助于掌握現代卷積神經網絡的演變邏輯
二、代碼實現
- LeNet 是一個經典的卷積神經網絡(CNN),由 Yann LeCun 等人于 1998 年提出,主要用于手寫數字識別(如 MNIST 數據集)
- MNIST數據集是機器學習領域中非常經典的一個數據集,由60000個訓練樣本和10000個測試樣本組成,每個樣本都是一張28 * 28像素的灰度手寫數字圖片
- 總體來看,LeNet(LeNet-5)由兩個部分組成:(1)
卷積編碼器
:由兩個卷積層組成(2)全連接層密集塊
:由三個全連接層組成
2.1 導包
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from torchsummary import summary
2.1 數據加載和處理
# 加載 MNIST 數據集
def load_data(batch_size=64):transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), # 將圖像轉換為張量torchvision.transforms.Normalize((0.5,), (0.5,)) # 歸一化])# 下載訓練集和測試集train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 創建 DataLoadertrain_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader
2.3 網絡構建
- LeNet 的網絡結構如下:
- 卷積層 1:輸入通道 1,輸出通道 6,卷積核大小 5x5
- 池化層 1:2x2 的最大池化
- 卷積層 2:輸入通道 6,輸出通道 16,卷積核大小 5x5。
- 池化層 2:2x2 的最大池化。
- 全連接層 1:輸入 16x5x5,輸出 120
- 全連接層 2:輸入 120,輸出 84
- 全連接層 3:輸入 84,輸出 10(對應 10 個類別)
#定義LeNet網絡架構
class LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.net=nn.Sequential(#卷積層1nn.Conv2d