《從卷積核到數字解碼:CNN 手寫數字識別實戰解析》

文章目錄

  • 一、手寫數字識別的本質與挑戰
  • 二、使用步驟
    • 1.導入torch庫以及與視覺相關的torchvision庫
    • 2.下載datasets自帶的手寫數字的數據集到本地
  • 三、完整代碼展示


一、手寫數字識別的本質與挑戰

手寫數字識別的核心是:從二維像素矩陣中提取具有判別性的特征,區分 0-9 這 10 個類別。其難點包括:
手寫風格多樣性:不同人書寫的數字(如 “3” 可能有開口或閉口)、筆畫粗細、傾斜角度差異大。
位置與尺度變化:數字在圖像中的位置(偏上 / 偏下)、大小可能不一致(如 MNIST 數據集中數字存在輕微平移)。
噪聲與形變:實際場景中可能存在筆畫斷裂、污漬等噪聲,或掃描時的圖像模糊。
傳統方法(如 SVM、KNN)依賴人工設計特征(如 HOG、SIFT、幾何矩),需專家經驗且泛化能力有限;而 CNN 通過自動化特征學習 + 結構化歸納偏置,天然適配這些挑戰。

二、使用步驟

1.導入torch庫以及與視覺相關的torchvision庫

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

2.下載datasets自帶的手寫數字的數據集到本地

"""下載測試數據集(包含圖片和標簽)"""training_data=datasets.MNIST(root='../data',train=True,download=True,transform=ToTensor()
)"""下載測試數據集(包含訓練圖片+標簽)"""test_data=datasets.MNIST(root='../data',train=False,download=True,transform=ToTensor()
)

3、將下載的數據集打包

train_dataloder=DataLoader(training_data,batch_size=64)
test_dataloder=DataLoader(test_data,batch_size=64)

4、指定數據訓練的設備

device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"{device}device")

5、定義神經網絡框架和前向傳播

class NeurakNetwork(nn.Module):     #通過調用類的形式來使用神經網絡,神經網絡的模型nn.moudledef __init__(self):super().__init__()  #繼承父類的初始化self.flatten=nn.Flatten()   #將二位數據展成一維數據self.hidden1=nn.Linear(28*28,128)   #第一個參數時有多少個神經元傳進來,第二個參數是有多少個數據傳出去self.hidden2=nn.Linear(128,256)self.out=nn.Linear(256,10)      #輸出必須與標簽類型相同,輸入必須是上一層神經元的個數def forward(self,x):    #前向傳播,指明數據的流向,使神經網絡連接起來,函數名稱不能修改x=self.flatten(x)x=self.hidden1(x)x=torch.relu(x)     #激活函數,torch使用relu或者tanh函數作為激活函數x=self.hidden2(x)x=torch.relu(x)x=self.out(x)return x

6、初始化神經網絡并將模型加載到設備中

model = NeurakNetwork().to(device)      #將剛剛定義的模型傳入到GPU中

7、定義模型訓練的函數

def train(dataloader,model,loss_fn,optimizer):model.train()       #告訴模型,即將開始訓練,其中的w進行隨機化操作,已經更新w,在訓練過程中,w會被修改"""pytorch提供兩種方式來切換訓練和測試的模式,分別是model.train()和model.eval()一般用法是,在訓練開始之前寫上model.train(),在測試時寫上model.eval()"""batch_size_num=1for X,y in dataloader:      #其中batch為每一個數據的編號X,y=X.to(device),y.to(device)   #將訓練數據集和標簽傳入cpu和gpupred=model.forward(X)loss=loss_fn(pred,y)    #通過交叉熵損失函數計算loss#Backpropagation  進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()   #梯度值清零loss.backward()     #反向傳播計算得到的每個參數的梯度值woptimizer.step()    #根據梯度更新網絡w參數loss_value=loss.item()  #從tensor數據中提取數據出來,tensor獲取損失值if batch_size_num%100==0:print(f"loss:{loss_value:>7f}[number:{batch_size_num}]")batch_size_num+=1

8、定義測試的函數

def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)model.eval()test_loss,correct=0,0with torch.no_grad():for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct +=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f"Test result:\n Accurracy:{(100*correct)}%,AVG loss:{test_loss}")

9、初始化損失函數創建優化器

