【動手學深度學習】LeNet:卷積神經網絡的開山之作

【動手學深度學習】LeNet:卷積神經網絡的開山之作

  • 1,LeNet卷積神經網絡簡介
  • 2,Fashion-MNIST圖像分類數據集
  • 3,LeNet總體架構
  • 4,LeNet代碼實現
    • 4.1,定義LeNet模型
    • 4.2,定義模型評估函數
    • 4.3,定義訓練函數進行訓練


1,LeNet卷積神經網絡簡介

LeNet 是一種經典的卷積神經網絡,是現代卷積神經網絡的起源之一。它是早期成功的神經網絡;LeNet先使用卷積層來學習圖片空間信息,使用池化層降低圖片敏感度,然后使用全連接層來轉換到類別空間。 其思想被廣泛應用于圖像分類、目標檢測、圖像分割等多個計算機視覺領域,為這些領域的研究和發展提供了新的思路和方法。例如,在安防領域用于面部識別和監控系統,在自動駕駛領域用于實時視頻分析和對象跟蹤等。

1989年,Yann LeCun等人在貝爾實驗室工作期間提出了LeNet-1。這個網絡主要用于手寫數字識別,引入了卷積操作和權值共享的概念,簡化了網絡結構,減少了參數數量,提高了模型的泛化能力和訓練速度。此后經過多年的迭代改進,1998年,LeCun等人正式發表了LeNet-5。LeNet-5在LeNet-1的基礎上進一步優化了網絡結構,增加了網絡的深度和復雜度,使其在手寫數字識別任務上取得了更好的性能。LeNet-5的成功應用證明了CNN在圖像識別領域的巨大潛力,為后續CNN的發展奠定了堅實的基礎。


2,Fashion-MNIST圖像分類數據集

Fashion-MNIST數據集是一個廣泛使用的圖像分類數據集。Fashion-MNIST中包含的10個類別,分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。

之前,已經學習過Fashion-MNIST數據集。 【動手學深度學習】Fashion-MNIST圖片分類數據集,其基本情況如下:

  • 訓練集:包含60,000張圖像,用于模型訓練;
  • 測試集:包含10,000張圖像,用于評估模型性能;
  • 數據集由灰度圖像組成,其通道數為1;
  • 每個圖像的高度和寬度均為28像素;
  • 調用load_data_fashion_mnist()函數加載數據集;

具體定義如下:

"""
下載Fashion-MNIST數據集,然后將其加載到內存中
參數resize表示調整圖片大小
"""
def load_data_fashion_mnist(batch_size, resize=None): # trans是一個用于轉換的 *列表*trans = [transforms.ToTensor()]if resize:    # resize不為空,表示需要調整圖片大小trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

3,LeNet總體架構

總體來看,LeNet(LeNet-5)由兩個部分組成:

  • 卷積編碼器:由兩個卷積層組成;
  • 全連接層密集塊:由三個全連接層組成;

在這里插入圖片描述

每個卷積塊中的基本單元是一個卷積層、一個sigmoid激活函數和平均匯聚層。(實際上使用ReLU激活函數和最大匯聚層更有效,但當時還沒有發現):

  • Fashion-MNIST數據集的圖像通道為1,大小為28×28,內部經過卷積層填充之后得到的實際輸入數據是32×32的圖像數據

  • 第一卷積層有6個輸出通道,而第二個卷積層有16個輸出通道;

  • 對應輸出通道的數量,第一個卷積層有6個5×5的卷積核,第二個卷積層有16個5×5的卷積核;

  • 每個卷積核應用于輸入數據時會產生一個特征圖(feature map),也就是一個輸出通道;

  • 每個卷積層都使用不同數量的5×5的卷積核和一個sigmoid激活函數。這些層將輸入映射到多個二維特征輸出,通常同時增加通道的數量;

  • 卷積操作后,通過2×2的池化操作默認步幅為2和池化窗口大小保持一致)將原特征圖的各維度減半。比如原來是28×28,池化后變為14×14;


4,LeNet代碼實現

接下來使用深度學習框架實現LeNet模型,并進行訓練和測試。


4.1,定義LeNet模型

LeNet模型總共七層: 兩層卷積層、兩層池化層、三層全連接層; 其中每層都使用sigmod作為激活函數,它將卷積層的輸出壓縮到0和1之間,有助于非線性變換。

