深度學習——基于卷積神經網絡的MNIST手寫數字識別詳解

文章目錄

    • 引言
    • 1. 環境準備和數據加載
      • 1.1 下載MNIST數據集
      • 1.2 數據可視化
    • 2. 數據預處理
    • 3. 設備配置
    • 4. 構建卷積神經網絡模型
    • 5. 訓練和測試函數
      • 5.1 訓練函數
      • 5.2 測試函數
    • 6. 模型訓練和評估
      • 6.1 初始化損失函數和優化器
      • 6.2 訓練過程
    • 7. 關鍵點解析
    • 8. 完整代碼
    • 9. 總結

引言

手寫數字識別是計算機視覺和深度學習領域的經典入門項目。本文將詳細介紹如何使用PyTorch框架構建一個卷積神經網絡(CNN)來實現MNIST手寫數字識別任務。我們將從數據加載、模型構建到訓練和測試,一步步解析整個過程。

1. 環境準備和數據加載

首先,我們需要導入必要的PyTorch模塊:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

1.1 下載MNIST數據集

MNIST數據集包含60,000個訓練樣本和10,000個測試樣本,每個樣本都是一個28x28像素的灰度手寫數字圖像。

# 下載訓練數據集
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 下載測試數據集
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)

1.2 數據可視化

我們可以使用matplotlib庫來查看數據集中的一些樣本:

from matplotlib import pyplot as pltfigure = plt.figure()
for i in range(9):img, label = training_data[i+59000]  # 提取后幾張圖片figure.add_subplot(3,3,i+1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2. 數據預處理

為了高效訓練模型,我們需要使用DataLoader將數據集分批次加載:

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

3. 設備配置

PyTorch支持在CPU、NVIDIA GPU和蘋果M系列芯片上運行,我們可以自動檢測最佳可用設備:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

4. 構建卷積神經網絡模型

我們定義一個CNN類來實現手寫數字識別:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 8, 3, 1, 1),  # (8,28,28)nn.ReLU(),nn.MaxPool2d(2),           # (8,14,14))self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 3, 1, 1), # (16,14,14)nn.ReLU(),nn.MaxPool2d(2),           # (16,7,7))self.out = nn.Linear(16*7*7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)      # flatten操作output = self.out(x)return outputmodel = CNN().to(device)

這個CNN模型包含:

  • 兩個卷積層,每個卷積層后接ReLU激活函數和最大池化層
  • 一個全連接輸出層
  • 輸入大小:(1,28,28)
  • 輸出大小:10(對應0-9的數字類別)

5. 訓練和測試函數

5.1 訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 100 == 0:print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")batch_size_num += 1

5.2 測試函數

def Test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")

6. 模型訓練和評估

6.1 初始化損失函數和優化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

6.2 訓練過程

# 初始訓練和測試
train(train_dataloader, model, loss_fn, optimizer)
Test(test_dataloader, model, loss_fn)# 多輪訓練
epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")# 最終測試
Test(test_dataloader, model, loss_fn)

7. 關鍵點解析

  1. 數據轉換:使用ToTensor()將圖像數據轉換為PyTorch張量,并自動歸一化到[0,1]范圍。

  2. 批處理:DataLoader的batch_size參數控制每次訓練使用的樣本數量,影響內存使用和訓練速度。

  3. 模型結構

    • 卷積層提取空間特征
    • ReLU激活函數引入非線性
    • 最大池化層降低特征圖尺寸
    • 全連接層輸出分類結果
  4. 訓練模式切換model.train()model.eval()分別用于訓練和測試階段,影響某些層(如Dropout和BatchNorm)的行為。

  5. 優化過程:Adam優化器結合了動量法和自適應學習率的優點,通常能獲得較好的訓練效果。

8. 完整代碼

