01、Tensorflow實現二元手寫數字識別

01、Tensorflow實現二元手寫數字識別(二分類問題)

開始學習機器學習啦,已經把吳恩達的課全部刷完了,現在開始熟悉一下復現代碼。對這個手寫數字實部比較感興趣,作為入門的素材非常合適。

基于Tensorflow 2.10.0

1、識別目標

識別手寫僅僅是為了區分手寫的0和1,所以實際上是一個二分類問題。

2、Tensorflow算法實現

STEP1:導入相關包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

import numpy as np:這是引入numpy庫,并為其設置一個縮寫np。Numpy是Python中用于大規模數值計算的庫,它提供了多維數組對象及一系列操作這些數組的函數。

import tensorflow as tf:這是引入tensorflow庫,并為其設置一個縮寫tf。TensorFlow是一個開源的深度學習框架,它被廣泛用于各種深度學習應用。

from keras.models import Sequential:這是從Keras庫中引入Sequential模型。Keras是一個高級神經網絡API,它可以運行在TensorFlow之上。Sequential模型是Keras中的線性堆棧模型,允許你簡單地堆疊多個網絡層。

from keras.layers import Dense:這是從Keras庫中引入Dense層。Dense層是神經網絡中的全連接層,每個輸入節點與輸出節點都是連接的。

from sklearn.model_selection import train_test_split:這是從scikit-learn庫中引入train_test_split函數。這個函數用于將數據分割為訓練集和測試集。

import matplotlib.pyplot as plt:這是引入matplotlib的pyplot模塊,并為其設置一個縮寫plt。Matplotlib是Python中的繪圖庫,而pyplot是其中的一個模塊,用于繪制各種圖形和圖像。

import warnings:這是引入Python的標準警告庫,它可以用來發出警告,或者過濾掉不需要的警告。

import logging:這是引入Python的標準日志庫,用于記錄日志信息,方便追蹤和調試代碼。

from sklearn.metrics import accuracy_score:這是從scikit-learn庫中引入accuracy_score函數。這個函數用于計算分類準確率,常用于評估分類模型的性能。


STEP2:屏蔽無用警告并允許中文

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文顯示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

logging.getLogger(“tensorflow”).setLevel(logging.ERROR):這行代碼用于設置 TensorFlow 的日志級別為 ERROR。這意味著只有當 TensorFlow 中發生錯誤時,才會在日志中輸出相關信息。較低級別的日志信息(如 WARNING、INFO、DEBUG)將被忽略。

tf.autograph.set_verbosity(0):這行代碼用于設置 TensorFlow 的自動圖形(Autograph)日志的冗長級別為 0。這意味著在將 Python 代碼轉換為 TensorFlow 圖形代碼時,將不會輸出任何日志信息。這有助于減少日志噪音,使日志更加干凈。

warnings.simplefilter(action=‘ignore’,category=FutureWarning):這行代碼用于忽略所有 FutureWarning 類型的警告。在 Python中,當使用某些即將過時或未來版本中可能發生變化的特性時,通常會發出 FutureWarning。通過設置action=‘ignore’,代碼將不會輸出這類警告,使控制臺輸出更加干凈。

plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:這行代碼用于設置 matplotlib 中的默認無襯線字體為 SimHei。SimHei 是一種常用于顯示中文的字體,這樣設置后,matplotlib 將在繪圖時使用 SimHei 字體來顯示中文,從而避免中文亂碼問題。

plt.rcParams[‘axes.unicode_minus’]=False:這行代碼用于解決 matplotlib
中負號顯示異常的問題。默認情況下,matplotlib 可能無法正確顯示負號,將其設置為 False 可以使用 ASCII字符作為負號,從而正常顯示。


STEP3:導入并劃分數據集

劃分10%作為測試:

X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

STEP4:模型構建與訓練

# 構建模型,三層模型進行分類,第一層輸入100個神經元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三層模型的參數
model.summary()
# 模型設定,學習率0.001,因為是分類,使用BinaryCrossentropy損失函數
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 開始訓練,訓練循環20
model.fit(X_train,y_train,epochs=20
)

STEP5:結果可視化與打印準確度信息
原始的輸入的數據集是400 * 1000的數組,共包含1000個手寫數字的數據,其中400為20*20像素的圖片,因此對每個400的數組進行reshape((20, 20))可以得到原始的圖片進而繪圖。

