卷積神經網絡--手寫數字識別

本文我們通過搭建卷積神經網絡模型,實現手寫數字識別。

pytorch中提供了手寫數字的數據集?,我們可以直接從pytorch中下載

MNIST中包含70000張手寫數字圖像:60000張用于訓練,10000張用于測試

圖像是灰度的,28x28像素

首先,下載數據集

import torch
from torchvision import datasets #封裝與圖像相關的模型,數據集
from torchvision.transforms import ToTensor # #數據轉換,張量,將其他類型的數據轉換為tensor張量training_data=datasets.MNIST(root='data',#表示下載的手寫數字到哪個路徑train=True,#讀取下載后數據中的訓練集download=True,#如果之前已經下載過,就不用再下載transform=ToTensor(),#張量,圖片不能直接傳入神經網絡模型
)test_data=datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)

打包數據

from torch.utils.data import DataLoader train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

判斷當前設備是否支持GPU

device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'using {device} device')

構建卷積神經網絡模型

from torch import nn #導入神經網絡模塊class CNN(nn.Module):def __init__(self):#初始化類super(CNN,self).__init__()#初始化父類self.conv1=nn.Sequential(# 將多個層(如卷積、激活函數、池化等)按順序打包,輸入數據會??依次通過這些層??,無需手動編寫每一層的傳遞邏輯。nn.Conv2d(#2D 卷積層,提取空間特征。in_channels=1,#輸入通道數out_channels=16,#輸出通道數kernel_size=3,#卷積核大小stride=1,#步長padding=1,#填充),nn.ReLU(),#激活函數,引入非線性變換,使得神經網絡能夠學習復雜的非線性變換,增強表達能力nn.MaxPool2d(kernel_size=2)# 2x2最大池化(尺寸減半))self.conv2=nn.Sequential(nn.Conv2d(16,32,3,1,1),nn.ReLU(),# nn.Conv2d(32,32,3,1,1),# nn.ReLU(),nn.MaxPool2d(2),)self.conv3=nn.Sequential(nn.Conv2d(32,64,3,1,1))self.out=nn.Linear(64*7*7,10)def forward(self,x):#前向傳播x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)x=x.view(x.size(0),-1)# 展平為向量(保留batch_size,合并其他維度)output=self.out(x)  # 全連接層輸出return output

返回的output結果大致如圖所示

?模型傳入GPU

model=CNN().to(device)
print(model)

??損失函數,衡量的是??模型預測的概率分布??與??真實的類別分布??之間的差異。

loss_fn=nn.CrossEntropyLoss()

??優化器,用于在訓練神經網絡時更新模型參數,目的是??在神經網絡訓練過程中,自動調整模型的參數(權重和偏置),以最小化損失函數??。

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)

?模型訓練

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)loss=loss_fn(pred,y)# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()               #梯度值清零loss.backward()                     #反向傳播計算得到每個參數的梯度值optimizer.step()                    #根據梯度更新網絡參數loss_value=loss.item()if batch_size_num%100==0:print(f'loss:{loss_value:>7f}[number:{batch_size_num}]')batch_size_num+=1epochs=10for i in range(epochs):print(f'第{i}次訓練')train(train_dataloader, model, loss_fn, optimizer)

模型測試

def test(dataloader,model,loss_fn):size = len(dataloader.dataset)# 測試集總樣本數num_batches = len(dataloader)# 測試集總批次數model.eval()#進入到模型的測試狀態,所有的卷積核權重被設為只讀模式test_loss, correct = 0, 0# 初始化累計損失和正確預測數#禁用梯度計算with torch.no_grad():#一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候。這可以減少計算所用內存消耗。for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')test(test_dataloader,model,loss_fn)

得到結果如圖所示

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/77667.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/77667.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/77667.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大文件分片上傳進階版(新增md5校驗、上傳進度展示、并行控制,智能分片、加密上傳、斷點續傳、自動重試),實現四位一體的網絡感知型大文件傳輸系統?

上篇文章我們總結了大文件分片上傳的主要核心,但是我對md5校驗和上傳進度展示這塊也比較感興趣,所以在deepseek的幫助下,擴展了一下我們的代碼,如果有任何問題和想法,非常歡迎大家在評論區與我交流,我需要學…

C# 點擊導入,將需要的參數傳遞到彈窗的頁面

點擊導入按鈕,獲取本頁面的datagridview標題的結構,并傳遞到導入界面。 新增一個datatable用于存儲datagridview的caption和name,這里用的是devexpress組件中的gridview。 DataTable dt new DataTable(); DataColumn CAPTION …

android的 framework 是什么

Android的Framework(框架)是Android系統的核心組成部分,它為開發者提供了一系列的API(應用程序編程接口),使得開發者能夠方便地創建各種Android應用。以下是關于它的詳細介紹: 位置與架構 在A…

【MySQL】表的約束(主鍵、唯一鍵、外鍵等約束類型詳解)、表的設計

目錄 1.數據庫約束 1.1 約束類型 1.2 null約束 — not null 1.3 unique — 唯一約束 1.4 default — 設置默認值 1.5 primary key — 主鍵約束 自增主鍵 自增主鍵的局限性:經典面試問題(進階問題) 1.6 foreign key — 外鍵約束 1.7…

數據結構-C語言版本(三)棧

數據結構中的棧:概念、操作與實戰 第一部分 棧分類及常見形式 棧是一種遵循后進先出(LIFO, Last In First Out)原則的線性數據結構。棧主要有以下幾種實現形式: 1. 數組實現的棧(順序棧) #define MAX_SIZE 100typedef struct …

如何以特殊工藝攻克超薄電路板制造難題?

