通過代碼認識 CNN:用 PyTorch 實現卷積神經網絡識別手寫數字

目錄

一、從代碼看 CNN 的核心組件

二、準備工作:庫導入與數據加載

三、核心:用代碼實現 CNN 并理解各層作用

1.網絡層結構

2.重點理解:卷積層參數與輸出尺寸計算

四、訓練 CNN

五、結果分析


卷積神經網絡(CNN)是計算機視覺領域的核心模型,相比全連接網絡,它能更高效地提取圖像特征。本文不空談理論,而是通過 PyTorch 代碼實現一個完整的 CNN 模型,帶你在實戰中理解卷積、池化等核心概念,掌握 CNN 的工作原理。


一、從代碼看 CNN 的核心組件

在實現模型前,先明確 CNN 的三個核心層 —— 這些是區別于全連接網絡的關鍵,后續代碼會逐一對應:

  1. 卷積層(Conv2d):通過滑動窗口提取局部特征(如邊緣、紋理);
  2. 激活層(ReLU):引入非線性,讓模型學習復雜模式;
  3. 池化層(MaxPool2d):降低特征圖尺寸,減少計算量,增強魯棒性。

我們將用這些組件構建一個識別 MNIST 手寫數字的 CNN 模型,邊寫代碼邊解釋原理。


二、準備工作:庫導入與數據加載

首先導入必要的庫,加載 MNIST 數據集(28×28 的手寫數字圖片):

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 加載MNIST數據集
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()  # 轉為張量,形狀為[1,28,28]
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)# 按批次加載數據(每批64張圖)
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)# 查看數據形狀([批次, 通道, 高度, 寬度])
for X, y in test_dataloader:print(f"數據形狀: {X.shape}")  # 輸出:torch.Size([64, 1, 28, 28])break

關鍵說明:MNIST 圖片是單通道(灰度圖),所以輸入形狀為[N,1,28,28](N 為批次大小),這會影響后續卷積層的參數設置。


三、核心:用代碼實現 CNN 并理解各層作用

1.網絡層結構

我們構建一個包含 4 個卷積塊的 CNN 模型,每個卷積塊由 “卷積層 + 激活層” 組成,部分塊后添加池化層。通過代碼注釋詳細說明每層的作用和參數含義。

# 自動選擇設備(優先GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用設備: {device}")class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一個卷積塊:卷積層+激活層+池化層self.conv1 = nn.Sequential(# 卷積層:輸入1通道,輸出16通道,卷積核5×5,步長1,填充2nn.Conv2d(in_channels=1,    # 輸入通道數(灰度圖為1)out_channels=16,  # 輸出通道數(16個不同的卷積核)kernel_size=5,    # 卷積核大小5×5stride=1,         # 步長1(每次滑動1個像素)padding=2         # 填充2(保持輸出尺寸與輸入一致:28→28)),nn.ReLU(),  # 激活層:引入非線性,過濾負值# 池化層:2×2窗口,步長2,輸出尺寸變為14×14(28/2)nn.MaxPool2d(kernel_size=2, stride=2))# 第二個卷積塊:卷積層+激活層(無池化)self.conv2 = nn.Sequential(# 輸入16通道(上一層輸出),輸出32通道,卷積核3×3nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),nn.ReLU()  # 輸出尺寸保持14×14)# 第三個卷積塊:卷積層+激活層+池化層self.conv3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2)  # 輸出尺寸變為7×7(14/2))# 第四個卷積塊:卷積層+激活層(無池化)self.conv4 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU()  # 輸出尺寸保持7×7)# 全連接層:將特征圖轉為10個類別(0-9)self.fc = nn.Linear(128 * 7 * 7, 10)  # 128通道×7×7尺寸def forward(self, x):# 前向傳播:數據依次經過各層x = self.conv1(x)  # 輸出形狀:[N,16,14,14]x = self.conv2(x)  # 輸出形狀:[N,32,14,14]x = self.conv3(x)  # 輸出形狀:[N,64,7,7]x = self.conv4(x)  # 輸出形狀:[N,128,7,7]x = x.view(x.size(0), -1)  # 展平:[N,128×7×7]x = self.fc(x)     # 輸出形狀:[N,10](10個類別分數)return x# 創建模型并移動到設備
model = CNN().to(device)
print("CNN模型結構:")
print(model)