# 繪制測試集的預測結果,繪制64個
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真實標簽, 預測的標簽", fontsize=16)
plt.show()# 給出預測的測試集誤差
y_pred=model.predict(X_test)
print("測試數據集準確率為:", accuracy_score(y_test, np.round(y_pred)))

3、運行結果

按照最初的劃分,數據集包含1000個數據,劃分10%為測試集,也就是100個數據。結果可視化隨機選擇其中的64個數據繪圖,每個圖像的上方標明了其真實標簽和預測的結果,這個是一個非常簡單的示例,準確度還是非常高的。
在這里插入圖片描述

在這里插入圖片描述

4、工程下載與全部代碼

工程鏈接:Tensorflow實現二元手寫數字識別(二分類問題)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_scorelogging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文顯示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_data/X.npy")y = np.load("Handwritten_Digit_Recognition_data/y.npy")X = X[0:1000]y = y[0:1000]return X, y# 加載數據集,查看數據集大小,可以看到有1000個數據集,每個輸入是20*20=400大小的圖片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 下面畫圖,隨便從原數據取出幾個畫圖,可以注釋
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # 將1*400的數據轉換為20*20的圖像格式
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
# plt.show()# 構建模型,三層模型進行分類,第一層輸入25個神經元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三層模型的參數
model.summary()
# 模型設定,學習率0.001,因為是分類,使用BinaryCrossentropy損失函數
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 開始訓練,訓練循環20
model.fit(X_train,y_train,epochs=20
)# 繪制測試集的預測結果,繪制64個
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真實標簽, 預測的標簽", fontsize=16)
plt.show()# 給出預測的測試集誤差
y_pred=model.predict(X_test)
print("測試數據集準確率為:", accuracy_score(y_test, np.round(y_pred)))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/167814.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/167814.shtml
英文地址,請注明出處:http://en.pswp.cn/news/167814.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

pandas獲取年月第一天、最后一天,加一秒、加一天、午夜時間

Timestamp對象 # ts = pandas.Timestamp(year=2023, month=10, day=15, # hour=15, minute=5, second=50, tz="Asia/Shanghai") ts = pandas.Timestamp("2023-10-15 15:05:50", tz="Asia/Shanghai") # 2023-10-15 15:05…

數據丟失預防措施包括什么

數據丟失預防措施是保護企業或個人重要數據的重要手段。以下是一些有效的預防措施: 可以通過域之盾軟件來實現數據防丟失,具體的功能包括: https://www.yuzhidun.cn/https://www.yuzhidun.cn/ 1、備份數據 定期備份所有重要數據&#xff0…

unittest指南——不拼花哨,只拼實用

📢專注于分享軟件測試干貨內容,歡迎點贊 👍 收藏 ?留言 📝 如有錯誤敬請指正!📢交流討論:歡迎加入我們一起學習!📢資源分享:耗時200小時精選的「軟件測試」資…

centos7 docker開啟認證的遠程端口2376配置

docker開啟2375會存在安全漏洞 暴露了2375端口的Docker主機。因為沒有任何加密和認證過程,知道了主機IP以后,,任何人都可以管理這臺主機上的容器和鏡像,以前貪圖方便,只開啟了沒有認證的docker2375端口,后…

代碼隨想錄算法訓練營第五十三天|1143.最長公共子序列 1035.不相交的線 53. 最大子序和

