神經網絡識別數字圖像案例

學習資料:從零設計并訓練一個神經網絡,你就能真正理解它了_嗶哩嗶哩_bilibili

這個視頻講得相當清楚。本文是學習筆記,不是原創,圖都是從視頻上截圖的。

1. 神經網絡

2. 案例說明

具體來說,設計一個三層的神經網絡。以數字圖像作為輸入,經過神經網絡的計算,識別出圖像中的數字是幾,從而實現數字圖像的分類。

3. 視頻講解內容的提綱

4. 神經網絡的設計和實現

我們要處理的數據是28*28像素的灰色通道圖像。

這樣的灰色圖像包括了28*28=784個數據點。需要先將他展平為1*784大小的向量。然后將這個向量輸入到神經網絡中。

用一個三層神經網絡處理圖片對應的向量X。輸入成需要接收784維的圖片向量X。X里面每個維度的數據都有一個神經元來接收。因此輸入層要包含784個神經元。

隱藏成用于特征提取特征向量,將輸入的特征向量處理成更高級的特征向量。

因為手寫數字圖像識別并不復雜,所以將隱藏層的神經元個數設置為256。這樣,輸入層和隱藏層之間就會有個784*256的線性層。它可以將一個784維的輸入向量轉換為256維的輸出向量。

該輸出向量會繼續向前傳播到達輸出層。

由于最終要將數字圖像識別為0~9,十種可能的數字。因此,輸出層需要定義10個神經元,對應這十種數字。

256維的向量在經過隱藏層和輸出層之間的線性層計算后,就得到了10維的輸出結果。這個10維的向量就代表了10個數字的預測得分。

為了繼續得到輸出層的預測概率,還要將輸出層的輸出輸入到softmax層。softmax層會將10維的向量轉換為10個概率值p0~p9。p0~p9相加的總和等于1.

5. 神經網絡的Pytorch實現

import torch
from torch import nn# 定義神經網絡Network
class Network(nn.Module):def __init__(self):super().__init__()# 線性層1,輸入層和隱藏層之間的線性層self.layer1 = nn.Linear(784, 258)# 線性層2,隱藏層和輸出層之間的線性層self.layer2 = nn.Linear(256, 10)# 在前向傳播,forward函數中,輸入為圖像xdef forward(self, x):x = x.view(-1, 28 * 28) # 使用view函數,將x展平x = self.layer1(x) # 將x輸入到layer1x = torch.relu(x) # 使用relu激活return self.layer2(x) # 輸入至layer2計算結果# 這里沒有直接定義softmax層,因為后面會使用CrossEntropyLoss損失函數# 在這個損失函數中,會實現softmax的計算

6. 訓練數據的準備和處理

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 初學只要知道大致的數據處理流程即可
if __name__ == '__main__'# 實現圖像的預處理pipelinetransform = trnasforms.Compose([# 轉換成單通道灰度圖transforms.Grayscale(num_output_channels=1),# 轉換為張量transforms.ToTensor()])# 使用ImageFolder函數,讀取數據文件夾,構建數據集dataset# 這個函數會將保持數據的文件夾的名字,作為數據的標簽,組織數據train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)# 打印他們的長度print("train_dataset length: ", len(train_dataset))print("test_dataset length: ", len(test_dataset))# 使用train_loader, 實現小批量的數據讀取# 這里設置小批量的大小,batch_size=64. 也就是每個批次,包括64個數據train_loader = DataLoader(train_datase, batch_size=64, shuffle=True)# 打印train_loader的長度print("train_loader length: ", len(train_loader))# 6000個訓練數據,如果每個小批量,讀入64個樣本,那么60000個數據會被分成938組# 938*64=60032,說明最后一組不夠64個數據# 循環遍歷train_loader# 每一次循環,都會取出64個圖像數據,作為一個小批量batchfor batch_idx, (data, label) in enumerate(train_loader)if batch_idx == 3:breakprint("batch_idx: ", batch_idx)print("data.shape: ", data.shape) # 數據的尺寸print("label: ", label.shape) # 圖像中的數字print(label)

7. 模型的訓練和測試

import torch
from torch import nn
from torch import optim
from model import Network
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoaderif __name__ == '__main__'# 圖像的預處理transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 讀入并構造數據集train_dataset = datasets.ImageFolder(root='./mnist_images/train', transform=transform)print("train_dataset length: ", len(train_dataset))# 小批量的數據讀入train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)print("train_loader length: ", len(train_loader))# 在使用Pytorch訓練模型時,需要創建三個對象:model = Network() # 1.模型本身,就是我們設計的神經網絡optimizer = optim.Adam(model.parameters()) #2.優化器,優化模型中的參數criterion = nn.CrossEntropyLoss() #3.損失函數,分類問題,使用交叉熵損失誤差# 進入模型的循環迭代# 外層循環,代表了整個訓練數據集的遍歷次數for epoch in range(10):# 內層循環使用train_loader, 進行小批量的數據讀取for batch_idx, (data, label) in enumerate(train_loader):# 內層每循環一次,就會進行一次梯度下降算法# 包括了5個步驟# 這5個步驟是使用pytorch框架訓練模型的定式,初學時先記住即可# 1. 計算神經網絡的前向傳播結果output = model(data)# 2. 計算output和標簽label之間的損失lossloss = criterion(output, label)# 3. 使用backward計算梯度loss.backward()# 4. 使用optimizer.step更新參數optimizer.step()# 5.將梯度清零optimizer.zero_grad()if batch_idx % 100 == 0:print(f"Epoch {epoch + 1}/10"f"| Batch {batch_idx}/{len(train_loader)}"f"| Loss: {loss.item():.4f}")torch.save(model.state_dict(), 'mnist.pth')