2.重點理解:卷積層參數與輸出尺寸計算

以第一個卷積層為例,輸入是[64,1,28,28](64 張圖,1 通道,28×28),經過kernel_size=5, padding=2, stride=1的卷積后,輸出尺寸計算公式:

輸出尺寸 = (輸入尺寸 - 卷積核大小 + 2×填充) / 步長 + 1
即:(28 - 5 + 2×2)/1 + 1 = 28

所以輸出仍為 28×28,再經 2×2 池化后變為 14×14—— 這就是卷積層如何在保留特征的同時控制尺寸的核心邏輯。


四、訓練 CNN

CNN 的訓練流程和全連接網絡類似,我們將訓練輪次調整為 10 輪,既能保證模型收斂,又能節省訓練時間。定義訓練和測試函數如下:

# 損失函數(多分類用交叉熵)
loss_fn = nn.CrossEntropyLoss()
# 優化器(Adam,學習率0.0001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train()  # 開啟訓練模式batch_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)# 前向傳播:計算預測pred = model(X)loss = loss_fn(pred, y)# 反向傳播:更新參數optimizer.zero_grad()  # 梯度清零loss.backward()        # 計算梯度optimizer.step()       # 更新參數# 每100批次打印一次損失if batch_num % 100 == 1:print(f"批次 {batch_num} | 損失: {loss.item():.4f}")batch_num += 1# 測試函數
def test(dataloader, model, loss_fn):model.eval()  # 開啟測試模式size = len(dataloader.dataset)num_batches = len(dataloader)correct = 0test_loss = 0with torch.no_grad():  # 禁用梯度計算for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算準確率和平均損失test_loss /= num_batchescorrect /= sizeprint(f"\n測試集:準確率 {100*correct:.2f}% | 平均損失 {test_loss:.4f}\n")# 開始訓練(10輪)
print("="*50)
print("開始訓練CNN模型(10輪)")
print("="*50)
for epoch in range(10):print(f"輪次 {epoch+1}/10")print("-"*30)train(train_dataloader, model, loss_fn, optimizer)# 每2輪測試一次if (epoch+1) % 2 == 0:test(test_dataloader, model, loss_fn)
print("="*50)
print("訓練結束")

五、結果分析

輪次 10/10
------------------------------
批次 1 | 損失: 0.0002
批次 101 | 損失: 0.0000
批次 201 | 損失: 0.0015
批次 301 | 損失: 0.0190
批次 401 | 損失: 0.0003
批次 501 | 損失: 0.0008
批次 601 | 損失: 0.0001
批次 701 | 損失: 0.0065
批次 801 | 損失: 0.0019
批次 901 | 損失: 0.0310測試集:準確率 99.17% | 平均損失 0.0355==================================================
訓練結束

即使只訓練 10 輪,CNN 在測試集上的準確率通常也能達到99% 以上,明顯高于同輪次的全連接網絡。這體現了 CNN 的高效性,原因在于:

  1. 局部感受野:卷積層通過滑動窗口只關注局部像素,更符合圖像的局部相關性;
  2. 權值共享:同一通道的卷積核參數共享,大幅減少參數數量(全連接層 784→128 需要近 10 萬個參數,而 5×5 的卷積層 1→16 僅需 400 個參數);
  3. 池化層:通過下采樣保留關鍵特征,增強模型對圖像位移、縮放的魯棒性。

這些特性讓 CNN 在較少的訓練輪次下就能達到較好的性能。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/94892.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/94892.shtml
英文地址,請注明出處:http://en.pswp.cn/web/94892.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于SpringBoot和Thymeleaf開發的英語學習網站