import torch
from torch import nn    #導入神經網絡模塊
from torch.utils.data import DataLoader  #數據包管理工具,打包數據
from torchvision import  datasets  #封裝了很多與圖像相關的模型,數據集
from torchvision.transforms import ToTensor  #數據轉換,張量,將其他類型的數據轉換為tensor張量,numpy array'''下載訓練數據集(包含訓練圖片+標簽)'''
training_data = datasets.MNIST( #跳轉到函數的內部源代碼,pycharm按下ctrl + 鼠標點擊root="data", #表示下載的手寫數字  到哪個路徑。60000train=True, #讀取下載后的數據中的訓練集download=True, #如果你之前已經下載過了,就不用下載transform=ToTensor(), #張量,圖片是不能直接傳入神經網絡模型)   #對于pytorch庫能夠識別的數據一般是tensor張量'''下載測試數據集(包含訓練圖片+標簽)'''
test_data = datasets.MNIST( #跳轉到函數的內部源代碼,pycharm按下ctrl + 鼠標點擊root="data", #表示下載的手寫數字  到哪個路徑。60000train=False, #讀取下載后的數據中的訓練集download=True, #如果你之前已經下載過了,就不用下載transform=ToTensor(), #Tensor是在深度學習中提出并廣泛應用的數據類型)   #Numpy數組只能在CPU上運行。Tensor可以在GPU上運行。這在深度學習應用中可以顯著提高計算速度。
print(len(training_data))'''展示手寫數字圖片,把訓練集中的59000張圖片展示'''
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = training_data[i+59000] #提取第59000張圖片figure.add_subplot(3,3,i+1) #圖像窗口中創建多個小窗口,小窗口用于顯示圖片plt.title(label)plt.axis("off")  #plt.show(I) 顯示矢量plt.imshow(img.squeeze(),cmap="gray") #plt.imshow()將Numpy數組data中的數據顯示為圖像,并在圖形窗口中顯示a = img.squeeze()  #img.squeeze()從張量img中去掉維度為1的,如果該維度的大小不為1,則張量不會改變
plt.show()'''創建數據DataLoader(數據加載器)'''
# batch_size:將數據集分為多份,每一份為batch_size個數據
#       優點:可以減少內存的使用,提高訓練速度train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)'''判斷當前設備是否支持GPU,其中mps是蘋果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驅動軟件的功能:pytorch能夠去執行cuda的命令
# 神經網絡的模型也需要傳入到GPU,1個batch_size的數據集也需要傳入到GPU,才可以進行訓練''' 定義神經網絡  類的繼承這種方式'''
class CNN(nn.Module): #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.mdouledef __init__(self): #輸入大小:(1,28,28)super(CNN,self).__init__()  #初始化父類self.conv1 = nn.Sequential(      #將多個層組合成一起,創建了一個容器,將多個網絡組合在一起nn.Conv2d(              # 2d一般用于圖像,3d用于視頻數據(多一個時間維度),1d一般用于結構化的序列數據in_channels=1,      # 圖像通道個數,1表示灰度圖(確定了卷積核 組中的個數)out_channels=8,     # 要得到多少個特征圖,卷積核的個數kernel_size=3,      # 卷積核大小 3×3stride=1,           # 步長padding=1,          # 一般希望卷積核處理后的結果大小與處理前的數據大小相同,效果會比較好),                      # 輸出的特征圖為(8,28,28)nn.ReLU(),  # Relu層,不會改變特征圖的大小nn.MaxPool2d(kernel_size=2),    # 進行池化操作(2×2操作),輸出結果為(8,14,14))self.conv2 = nn.Sequential(nn.Conv2d(8,16,3,1,1),  #輸出(16,14,14)nn.ReLU(),  #Relu層  (16,14,14)nn.MaxPool2d(kernel_size=2),    #池化層,輸出結果為(16,7,7))self.out = nn.Linear(16*7*7,10)  # 全連接層得到的結果def forward(self,x):   #前向傳播,你得告訴它 數據的流向 是神經網絡層連接起來,函數名稱不能改x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0),-1)    # flatten操作,結果為:(batch_size,64 * 7 * 7)output = self.out(x)return output
model = CNN().to(device) #把剛剛創建的模型傳入到GPU
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告訴模型,我要開始訓練,模型中w進行隨機化操作,已經更新w,在訓練過程中,w會被修改的
# pytorch提供2種方式來切換訓練和測試的模式,分別是:model.train() 和 mdoel.eval()
# 一般用法是:在訓練開始之前寫上model.train(),在測試時寫上model.eval()batch_size_num = 1for X,y in dataloader:              #其中batch為每一個數據的編號X,y = X.to(device),y.to(device) #把訓練數據集和標簽傳入cpu或GPUpred = model.forward(X)         # .forward可以被省略,父類種已經對此功能進行了設置loss = loss_fn(pred,y)          # 通過交叉熵損失函數計算損失值loss# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()           # 梯度值清零loss.backward()                 # 反向傳播計算得到每個參數的梯度值woptimizer.step()                # 根據梯度更新網絡w參數loss_value = loss.item()        # 從tensor數據種提取數據出來,tensor獲取損失值if batch_size_num %100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def Test(dataloader,model,loss_fn):size = len(dataloader.dataset)  #10000num_batches = len(dataloader)  # 打包的數量model.eval()        #測試,w就不能再更新test_loss,correct =0,0with torch.no_grad():       #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss += loss_fn(pred,y).item() #test_loss是會自動累加每一個批次的損失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值對應的索引號,dim=0表示每一列中的最大值對應的索引號b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能來衡量模型測試的好壞correct /= size  #平均的正確率print(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")loss_fn = nn.CrossEntropyLoss()  #創建交叉熵損失函數對象,因為手寫字識別一共有十種數字,輸出會有10個結果
#
optimizer = torch.optim.Adam(model.parameters(),lr=0.01) #創建一個優化器,SGD為隨機梯度下降算法
# # params:要訓練的參數,一般我們傳入的都是model.parameters()
# # lr:learning_rate學習率,也就是步長
#
# # loss表示模型訓練后的輸出結果與樣本標簽的差距。如果差距越小,就表示模型訓練越好,越逼近真實的模型
train(train_dataloader,model,loss_fn,optimizer) #訓練1次完整的數據。多輪訓練
Test(test_dataloader,model,loss_fn)epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)

9. 總結

通過本文,我們學習了如何使用PyTorch實現一個完整的手寫數字識別項目。從數據加載、模型構建到訓練和評估,每個步驟都展示了PyTorch框架的簡潔和強大。這個簡單的CNN模型在MNIST數據集上可以達到很高的準確率,為進一步學習更復雜的計算機視覺任務打下了良好基礎。

未來可以嘗試:

  • 調整網絡結構(增加層數、改變通道數)
  • 嘗試不同的優化器和學習率
  • 添加數據增強技術
  • 在更復雜的數據集上應用類似方法

希望這篇教程能幫助你入門PyTorch和計算機視覺領域!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/84835.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/84835.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/84835.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Activiti初識

文章目錄 1 工作流介紹1_工作流概念介紹2 工作流系統3 適用行業4 具體應用5 實現方式 2 Activiti介紹1_BPM2 BPM 軟件3 BPMN 3 使用步驟1_部署 activiti2 流程定義3 流程定義部署4 啟動一個流程實例5 用戶查詢待辦任務(Task)6 用戶辦理任務7 流程結束 4 Activiti應用1_Activiti…

CyclicBarrier入門代碼解析

文章目錄 核心思想:組隊出游,人到齊了才出發 🚌最簡單易懂的代碼示例代碼解析運行效果分析CyclicBarrier vs CountDownLatch 的關鍵區別CyclicBarrier在業務系統里面通常有什么常用的應用場景核心應用模式1. 數據并行處理與ETL(最…

Maven 配置中繞過 HTTP 阻斷機制的完整解決方案

Maven 配置中繞過 HTTP 阻斷機制的完整解決方案 一、背景與問題分析 自 Maven 3.8.1 版本起&#xff0c;出于安全考慮&#xff0c;默認禁止了對 HTTP 倉庫的訪問。這一機制通過 <mirror> 配置中的 maven-default-http-blocker 實現&#xff0c;其作用是攔截所有使用 HT…

【大廠機試題解法筆記】恢復數字序列

題目 對于一個連續正整數組成的序列&#xff0c;可以將其拼接成一個字符串&#xff0c;再將字符串里的部分字符打亂順序。如序列8 9 10 11 12,拼接成的字符串為89101112,打亂一部分字符后得到90811211,原來的正整數10就被拆成了0和1。 現給定一個按如上規則得到的打亂字符的字…

MongoDB 事務有哪些限制和注意事項?

MongoDB 的多文檔 ACID 事務雖然強大&#xff0c;但在使用時確實有一些限制和需要特別注意的事項。 以下是主要的限制和注意事項&#xff1a; 1. 性能開銷 (Performance Overhead) 額外協調: 事務需要額外的協調工作&#xff0c;包括跟蹤事務狀態、管理鎖&#xff08;即使是樂…

CTF實戰技巧:獲取初始權限后如何高效查找Flag

CTF實戰技巧&#xff1a;獲取初始權限后如何高效查找Flag 在CTF比賽中&#xff0c;獲得初始訪問權限只是開始&#xff0c;真正的挑戰在于如何在系統中高效定位Flag。本文將分享我在滲透測試中總結的系統化Flag搜索方法&#xff0c;涵蓋Linux和Windows雙平臺。 引言&#xff1a;…

kafka Tool (Offset Explorer)使用SASL Plaintext進行身份驗證

一、前面和不需要認證的情況相同&#xff1a; 1、填寫Properties中的cluster name和版本&#xff0c;以及zk的ip和port 2、Advanced中填寫bootstrap servers 二、和不需要認證時不同的點&#xff1a; 1、Security的Type&#xff0c;不需要認證時選plaintext&#xff0c;需要認…

最小費用最大流算法

最小費用最大流算法 原理 問題:網絡中有源點(起點)和匯點(終點),每條邊有流量上限和單位流量費用。求: 從源點到匯點的最大流量在流量最大的前提下,總費用最小核心思想:在找增廣路時,選擇單位費用之和最小的路徑(使用SPFA找最短路) 實現步驟 建圖:使用鏈式前向…

從匯編的角度揭開C++ this指針的神秘面紗(上)

C中的this指針一直比較神秘。任何類的對象&#xff0c;都有一個this指針&#xff0c;無處不在。那么this指針的本質究竟是什么&#xff1f;this指針什么時候會被用到&#xff1f;今天通過幾段簡單的代碼&#xff0c;來揭秘一下。 要先揭秘this指針&#xff0c;先來說一下函數調…

18 - GCNet

論文《GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond》 1、作用 GCNet通過聚合每個查詢位置的全局上下文信息來捕獲長距離依賴關系&#xff0c;從而改善了圖像/視頻分類、對象檢測和分割等一系列識別任務的性能。非局部網絡&#xff08;NLNet&…

人工智能學習17-Pandas-查看數據

人工智能學習概述—快手視頻 人工智能學習17-Pandas-查看數據—快手視頻

RV1126+OPENCV在視頻中添加LOGO圖像

一.RV1126OPENCV在視頻中添加LOGO圖像大體流程圖 主要是利用RV1126的視頻流結合OPENCV的API在視頻流里面添加LOGO圖像&#xff0c;換言之就是在RV1126的視頻流里面疊加圖片。大體流程我們來看上圖&#xff0c;要完成這個功能我們需要創建兩個線程(實際上還有初始化過程&#xf…

汽車制造通信革新:網關模塊讓EtherCAT成功對接CCLINK

?在現代工業自動化生產領域&#xff0c;不同品牌和類型的設備往往采用不同的通信協議&#xff0c;這給設備之間的互聯互通帶來了挑戰。某汽車制造企業的生產線上&#xff0c;采用了三菱FX5U PLC作為主站進行整體生產流程的控制和調度&#xff0c;同時配備了庫卡機器人作為從站…

vue父類跳轉到子類帶參數,跳轉完成后去掉參數

當通過路由導航的時候&#xff0c;由于父類頁面帶參數到子類&#xff0c;導致路徑上面有參數 這樣不僅不美觀&#xff0c;而且在點擊導航菜單按鈕時還會有各種問題&#xff0c;這時我們只需要將路由后面的參數去掉就好了&#xff0c;在子頁面mounted()函數里面獲取到父類的參數…

純 CSS 實現的的3種掃光效果

介紹一個比較常見的動畫效果。 在日常開發中&#xff0c;為了強調凸顯某些文本或者元素&#xff0c;會加一些掃光動效&#xff0c;起到吸引眼球的效果&#xff0c;比如文本的 或者是一個卡片容器&#xff0c;里面可能是圖片或者文本或者任意元素 除此之外&#xff0c;還有那…

如何在FastAPI中構建一個既安全又靈活的多層級權限系統?

title: 如何在FastAPI中構建一個既安全又靈活的多層級權限系統? date: 2025/06/14 12:43:05 updated: 2025/06/14 12:43:05 author: cmdragon excerpt: FastAPI通過依賴注入系統和OAuth2、JWT等安全方案,支持構建多層級權限系統。系統設計包括基于角色的訪問控制、細粒度權…

大模型_Ubuntu24.04安裝RagFlow_使用hyper-v虛擬機_超級詳細--人工智能工作筆記0251

因為之前使用dify搭建了一個知識庫&#xff0c;但是dify的效果&#xff0c;尤其是在文檔解析方面是非常不友好的&#xff0c;雖然測試了&#xff0c;納米的效果非常好&#xff0c;但是納米只能容納2000個文件&#xff0c;如果 你的知識庫中有代碼&#xff0c;sql文件等等&…

LeetCode - LCR 173. 點名

題目 LCR 173. 點名 - 力扣&#xff08;LeetCode&#xff09; 思路 首先對數組進行排序&#xff0c;使學號按順序排列 在排序后的數組中&#xff0c;如果沒有缺失的學號&#xff0c;那么每個元素應該等于其索引值 使用二分查找找到第一個不等于其索引的元素位置&#xff1…

VSCode如何優雅的debug python文件,包括外部命令uv run main.py等等

debug程序的方式有很多種。每一種方式都各有缺點:有的方式雖然優雅,但是局限性很大;有的方式麻煩,但是局限性小。 常規方式: 優點:然后可以觀察所有線程。一勞永逸。缺點:就是寫參數很麻煩,但是你可以讓chatgpt等大模型幫你寫。最最最優雅的方式: 優點:就是需要在代碼…

[調試技巧]VS Code如何在代理模式下使用 MCP 工具?

在開發環境調試MCP&#xff0c;通過agent模式與大模型對話&#xff0c;并不能保證每次均正確調用tool。在閱讀官方文檔之后&#xff0c;得知以下小技巧。 添加 MCP 服務器后&#xff0c;您可以在代理模式下使用它提供的工具。要在代理模式下使用 MCP 工具 打開聊天視圖 (CtrlAl…