import torch
from torch import nn
from d2l import torch as d2l
""" 默認情況下,深度學習框架中的步幅與匯聚窗口的大小相同(窗口沒有重疊)"""# nn.Sequential 是一個容器,可按順序包裝一系列子模塊(如層、激活函數)。使得模型的構建變得更加簡潔
net = nn.Sequential(# 第一個二維卷積層,輸入通道是1(灰度圖像),輸出通道是6,卷積核大小5×5,圖像周圍加入兩層0填充# 使用sigmod激活函數nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),# 第一個平均池化層:用2x2的池化窗口,步長為2。經此池化操作后得6個14×14的特征圖nn.AvgPool2d(kernel_size=2, stride=2),# 這是第二個二維卷積層,輸入通道數為6(與第一個卷積層的輸出通道數相匹配),輸出通道數為16。卷積核的大小為5x5,沒有使用padding填充# 使用sigmod激活函數nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),# 第二個平均池化層:配置與第一層平均池化層相同。nn.AvgPool2d(kernel_size=2, stride=2),# 在將數據傳遞給全連接層之前,需要將多維的卷積和池化輸出展平為一維向量。以便傳給全連接層nn.Flatten(),# 經過前面的卷積和池化操作后,輸出16個5×5的特征圖# 全連接層,輸入特征的數量是16 * 5 * 5。輸出特征的數量是120。nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),# 全連接層,輸入特征數量120,輸出84nn.Linear(120, 84), nn.Sigmoid(),# 全連接層,輸入特征數量84,輸出10,對應Fashion-MNIST數據集的10個類別nn.Linear(84, 10))

下面,我們將一個大小為 28 × 28 28 \times 28 28×28的單通道(黑白)圖像通過LeNet。通過在每一層打印輸出的形狀,我們可以檢查模型,以確保其操作與我們期望的一致。

# 打印調試信息,檢查模型
# size=(1, 1, 28, 28):批次大小1,通道數1,形狀28*28 
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)# 遍歷了神經網絡 net 中的每一層
for layer in net:X = layer(X)# 打印該層的類型(Conv2d、AvgPool2d、Flatten、Linear)以及輸出張量的形狀print(layer.__class__.__name__,'output shape: \t',X.shape)# torch.Size([1, 6, 28, 28])中的1代表批次大小,6表示通道數

運行結果如下:

在這里插入圖片描述


4.2,定義模型評估函數

我們已經實現了LeNet,接下來讓我們看看LeNet在Fashion-MNIST數據集上的表現。

加載Fashion-MNIST圖片分類數據集

batch_size = 256  # 批量大小
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

定義評估函數計算預測準確率

def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU計算模型在數據集上的精度"""if isinstance(net, nn.Module):net.eval()  # 設置為評估模式if not device: # 若沒有指定device,則通過獲取模型參數的第一個元素的設備來確定應該使用的設備# net.parameters()返回模型的所有可學習參數(如權重和偏置)# next() 函數從迭代器中獲取第一個元素。通常是第一個層的權重或偏置# .device 是 PyTorch 張量(torch.Tensor)的一個屬性,表示該張量所在的備(如 GPU 或 CPU)# 例如,模型在 GPU 上運行,.device 的值可能是 device(type='cuda', index=0)device = next(iter(net.parameters())).device# 累加器記錄正確預測的數量和總預測的數量metric = d2l.Accumulator(2)with torch.no_grad():  # 評估模型時,不需要計算梯度for X, y in data_iter: # 每次迭代獲取一個數據批次X和對應的標簽yif isinstance(X, list):  # x為list,每個元素都挪到對應的設備X = [x.to(device) for x in X]else:   # x是tensor,只需要挪一次X = X.to(device)y = y.to(device)# accuracy可以計算出預測正確的樣本數量# y.numel()計算出樣本總數metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

4.3,定義訓練函數進行訓練

定義可以使用GPU訓練的訓練函數。

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU訓練模型"""def init_weights(m): # 初始化權重# 如果是全連接層或卷積層使用Xavier均勻初始化方法if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)# 模型移動到設備net.to(device)# 使用隨機梯度下降(SGD)優化器optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 使用交叉熵損失函數(nn.CrossEntropyLoss),適用于分類任務loss = nn.CrossEntropyLoss()# 實現動畫效果打印輸出animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 累加器記錄訓練損失之和,訓練準確率之和,樣本數metric = d2l.Accumulator(3)# 將模型設置為訓練模式,這會啟用Dropout等訓練時特有的操作net.train()for i, (X, y) in enumerate(train_iter): # 遍歷訓練數據集timer.start()optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 將輸入數據X和標簽y移動到指定的設備# 前向傳播,得到預測結果 y_haty_hat = net(X) """在 PyTorch 中,nn.CrossEntropyLoss 默認會對每個樣本的損失值進行平均,返回的是批次中所有樣本損失的平均值。"""l = loss(y_hat, y) # 計算損失# 進行反向傳播,計算梯度。l.backward()# 使用優化器更新模型參數。optimizer.step()with torch.no_grad():  # 禁用梯度計算# l * X.shape[0]是當前批次的總損失。樣本平均損失乘當前批次樣本數# d2l.accuracy(y_hat, y) 計算當前批次正確預測的樣本數# X.shape[0]代表當前批次的樣本數# 最終累加器累積了整個訓練集的總損失,預測正確的樣本總數和總樣本數metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()# 計算整個訓練集上每個訓練樣本的平均損失train_l = metric[0] / metric[2]# 計算訓練準確率train_acc = metric[1] / metric[2]# 更新動畫if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 在每個epoch結束時,計算測試集上的準確率test_acc = evaluate_accuracy_gpu(net, test_iter)# 更新動畫animator.add(epoch + 1, (None, None, test_acc))# 打印訓練損失、訓練準確率和測試準確率print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')# 打印訓練過程中每秒處理的樣本數(即訓練效率),以及訓練所使用的設備。print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

