用PyTorch搭建卷積神經網絡實現MNIST手寫數字識別

用PyTorch搭建卷積神經網絡實現MNIST手寫數字識別

在深度學習領域,卷積神經網絡(Convolutional Neural Network,簡稱CNN)是處理圖像數據的強大工具。它通過卷積層、池化層和全連接層等組件,自動提取圖像特征,在圖像分類、目標檢測等任務中表現卓越。本文將使用PyTorch框架,搭建一個CNN模型來實現MNIST手寫數字識別,并詳細解析每一步代碼。

一、MNIST數據集介紹

MNIST數據集是深度學習領域經典的入門數據集,包含70,000張手寫數字圖像,其中60,000張用于訓練,10,000張用于測試。這些圖像均為灰度圖,尺寸是28x28像素,并且已經做了居中處理,這在一定程度上減少了預處理的工作量,能夠加快模型的訓練和運行速度。

二、環境準備與數據加載

2.1 導入必要的庫

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

上述代碼導入了PyTorch的核心庫、神經網絡模塊、數據加載工具以及用于圖像數據處理和數據集管理的庫。

2.2 下載并加載數據集

training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor()
)test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)

通過datasets.MNIST函數分別下載訓練集和測試集。root參數指定數據下載的路徑;train=True表示下載訓練集數據,train=False則表示下載測試集數據;download=True確保如果數據尚未下載,會自動進行下載;transform=ToTensor()將圖像數據轉換為PyTorch能夠處理的張量格式。

2.3 數據可視化

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

這段代碼使用matplotlib庫展示了訓練數據集中的部分手寫數字圖像,通過plt.imshow函數將張量格式的圖像數據可視化,直觀感受MNIST數據集的內容。

2.4 創建數據加載器

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

DataLoader用于將數據集打包成批次,batch_size參數指定每個批次包含的數據樣本數量。將數據集分成批次進行訓練,能夠有效減少內存使用,并提高訓練速度。

三、設備配置

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

這段代碼檢測當前設備是否支持GPU(CUDA)或蘋果M系列芯片的GPU(MPS),如果都不支持,則使用CPU進行計算。后續模型和數據都會被移動到選定的設備上運行,以充分利用硬件資源加速訓練。

四、定義卷積神經網絡模型

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU())self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

