用卷積神經網絡 (CNN) 實現 MNIST 手寫數字識別

在深度學習領域,MNIST 手寫數字識別是經典的入門級項目,就像編程世界里的 “Hello, World”。卷積神經網絡(Convolutional Neural Network,CNN)作為處理圖像數據的強大工具,在該任務中展現出卓越的性能。本文將結合具體的 PyTorch 代碼,詳細解析如何利用 CNN 實現 MNIST 手寫數字識別,帶大家從代碼實踐深入理解背后的技術原理。

一、數據準備:加載與預處理 MNIST 數據集

MNIST 數據集包含 6 萬張訓練圖像和 1 萬張測試圖像,涵蓋 0 - 9 這十個數字的手寫體。我們借助torchvision庫中的datasets.MNIST函數來加載數據,具體代碼如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortraining_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)

上述代碼中,root="data"指定數據集的存儲路徑;train=True表示加載訓練集,train=False用于加載測試集;download=True確保本地無數據集時自動下載;transform=ToTensor()將圖像數據轉換為 PyTorch 張量格式,并把像素值從 0 - 255 歸一化到 0 - 1 區間,便于后續處理。

為直觀感受數據,我們用matplotlib庫繪制 9 張訓練圖像及其標簽:

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

完成數據加載后,使用DataLoader將數據封裝成批次,方便模型訓練和測試:

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

batch_size=64意味著每次訓練或測試,模型會同時處理 64 個樣本,能提高計算效率和訓練穩定性。

二、模型構建:搭建卷積神經網絡架構