調用函數進行訓練

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

運行結果如下:

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/76516.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/76516.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/76516.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

代碼隨想錄第15天:(二叉樹)

一、二叉搜索樹的最小絕對差(Leetcode 530) 思路1 :中序遍歷將二叉樹轉化為有序數組,然后暴力求解。 class Solution:def __init__(self):# 初始化一個空的列表,用于保存樹的節點值self.vec []def traversal(self, r…

計算機操作系統-【死鎖】

文章目錄 一、什么是死鎖?死鎖產生的原因?死鎖產生的必要條件?互斥條件請求并保持不可剝奪環路等待 二、處理死鎖的基本方法死鎖的預防摒棄請求和保持條件摒棄不可剝奪條件摒棄環路等待條件 死鎖的避免銀行家算法案例 提示:以下是…

vue拓撲圖組件

vue拓撲圖組件 介紹技術棧功能特性快速開始安裝依賴開發調試構建部署 使用示例演示截圖組件源碼 介紹 一個基于 Vue3 的拓撲圖組件,具有以下特點: 1.基于 vue-flow 實現,提供流暢的拓撲圖展示體驗 2.支持傳入 JSON 對象自動生成拓撲結構 3.自…

go 通過匯編分析函數傳參與返回值機制

文章目錄 概要一、前置知識二、匯編分析2.1、示例2.2、匯編2.2.1、 寄存器傳值的匯編2.2.2、 棧內存傳值的匯編 三、拓展3.1 了解go中的Duff’s Device3.2 go tool compile3.2 call 0x46dc70 & call 0x46dfda 概要 在上一篇文章中,我們研究了go函數調用時的棧布…

python-1. 找單獨的數

問題描述 在一個班級中,每位同學都拿到了一張卡片,上面有一個整數。有趣的是,除了一個數字之外,所有的數字都恰好出現了兩次。現在需要你幫助班長小C快速找到那個拿了獨特數字卡片的同學手上的數字是什么。 要求: 設…

算法學習C++需注意的基本知識

文章目錄 01_算法中C需注意的基本知識cmath頭文件一些計算符ASCII碼表數據類型長度運算符cout固定輸出格式浮點數的比較max排序自定義類型字符的大小寫轉換與判斷判斷字符是數字還是字母 02_數據結構需要注意的內容1.stringgetline函數的使用string::findsubstr截取字符串strin…

從零開始寫android 的智能指針

Android中定義了兩種智能指針類型,一種是強指針sp(strong pointer),源碼中的位置在system/core/include/utils/StrongPointer.h。另外一種是弱指針(weak pointer)。其實稱之為強引用和弱引用更合適一些。強…

【leetcode hot 100 152】乘積最大子數組