一、超薄PCB的行業定義與核心挑戰 超薄PCB通常指厚度低于1.0毫米的電路板,而高端產品可進一步壓縮至0.4毫米甚至0.2毫米以下。這類電路板因體積小、重量輕、熱傳導性能優異,被廣泛應用于折疊屏手機、智能穿戴設備、醫療植入器械及新能源汽車等領域。然而…

AI 賦能 3D 創作!Tripo3D 全功能深度解析與實操教程

大家好,歡迎來到本期科技工具分享! 今天要給大家帶來一款革命性的 AI 3D 模型生成平臺 ——Tripo3D。 無論你是游戲開發者、設計師,還是 3D 建模愛好者,只要想降低創作門檻、提升效率,這款工具都值得深入了解。 接下…

如何理解抽象且不易理解的華為云 API?

API的概念在華為云的使用中非常抽象,且不容易理解,用通俗的語言 形象的比喻來講清楚——什么是華為云 API,怎么用,背后原理,以及主要元素有哪些,盡量讓新手也能明白。 🧠 一句話先理解&#xf…

第 7 篇:總結與展望 - 時間序列學習的下一步

第 7 篇:總結與展望 - 時間序列學習的下一步 (圖片來源: Guillaume Hankenne on Pexels) 恭喜你!如果你一路跟隨這個系列走到了這里,那么你已經成功地完成了時間序列分析的入門之旅。我們從零開始,一起探索了時間數據的基本概念、…

PPT無法編輯怎么辦?原因及解決方法全解析

在日常辦公中,我們經常會遇到需要編輯PPT的情況。然而,有時我們會發現PPT文件無法編輯,這可能由多種原因引起。今天我們來看看PPT無法編輯的幾種常見原因,并提供實用的解決方法,幫助你輕松應對。 原因1:文…

前端面試題---GET跟POST的區別(Ajax)

GET 和 POST 是兩種 HTTP 請求方式,它們在傳輸數據的方式和所需空間上有一些重要區別: ? 一句話概括: GET 數據放在 URL 中,受限較多;POST 數據放在請求體中,空間更大更安全。 📦 1. 所需空間…

第 5 篇:初試牛刀 - 簡單的預測方法

第 5 篇:初試牛刀 - 簡單的預測方法 經過前面四篇的學習,我們已經具備了處理時間序列數據的基本功:加載、可視化、分解以及處理平穩性。現在,激動人心的時刻到來了——我們要開始嘗試預測 (Forecasting) 未來! 預測是…

從代碼學習深度學習 - 學習率調度器 PyTorch 版

文章目錄 前言一、理論背景二、代碼解析2.1. 基本問題和環境設置2.2. 訓練函數2.3. 無學習率調度器實驗2.4. SquareRootScheduler 實驗2.5. FactorScheduler 實驗2.6. MultiFactorScheduler 實驗2.7. CosineScheduler 實驗2.8. 帶預熱的 CosineScheduler 實驗三、結果對比與分析…

k8s 基礎入門篇之開啟 firewalld

前面在部署k8s時,都是直接關閉的防火墻。由于生產環境需要開啟防火墻,只能放行一些特定的端口, 簡單記錄一下過程。 1. firewall 與 iptables 的關系 1.1 防火墻(Firewall) 定義: 防火墻是網絡安全系統&…

RSS 2025|蘇黎世提出「LLM-MPC混合架構」增強自動駕駛,推理速度提升10.5倍!

論文題目:Enhancing Autonomous Driving Systems with On-Board Deployed Large Language Models 論文作者:Nicolas Baumann,Cheng Hu,Paviththiren Sivasothilingam,Haotong Qin,Lei Xie,Miche…

list的學習

list的介紹 list文檔的介紹 list是可以在常數范圍內在任意位置進行插入和刪除的序列式容器,并且該容器可以前后雙向迭代。list的底層是雙向鏈表結構,雙向鏈表中每個元素存儲在互不相關的獨立節點中,在節點中通過指針指向其前一個元素和后一…

生物信息學技能樹(Bioinformatics)與學習路徑

李升偉 整理 生物信息學是一門跨學科領域,涉及生物學、計算機科學以及統計學等多個方面。以下是關于生物信息學的學習路徑及相關技能的詳細介紹。 一、基礎理論知識 1. 生物學基礎知識 需要掌握分子生物學、遺傳學、細胞生物學等相關概念。 對基因組結構、蛋白質…

AOSP Android14 Launcher3——遠程窗口動畫關鍵類SurfaceControl詳解

在 Launcher3 執行涉及其他應用窗口(即“遠程窗口”)的動畫時,例如“點擊桌面圖標啟動應用”或“從應用上滑回到桌面”的過渡動畫,SurfaceControl 扮演著至關重要的角色。它是實現這些跨進程、高性能、精確定制動畫的核心技術。 …

超詳細實現單鏈表的基礎增刪改查——基于C語言實現

文章目錄 1、鏈表的概念與分類1.1 鏈表的概念1.2 鏈表的分類 2、單鏈表的結構和定義2.1 單鏈表的結構2.2 單鏈表的定義 3、單鏈表的實現3.1 創建新節點3.2 頭插和尾插的實現3.3 頭刪和尾刪的實現3.4 鏈表的查找3.5 指定位置之前和之后插入數據3.6 刪除指定位置的數據和刪除指定…

17.整體代碼講解

從入門AI到手寫Transformer-17.整體代碼講解 17.整體代碼講解代碼 整理自視頻 老袁不說話 。 17.整體代碼講解 代碼 import collectionsimport math import torch from torch import nn import os import time import numpy as np from matplotlib import pyplot as plt fro…