我們定義一個名為CNN的類,繼承自nn.Module,用于構建卷積神經網絡:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1,),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1),nn.ReLU(),)self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

  • 卷積層(nn.Conv2d:在conv1conv2conv3中,通過卷積層提取圖像特征。例如conv1中的nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)in_channels=1表示輸入圖像為單通道灰度圖,out_channels=16表示輸出 16 個特征圖,kernel_size=3指定 3×3 的卷積核,stride=1是步長,padding=1用于保持圖像尺寸不變。
  • 激活函數(nn.ReLU:緊跟在卷積層之后,為模型引入非線性,幫助模型學習復雜的模式。
  • 池化層(nn.MaxPool2d:通過下采樣操作,如nn.MaxPool2d(2)將圖像尺寸減半,減少數據量和模型參數,同時保留重要特征,防止過擬合。
  • 全連接層(nn.Linearself.out = nn.Linear(64 * 7 * 7, 10)將卷積層輸出的特征圖展平后連接到全連接層,輸出 10 個神經元對應 0 - 9 十個數字類別,完成最終分類。

最后,將模型移動到合適的計算設備(GPU、MPS 或 CPU)上:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = CNN().to(device)
print(model)

三、模型訓練與測試:優化與評估

3.1 訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

在訓練函數中,model.train()將模型設為訓練模式。遍歷數據加載器,將每一批數據和標簽移至指定設備,前向傳播計算預測值,通過交叉熵損失函數nn.CrossEntropyLoss()計算損失,optimizer.zero_grad()清空梯度,loss.backward()反向傳播計算梯度,optimizer.step()更新模型參數,每 100 個批次打印一次損失值。

3.2 測試函數

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return test_loss, correct

測試函數中,model.eval()將模型設為評估模式,關閉如 Dropout 等訓練時的操作。在with torch.no_grad()下遍歷測試數據,計算測試損失和正確預測的樣本數,最后計算平均損失和準確率并輸出。

3.3 執行訓練與測試

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

我們選用交叉熵損失函數和 Adam 優化器,學習率設為 0.01,通過 10 個訓練周期不斷優化模型,訓練完成后在測試集上評估模型性能,得到最終的準確率和平均損失。

四、總結與展望

通過上述代碼實踐,我們成功利用卷積神經網絡實現了 MNIST 手寫數字識別。從數據加載、模型構建到訓練測試,每個環節都緊密相連,展示了 CNN 在圖像識別任務中的強大能力。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/81910.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/81910.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/81910.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

從 MDM 到 Data Fabric:下一代數據架構如何釋放 AI 潛能

從 MDM 到 Data Fabric:下一代數據架構如何釋放 AI 潛能 —— 傳統治理與新興架構的范式變革與協同進化 引言:AI 規模化落地的數據困境 在人工智能技術快速發展的今天,企業對 AI 的期望已從 “單點實驗” 轉向 “規模化落地”。然而&#…

蒼穹外賣部署到云服務器使用Docker

部署前端 1.創建nginx鏡像 docker pull nginx 2.宿主機(云服務器)創建掛載目錄和文件 最好手動創建 而不是通過docker run創建,否則nginx.conf 默認會被創建為文件夾 nginx.conf 和html可以直接從黑馬給的資料里導入 3.運行nginx容器&am…

C++ 滲透 數據結構中的二叉搜索樹

歡迎來到干貨小倉庫 "沙漠盡頭必是綠洲。" --面對技術難題時,堅持終會看到希望。 1.二叉搜索樹的概念 二叉搜索樹又稱二叉排序樹,它或者是一顆空樹,或者是具有以下性質的二叉樹: a、若它的左子樹不為空,則…

實現滑動選擇器從離散型的數組中選擇

1.使用原生的input 詳細代碼如下&#xff1a; <template><div class"slider-container"><!-- 滑動條 --><inputtype"range"v-model.number"sliderIndex":min"0":max"customValues.length - 1"step&qu…

ARM尋址方式

尋址方式指的是確定操作數位置的方式。 尋址方式&#xff1a; 立即數尋址 直接尋址&#xff08;絕對尋址&#xff09;&#xff0c;ARM不支持這種尋址方式&#xff0c;但所有CISC處理器都支持 寄存器間接尋址 3種尋址方式總結如下&#xff1a; 助記符 RTL格式 描述 ADD r0,r1…

學苑教育雜志學苑教育雜志社學苑教育編輯部2025年第9期目錄

專題研究 核心素養下合作學習在初中數學中的應用 鄭鐵洪; 4-6 教育管理 小學班級管理應用賞識教育的策略研究 芮望; 7-9 課堂教學 小學數學概念教學的實踐策略 劉淑萍; 10-12 “減負提質”下小學五年級語文課堂情境教學 王利;梁巖; 13-15 小練筆的美麗轉身…

關于類型轉換的細節(隱式類型轉換的臨時變量和理解const權限)

文章目錄 前言類型轉換的細節1. 類型轉換的臨時變量細節二&#xff1a;const與指針 前言 關于類型轉換的細節&#xff0c;這里小編和大家探討兩個方面&#xff1a; 關于類型轉化的臨時變量的問題const關鍵字的權限問題 — 即修改權限。小編或通過一道例題&#xff08;配圖&am…

技術對暴力的削弱

信息時代的大政治分析&#xff1a;效率對暴力的顛覆 一、工業時代勒索邏輯的終結 工廠罷工的消亡 1930年代通用汽車罷工依賴工廠的物理集中、高資本投入和流水線脆弱性&#xff0c;通過暴力癱瘓生產實現勒索。 信息時代企業分散化、資產虛擬化&#xff08;如軟件公司可攜帶代碼…

深入理解分布式鎖——以Redis為例

一、分布式鎖簡介 1、什么是分布式鎖 分布式鎖是一種在分布式系統環境下&#xff0c;通過多個節點對共享資源進行訪問控制的一種同步機制。它的主要目的是防止多個節點同時操作同一份數據&#xff0c;從而避免數據的不一致性。 線程鎖&#xff1a; 也被稱為互斥鎖&#xff08…

yolo訓練用的數據集的數據結構

Football Players Detection using YOLOV11 可以在roboflow上標注 Sign in to Roboflow 訓練數據集只看這個data.yaml 里面是train的image地址和classnames 每個image一一對應一個label 第一個位是分類&#xff0c;0是classnames[0]對應的物體&#xff0c;現在是cuboid &…

Redis 使用及命令操作

文章目錄 一、基本命令二、redis 設置鍵的生存時間或過期時間三、SortSet 排序集合類型操作四、查看中文五、密碼設置和查看密碼的方法六、關于 Redis 的 database 相關基礎七、查看內存占用 一、基本命令 # 查看版本 redis-cli --version 結果&#xff1a;redis-cli 8.0.0red…

Java大師成長計劃之第13天:Java中的響應式編程

&#x1f4e2; 友情提示&#xff1a; 本文由銀河易創AI&#xff08;https://ai.eaigx.com&#xff09;平臺gpt-4o-mini模型輔助創作完成&#xff0c;旨在提供靈感參考與技術分享&#xff0c;文中關鍵數據、代碼與結論建議通過官方渠道驗證。 隨著現代應用程序的復雜性增加&…

華為私有協議Hybrid

實驗top圖 理論環節 1. 基本概念 Hybrid接口&#xff1a; 支持同時處理多個VLAN流量&#xff0c;且能針對不同VLAN配置是否攜帶標簽&#xff08;Tagged/Untagged&#xff09;。 核心特性&#xff1a; 靈活控制數據幀的標簽處理方式&#xff0c;適用于復雜網絡場景。 2. 工作…

K8s 常用命令、對象名稱縮寫匯總

K8s 常用命令、對象名稱縮寫匯總 前言 在之前的文章中已經陸續介紹過 Kubernetes 的部分命令&#xff0c;本文將專題介紹 Kubernetes 的常用命令&#xff0c;處理日常工作基本夠用了。 集群相關 1、查看集群信息 kubectl cluster-info # 輸出信息Kubernetes master is run…

【HDLBits刷題】Verilog Language——1.Basics

目錄 一、題目與題解 1.Simple wire&#xff08;簡單導線&#xff09; 2.Four wires&#xff08;4線&#xff09; 3.Inverter&#xff08;逆變器&#xff08;非門&#xff09;&#xff09; 4.AND gate &#xff08;與門&#xff09; 5. NOR gate &#xff08;或非門&am…

C語言|遞歸求n!

C語言| 函數的遞歸調用 【遞歸求n!】 0!1; 1!1 n! n*(n-1)*(n-2)*(n-3)*...*3*2*1; 【分析過程】 定義一個求n&#xff01;的函數&#xff0c;主函數直接調用 [ Factorial()函數 ] 1 用if語句去實現&#xff0c;把求n!的情況列舉出來 2 if條件有3個&#xff0c;n<0; n0||n…

Android第四次面試總結之Java基礎篇(補充)

一、設計原則高頻面試題&#xff08;附大廠真題解析&#xff09; 1. 單一職責原則&#xff08;SRP&#xff09;在 Android 開發中的應用&#xff08;字節跳動真題&#xff09; 真題&#xff1a;“你在項目中如何體現單一職責原則&#xff1f;舉例說明。”考點&#xff1a;結合…

OpenHarmony GPIO應用開發-LED

學習于&#xff1a; https://docs.openharmony.cn/pages/v5.0/zh-cn/device-dev/driver/driver-platform-gpio-develop.md https://docs.openharmony.cn/pages/v5.0/zh-cn/device-dev/driver/driver-platform-gpio-des.md 通過OpenHarmony官方文檔指導可獲知&#xff1a;芯片廠…

XILINX原語之——xpm_fifo_async(異步FIFO靈活設置位寬、深度)

目錄 一、"fwft"模式&#xff08;First-Word-Fall-Through read mode&#xff09; 1、寫FIFO 2、讀FIFO 二、"std"模式&#xff08;standard read mode&#xff09; 1、寫FIFO 2、讀FIFO 調用方式和xpm_fifo_sync基本一致&#xff1a; XILINX原語之…

系統學習算法:動態規劃(斐波那契+路徑問題)

題目一&#xff1a; 思路&#xff1a; 作為動態規劃的第一道題&#xff0c;這個題很有代表性且很簡單&#xff0c;適合入門 先理解題意&#xff0c;很簡單&#xff0c;就是斐波那契數列的加強版&#xff0c;從前兩個數變為前三個數 算法原理&#xff1a; 這五步可以說是所有…