loss_fn=nn.CrossEntropyLoss()   #創建交叉熵損失函數對象,適合做多分類optimizer=torch.optim.SGD(model.parameters(),lr=0.01)   #創建優化器,使用SGD隨機梯度下降

10、調用訓練和測試的函數,完成訓練一次測試一次

train(train_dataloder,model,loss_fn,optimizer)  #訓練一次完整的數據,多輪訓練
test(test_dataloder,model,loss_fn)

11、訓練20輪,測試一次

epochs=20
for epoch in range(epochs):train(train_dataloder,model,loss_fn,optimizer)print(f"epoch{epoch}")
test(test_dataloder,model,loss_fn)

三、完整代碼展示


"""手寫數字識別"""
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor"""下載測試數據集(包含圖片和標簽)"""training_data=datasets.MNIST(root='../data',train=True,download=True,transform=ToTensor()
)"""下載測試數據集(包含訓練圖片+標簽)"""test_data=datasets.MNIST(root='../data',train=False,download=True,transform=ToTensor()
)
print(len(training_data))"""展示手寫圖片,把訓練集中的前59000張圖片展示一下"""
from matplotlib import pyplot as plt
figure=plt.figure()
for i in range(9):img,label=training_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(),cmap='gray')a=img.squeeze()
plt.show()train_dataloder=DataLoader(training_data,batch_size=64)
test_dataloder=DataLoader(test_data,batch_size=64)device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"{device}device")"""self參數理解:在類內部開辟出了一個共享空間,所有被定義在這片空間的參數都能夠使用self.參數名來調用"""class NeurakNetwork(nn.Module):     #通過調用類的形式來使用神經網絡,神經網絡的模型nn.moudledef __init__(self):super().__init__()  #繼承父類的初始化self.flatten=nn.Flatten()   #將二位數據展成一維數據self.hidden1=nn.Linear(28*28,128)   #第一個參數時有多少個神經元傳進來,第二個參數是有多少個數據傳出去self.hidden2=nn.Linear(128,256)self.out=nn.Linear(256,10)      #輸出必須與標簽類型相同,輸入必須是上一層神經元的個數def forward(self,x):    #前向傳播,指明數據的流向,使神經網絡連接起來,函數名稱不能修改x=self.flatten(x)x=self.hidden1(x)x=torch.relu(x)     #激活函數,torch使用relu或者tanh函數作為激活函數x=self.hidden2(x)x=torch.relu(x)x=self.out(x)return xmodel = NeurakNetwork().to(device)      #將剛剛定義的模型傳入到GPU中
print(model)def train(dataloader,model,loss_fn,optimizer):model.train()       #告訴模型,即將開始訓練,其中的w進行隨機化操作,已經更新w,在訓練過程中,w會被修改"""pytorch提供兩種方式來切換訓練和測試的模式,分別是model.train()和model.eval()一般用法是,在訓練開始之前寫上model.train(),在測試時寫上model.eval()"""batch_size_num=1for X,y in dataloader:      #其中batch為每一個數據的編號X,y=X.to(device),y.to(device)   #將訓練數據集和標簽傳入cpu和gpupred=model.forward(X)loss=loss_fn(pred,y)    #通過交叉熵損失函數計算loss#Backpropagation  進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()   #梯度值清零loss.backward()     #反向傳播計算得到的每個參數的梯度值woptimizer.step()    #根據梯度更新網絡w參數loss_value=loss.item()  #從tensor數據中提取數據出來,tensor獲取損失值if batch_size_num%100==0:print(f"loss:{loss_value:>7f}[number:{batch_size_num}]")batch_size_num+=1def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)model.eval()test_loss,correct=0,0with torch.no_grad():for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct +=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f"Test result:\n Accurracy:{(100*correct)}%,AVG loss:{test_loss}")loss_fn=nn.CrossEntropyLoss()   #創建交叉熵損失函數對象,適合做多分類optimizer=torch.optim.Adam(model.parameters(),lr=0.01)   #創建優化器,使用Adam優化器#params:要訓練的參數,一般傳入的都是model.parameters()
#lr是指學習率,也就是步長#loss表示模型訓練后的輸出結果與樣本標簽的差距,如果差距越小,就表示模型訓練越好,越逼近于真實的模型
train(train_dataloder,model,loss_fn,optimizer)  #訓練一次完整的數據,多輪訓練
test(test_dataloder,model,loss_fn)epochs=20
for epoch in range(epochs):train(train_dataloder,model,loss_fn,optimizer)print(f"epoch{epoch}")
test(test_dataloder,model,loss_fn)

