# 手寫數字識別:使用PyTorch構建MNIST分類器

手寫數字識別:使用PyTorch構建MNIST分類器

在這篇文章中,我將引導你通過使用PyTorch框架構建一個簡單的神經網絡模型,用于識別MNIST數據集中的手寫數字。MNIST數據集是一個經典的機器學習數據集,包含了60,000張訓練圖像和10,000張測試圖像,每張圖像都是28x28像素的灰度手寫數字。
在這里插入圖片描述

在這里插入圖片描述

環境準備

首先,確保你的環境中安裝了PyTorch和torchvision。可以通過以下命令安裝:

pip install torch torchvision

數據加載與預處理

我們首先加載MNIST數據集,并將圖像轉換為PyTorch張量格式,以便模型可以處理。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下載訓練數據集(包含訓練圖片+標簽)'''
training_data = datasets.MNIST( #跳轉到函數的內部源代碼,pycharm 按下ctrl+鼠標點擊 training_data:Datasetroot="data",#表示下載的手寫數字 到哪個路徑。60000train=True, #讀取下載后的數據 中的 訓練集download=True,#如果你之前已經下載過了,就不用再下載transform=ToTensor(), #張量,圖片是不能直接傳入神經網絡模型
)   #對于pytorch庫能夠識別的數據一般是tensor張量。'''下載測試數據集(包含訓練圖片+標簽)'''
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
print(len(training_data))

數據可視化

為了更好地理解數據,我們可以展示一些手寫數字圖像。

''展示手寫字圖片,把訓練數據集中的前59000張圖片展示一下'''from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i+59000] #提取第59000張圖片figure.add_subplot(3, 3, i+1) #圖像窗口中創建多個小窗口,小窗口用于顯示圖片plt.title(label)plt.axis("off") # plt.show(I)#是示矢量,plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

創建DataLoader

為了高效地加載數據,我們使用DataLoader來批量加載數據。

# '"創建數據DataLoader(數據加載器)開'
#  'batch_size:將數據集分成多份,每一份為batch_size個數據'
#  '優點:可以減少內存的使用,提高訓練速度。train_dataloader = DataLoader(training_data, batch_size=64) #64張圖片為一個包,train_dataloader:<torch
test_dataloader = DataLoader(test_data, batch_size=64)

模型定義

接下來,我們定義一個簡單的神經網絡模型,包含兩個隱藏層和一個輸出層。

'''定義神經網絡類的繼承這種方式'''
class NeuralNetwork(nn.Module):  #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.moduledef __init__(self): #python基礎關于類,self類自已本身super().__init__() #繼承的父類初始化self.flatten = nn.Flatten() #展開,創建一個展開對象flattenself.hidden1 = nn.Linear(28*28, 128 ) #第1個參數:有多少個神經元傳入進來,第2個參數:有多少個數據傳出self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x) #圖像進行展開x = self.hidden1(x)x = torch.relu(x) #激活函數,torch使用的relu函數 relu,tanhx = self.hidden2(x)x = torch.relu(x)x = self.out(x)return xmodel = NeuralNetwork().to(device) #把剛剛創建的模型傳入到Gpu
print(model)

訓練與測試

我們定義訓練和測試函數,使用交叉熵損失函數和隨機梯度下降優化器。

def train(dataloader, model, loss_fn, optimizer):model.train() #告訴模型,我要開始訓練,模型中w進行隨機化操作,已經更新w。在訓練過程中,w會被修改的
# #pytorch提供2種方式來切換訓練和測試的模式,分別是:model.train()和 model.eval()。# 一般用法是:在訓練開始之前寫上model.trian(),在測試時寫上 model.eval()batch_size_num = 1for X, y in dataloader: #其中batch為每一個數據的編號X, y = X.to(device), y.to(device) #把訓練數據集和標簽傳入cpu或GPUpred = model.forward(X) #.forward可以被省略,父類中已經對次功能進行了設置。自動初始化wloss= loss_fn(pred, y) #通過交叉熵損失函數計算損失值loss# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad() #梯度值清零loss.backward() #反向傳播計算得到每個參數的梯度值woptimizer.step() #根據梯度更新網絡w參數loss_value = loss.item() #從tensor數據中提取數據出來,tensor獲取損失值if batch_size_num % 100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset) #10000num_batches = len(dataloader) #打包的數量model.eval() #測試,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad(): #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候。這for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  #test_loss是會自動累加每一個批次的損失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)   #dim=1表示每一行中的最大值對應的索引號,dim=0表示每一列中的最大值b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能來衡量模型測試的好壞。correct /= size #平均的正確率print(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")

訓練模型

最后,我們訓練模型并測試其性能。