文檔講解:代碼隨想錄 視頻講解:代碼隨想錄B站賬號 狀態:看了視頻題解和文章解析后做出來了 1143.最長公共子序列 class Solution:def longestCommonSubsequence(self, text1: str, text2: str) -> int:dp [[0] * (len(text2) 1) for _ i…

機器學習入門

簡介 https://huggingface.co/是一個AI社區,類似于github的地位。它開源了許多機器學習需要的基礎組件如:Transformers, Tokenizers等。 許多公司也在不斷地往上面提交新的模型和數據集,利用它你可以獲取以下內容: Datasets : 數…

hikariCP 數據庫連接池配置

springBoot 項目默認自動使用 HikariCP ,HikariCP 的性能比 alibaba/druid快。 一、背景 系統中多少個線程在進行與數據庫有關的工作?其中,而多少個線程正在執行 SQL 語句?這可以讓我們評估數據庫是不是系統瓶頸。 多少個線程在…

基于法醫調查算法優化概率神經網絡PNN的分類預測 - 附代碼

基于法醫調查算法優化概率神經網絡PNN的分類預測 - 附代碼 文章目錄 基于法醫調查算法優化概率神經網絡PNN的分類預測 - 附代碼1.PNN網絡概述2.變壓器故障診街系統相關背景2.1 模型建立 3.基于法醫調查優化的PNN網絡5.測試結果6.參考文獻7.Matlab代碼 摘要:針對PNN神…

【學生成績管理】數據庫示例數據(MySQL代碼)

【學生成績管理】數據庫示例數據(MySQL代碼) 目錄 【學生成績管理】數據庫示例數據(MySQL代碼)一、創建數據庫二、創建dept(學院)表1、創建表結構2、添加示例數據3、查看表中數據 三、創建stu(學…

35.邏輯運算符

目錄 一.什么是邏輯運算符 二.C語言中的邏輯運算符 三.邏輯表達式 三.視頻教程 一.什么是邏輯運算符 同時對倆個或者倆個以上的表達式進行判斷的運算符叫做邏輯運算符。 舉例:比如去網吧上網,只有年滿十八周歲并且帶身份證才可以上網。在C語言中如果…

為什么 Flink 拋棄了 Scala

曾經紅遍一時的Scala 想當初Spark橫空出世之后,Scala簡直就是語言界的一顆璀璨新星,惹得大家紛紛側目,連Kafka這類技術框架也選擇用Scala語言進行開發重構。 可如今,Flink竟然公開宣布棄用Scala 在Flink1.18的官方文檔里&#x…

國家開放大學的學子們 練習題 走起!

試卷代號:1356 高級英語聽說(2) 參考 試題 Section One (20 points, 2 points each) Directions: Listen to the conversation and fill in the blanks with the words you hear. Write the words on the Answer Sheet The conversation will be read TWICE. M…

windows11上安裝WSL

Windows電腦上要配置linux(這里指ubuntu)開發環境,主要有三種方式: 1)在windows上裝個虛擬機(比如vmware)。缺點是vmware加載ubuntu后系統會變慢很多,而且需要通過samba來實現window…

使用Java連接Hbase

我在網上試 了很多代碼,但是大部分都不能實現,Java連接Hbase,一直報一個錯 java.util.concurrent.ExecutionException: org.apache.zookeeper.KeeperException$NoNodeException: KeeperErrorCode NoNode for /hbase/hbaseid一直也不清楚為什…

計算機組成原理。3-408

1.動態存儲和靜態存儲 2.雙端口RAM 注意:cpu通過地址線和數據線讀寫數據時,不能同時寫,但可以同時讀,也不能一邊讀一邊寫。 3.多體并行存儲器 分為高位存儲和低位存儲 小結 4.磁盤存儲器的組成 5.磁盤的性能指標 磁盤讀寫尋道…

如何對網站進行滲透測試

信息搜集 信息搜集拿到域名后獲取真實IP,如果存在CDN想辦法繞過端口掃描,針對開放的端口在獲取客戶同意的前提下進行爆破查找網站子域名,后臺目錄判斷網站的CMS 可以使用 Wappalyzer插件 whatcms 是一個可以用來確定特定網站正在使用的什么…

Vue中Slot的使用指南

目錄 前言 什么是slot? 單個slot的使用 具名slot的使用 作用域插槽 總結 前言 在Vue中,slot是一種非常強大和靈活的功能,它允許你在組件模板中預留出一個或多個"插槽",然后在使用這個組件的時候動態地填充內容。這…

TSINGSEE青犀智能分析網關道路積水識別AI算法方案

在各處的街道、路口等區域,及時發現道路積水問題,可以大大減少城市管理部門壓力,及時處理,減少交通事故與人員摔倒事故。通過道路積水AI算法,能有效提高城市管理部門效率,優化城市管理方式。 那么&#xff…

【Web】PhpBypassTrick相關例題wp

目錄 ①[NSSCTF 2022 Spring Recruit]babyphp ②[鶴城杯 2021]Middle magic ③[WUSTCTF 2020]樸實無華 ④[SWPUCTF 2022 新生賽]funny_php 明天中期考,先整理些小知識點冷靜一下 ①[NSSCTF 2022 Spring Recruit]babyphp payload: a[]1&b1[]1&b2[]2&…

NLP的使用

參考: Apache openNLP 簡介 - 鏈滴 (ld246.com) opennlp 模型下載地址:Index of /apache/opennlp/models/ud-models-1.0/ (tencent.com) OpenNLP是一個流行的開源自然語言處理工具包,它提供了一系列的NLP模型和算法。然而,Open…