在這里插入圖片描述
可以看到經過20輪的訓練模型的正確率為96.91%。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/77877.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/77877.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/77877.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

UniOcc:自動駕駛占用預測和預報的統一基準

25年3月來自 UC Riverside、U Wisconsin 和 TAMU 的論文"UniOcc: A Unified Benchmark for Occupancy Forecasting and Prediction in Autonomous Driving"。 UniOcc 是一個全面統一的占用預測基準(即基于歷史信息預測未來占用)和基于攝像頭圖…

模型量化核心技術解析:從算法原理到工業級實踐

一、模型量化為何成為大模型落地剛需&#xff1f; 算力困境&#xff1a;175B參數模型FP32推理需0.5TB內存&#xff0c;超出主流顯卡容量 速度瓶頸&#xff1a;FP16推理延遲難以滿足實時對話需求&#xff08;如客服場景<200ms&#xff09; 能效挑戰&#xff1a;邊緣設備運行…

AD9253鏈路訓練

傳統方式 參考Xilinx官方文檔xapp524。對于AD9253器件 - 125M采樣率 - DDR模式&#xff0c;ADC器件的DCO采樣時鐘(500M Hz)和FCO幀時鐘是中心對齊的&#xff0c;適合直接采樣。但是DCO時鐘不能直接被FPGA內部邏輯使用&#xff0c;需要經過BUFIO和BUFR緩沖后&#xff0c;得到s_b…

解決方案:遠程shell連不上Ubuntu服務器

服務器是可以通過VNC登錄&#xff0c;排除了是服務器本身故障 檢查服務是否在全網卡監聽 sudo ss -tlnp | grep sshd確保有一行類似 LISTEN 0 128 0.0.0.0:22 0.0.0.0:* users:(("sshd",pid...,fd3))返回無結果&#xff0c;表明系統里并沒有任…

關于大數據的基礎知識(四)——大數據的意義與趨勢

成長路上不孤單&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///計算機愛好者&#x1f60a;///持續分享所學&#x1f60a;///如有需要歡迎收藏轉發///&#x1f60a;】 今日分享關于大數據的基礎知識&#xff08;四&a…

智能指針(weak_ptr )之三

1. std::weak_ptr 1.1 定義與用法 std::weak_ptr 是一種不擁有對象所有權的智能指針&#xff0c;用于觀察但不影響對象的生命周期。主要用于解決 shared_ptr 之間的循環引用問題。 主要特性&#xff1a; 非擁有所有權&#xff1a;不增加引用計數。可從 shared_ptr 生成&…

學習海康VisionMaster之卡尺工具

一&#xff1a;進一步學習了 今天學習下VisionMaster中的卡尺工具&#xff1a;主要用于測量物體的寬度、邊緣的特征的位置以及圖像中邊緣對的位置和間距 二&#xff1a;開始學習 1&#xff1a;什么是卡尺工具&#xff1f; 如果我需要檢測芯片的每一個PIN的寬度和坐標&#xff…

Java面試實戰:從Spring Boot到微服務的深入探討

Java面試實戰&#xff1a;從Spring Boot到微服務的深入探討 場景&#xff1a;電商場景的面試之旅 在某互聯網大廠的面試間&#xff0c;面試官李老師正襟危坐&#xff0c;而對面坐著的是傳說中的“水貨程序員”趙大寶。 第一輪&#xff1a;核心Java與構建工具 面試官&#x…

深入理解 Spring @Configuration 注解

在 Spring 框架中,@Configuration 注解是一個非常重要的工具,它用于定義配置類,這些類可以包含 Bean 定義方法。通過使用 @Configuration 和 @Bean 注解,開發者能夠以編程方式創建和管理應用程序上下文中的 Bean。本文將詳細介紹 @Configuration 注解的作用、如何使用它以及…

密碼學中的鹽值是什么?

目錄 1. 鹽值的基本概念 2. 鹽值的作用 (1) 防止彩虹表攻擊 (2) 防止相同的密碼生成相同的哈希值 (3) 增加暴力破解的難度 3. 如何使用鹽值&#xff1f; (1) 生成鹽值 (2) 將鹽值附加到密碼 (3) 存儲鹽值和哈希值 (4) 驗證密碼 4. 鹽值如何增加暴力破解的難度 在線暴…