loss_fn = nn.CrossEntropyLoss() #創建交叉熵損失函數對象,因為手寫字識別中一共有10個數字,輸出會有10個結果optimizer = torch.optim.SGD(model.parameters(), lr=0.01) #創建一個優化器,SGD為隨機梯度下降算法
# #params:要訓練的參數,一般我們傳入的都是model.parameters()# #lr:learning_rate學習率,也就是步長#loss表示模型訓練后的輸出結果與,樣本標簽的差距。如果差距越小,就表示模型訓練越好,越逼近干真實的模型。# train(train_dataloader, model, loss_fn, optimizer)
# test(test_dataloader, model, loss_fn)epochs = 30
for t in range(epochs):print(f"Epoch {t+1}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

運行結果

在這里插入圖片描述

結論

通過這篇文章,我們成功構建了一個簡單的神經網絡模型來識別MNIST數據集中的手寫數字。這個模型展示了如何使用PyTorch進行數據處理、模型定義、訓練和測試。希望這能幫助你開始自己的深度學習項目!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902313.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902313.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902313.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

強化學習筆記(三)——表格型方法(蒙特卡洛、時序差分)

強化學習筆記&#xff08;三&#xff09;——表格型方法&#xff08;蒙特卡洛、時序差分&#xff09; 一、馬爾可夫決策過程二、Q表格三、免模型預測1. 蒙特卡洛策略評估1) 動態規劃方法和蒙特卡洛方法的差異 2. 時序差分2.1 時序差分誤差2.2 時序差分方法的推廣 3. 自舉與采樣…

c++_csp-j算法 (4)

迪克斯特拉() 介紹 迪克斯特拉算法(Dijkstra算法)是一種用于解決單源最短路徑問題的經典算法,由荷蘭計算機科學家艾茲赫爾迪克斯特拉(Edsger W. Dijkstra)于1956年提出。迪克斯特拉算法的基本思想是通過逐步擴展已經找到的最短路徑集合,逐步更新節點到源節點的最短路…

(13)VTK C++開發示例 --- 透視變換

文章目錄 1. 概述2. CMake鏈接VTK3. main.cpp文件4. 演示效果 更多精彩內容&#x1f449;內容導航 &#x1f448;&#x1f449;VTK開發 &#x1f448; 1. 概述 在VTK&#xff08;Visualization Toolkit&#xff09;中&#xff0c;vtkPerspectiveTransform 和 vtkTransform 都是…

深入探索Qt異步編程--從信號槽到Future

概述 在現代軟件開發中,應用程序的響應速度和用戶體驗是至關重要的。尤其是在圖形用戶界面(GUI)應用中,長時間運行的任務如果直接在主線程執行會導致界面凍結,嚴重影響用戶體驗。 Qt提供了一系列工具和技術來幫助開發者實現異步編程,從而避免這些問題。本文將深入探討Qt…

基于Python的圖片/簽名轉CAD小工具開發方案

基于Python的圖片/簽名轉CAD工具開發方案 一、項目背景 傳統設計流程中&#xff0c;設計師常常需要將手寫簽名或掃描圖紙轉換為CAD格式。本文介紹如何利用Python快速開發圖像矢量化工具&#xff0c;實現&#xff1a; &#x1f4f7; 圖像自動預處理?? 輪廓精確提取?? 參數…

【倉頡 + 鴻蒙 + AI Agent】CangjieMagic框架(17):PlanReactExecutor

CangjieMagic框架&#xff1a;使用華為倉頡編程語言編寫&#xff0c;專門用于開發AI Agent&#xff0c;支持鴻蒙、Windows、macOS、Linux等系統。 這篇文章剖析一下 CangjieMagic 框架中的 PlanReactExecutor。 1 PlanReactExecutor的工作原理 #mermaid-svg-OqJUCSoxZkzylbDY…

一文了解相位陣列天線中的真時延

本文要點 真時延是寬帶帶相位陣列天線的關鍵元素之一。 真時延透過在整個信號頻譜上應用可變相移來消除波束斜視現象。 在相位陣列中使用時延單元或電路板&#xff0c;以提供波束控制和相移。 市場越來越需要更快、更可靠的通訊網絡&#xff0c;而寬帶通信系統正在努力滿…

Java中 關于編譯(Compilation)、類加載(Class Loading) 和 運行(Execution)的詳細區別解析

以下是Java中 編譯&#xff08;Compilation&#xff09;、類加載&#xff08;Class Loading&#xff09; 和 運行&#xff08;Execution&#xff09; 的詳細區別解析&#xff1a; 1. 編譯&#xff08;Compilation&#xff09; 定義 將Java源代碼&#xff08;.java文件&#x…

【KWDB 創作者計劃】_深度學習篇---松科AI加速棒

文章目錄 前言一、簡介二、安裝與配置硬件連接驅動安裝軟件環境配置三、使用步驟初始化設備調用SDK接口檢測設備狀態:集成到AI項目四、注意事項兼容性散熱固件更新安全移除五、硬件架構與技術規格核心芯片專用AI處理器內存配置接口類型物理接口虛擬接口能效比散熱設計六、軟件…

如何清理Windows系統中已失效或已刪除應用的默認打開方式設置

在使用Windows系統的過程中&#xff0c;我們可能會遇到一些問題&#xff1a;某些已卸載或失效的應用程序仍然出現在默認打開方式的列表中&#xff0c;這不僅顯得雜亂&#xff0c;還可能影響我們快速找到正確的程序來打開文件。 如圖&#xff0c;顯示應用已經被geek強制刪除&am…

NFC碰一碰發視頻推廣工具開發注意事項丨支持OEM搭建

隨著線下門店短視頻推廣需求的爆發&#xff0c;基于NFC技術的“碰一碰發視頻”推廣工具成為商業熱點。集星引擎在開發同類系統時&#xff0c;總結出六大核心開發注意事項&#xff0c;幫助技術團隊與品牌方少走彎路&#xff0c;打造真正貼合商戶需求的實用型工具&#xff1a; 一…

pgsql中使用jsonb的mybatis-plus和Spring Data JPA的配置

在pgsql中使用jsonb類型的數據時&#xff0c;實體對象要對其進行一些相關的配置&#xff0c;而mybatis和jpa中使用各不相同。 在項目中經常會結合 MyBatis-Plus 和 JPA 進行開發&#xff0c;MyBatis_plus對于操作數據更靈活&#xff0c;jpa可以自動建表&#xff0c;兩者各取其…

kotlin + spirngboot3 + spring security6 配置登錄與JWT

1. 導包 implementation("com.auth0:java-jwt:3.14.0") implementation("org.springframework.boot:spring-boot-starter-security")配置用戶實體類 Entity Table(name "users") data class User(IdGeneratedValue(strategy GenerationType.I…

【JavaWeb后端開發03】MySQL入門

文章目錄 1. 前言1.1 引言1.2 相關概念 2. MySQL概述2.1 安裝2.2 連接2.2.1 介紹2.2.2 企業使用方式(了解) 2.3 數據模型2.3.1 **關系型數據庫&#xff08;RDBMS&#xff09;**2.3.2 數據模型 3. SQL語句3.1 DDL語句3.1.1 數據庫操作3.1.1.1 查詢數據庫3.1.1.2 創建數據庫3.1.1…

人工智能在智能家居中的應用與發展

隨著人工智能&#xff08;AI&#xff09;技術的飛速發展&#xff0c;智能家居逐漸成為現代生活的重要組成部分。從智能語音助手到智能家電&#xff0c;AI正在改變我們與家居環境的互動方式&#xff0c;讓生活更加便捷、舒適和高效。本文將探討人工智能在智能家居中的應用現狀、…

【EasyPan】項目常見問題解答(自用持續更新中…)

EasyPan 網盤項目介紹 一、項目概述 EasyPan 是一個基于 Vue3 SpringBoot 的網盤系統&#xff0c;支持文件存儲、在線預覽、分享協作及后臺管理&#xff0c;技術棧涵蓋主流前后端框架及中間件&#xff08;MySQL、Redis、FFmpeg&#xff09;。 二、核心功能模塊 用戶認證 注冊…

4.1騰訊校招簡歷優化與自我介紹攻略:公式化表達+結構化呈現

騰訊校招簡歷優化與自我介紹攻略&#xff1a;公式化表達結構化呈現 在騰訊校招中&#xff0c;簡歷是敲開面試大門的第一塊磚&#xff0c;自我介紹則是展現個人魅力的黃金30秒。本文結合騰訊面試官偏好&#xff0c;拆解簡歷撰寫公式、自我介紹黃金結構及分崗位避坑指南&#xf…

【Easylive】consumes = MediaType.MULTIPART_FORM_DATA_VALUE 與 @RequestPart

【Easylive】項目常見問題解答&#xff08;自用&持續更新中…&#xff09; 匯總版 consumes MediaType.MULTIPART_FORM_DATA_VALUE 的作用 1. 定義請求的數據格式 ? 作用&#xff1a;告訴 Feign 和 HTTP 客戶端&#xff0c;這個接口 接收的是 multipart/form-data 格式的…

OpenSSL1.1.1d windows安裝包資源使用

環境&#xff1a; QT版本&#xff1a;5.14.2 用途: openssl1.1.1d版本 問題描述&#xff1a; 今天嘗試用百度云人臉識別api搭載QT的人臉識別程序&#xff0c;需要用到 QNetworkManager 訪問 https 開頭的網址。 但是遇到了QT缺乏 openssl 的相關問題&#xff0c;找了大半天…

代碼實戰保險花銷預測

文章目錄 摘要項目地址實戰代碼&#xff08;初級版&#xff09;實戰代碼&#xff08;進階版&#xff09; 摘要 本文介紹了一個完整的機器學習流程項目&#xff0c;重點涵蓋了多元線性回歸的建模與評估方法。項目詳細講解了特征工程中的多項實用技巧&#xff0c;包括&#xff1…