在這個自定義的CNN類中,繼承自nn.Module__init__方法中定義了網絡的結構:

  • 卷積層(nn.Conv2d:用于提取圖像特征,通過設置in_channels(輸入通道數)、out_channels(輸出通道數,即卷積核個數)、kernel_size(卷積核大小)、stride(步長)和padding(填充)等參數,控制卷積操作。
  • 激活函數層(nn.ReLU:引入非線性,增強網絡的表達能力。
  • 池化層(nn.MaxPool2d:對特征圖進行下采樣,減少數據量和計算量,同時保留主要特征。
  • 全連接層(nn.Linear:將卷積層和池化層提取的特征映射到輸出類別(MNIST數據集中有10個數字類別)。

forward方法定義了數據在網絡中的前向傳播路徑,確保數據按照網絡結構依次經過各層處理,最終輸出預測結果。

五、訓練與測試模型

5.1 定義損失函數和優化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

nn.CrossEntropyLoss是適用于多分類任務的交叉熵損失函數,用于計算模型預測結果與真實標簽之間的差距。torch.optim.Adam是一種常用的優化器,通過調整模型的參數(model.parameters())來最小化損失函數,lr參數設置學習率,控制參數更新的步長。

5.2 訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1

在訓練函數中:

  • model.train()將模型設置為訓練模式,此時模型中的一些層(如Dropout層)會按照訓練規則工作。
  • 遍歷數據加載器中的每一個批次數據,將數據和標簽移動到指定設備上。
  • 通過模型進行預測,計算損失值。
  • 使用optimizer.zero_grad()清零梯度,loss.backward()進行反向傳播計算梯度,optimizer.step()根據梯度更新模型參數。
  • 每隔100個批次,打印當前的損失值,以便觀察訓練過程中的損失變化。

5.3 測試函數

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

測試函數中:

  • model.eval()將模型設置為測試模式,關閉一些在訓練過程中起作用但在測試時不需要的操作(如Dropout)。
  • 使用with torch.no_grad()上下文管理器,關閉梯度計算,因為在測試階段不需要更新模型參數,這樣可以節省計算資源。
  • 遍歷測試數據,計算每個批次的損失值并累加,同時統計預測正確的樣本數量。
  • 最后計算并打印測試集上的平均損失和準確率,評估模型的性能。

5.4 執行訓練和測試

epoch = 9
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)

通過設置訓練輪數(epoch),循環調用訓練函數進行模型訓練,每一輪訓練結束后,調用測試函數評估模型在測試集上的性能。

六、總結

本文通過詳細的代碼解析,展示了如何使用PyTorch搭建一個簡單的卷積神經網絡來實現MNIST手寫數字識別任務。從數據加載、模型定義,到訓練和測試,每一個步驟都體現了CNN在圖像分類任務中的核心思想和實現方法。通過不斷調整模型結構、超參數等,還可以進一步提升模型的性能。卷積神經網絡在圖像領域的應用遠不止于此,它在更復雜的圖像任務和其他領域也有著廣泛的應用前景,希望本文能為大家深入學習深度學習提供一個良好的開端。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/79857.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/79857.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/79857.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Tensorrt 基礎入門

什么是tensorrt? 其他廠商: Qualcomm, Hailo, google TPU tensorrt的優劣勢 使用tensorrt的pipeline tensorrt使用中存在的問題以及解決方案 tensorrt的應用場景 自動駕駛模型部署需要關注的問題: 邊端硬件資源有限 散熱(不能水冷) 實時性&…

Qt 顯示QRegExp 和 QtXml 不存在問題

QRegExp 和 QtXml 問題 在Qt6 中 已被棄用; 1)QRegExp 已被棄用,改用 QRegularExpression Qt5 → Qt6 重大變更:QRegExp 被移到了 Qt5Compat 模塊,默認不在 Qt6 核心模塊中。 錯誤類型解決方法QRegExp 找不到改用 Q…

玩玩OCR

一、Tesseract: 1.下載windows版: tesseract 2. 安裝并記下路徑,等會要填 3.保存.py文件 import pytesseract from PIL import Image def ocr_local_image(image_path):try:pytesseract.pytesseract.tesseract_cmd rD:\Programs\Tesseract-OCR\tesse…

Dify 完全指南(一):從零搭建開源大模型應用平臺(Ollama/VLLM本地模型接入實戰)》

文章目錄 1. 相關資源2. 核心特性3. 安裝與使用(Docker Compose 部署)3.1 部署Dify3.2 更新Dify3.3 重啟Dify3.4 訪問Dify 4. 接入本地模型4.1 接入 Ollama 本地模型4.1.1 步驟4.1.2 常見問題 4.2 接入 Vllm 本地模型 5. 進階應用場景6. 總結 1. 相關資源…

C++ Windows 打包exe運行方案(cmake)

文章目錄 背景動態庫梳理打包方案一、使用 Vcpkg 安裝靜態庫(關鍵基礎配置)1. 初始化 Vcpkg2. 安裝靜態庫(注意 x64-windows-static 后綴) 二、CMakeLists.txt 關鍵配置三、編譯四、驗證 不同平臺代碼兼容\_\_attribute\_\_((pack…

Java學習手冊:Hibernate/JPA 使用指南

一、Hibernate 和 JPA 的核心概念 實體(Entity) :實體是 JPA 中用于表示數據庫表的 Java 對象。通過在實體類上添加 Entity 注解,JPA 可以將實體類映射到數據庫表。例如,定義一個 User 實體類: import ja…

字符串匹配 之 拓展 KMP算法(Z算法)

文章目錄 習題2223.構造字符串的總得分和3031.將單詞恢復初始狀態所需的最短時間 II 靈神代碼模版 區別與KMP算法 KMP算法可用于求解在線性時間復雜度0(n)內求解模式串p在主串s中匹配的未知當然,由于在KMP算法中,預處理求解出了next數組,也就…

安全為上,在系統威脅建模中使用量化分析

*注:Open FAIR? 知識體系是一種開放和獨立的信息風險分析方法。它為理解、分析和度量信息風險提供了分類和方法。Open FAIR作為領先的風險分析方法論,已得到越來越多的大型組織認可。 在數字化風險與日俱增的今天,企業安全決策正面臨雙重挑戰…

游戲引擎學習第259天:OpenGL和軟件渲染器清理

回顧并為今天的內容做好鋪墊 今天,我們將對游戲的分析器進行升級。在之前的修復中,我們解決了分析器的一些敏感問題,例如它無法跨代碼重新加載進行分析,以及一些復雜的小問題。現在,我們的分析器看起來已經很穩定了。…

訊睿CMS模版常用標簽參數匯總

一、模板調用標簽 1、首頁 網站名稱:{SITE_NAME} 標題:{$meta_title}(列表頁通用) Keywords:{$meta_keywords} Description:{$meta_description}2、列表頁 迅睿cms調用本欄目基礎信息標簽代碼 當前欄目…

【C#】Buffer.BlockCopy的使用

Buffer.BlockCopy 是 C# 中的一個方法,用于在數組之間高效地復制字節塊。它主要用于操作字節數組(byte[]),但也可以用于其他類型的數組,因為它直接基于內存操作。 以下是關于 Buffer.BlockCopy 的詳細說明和使用示例&…

記一次pdf轉Word的技術經歷

一、發現問題 前幾天在打開一個pdf文件時,遇到了一些問題,在Win10下使用WPS PDF、萬興PDF、Adobe Acrobat、Chrome瀏覽器打開都是正常顯示的;但是在macOS 10.13中使用系統自帶的預覽程序和Chrome瀏覽器(由于macOS版本比較老了&am…

在Laravel 12中實現4A日志審計

以下是在Laravel 12中實現4A(認證、授權、賬戶管理、審計)日志審計并將日志存儲到MongoDB的完整方案(包含性能優化和安全增強措施): 一、環境配置 安裝MongoDB擴展包 composer require jenssegers/mongodb配置.env …

鏈表高級操作與算法

鏈表是數據結構中的基礎,但也是面試和實際開發中的重點考察對象。今天我們將深入探討鏈表的高級操作和常見算法,讓你能夠輕松應對各種鏈表問題。 1. 鏈表翻轉 - 最經典的鏈表問題 鏈表翻轉是面試中的常見題目,也是理解鏈表指針操作的絕佳練…

架構思維:構建高并發讀服務_使用懶加載架構實現高性能讀服務

文章目錄 一、引言二、讀服務的功能性需求三、兩大基本設計原則1. 架構盡量不要分層2. 代碼盡可能簡單 四、實戰方案:懶加載架構及其四大挑戰五、改進思路六、總結與思考題 一、引言 在任何后臺系統設計中,「讀多寫少」的業務場景占據主流:瀏…

在運行 Hadoop 作業時,遇到“No such file or directory”,如何在windows里打包在虛擬機里運行

最近在學習Hadoop集群map reduce分布運算過程中,經多方面排查可能是電腦本身配置的原因導致每次運行都會報“No such file or directory”的錯誤,最后我是通過打包文件到虛擬機里運行得到結果,具體步驟如下: 前提是要保證maven已經…

軟考-軟件設計師中級備考 11、計算機網絡

1、計算機網絡的分類 按分布范圍分類 局域網(LAN):覆蓋范圍通常在幾百米到幾千米以內,一般用于連接一個建筑物內或一個園區內的計算機設備,如學校的校園網、企業的辦公樓網絡等。其特點是傳輸速率高、延遲低、誤碼率低…

【C#】.net core6.0無法訪問到控制器方法,直接404。由于自己的不仔細,出現個低級錯誤,這讓DeepSeek看出來了,是什么錯誤呢,來瞧瞧

🌹歡迎來到《小5講堂》🌹 🌹這是《C#》系列文章,每篇文章將以博主理解的角度展開講解。🌹 🌹溫馨提示:博主能力有限,理解水平有限,若有不對之處望指正!&#…

當LLM遇上Agent:AI三大流派的“復仇者聯盟”

你一定聽說過ChatGPT和DeepSeek,也知道它們背后的LLM(大語言模型)有多牛——能寫詩、寫代碼、甚至假裝人類。但如果你以為這就是AI的極限,那你就too young too simple了! 最近,**Agent(智能體&a…

Spring Boot多模塊劃分設計

在Spring Boot多模塊項目中,模塊劃分主要有兩種思路:??技術分層劃分??和??業務功能劃分??。兩種方式各有優缺點,需要根據項目規模、團隊結構和業務特點來選擇。 ??1. 技術分層劃分(橫向拆分)?? 結構示例&…