基于瑞芯微RK3576國產ARM八核2.2GHz A72 工業評估板——Docker容器部署方法說明

前 言 本文適用開發環境: Windows開發環境:Windows 7 64bit、Windows 10 64bit Linux開發環境:VMware16.2.5、Ubuntu22.04.5 64bit U-Boot:U-Boot-2017.09 Kernel:Linux-6.1.115 LinuxSDK:LinuxSDK-[版本號](基于rk3576_linux6.1_release_v1.1.0) Docker是一個開…

大數據技術全解析

目錄 前言1. Kafka&#xff1a;流數據的傳輸平臺1.1 Kafka概述1.2 Kafka的應用場景1.3 Kafka的特點 2. HBase&#xff1a;分布式列式數據庫2.1 HBase概述2.2 HBase的應用場景2.3 HBase的特點 3. Hadoop&#xff1a;大數據處理的基石3.1 Hadoop概述3.2 Hadoop的應用場景3.3 Hado…

mcpo的簡單使用

1.安裝依賴 conda create -n mcpo python3.11 conda activate mcpo pip install mcpo pip install uv2.隨便從https://github.com/modelcontextprotocol/servers?tabreadme-ov-file 找一個mcp服務使用就行&#xff0c;我這里選的是爬蟲 然后安裝 pip install mcp-server-f…

uniapp-商城-32-shop 我的訂單-訂單詳情和組件goods-list

上面完成了我的訂單&#xff0c;通過點擊我的訂單中每一條數據&#xff0c;可以跳轉到訂單詳情中。 這里就需要展示訂單的狀態&#xff0c;支付狀態&#xff0c;物流狀態&#xff0c;取貨狀態&#xff0c;用戶信息&#xff0c;訂單中的貨物詳情等。 1、創建一個訂單詳情文件 …

XCVU13P-2FHGA2104I Xilinx Virtex UltraScale+ FPGA

XCVU13P-2FHGA2104I 是 Xilinx&#xff08;現為 AMD&#xff09;Virtex UltraScale? FPGA 系列中的高端 Premium 器件&#xff0c;基于 16nm FinFET 工藝并采用 3D IC 堆疊硅互連&#xff08;SSI&#xff09;技術&#xff0c;提供業內頂級的計算密度和帶寬?。該芯片集成約 3,…

【Python3】Django 學習之路

第一章&#xff1a;Django 簡介 1.1 什么是 Django&#xff1f; Django 是一個高級的 Python Web 框架&#xff0c;旨在讓 Web 開發變得更加快速和簡便。它鼓勵遵循“不要重復自己”&#xff08;DRY&#xff0c;Don’t Repeat Yourself&#xff09;的原則&#xff0c;并提供了…

Python 設計模式:模板模式

1. 什么是模板模式&#xff1f; 模板模式是一種行為設計模式&#xff0c;它定義了一個操作的算法的骨架&#xff0c;而將一些步驟延遲到子類中。模板模式允許子類在不改變算法結構的情況下&#xff0c;重新定義算法的某些特定步驟。 模板模式的核心思想是將算法的固定部分提取…

【后端】構建簡潔的音頻轉寫系統:基于火山引擎ASR實現

在當今數字化時代&#xff0c;語音識別技術已經成為許多應用不可或缺的一部分。無論是會議記錄、語音助手還是內容字幕&#xff0c;將語音轉化為文本的能力對提升用戶體驗和工作效率至關重要。本文將介紹如何構建一個簡潔的音頻轉寫系統&#xff0c;專注于文件上傳、云存儲以及…

音頻base64

音頻 Base64 是一種將二進制音頻數據&#xff08;如 MP3、WAV 等格式&#xff09;編碼為 ASCII 字符串的方法。通過 Base64 編碼&#xff0c;音頻文件可以轉換為純文本形式&#xff0c;便于在文本協議&#xff08;如 JSON、XML、HTML 或電子郵件&#xff09;中傳輸或存儲&#…

240422 leetcode exercises

240422 leetcode exercises jarringslee 文章目錄 240422 leetcode exercises[237. 刪除鏈表中的節點](https://leetcode.cn/problems/delete-node-in-a-linked-list/)&#x1f501;節點覆蓋法 [392. 判斷子序列](https://leetcode.cn/problems/is-subsequence/)&#x1f501;…