from model import Network
from torchvision import transforms
from torchvision import datasets
import torchif __name__ == '__main__'transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),transforms.ToTensor()])# 讀取測試數據集test_dataset = datasets.ImageFolder(root='./mnist_images/test', transform=transform)print("test_dataset length: ", len(test_dataset))model = Network() # 定義神經網絡模型model.load_state_dict(torch.load('mnist.pth')) # 加載剛剛訓練好的模型文件rigth = 0 # 保存正確識別的數量for i, (x, y) in enumerate(test_dataset):output = model(x) # 將其中的數據x輸入到模型predict = output.argmax(1).item() # 選擇概率最大標簽的作為預測結果# 對比預測值predict和真實標簽yif predict == y:right += 1else:# 將識別錯誤的樣例打印出來img_path = test_dataset.samples[i][0]print(f"wrong case: predict = {predict} y = {y} img_path = {img_path}")# 計算出測試效果sample_num = len(test_dataset)acc = right * 1.0 / sample_numprint("test accuracy = %d / %d = %.3lf" % (right, sample_num, acc))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/43766.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/43766.shtml
英文地址,請注明出處:http://en.pswp.cn/web/43766.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

如何找工作 校招 | 社招 | 秋招 | 春招 | 提前批

馬上又秋招了,作者想起以前讀書的時候,秋招踩了很多坑,但是第一份工作其實挺重要的。這里寫一篇文章,分享一些校招社招的心得。 現在大學的情況是,管就業的人,大都是沒有就業的輔導員(筆者見過…

億發512版本更新,看數據駕駛艙、掃碼揀貨、UDI序列號的新功能

如果您正尋求突破傳統業務模式的束縛,希望擁抱數字化轉型帶來的無限可能,我們誠邀您體驗億發軟件。億發專業團隊將為您提供個性化的咨詢和定制服務,幫助您的企業快速適應市場變化,實現業務模式和商業模式的創新。

【騰訊云生成式AI產品解決方案深度分析 2024】

文末有福利! 騰訊云生成式AI產品解決方案 (一) 基于生成式AI的騰訊云產品架構升級 (二) 騰訊云完善的產品矩陣,滿足不同路線客戶需求 1. 路線一 標準軟件 (1) 騰訊樂享AI助手 落地背景及挑戰在企業知識管理、培訓學習、辦公協同場景中,存…

初識C++ | 基本介紹、命名空間、輸入輸出、缺省函數、函數重載、引用、內聯函數、nullptr

基本介紹 C的起源 1979年,當時的 Bjarne Stroustrup 正在?爾實驗室從事計算機科學和軟件?程的研究?作。?對項?中復雜的軟件開 發任務,特別是模擬和操作系統的開發?作,他感受到了現有語?(如C語?)在表達能?、可…

無法定位程序輸入點kernel32.dll ——一鍵修復丟失kernel32.dll方案

無法定位程序輸入點" 錯誤通常發生在 Windows 操作系統中,當一個程序試圖加載一個 DLL(動態鏈接庫)文件中的特定函數,但無法找到該函數的入口點時。kernel32.dll 是 Windows 操作系統中的一個關鍵 DLL 文件,它包含…

Backyard二指夾爪硬件安裝與軟件配置

一、背景 每次要用機械臂做實驗時,都要重新配置好一會,尤其這個Backyard二指夾爪,各種連接線和外接電源。雖然很麻煩,但理清思路后,10分鐘就可以搞定。所以說腦力勞動的效率永遠大于體力勞動,要多想&#…

HiFi音頻pro和普通HiFi音頻

針對那些對音質要求極高、追求專業級音頻表現的用戶,音頻設備公司專門設計了HiFi 音頻Pro系列。它們在設計和性能上更為精細和高級,當然價格通常也會反映其高端定位和專業水準。相比之下,普通HiFi音頻設備雖然也能提供良好的音質,…

設置DepthBufferBits和設置DepthStencilFormat的區別

1)設置DepthBufferBits和設置DepthStencilFormat的區別 2)Unity打包exe后,游戲內拉不起Steam的內購 3)Unity 2022以上Profiler.FlushMemoryCounters耗時要怎么關掉 4)用GoodSky資產包如何實現晝夜播發不同音樂功能 這是…

