深度學習(魚書)day04--手寫數字識別項目實戰

深度學習(魚書)day04–手寫數字識別項目實戰

魚書的相關源代碼下載:
點擊鏈接:http://www.ituring.com.cn/book/1921
點擊“隨書下載”
在這里插入圖片描述
第三項就是源代碼:
在這里插入圖片描述
解壓后,在pycharm(或其它IDE)中打開此文件夾查看或運行即可。(紅框內是本人自建的文件)
在這里插入圖片描述

一、MNIST數據集

  • 和求解機器學習問題的步驟(分成學習和推理兩個階段進行)一樣,使用神經網絡解決問題時,也需要首先使用訓練數據(學習數據)進行權重參數的學習;進行推理時,使用剛才學習到的參數,對輸入數據進行分類

  • 這里我們來進行手寫數字圖像的分類。假設學習已經全部結束,我們使用學習到的參數,先實現神經網絡的“推理處理”。這個推理處理也稱為神經網絡的前向傳播(forward propagation)

  • 這里使用的數據集是MNIST手寫數字圖像集。MNIST是機器學習領域最有名的數據集之一,被應用于從簡單的實驗到發表的論文研究等各種場合。MNIST數據集是由0到9的數字圖像構成的。訓練圖像有6萬張,測試圖像有1萬張,這些圖像可以用于學習和推理。MNIST數據集的一般使用方法是,先用訓練圖像進行學習,再用學習到的模型度量能在多大程度上對測試圖像進行正確的分類。
    在這里插入圖片描述

  • MNIST的圖像數據是28像素 × 28像素的灰度圖像1通道),各個像素的取值在0到255之間。每個圖像數據都相應地標有“7”“2”“1”等標簽

本書提供了便利的Python腳本mnist.py,該腳本支持從下載MNIST數據集到將這些數據轉換成NumPy數組等處理(mnist.py在dataset目錄下)。使用mnist.py時,當前目錄必須是ch01、ch02、ch03、…、ch08目錄中的一個。使用mnist.py中的load_mnist()函數,就可以按下述方式輕松讀入MNIST數據。

import sys, os
sys.path.append(os.pardir)  # 為了導入父目錄中的文件而進行的設定
from dataset.mnist import load_mnist(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)

在這里插入圖片描述

load_mnist()它負責加載并預處理 MNIST 數據集,使其適合機器學習模型的訓練和測試。

