pytorch框架認識--手寫數字識別

手寫數字是機器學習中非常經典的案例,本文將通過pytorch框架,利用神經網絡來實現手寫數字識別

pytorch中提供了手寫數字的數據集,我們可以直接從pytorch中下載

MNIST中包含70000張手寫數字圖像:60000張用于訓練,10000張用于測試

圖像是灰度的,28x28像素

下載數據集

import torch
from torchvision import datasets#封裝了許多與圖像相關的模型,數據集
from torchvision.transforms import ToTensor#數據類型轉換,將其他類型數據轉換為tensor張量'''下載訓練數據集(包含訓練圖片+標簽)'''
training_data = datasets.MNIST(  #導入數據集root="data",#下載的數據集到哪個路徑train=True,#讀取訓練集download=True,#如果已經下載,就不用再下載transform=ToTensor(),#數據類型轉換
)   '''下載測試數據集(包含訓練圖片+標簽) '''
test_data = datasets.MNIST(#導入數據集root="data",#下載的數據集到哪個路徑train=False,#讀取測試集download=True,#如果已經下載,就不用再下載transform=ToTensor(),#數據類型轉換
)

數據可視化,展示手寫數字

from matplotlib import pyplot as plt
figure=plt.figure()
for i in range(9):img,label=training_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a=img.squeeze()
plt.show()

得到結果如下

?打包數據

from torch.utils.data import DataLoader #數據包管理工具,打包數據train_dataloader=DataLoader(training_data,batch_size=64)#64張圖片為一個包
test_dataloader=DataLoader(test_data,batch_size=64)

判斷當前設備是否支持GPU

device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'using {device} device')

構建神經網絡模型

from torch import nn #導入神經網絡模塊class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten=nn.Flatten()self.hidden1=nn.Linear(28*28,128)self.hidden2=nn.Linear(128,256)self.out=nn.Linear(256,10)def forward(self,x):#前向傳播x=self.flatten(x)x=self.hidden1(x)x=torch.relu(x)#引入非線性變換,使神經網絡能夠學習復雜的非線性關系,增強表達能力。x=self.hidden2(x)x=torch.relu(x)x=self.out(x)return x

返回的x結果大致如圖所示?

模型傳入GPU

model=NeuralNetwork().to(device)
print(model)

?損失函數

loss_fn = nn.CrossEntropyLoss() #創建交叉熵損失函數對象,因為手寫字識別中一共有10個數字,輸出會有10個結果

?優化器,用于在訓練神經網絡時更新模型參數,目的是??在神經網絡訓練過程中,自動調整模型的參數(權重和偏置),以最小化損失函數??。

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)#獲取模型中所有需要訓練的參數

?模型訓練

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)#自動初始化 w權值loss=loss_fn(pred,y) #通過交叉熵損失函數計算損失值loss#更新模型參數以最小化損失函數。optimizer.zero_grad()#梯度值清零loss.backward()#反向傳播計算得到每個參數的梯度值optimizer.step() #根據梯度更新網絡參數loss_value=loss.item()if batch_size_num%100==0:print(f'loss:{loss_value:7f}[number:{batch_size_num}]')batch_size_num+=1epochs=10for i in range(epochs):print(f'第{i}次訓練')train(train_dataloader, model, loss_fn, optimizer)

模型測試

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()#進入到模型的測試狀態,所有的卷積核權重被設為只讀模式test_loss, correct = 0, 0with torch.no_grad():   #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候。這可以減少計算所用內存消耗。for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item() #correct += (pred.argmax(1) == y).type(torch.float).sum().item()#a = (pred.argmax(1) == y)  #b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")test(test_dataloader,model,loss_fn)

得到結果如圖

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/76974.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/76974.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/76974.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

WPF 使用依賴注入后關閉窗口程序不結束

原因是在ViewModel中在構造函數中注入了Window 對象,即使沒有使用,主窗口關閉程序不會退出,即使 ViewModel 是 AddTransient 注入的。 解決方法:不使用構造函數注入Window,通過GetService獲取Window 通過注入對象調用…

用戶管理(添加和刪除,查詢信息,切換用戶,查看登錄用戶,用戶組,配置文件)