角色: 管理員、用戶 技術: SpringBoot、Thymeleaf、MySQL、MyBatis、jQuery、Bootstrap 核心功能: 這是一個基于SpringBoot的英語學習平臺,旨在為用戶提供英語學習資料(如書籍、聽力、單詞)的管理和學習功能…

把 AI 塞進「智能跳繩」——基于 MEMS 傳感器的零樣本卡路里估算器

標簽:MEMS、卡路里估算、零樣本、智能跳繩、TinyML、RISC-V、低功耗、邊緣 AI ---- 1. 背景:為什么跳繩要「算卡路里」? 全球 1.5 億人把跳繩當日常運動,卻苦于: ? 機械計數器誤差大; ? 手機 App 需聯網…

礦用隨鉆測量現場應用中,最新的MEMS陀螺定向短節的優勢是什么?

在當代礦業開發向深部復雜地層進軍的過程中,隨鉆測量技術是控制鉆井定向打孔質量和提升長距離鉆探中靶精度的核心手段,煤礦井下定向鉆孔、瓦斯抽放孔、探放水孔等關鍵工程面臨著一系列特殊挑戰:強磁干擾、劇烈振動、空間受限等惡劣條件。最新…

Spring Boot 使用 RestTemplate 調用 HTTPS 接口時報錯:PKIX path building failed 解決方案

在使用 Spring Boot RestTemplate 調用 HTTPS 接口時,很多同學會遇到類似下面的報錯:javax.net.ssl.SSLHandshakeException: PKIX path building failed: sun.security.provider.certpath.SunCertPathBuilderException: unable to find valid certif…

【C語言入門級教學】sizeof和strlen的對?

1.sizeof和strlen的對? 1.1 sizeof sizeof 計算變量所占內存空間??的,單位是字節,如果操作數是類型的話,計算的是使?類型創建的變量所占內存空間的??。 sizeof 只關注占?內存空間的??,不在乎內存中存放什么數據。 ?如&a…

線程安全及死鎖問題

系列文章目錄 初步了解多線程-CSDN博客 目錄 系列文章目錄 前言 一、線程安全 1. 線程安全問題 2. 問題原因分析 3. 問題解決辦法 4. synchronized 的優勢 1. 自動解鎖 2. 是可重入鎖 二、死鎖 1. 一個線程一把鎖 2. 兩個線程兩把鎖 3. N 個線程 M 把鎖 4. 死鎖…

2025年8月無人駕駛技術現有技術報告

第1章 引言 無人駕駛技術作為21世紀交通運輸領域最具革命性的技術創新之一,正在深刻地改變著人類的出行方式和生活模式。進入2025年,隨著人工智能、5G通信、高精度傳感器等關鍵技術的快速發展與成熟,無人駕駛技術已從實驗室的概念驗證階段逐…

CETOL 6σ 助力康美醫療(CONMED Corporation)顯著提升一次性穿刺器產品合格率

概述 康美醫療 (CONMED Corporation)將 Sigmetrix 的 CETOL 6σ 公差分析軟件應用于一次性穿刺器的結構優化。該裝置是微創外科技術的一次早期突破。在設計階段,團隊發現“測量臨界間隙”存在尺寸偏差、超出預期范圍,可能在手術中造成患者皮膚損傷&…

LaunchScreen是啥?AppDelegate是啥?SceneDelegate是啥?ContentView又是啥?Main.storyboard是啥?

雖然我很想挑戰一下swiftui,但是精力真的是有限,把精力分散開不是一個很好的選擇,so swiftui淺嘗則止了,目前的精力在html上面。 AppDelegate todo SceneDelegate todo ContentView 最明顯的就是這個,當編輯的時候,頁面…

垃圾回收機制(GC)

目錄 垃圾回收機制 引用計數法 可達性分析算法 垃圾回收算法 標記清除算法 復制算法 標記壓縮算法 JVM中一次完整的GC(分代收集算法) 在新生代中 在老年代中 空間分配擔保原則 對象從新生代進入老年代的幾種情況? Young GC 和 Full GC 垃…

