MNIST 手寫數字識別模型分析

功能概述

這段代碼實現了一個基于TensorFlow和Keras的MNIST手寫數字識別模型。主要功能包括:

  1. 加載并預處理MNIST數據集
  2. 構建一個簡單的全連接神經網絡模型
  3. 訓練模型并評估其性能
  4. 使用訓練好的模型進行預測
  5. 保存和加載模型

代碼解析

1. 導入必要的庫

import matplotlib
import tensorflow.keras as keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from pasta.augment import inline
  • 導入TensorFlow和Keras用于構建和訓練神經網絡
  • 導入NumPy用于數值計算
  • 導入Matplotlib用于數據可視化
  • 從pasta.augment導入inline用于在Jupyter Notebook中直接顯示圖像

2. 打印TensorFlow版本

print(tf.__version__)

輸出當前使用的TensorFlow版本,用于環境檢查。

3. 加載MNIST數據集

path = '../doc/mnist.npz'
with np.load(path) as data:x_train, y_train = data['x_train'], data['y_train']x_test, y_test = data['x_test'], data['y_test']
print(x_train[0])
  • 從本地文件加載MNIST數據集
  • 數據集包含訓練集(x_train, y_train)和測試集(x_test, y_test)
  • 打印第一個訓練樣本的像素值

4. 數據可視化

%matplotlib inline
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()
  • 使用Matplotlib顯示第一個訓練樣本的圖像
  • cmap=plt.cm.binary設置為黑白顯示

5. 打印第一個訓練樣本的標簽

print(y_train[0])

輸出第一個訓練樣本對應的數字標簽。

6. 數據歸一化

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
print(x_train[0])
  • 對圖像數據進行歸一化處理,將像素值縮放到0-1范圍
  • 打印歸一化后的第一個訓練樣本

7. 構建神經網絡模型

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
  • 創建一個Sequential模型
  • 添加Flatten層將28x28的圖像展平為784維向量
  • 添加兩個全連接層(Dense),每層128個神經元,使用ReLU激活函數
  • 添加輸出層,10個神經元對應10個數字類別,使用Softmax激活函數

8. 編譯模型

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
  • 使用Adam優化器
  • 使用稀疏分類交叉熵作為損失函數
  • 使用準確率作為評估指標

9. 訓練模型

model.fit(x_train, y_train, epochs=3)
  • 訓練模型3個epoch
  • 使用訓練數據進行擬合

10. 評估模型

val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)
  • 在測試集上評估模型性能
  • 輸出測試損失和準確率

11. 使用模型進行預測

predictions = model.predict(x_test)
print(predictions)
print(np.argmax(predictions[0]))
plt.imshow(x_test[0], cmap=plt.cm.binary)
plt.show()
  • 對測試集進行預測
  • 打印預測結果(概率分布)
  • 使用argmax獲取第一個測試樣本的預測標簽
  • 顯示第一個測試樣本的圖像

12. 保存和加載模型

def softmax_v2(x):return tf.keras.activations.softmax(x)new_model = tf.keras.models.load_model('epic_num_reader.model.keras',custom_objects={'softmax_v2': softmax_v2}
)predictions = new_model.predict(x_test)
print(np.argmax(predictions[0]))
  • 定義一個softmax_v2函數用于兼容性
  • 加載之前保存的模型
  • 使用加載的模型進行預測

總結

這段代碼實現了一個簡單但有效的MNIST手寫數字分類器。主要特點包括:

  1. 使用全連接神經網絡結構
  2. 實現了數據預處理和歸一化
  3. 達到了較高的測試準確率(約97%)
  4. 包含了模型保存和加載功能
  5. 提供了可視化工具檢查數據和預測結果

demo001.ipynb