錯誤解法:db[i]表示以i結尾的最大的非空連續,動態規劃:dp[i] Math.max(nums[i], nums[i] * dp[i - 1]); class Solution {public int maxProduct(int[] nums) {int n nums.length;int[] dp new int[n]; // db[i]表示以i結尾的最大的非空連…

圖論整理復習

回溯: 模板: void backtracking(參數) {if (終止條件) {存放結果;return;}for (選擇:本層集合中元素(樹中節點孩子的數量就是集合的大小)) {處理節點;backtracking(路徑,選擇列表); // 遞歸回溯&#xff…

uniapp離線打包提示未添加videoplayer模塊

uniapp中使用到video標簽,但是離線打包放到安卓工程中,運行到真機中時提示如下: 解決方案: 1、把media-release.aar、weex_videoplayer-release.aar放到工程的libs目錄下; 文檔:https://nativesupport.dcloud.net.cn/…

打包構建替換App名稱

方案適用背景 一套代碼出多個安裝包,且安裝包的應用名稱、圖標都不一樣考慮三語名稱問題 通過 Gradle 腳本實現 gradle.properties 里面定義標識來區分應用,如下文里的 APP_TYPEAAA 、APP_TYPEBBB// 定義 groovy 替換方法 def replaceAppName(String …

DrissionPage移動端自動化:從H5到原生App的跨界測試

一、移動端自動化測試的挑戰與機遇 移動端測試面臨多維度挑戰: 設備碎片化:Android/iOS版本、屏幕分辨率差異 混合應用架構:H5頁面與原生組件的深度耦合 交互復雜性:多點觸控、手勢操作、傳感器模擬 性能監控:內存…

達夢數據庫用函數實現身份證合法校驗

達夢數據庫用函數實現身份證合法校驗 拿走不謝~ CREATE OR REPLACE FUNCTION CHECK_IDCARD(A_SFZ IN VARCHAR2) RETURN VARCHAR2 IS TYPE WEIGHT_TAB IS VARRAY(17) OF NUMBER; TYPE CHECK_TAB IS VARRAY(11) OF CHAR; WEIGHT_FACTOR WEIGHT_TAB : WEIGHT_TAB(7,9,10,5,8,4,…

3dmax的python通過普通的攝像頭動捕表情

1、安裝python 進入cdm,打python要能顯示版本號 >>>(進入python提示符模式) import sys sys.path顯示python的安裝路徑, 進入到python.exe的路徑 在python目錄中安裝(ctrlz退出python交互模式) 2、pip install mediapipe…

國產Linux統信安裝mysql8教程步驟

系統環境 uname -a Linux FlencherHU-PC 6.12.9-amd64-desktop-rolling #23.01.01.18 SMP PREEMPT_DYNAMIC Fri Jan 10 18:29:31 CST 2025 x86_64 GNU/Linux下載離線安裝包 瀏覽器下載https://downloads.mysql.com/archives/get/p/23/file/mysql-test-8.0.33-linux-glibc2.28…

Vite 權限繞過導致任意文件讀取(CVE-2025-32395)(附腳本)

免責申明: 本文所描述的漏洞及其復現步驟僅供網絡安全研究與教育目的使用。任何人不得將本文提供的信息用于非法目的或未經授權的系統測試。作者不對任何由于使用本文信息而導致的直接或間接損害承擔責任。如涉及侵權,請及時與我們聯系,我們將盡快處理并刪除相關內容。 前言…

poi-tl

官網地址 Poi-tl Documentationword模板引擎https://deepoove.com/poi-tl github 地址 https://github.com/Sayi/poi-tl/tree/master gitcode 加速地址 GitCode - 全球開發者的開源社區,開源代碼托管平臺GitCode是面向全球開發者的開源社區,包括原創博客,開源代碼托管,代碼…

操作系統 4.1-I/O與顯示器

外設工作起來 操作系統讓外設工作的基本原理和過程,具體來說,它概括了以下幾個關鍵步驟: 發出指令:操作系統通過向控制器中的寄存器發送指令來啟動外設的工作。這些指令通常是通過I/O指令(如out指令)來實現…

琥珀掃描 2.0.5.0 | 文檔處理全能助手,支持掃描、文字提取及表格識別

琥珀掃描是一款功能強大的文檔處理應用程序。它不僅僅支持基本的文檔掃描功能,還涵蓋了文字提取、證件掃描、表格識別等多種實用功能。無論是學生、職員還是教師,都能從中找到適合自己的功能。該應用支持拍照生成電子件,并能自動矯正文檔邊緣…

jQuery UI 小部件方法調用詳解

jQuery UI 小部件方法調用詳解 引言 jQuery UI 是一個基于 jQuery 的用戶界面和交互庫,它提供了一系列小部件,如按鈕、對話框、進度條等,這些小部件極大地豐富了網頁的交互性和用戶體驗。本文將詳細介紹 jQuery UI 中小部件的方法調用,幫助開發者更好地理解和應用這些小部…