【北京迅為】《i.MX8MM嵌入式Linux開發指南》-第一篇 嵌入式Linux入門篇-第十八章 Linux編寫第一個自己的命令

i.MX8MM處理器采用了先進的14LPCFinFET工藝,提供更快的速度和更高的電源效率;四核Cortex-A53,單核Cortex-M4,多達五個內核 ,主頻高達1.8GHz,2G DDR4內存、8G EMMC存儲。千兆工業級以太網、MIPI-DSI、USB HOST、WIFI/BT…

Python-找客戶軟件

軟件功能 請求代碼: 填充表格: 可以search全國各個區縣的所有企業信息,過濾手機號、查看是否續存/在業狀態。方便找客戶。 支持定-制-其他引-留-阮*件(XHSS,DYY,KS,Bi-li*Bi-li) V*…

AutoHotKey自動熱鍵(八)腳本快速暫停與重新加載

我們在編輯腳本的時候,可以添加快捷鍵來改變腳本的狀態 ;暫停腳本 F11::Suspend;重置腳本 F12::Reloadreload用來重置腳本 我們可以在腳本開頭加上標簽提示腳本重啟成功 ToolTip, 腳本已經重啟 Sleep, 1000 ToolTip第二個ToolTip是用來關閉提示器用的 這個提示功能一定要寫…

oracle dba常用腳本2

11、表空間實有、現有、使用情況查詢對比 SELECT TABLESPACE_NAME 表空間,TO_CHAR(ROUND(BYTES / 1024, 2), 99990.00) || 實有,TO_CHAR(ROUND(FREE / 1024, 2), 99990.00) || G 現有,TO_CHAR(ROUND((BYTES - FREE) / 1024, 2), 99990.00) || G 使用,TO_CHAR(ROUND(10000 * US…

【開源合規】開源許可證風險場景詳細解讀

文章目錄 前言關于BlackDuck許可證風險對比圖弱互惠型許可證舉個例子具體示例LGPL系列LGPL-2.0-onlyLGPL-2.0-or-laterLGPL-2.1-onlyLGPL-2.1-or-laterLGPL-3.0-onlyLGPL-3.0-or-laterMPL系列MPL-1.0MPL-1.1MPL-2.0EPL系列EPL-1.0EPL-2.0互惠型許可證GPL系列GPL-1.0GPL-2.0GPL-…

常用錄屏軟件,分享這四款寶藏軟件!

在數字化時代,錄屏軟件已經成為我們日常工作、學習和娛樂中不可或缺的工具。無論你是需要錄制教學視頻、游戲過程,還是進行產品演示,一款高效、易用的錄屏軟件都能讓你的工作事半功倍。今天,就為大家揭秘四款寶藏級錄屏軟件&#…

重磅|九科信息完成諾輝領投的B1輪融資,累計融資已達億級

近日,九科信息宣布B1輪融資順利完成。本輪由深圳諾輝嶺南投資管理有限公司領投,深創投索斯福(深圳)私募創業投資基金跟投。 截至本輪,九科信息累計融資達億級。但真正讓九科人驕傲的,并非融資本身&#xff…

無法找到模塊“@wangeditor/editor-for-vue”的聲明文件

vue3項目中使用wangeditor/editor遇到的問題 開發環境不管紅線報錯正常使用 打包的時候就會報錯了 1.安裝依賴 pnpm install --save wangeditor/editor wangeditor/editor-for-vuenext 2.遇到的問題 3.解決方法 在src目錄下面創建 wangeditor-types.d.ts 文件 代碼如下 de…

IEC62056標準體系簡介-6.IEC62056標準體系的特點

相對于其它常用的計量儀表通信協議,如IEC1107、IEC 62056-31、IEC 60870-5-102以及北美使用的通信協議ANSI C12.18(光口)、C12.19(公用表)和C12.21(電話通信)和國內使用的DL/T645等,…

The First項目報告:創新型金融生態Lista DAO

一、Lista DAO是什么? LISTA是Lista DAO的原生加密協議代幣,設計為一種可互操作的實用代幣,旨在促進去中心化金融(DeFi)領域內的支付、治理與激勵。LISTA的誕生源于Lista DAO項目,該項目是一個基于BNB鏈的…

springboot3 集成GraalVM

目錄 安裝GraalVM 配置環境變量 Pom.xml 配置 build包 測試 安裝GraalVM Download GraalVM 版本和JDK需要自己選擇 配置環境變量 Jave_home 和 path 設置setting.xml <profile><id>graalvm-ce-dev</id><repositories><repository><id&…

2024最新版pycharm安裝激火教程,附安裝包+激huo馬,Python教程,pycharm安裝包!!

PyCharm的安裝 PyCharm 是一個專門為 Python 開發者設計的 IDE&#xff0c;它同樣具有代碼導航、重構、調試和分析等功能。PyCharm 支持多種項目類型&#xff0c;如普通項目、Python 測試項目、Django 項目等&#xff0c;并提供了大量的內置模板和插件&#xff0c;以幫助您更快…