# 導入 keras 模塊
import matplotlib
import tensorflow.keras as keras
# 導入 tensorflow 模塊
import tensorflow as tf
# 導入 pasta 模塊中的 augment 和 inline 子模塊
from pasta.augment import inline# 打印 TensorFlow 的版本
print(tf.__version__)# 指定本地文件路徑
path = '../doc/mnist.npz'
# 導入 numpy 模塊
import numpy as np
# 從本地加載 MNIST 數據集
with np.load(path) as data:x_train, y_train = data['x_train'], data['y_train']x_test, y_test = data['x_test'], data['y_test']
# 打印訓練數據集的第一個樣本
print(x_train[0])# 導入 matplotlib.pyplot 模塊
import matplotlib.pyplot as plt
# 使用 inline 后,圖形將直接顯示在 Jupyter Notebook 中
# %matplotlib inline
# 可視化訓練數據集的第一個樣本
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()# 打印訓練標簽的第一個樣本
print(y_train[0])# 對訓練和測試數據進行歸一化處理
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)# 打印歸一化后的訓練數據集的第一個樣本
print(x_train[0])# 可視化歸一化后的訓練數據集的第一個樣本
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()# 創建一個 Sequential 模型
model = tf.keras.models.Sequential()
# 添加一個 Flatten 層,用于將輸入數據展平
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
# 添加一個 Dense 層,包含 128 個神經元,使用 ReLU 激活函數
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
# 再添加一個 Dense 層,配置同上
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
# 添加一個 Dense 層,包含 10 個神經元,使用 Softmax 激活函數
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
# 編譯模型,指定優化器、損失函數和評估指標
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# 訓練模型
model.fit(x_train, y_train, epochs=3)
# 評估模型
val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)# 使用模型進行預測
predictions = model.predict(x_test)
print(predictions)# 導入 numpy 模塊
import numpy as np# 打印第一個測試樣本的預測標簽
print(np.argmax(predictions[0]))# 可視化第一個測試樣本
plt.imshow(x_test[0], cmap=plt.cm.binary)
plt.show()# 保存模型
def softmax_v2(x):# 將 softmax_v2 映射到標準 softmaxreturn tf.keras.activations.softmax(x)# 加載之前保存的模型
new_model = tf.keras.models.load_model('epic_num_reader.model.keras',custom_objects={'softmax_v2': softmax_v2}
)# 使用加載的模型進行預測
predictions = new_model.predict(x_test)
# 打印第一個測試樣本的預測標簽
print(np.argmax(predictions[0]))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/90275.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/90275.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/90275.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

進階系統策略

該策略主要基于價格動態分析,結合多種技術指標和數學計算來生成交易信號。其核心邏輯包括: 1. 價格極值計算:首先,策略計算給定周期(由`Var3`定義)內的最高價和最低價,分別存儲在`Var12`和`Var13`中。這一步驟旨在捕捉價格的短期波動范圍。 2. 相對位置計算:接著,策…

【Linux內核】Linux驅動開發

推薦書籍: 《Linux內核探秘:深入解析文件系統和設備驅動的架構與設計》 知識點 x86的IO地址空間和內存地址空間是獨立的兩套地址空間,并且使用不同的指令訪問。MOV, IN, OUT。內存映射I/O可以將IO映射到內存。ARM等RISC采用統一編編址&#x…

MySQL用戶管理(15)

文章目錄前言一、用戶用戶信息創建用戶修改密碼刪除用戶二、數據庫的權限MySQL中的權限給用戶授權回收權限總結前言 其實與 Linux 操作系統類似,MySQL 中也有 超級用戶 和 普通用戶 之分 如果一個用戶只需要訪問 MySQL 中的某一個數據庫,甚至數據庫中的某…

react19相關問題和解答

目錄 1. react19將ref放在了props中(不再需要 forwardRef),那么是不是可以通過ref獲取子組件的全部變量了? 我的子組件的useImperativeHandle還需要定義嗎? 1.1. ref 在 props 中的本質變化 1.2. 為什么不能訪問全部變量? 2. In HTML,cannot be a descendant of. Thi…

Code Composer Studio:CCS 設置代碼折疊

Code Composer Studio:設置代碼折疊,可以按函數,if, 等把代碼折疊起來。1.2.開啟折疊選項3.開啟后,如果文件已經打開,要關掉重新打開文件就可以開到折疊功能生效。

JMeter groovy 編譯成.jar 文件

groovy 編譯 一、windows 下手動安裝Groovy 下載 Groovy 二進制包 前往官網:https://groovy.apache.org/download.html 下載 Binary release( https://groovy.jfrog.io/ui/native/dist-release-local/groovy-zips/apache-groovy-sdk-4.0.27.zip &#xf…

使用maven-shade-plugin解決依賴版本沖突

項目里引入多個版本依賴時,最后只會使用其中一個,一般可以通過排除不使用的依賴處理,但是如果需要同時使用多個版本,可以使用maven-shade-plugin解決。以最典型的poi為例,poi版本兼容性很低,如果出現找不到…

[CH582M入門第十一步]DS18B20驅動

學習目標: 1、介紹DS18B20 2、學習單總線 3、學習DS18B20程序驅動一、DS18B20介紹 DS18B20 是一款由 Maxim Integrated(原Dallas Semiconductor) 推出的 數字溫度傳感器,以其單總線(1-Wire)通信協議、高精度和廣泛應用而聞名。以下是其核心特點和應用介紹: 主要特性 數…

SGLang + 分布式推理部署DeepSeek671B滿血版