def load_mnist(normalize=True, flatten=True, one_hot_label=False):"""讀入MNIST數據集Parameters----------normalize : 將圖像的像素值正規化為0.0~1.0one_hot_label : one_hot_label為True的情況下,標簽作為one-hot數組返回one-hot數組是指[0,0,1,0,0,0,0,0,0,0]這樣的數組flatten : 是否將圖像展開為一維數組Returns-------(訓練圖像, 訓練標簽), (測試圖像, 測試標簽)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) 

下面詳細解釋這個函數的功能、參數和返回值:

參數說明

  1. normalize(默認 True
    • 是否對圖像像素值進行歸一化(將 0-255的像素值縮放到 0.0-1.0的浮點數)。
    • 如果設為 False,則保持原始的 0-255uint8格式。
  2. flatten(默認 True
    • 是否將圖像展平為一維數組(784維向量)。
    • 如果設為 False,則保留原始圖像形狀 (1, 28, 28)(單通道,28×28 像素)。
  3. one_hot_label(默認 False
    • 是否將標簽轉換為 one-hot 編碼(例如,數字 3變為 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0])。
    • 如果設為 False,則標簽保持為原始的數字(0-9)。

Python有 pickle這個便利的功能。這個功能可以將程序運行中的對象保存為文件。如果加載保存過的 pickle文件,可以立刻復原之前程序運行中的對象。用于讀入MNIST數據集的load_mnist()函數內部也使用了 pickle功能(在第 2次及以后讀入時)。利用 pickle功能,可以高效地完成MNIST數據的準備工作

顯示MNIST圖像:

import sys, os
sys.path.append(os.pardir)  # 為了導入父目錄中的文件而進行的設定
from dataset.mnist import load_mnist
import numpy as np
from PIL import Imagedef img_show(img):pil_img = Image.fromarray(np.uint8(img))pil_img.show()(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)img = x_train[1]
label = t_train[1]
print(label)
print(img.shape)img = img.reshape(28, 28)
print(img.shape)img_show(img)

在這里插入圖片描述

注意的是,flatten=True時讀入的圖像是以一列(一維)NumPy數組的形式保存的,因此,顯示圖像時,需要把它變為原來的28像素 × 28像素的形狀

二、神經網絡的推理處理

神經網絡的輸入層有784個神經元,輸出層有10個神經元。

輸入層的784這個數字來源于圖像大小的28 × 28 = 784,輸出層的10這個數字來源于10類別分類(數字0到9,共10類別)。

此外,這個神經網絡有2個隱藏層,第1個隱藏層有50個神經元,第2個隱藏層有100個神經元。這個50和100可以設置為任何值。下面我們先定義**get_data()、init_network()、predict()**這3個函數。

def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)return x_test, t_testdef init_network():with open("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef predict(network, x):W1, W2, W3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1,W2)z2 = sigmoid(a2)a3 = np.dot(z2, W3) + b3y = softmax(a3)return y

init_network()會讀入保存在pickle文件sample_weight.pkl中的學習到的權重參數。這個文件中以字典變量的形式保存了權重和偏置參數。我們用這3個函數來實現神經網絡的推理處理。然后,評價它的識別精度(accuracy),即能在多大程度上正確分類。

因為之前我們假設學習已經完成,所以學習到的參數被保存下來。假設保存在sample_weight.pkl文件中,在推理階段,我們直接加載這些已經學習到的參數。

x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):y = predict(network, x[i])p = np.argmax(y) # 獲取概率最大的下標if p == t[i]:accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

在這里插入圖片描述

predict()函數以NumPy數組的形式輸出各個標簽對應的概率。比如輸出[0.1, 0.3, 0.2, …, 0.04]的數組,該數組表示“0”的概率為0.1,“1”的概率為0*.*3。

我們取出這個概率列表中的最大值的索引(第幾個元素的概率最高),作為預測結果。可以用np.argmax(x)函數取出數組中的最大值的索引,np.argmax(x)將獲取被賦給參數x的數組中的最大值元素的索引。

最后,比較神經網絡所預測的答案和正確解標簽,將回答正確的概率作為識別精度。

正規化:將normalize設置成True后,函數內部會進行轉換,將圖像的各個像素值除以255,使得數據的值在0.0~1.0的范圍內。像這樣把數據限定到某個范圍內的處理稱為正規化(normalization)

預處理:對神經網絡的輸入數據進行某種既定的轉換稱為預處理(pre-processing)。這里,作為對輸入圖像的一種預處理,我們進行了正規化

實際上,很多預處理都會考慮到數據的整體分布。比如,利用數據整體的均值或標準差,移動數據,使數據整體以 0為中心分布,或者進行正規化,把數據的延展控制在一定范圍內。除此之外,還有將數據整體的分布形狀均勻化的方法,即數據白化(whitening)等。

三、批處理

network = init_network()
W1, W2, W3 = network['W1'], network['W2'], network['W3']
print(x.shape)
print(x[0].shape)
print(W1.shape)
print(W2.shape)
print(W3.shape)

在這里插入圖片描述

通過上述結果來確認一下多維數組的對應維度的元素個數是否一致,省略了偏置:

在這里插入圖片描述

現在我們來考慮打包輸入多張圖像的情形。比如,我們想用predict()函數一次性打包處理100張圖像。為此,可以把x的形狀改為100 × 784,將100張圖像打包作為輸入數據:
在這里插入圖片描述

輸出數據的形狀為100 × 10。這表示輸入的100張圖像的結果被一次性輸出了。比如,x[0]和y[0]中保存了第0張圖像及其推理結果,x[1]和y[1]中保存了第1張圖像及其推理結果。

這種打包式的輸入數據稱為批(batch)。批有“捆”的意思,圖像就如同紙幣一樣扎成一捆。

批處理可以縮短處理時間。這是因為大多數處理數計算的庫都進行了能夠高效處理大型數組運算的最優化。并且,在神經網絡的運算中,當數據傳送成為瓶頸時,批處理可以減輕數據總線的負荷(嚴格地講,相對于數據讀入,可以將更多的時間用在計算上)。也就是說,批處理一次性計算大型數組要比分開逐步計算各個小型數組速度更快。

batch_size = 100 # 批數量
accuracy_cnt = 0
for i in range(0,len(x),batch_size):batch_x = x[i:i+batch_size]batch_y = predict(network, x[i:i+batch_size])p = np.argmax(batch_y,axis=1)  # 獲取概率最大的下標accuracy_cnt += np.sum(p == t[i:i+batch_size])
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

在這里插入圖片描述

部分代碼詳解:

range()函數若指定為range(start, end),則會生成一個由startend-1之間的整數構成的列表。range(start, end, step)這樣指定3個整數,則生成的列表中的下一個元素會增加step指定的值。

x[i:i+batch_n]會取出從第i個到第i+batch_n個之間的數據。本例中是像x[0:100]、x[100:200]……這樣,從頭開始以100為單位將數據提取為批數據。

list( range(0, 10) ) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]list( range(0, 10, 3) ) # [0, 3, 6, 9]

通過argmax()獲取值最大的元素的索引。不過這里需要注意的是,我們給定了參數axis=1。這指定了在100 × 10的數組中,沿著**第1維方向(以第1維為軸)**找到值最大的元素的索引(第0維對應第1個維度

  • 矩陣的第0維是列方向,第1維是行方向。
x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6],[0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])
y = np.argmax(x, axis=1)
print(y) # [1 2 1 0]

使用比較運算符(==)生成由True/False構成的布爾型數組,并計算True的個數:

y = np.array([1, 2, 1, 0])
t = np.array([1, 2, 0, 0])
print(y==t) # [True True False True]
np.sum(y==t) # 3

本文參考了該博主的文章

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/90928.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/90928.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/90928.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【自用】NLP算法面經(6)

一、FlashAttention 1、Tile-Based計算 將q,k,v分塊為小塊,每次僅處理一小塊: 利用gpu的片上SRAM完成QK^T和softmax避免中間結果寫入HBM 標準attention的計算算法如下:標準attention實現大量中間結果需要頻繁訪問HBM,而HBM的訪問速…

Vue頁面卡頓優化:從理論到實戰的全面解釋

目錄 1. 理解Vue頁面卡頓的幕后黑手 1.1 響應式系統的“雙刃劍” 1.2 虛擬DOM的“隱藏成本” 1.3 瀏覽器渲染的“性能陷阱” 實戰案例:一個“罪魁禍首”的排查 2. 優化響應式系統:讓數據“輕裝上陣” 2.1 使用v-if和v-show控制渲染 2.2 凍結靜態數據 2.3 精細化響應式…

從0開始學linux韋東山教程Linux驅動入門實驗班(6)

本人從0開始學習linux,使用的是韋東山的教程,在跟著課程學習的情況下的所遇到的問題的總結,理論雖枯燥但是是基礎。本人將前幾章的內容大致學完之后,考慮到后續驅動方面得更多的開始實操,后續的內容將以韋東山教程Linux驅動入門實…

高性能反向代理與負載均衡 HAProxy 與 Nginx

在現代高并發 Web 架構中,HAProxy 和 Nginx 是兩個非常重要的工具。它們在反向代理、負載均衡、SSL 終止、緩存、限流等方面發揮著關鍵作用。 一、HAProxy 與 Nginx 簡介 1. HAProxy 簡介 HAProxy(High Availability Proxy) 是一個使用 C …

AI安全“面壁計劃”:我們如何對抗算法時代的“智子”封鎖?

> 在算法窺視一切的今天,人類需要一場數字世界的“面壁計劃” 2025年,某醫院部署的AI分診系統被發現存在嚴重偏見:當輸入相同癥狀時,系統為白人患者分配急診通道的概率是黑人患者的**1.7倍**。調查發現,訓練數據中少數族裔樣本不足**15%**,導致AI在“認知”上形成了結…

數據庫數據恢復—報錯“system01.dbf需要更多的恢復來保持一致性”的Oracle數據恢復案例

Oracle數據庫故障: 某公司一臺服務器上部署Oracle數據庫。服務器意外斷電導致數據庫報錯,報錯內容為“system01.dbf需要更多的恢復來保持一致性”。該Oracle數據庫沒有備份,僅有一些斷斷續續的歸檔日志。Oracle數據庫恢復流程: 1、…

Spring Cloud Gateway 服務網關

Spring Cloud Gateway是 Spring Cloud 生態系統中的一個 API 網關服務,用于替換由Zuul開發的網關服務,基于Spring 5.0Spring Boot 2.0WebFlux等技術開發,提供了網關的基本功能,例如安全、監控、埋點和限流等,旨在為微服…

[數據結構]#6 樹

樹是一種非線性的數據結構,它由節點組成,并且這些節點之間通過邊連接。樹的每個節點可以有一個或多個子節點,并且有一個特殊的節點叫做根節點(沒有父節點)。樹在計算機科學中應用廣泛,尤其是在數據庫索引、…

車輛網絡安全規定之R155與ISO/SAE 21434

隨著科技的不斷進步,車輛已經從傳統的機械裝置演變為高度智能化的移動終端。現代汽車不僅配備了先進的駕駛輔助系統(ADAS)、車載信息娛樂系統(IVI),還具備聯網功能,能夠實現遠程診斷、自動駕駛、…

Go語言實戰案例-合并多個文本文件為一個

以下是《Go語言100個實戰案例》中的 文件與IO操作篇 - 案例21:合并多個文本文件為一個 的完整內容,適用于初學者學習文件讀取與寫入的綜合運用。🎯 案例目標使用 Go 語言將指定目錄下的多個 .txt 文件,合并成一個新的總文件。&…

基坑滲壓數據不準?選對滲壓計能實現自動化精準監測嗎?

一、滲壓監測的背景 滲壓計是一種專門用于測量構筑物內部孔隙水壓力或滲透壓力的傳感器,適用于長期埋設在水工結構物或其它混凝土結構物及土體內,以測量結構物或土體內部的滲透(孔隙)水壓力。 在水利工程中,大壩、水庫…

Linux網絡:阿里云輕量級應用服務器配置防火墻模板開放端口

1.問題介紹在使用Udp協議或其他協議進行兩臺主機或同一臺主機通信時,常常會出現bind成功,但是在客戶端向服務端發送數據后,服務端無響應的情況,如果使用輕量級應用服務器,大概率是服務器的端口因為防火墻未對公網IP開放…

《 Spring Boot整合多數據源:分庫業務的標準做法》

🚀 Spring Boot整合多數據源:分庫業務的標準做法 文章目錄🚀 Spring Boot整合多數據源:分庫業務的標準做法🔍 一、為什么需要多數據源支持?💡 典型業務場景?? 二、多數據源集成方案對比&#…

前端ApplePay支付-H5全流程實戰指南

提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔前言近期公司開展關于蘋果支付的相關業務,與之前不同的是,以前后臺直接獲取第三方Wallet封裝好的接口獲取支付地址,H5頁面直接跳轉使用Appl…

Flink窗口:解鎖流計算的秘密武器

Flink 窗口初識在大數據的世界里,數據源源不斷地產生,形成了所謂的 “無限數據流”。想象一下,網絡流量監控中,每一秒都有海量的數據包在網絡中穿梭,這些數據構成了一個無始無終的流。對于這樣的無限數據流&#xff0c…

Java排序算法之<希爾排序>

目錄 1、希爾排序介紹 1.1、定義 1.2、核心思想 2、希爾排序的流程 第 1 輪:gap 4 第 2 輪:gap 2 第 3 輪:gap 1 3、希爾排序的實現 4、時間復雜度分析 5、希爾排序的優缺點 6、適用場景 前言 希爾排序(Shell Sort&…

c++加載qml文件

這里展示了c加載qml文件的三種方式以及qml文件中根節點的訪問準備在創建工程的初期,遇到了一個問題,cmake文件以前都是系統自動生成的,不需要我做過多的操作修改,但是,加載qml的程序主函數是需要用到QGuiApplication&a…

007TG洞察:GPT-5前瞻與AI時代競爭力構建:技術挑戰與落地路徑

最近,GPT-5 即將發布的消息刷爆了科技圈,更讓人期待的是,GPT-6 已經悄悄啟動訓練了,OpenAI 的奧特曼表示對未來1-2年的模型充滿信心,預測AI將進化為能夠發現新知識的“AI科學家”。面對日益強大的通用AI,企…

Windows下編譯OpenVDB

本文記錄在Windows下編譯OpenVDB的流程。 零、環境 操作系統Windows 11VS Code1.92.1Git2.34.1MSYS2msys2-x86_64-20240507Visual StudioVisual Studio Community 2022CMake3.22.1 一、編譯 1.1 下載 git clone https://github.com/AcademySoftwareFoundation/openvdb.git …

react 內置hooks 詳細使用場景,使用案例

useState場景&#xff1a;組件中管理局部狀態&#xff0c;如表單值、開關、計數器等。const [count, setCount] useState(0); return <button onClick{() > setCount(count 1)}>Click {count}</button>;useEffect 場景&#xff1a;組件掛載時執行副作用&#…