目錄 添加和刪除用戶 查詢用戶信息 切換用戶 查看當前的操作用戶是誰 查看首次登錄的用戶是誰 用戶組(對屬于同個角色的用戶統一管理) 新增組 刪除組 添加用戶的同時,指定組 修改用戶的組 組的配置文件(/etc/group&…

PyTorch學習-小土堆教程

網絡搭建torch.nn.Module 卷積操作 torch.nn.functional.conv2d(input, weight, biasNone, stride1, padding0, dilation1, groups1) 神經網絡-卷積層

MVCC詳細介紹及面試題

目錄 1.什么是mvcc? 2.問題引入 3. MVCC實現原理? 3.1 隱藏字段 3.2 undo log 日志 3.2.1 undo log版本鏈 3.3 readview 3.3.1 當前讀 ?編輯 3.3.2 快照讀 3.3.3 ReadView中4個核心字段 3.3.4 版本數據鏈訪問的規則(了解&#x…

企業級Active Directory架構設計與運維管理白皮書

企業級Active Directory架構設計與運維管理白皮書 第一章 多域架構設計與信任管理 1.1 企業域架構拓撲設計 1.1.1 林架構設計規范 林根域規劃原則: 采用三段式域名結構(如corp.enterprise.com),避免使用不相關的頂級域名架構主…

android11 DevicePolicyManager淺析

目錄 📘 簡單定義 📘應用啟用設備管理者 📂 文件位置 🧠 DevicePolicyManager 功能分類舉例 🛡? 1. 安全策略控制 📷 2. 控制硬件功能 🧰 3. 應用管理 🔒 4. 用戶管理 &am…

Java學習手冊:Java線程安全與同步機制

在Java并發編程中,線程安全和同步機制是確保程序正確性和數據一致性的關鍵。當多個線程同時訪問共享資源時,如果不加以控制,可能會導致數據不一致、競態條件等問題。本文將深入探討Java中的線程安全問題以及解決這些問題的同步機制。 線程安…

PyTorch核心函數詳解:gather與where的實戰指南

PyTorch中的torch.gather和torch.where是處理張量數據的關鍵工具,前者實現基于索引的靈活數據提取,后者完成條件篩選與動態生成。本文通過典型應用場景和代碼演示,深入解析兩者的工作原理及使用技巧,幫助開發者提升數據處理的靈活…

聲學測溫度原理解釋

已知聲速,就可以得到溫度。 不同溫度下的勝訴不同。 25度的聲速大約346m/s 絕對溫度-273度 不同溫度下的聲速。 FPGA 通過測距雷達測溫度,固定測量距離,或者可以測出當前距離。已知距離,然后雷達發出聲波到接收到回波的時間&a…

【網絡篇】UDP協議的封裝分用全過程

大家好呀 我是浪前 今天講解的是網絡篇的第二章:UDP協議的封裝分用 我們的協議最開始是OSI七層網絡協議 這個OSI 七層網絡協議 是計算機的大佬寫的,但是這個協議一共有七層,太多了太麻煩了,于是我們就把這個七層網絡協議就簡化為…

spring-ai-alibaba使用Agent實現智能機票助手

示例目標是使用 Spring AI Alibaba 框架開發一個智能機票助手,它可以幫助消費者完成機票預定、問題解答、機票改簽、取消等動作,具體要求為: 基于 AI 大模型與用戶對話,理解用戶自然語言表達的需求支持多輪連續對話,能…

嵌入式C語言高級編程:OOP封裝、TDD測試與防御性編程實踐

一、面向對象編程(OOP) 盡管 C 語言并非面向對象編程語言,但借助一些編程技巧,也能實現面向對象編程(OOP)的核心特性,如封裝、繼承和多態。 1.1 封裝 封裝是把數據和操作數據的函數捆綁在一起,對外部隱藏…

藍橋杯 web 常考到的一些知識點

filter:filter方法創建一個新數組,其包含通過所提供函數實現的測試的所有元素。這個 方法不會改變原數組,而是返回一個新的數組。 map:map方法創建一個新數組,其結果是該數組中的每個元素都調用一個提供的函數后的 返回…

音視頻小白系統入門筆記-0

本系列筆記為博主學習李超老師課程的課堂筆記&#xff0c;僅供參閱 音視頻小白系統入門課 音視頻基礎ffmpeg原理 緒論 ffmpeg推流 ffplay/vlc拉流 使用rtmp協議 ffmpeg -i <source_path> -f flv rtmp://<rtmp_server_path> 為什么會推流失敗&#xff1f; 默認…

mysql按條件三表并聯查詢

下面為你呈現一個 MySQL 按條件三表并聯查詢的示例。假定有三個表&#xff1a;students、courses 和 enrollments&#xff0c;它們的結構和關聯如下&#xff1a; students 表&#xff1a;包含學生的基本信息&#xff0c;有 student_id 和 student_name 等字段。courses 表&…

UML之序列圖的消息

序列圖表現各參與者之間為完成某個行為而發生的交互及其時間順序&#xff0c;序列圖中的交互通過消息實現。消息是從一條生命線到另一條生命線的通信&#xff0c;它們通常是水平或傾斜向下的箭頭&#xff0c;從發送方生命線離開&#xff0c;到達接收方生命線。如果需要&#xf…

UniAD:自動駕駛的統一架構 - 創新與挑戰并存

引言 自動駕駛技術正經歷一場架構革命。傳統上&#xff0c;自動駕駛系統采用模塊化設計&#xff0c;將感知、預測和規劃分離為獨立組件。而上海人工智能實驗室的OpenDriveLab團隊提出的UniAD&#xff08;Unified Autonomous Driving&#xff09;則嘗試將這些任務整合到一個統一…

如何寫好合同管理系統需求分析

引言 在當今企業數字化轉型的浪潮中&#xff0c;合同管理系統作為企業法律合規和商業運營的重要支撐工具&#xff0c;其需求分析的準確性和完整性直接關系到系統建設的成敗。本文基于Volere需求過程方法論&#xff0c;結合江鈴汽車集團合同管理系統需求規格說明書實踐案例&…

libevent服務器附帶qt界面開發(附帶源碼)

本章是入門章節&#xff0c;講解如何實現一個附帶界面的服務器&#xff0c;后續會完善與優化 使用qt編譯libevent源碼演示視頻qt的一些知識 1.主要功能有登錄界面 2.基于libevent實現的服務器的業務功能 使用qt編譯libevent 下載這個&#xff0c;其他版本也可以 主要是github上…

八、自動化函數

1.元素的定位 web自動化測試的操作核心是能夠找到頁面對應的元素&#xff0c;然后才能對元素進行具體的操作。 常見的元素定位方式非常多&#xff0c;如id,classname,tagname,xpath,cssSelector 常用的主要由cssSelector和xpath 1.1 cssSelector選擇器 選擇器的功能&#x…