DNS域名系統

DNS域名系統一、什么是DNS?二、DNS的域名層級1. 根域2. 頂級域3. 二級域4. 三級域(子域)5. 主機名三、DNS服務器的分類四、DNS的解析過程五、DNS的記錄類型六、FQDN(完全限定域名)一、什么是DNS? DNS(Domain Name S…

虛擬內存和虛擬頁面

虛擬內存虛擬內存是現代操作系統提供的一種內存管理機制,它允許程序訪問比實際物理內存更大的地址空間。虛擬內存通過將程序的地址空間劃分為多個固定大小的塊(稱為頁面),并將這些頁面映射到物理內存或磁盤上的頁面文件中&#xf…

【2025年電賽E題】基于k230的矩形框識別鎖定1

文章目錄 概要 整體架構流程 技術名詞解釋 技術細節 1. 多閾值適配與目標識別邏輯 2. 動態ROI與狀態管理機制 3. 數據平滑與偏差計算 4. 硬件適配與UART通信 小結 靜態矩形框識別 動態矩形框追蹤 概要 本文分析的代碼是基于立創廬山派K230CanMV開發板的目標追蹤系統實現,主要…

c語言中的數組可以用int a[3]來創建。寫一次int就可以了,而java中要聲明兩次int類型像這樣:int[] arr = new int[3];

C 語言數組只需寫一次int,而 Java 需兩次int相關聲明,核心原因是兩種語言的數組本質定義、類型系統設計和內存管理邏輯完全不同,具體可拆解為兩點核心差異:一、C 語言:數組是 “內存塊的類型綁定”,一次聲明…

深度學習——詳細教學:神經元、神經網絡、感知機、激活函數、損失函數、優化算法(梯度下降)

神經網絡實戰: 深度學習——神經網絡簡單實踐(在乳腺癌數據集上的小型二分類示例)-CSDN博客https://blog.csdn.net/2302_78022640/article/details/150779819?spm1001.2014.3001.5502 深度學習——神經網絡(PyTorch 實現 MNIST…

Ubuntu 軟件安裝的五種方法

1、App Store 安裝 Ubuntu 里面有 一個App叫 “Ubuntu軟件” 2、Sudo apt-get install 安裝法 注意 使用apt工具安裝軟件,需要sudo,也就是root權限 例子 apt -get install git 會提示查看是否以root用戶運行,install-安裝sudo a…

Day15 (前端:JavaScript基礎階段)

接續上文:Day14——JavaScript 核心知識全解析:變量、類型與操作符深度探秘-CSDN博客 點關注不迷路喲。你的點贊、收藏,一鍵三連,是我持續更新的動力喲!!! 主頁:一位搞嵌入式的 genius-CSDN博…

在線旅游及旅行管理系統項目SQL注入

1.前言 之前在網上隨便逛逛的時候,發現一個有各種各樣的PHP項目的管理系統,隨便點進一個查看,發現還把mysql版本都寫出來了,而且還是PHP語言。 https://itsourcecode.com/free-projects/php-project/online-tours-and-travels-m…

Java網絡編程(UDP, TCP, HTTP)

1. OSI 七層網絡模型層級名稱核心功能協議示例數據單元7應用層提供用戶接口和網絡服務HTTP, FTP, SMTP, DNS報文6表示層數據格式轉換、加密/解密、壓縮/解壓SSL, JPEG, MPEG數據流5會話層建立、管理和終止會話連接NetBIOS, RPC會話數據4傳輸層端到端可靠傳輸、流量控制、差錯校…

【P2P】P2P主要技術及RELAY服務1:python實現

P2P 技術 P2P(點對點)網絡的核心是去中心化的網絡拓撲和通信協議。DP的應用相對較少,但可能出現在: 路由優化:在一些復雜的P2P網絡中,一個節點需要向另一個節點發送消息。為了找到一條延遲最低或跳數最少的路徑,可能會用到類似最短路徑的算法,而這類算法(如Bellman-F…