部署設備:28A100 80G,兩臺機器,每臺機器8張A100。 模型:deepseek-671B-int8 模型下載地址:https://huggingface.co/meituan/DeepSeek-R1-Block-INT8 模型參考: 1、SGLang Docker部署 github地址&#…

PCL 間接平差擬合球

目錄 一、算法原理 1、計算流程 2、參考文獻 二、代碼實現 三、結果展示 本文由CSDN點云俠原創,首發于2025年7月24日。博客長期更新,本文最新更新時間為:2025年7月24日。 一、算法原理 1、計算流程 空間球方程: ( x ? a ) 2 + ( y ? b ) 2 + ( z ? c ) 2 = R 2 (1) (…

基于 HAProxy 搭建 EMQ X 集群

負載均衡器(LB)負責分發設備的 MQTT 連接與消息到 EMQ X 集群,采用 LB 可以提高 EMQ X 集群可用性、實現負載平衡以及動態擴容。 HAProxy簡介 HAProxy 是一款高性能的 開源負載均衡器 和 反向代理服務器,主要用于在多個服務器之…

RISC-V基金會Datacenter SIG月會圓滿舉辦,探討RAS、PMU性能分析實踐和經驗

一直以來,龍蜥社區在 RISC-V 生態建設中持續投入,并積極貢獻上游社區。多位龍蜥社區成員在 RISC-V 國際基金會擔任主席/副主席角色,與來自阿里云、阿里達摩院、中興通訊、浪潮信息、中科院軟件所、字節跳動、Google、 MIT、Akeana 等企業的專…

CloudComPy使用PyInstaller打包后報錯解決方案

情況描述 筆者在spec文件中,datas變量設置如下。如果你的報錯類似于“找不到cloudComPy”,先嘗試如下的設置。 datas[(CloudCompare,cloudComPy)], 筆者在打包完成后,打開軟件發現報錯: from cloudComPy import* ModuleNotFoun…

node.js中的path模塊

在 Node.js 中,path 模塊提供了處理和操作文件路徑的功能,其中 path.join 和 path.resolve 是兩個常用的方法。它們在處理路徑時有不同的行為和用途: 功能概述 path.join(): 該方法主要用于將多個路徑片段拼接成一個完整的路徑字符串。它會正…

將Scrapy項目容器化:Docker鏡像構建的工程實踐

引言:爬蟲容器化的戰略意義在云原生與微服務架構主導的時代,??容器化技術??已成為爬蟲項目交付的黃金標準。據2023年分布式系統調查報告顯示:92%的生產爬蟲系統采用容器化部署容器化使爬蟲環境配置時間??減少87%??Docker化爬蟲的故障…

Unity × RTMP × 頭顯設備:打造沉浸式工業遠控視頻系統的完整方案

結合工業現場需求,探索如何通過大牛直播SDK打造可在 Pico、Quest 等頭顯設備中運行的 RTMP 低延遲播放器,助力構建沉浸式遠程操控系統。 一、背景:沉浸式遠程操控的新趨勢 隨著工業自動化、5G 專網、XR 技術的發展,遠程操控正在從…

HTTPS如何保障安全?詳解證書體系與加密通信流程

HTTP協議本身是明文傳輸的,安全性較低,因此現代互聯網普遍采用 HTTPS(HTTP over TLS/SSL) 來實現加密通信。HTTPS的核心是 TLS/SSL證書體系 和 加密通信流程。一、HTTPS 證書體系HTTPS依賴 公鑰基礎設施(PKI, Public K…

數據的評估與清洗篇---清洗數據

處理前的準備 檢查索引與列名 在處理內容之前,需要先看看索引或列名是否有意義,若索引和列名都是亂七八糟的,應該對他們進行重命名或者重新排序,以便我們理解數據。 清洗數據 清洗數據原則 針對數據內容,一般先解決結構性問題,再處理內容性問題。整潔數據的特點是: …

Ubuntu apt和apt-get的區別

好的,這是一個非常經典且重要的問題。apt install 和 apt-get install 的區別是很多 Ubuntu/Debian 新手都會遇到的困惑。 簡單來說,它們的功能非常相似,但設計目標和用戶體驗不同。 一句話總結 apt 是 apt-get 的一個更新、更友好、更現代化…

多端適配災難現場:可視化界面在PC/平板/大屏端的響應式布局實戰

摘要精心設計的可視化大屏,在平板上顯示時圖表擠成一團,在PC端操作按鈕小到難以點擊,某企業的可視化項目曾因多端適配失敗淪為“災難現場”,不僅用戶差評如潮,還被競爭對手嘲諷技術落后。多端適配真的只能靠“反復試錯…