【TensorFlow篇】--Tensorflow框架實現SoftMax模型識別手寫數字集

一、前述

本文講述用Tensorflow框架實現SoftMax模型識別手寫數字集,來實現多分類。

同時對模型的保存和恢復做下示例。

二、具體原理

代碼一:實現代碼

#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 文件名: 12_Softmax_regression.pyfrom tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf# mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)#從本地路徑加載進來# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)#訓練集圖片
# 10,000 points of test data (mnist.test), and#測試集圖片
# 5,000 points of validation data (mnist.validation).#驗證集圖片# Each image is 28 pixels by 28 pixels# 輸入的是一堆圖片,None表示不限輸入條數,784表示每張圖片都是一個784個像素值的一維向量
# 所以輸入的矩陣是None乘以784二維矩陣
x = tf.placeholder(dtype=tf.float32, shape=(None, 784)) #x矩陣是m行*784列
# 初始化都是0,二維矩陣784乘以10個W值 #初始值最好不為0
W = tf.Variable(tf.zeros([784, 10]))#W矩陣是784行*10列
b = tf.Variable(tf.zeros([10]))#bias也必須有10個

y = tf.nn.softmax(tf.matmul(x, W) + b)# x*w 即為m行10列的矩陣就是y #預測值# 訓練
# labels是每張圖片都對應一個one-hot的10個值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))#真實值 m行10列
# 定義損失函數,交叉熵損失函數
# 對于多分類問題,通常使用交叉熵損失函數
# reduction_indices等價于axis,指明按照每行加,還是按照每列加
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))#指明按照列加和 一列是一個類別
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#將損失函數梯度下降 #0.5是學習率# 初始化變量
sess = tf.InteractiveSession()#初始化Session
tf.global_variables_initializer().run()#初始化所有變量
for _ in range(1000):batch_xs, batch_ys = my_mnist.train.next_batch(100)#每次迭代取100行數據sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#每次迭代內部就是求梯度,然后更新參數
# 評估# tf.argmax()是一個從tensor中尋找最大值的序號 就是分類號,tf.argmax就是求各個預測的數字中概率最大的那一個
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 用tf.cast將之前correct_prediction輸出的bool值轉換為float32,再求平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 測試
print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))# 總結
# 1,定義算法公式,也就是神經網絡forward時的計算
# 2,定義loss,選定優化器,并指定優化器優化loss
# 3,迭代地對數據進行訓練
# 4,在測試集或驗證集上對準確率進行評測

代碼二:保存模型

# 有時候需要把模型保持起來,有時候需要做一些checkpoint在訓練中
# 以致于如果計算機宕機,我們還可以從之前checkpoint的位置去繼續
# TensorFlow使得我們去保存和加載模型非常方便,僅需要去創建Saver節點在構建階段最后
# 然后在計算階段去調用save()方法from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation).# Each image is 28 pixels by 28 pixels# 輸入的是一堆圖片,None表示不限輸入條數,784表示每張圖片都是一個784個像素值的一維向量
# 所以輸入的矩陣是None乘以784二維矩陣
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
# 初始化都是0,二維矩陣784乘以10個W值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x, W) + b)# 訓練
# labels是每張圖片都對應一個one-hot的10個值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
# 定義損失函數,交叉熵損失函數
# 對于多分類問題,通常使用交叉熵損失函數
# reduction_indices等價于axis,指明按照每行加,還是按照每列加
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)# 初始化變量
init = tf.global_variables_initializer()
# 創建Saver()節點
saver = tf.train.Saver()#在運算之前,初始化之后

n_epoch = 1000with tf.Session() as sess:sess.run(init)for epoch in range(n_epoch):if epoch % 100 == 0:save_path = saver.save(sess, "./my_model.ckpt")#每跑100次save一次模型,可以保證容錯性#直接保存session即可。
batch_xs, batch_ys = my_mnist.train.next_batch(100)#每一批次跑的數據 用m行數據/迭代次數來計算出來。sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})best_theta = W.eval()save_path = saver.save(sess, "./my_model_final.ckpt")#保存最后的模型,session實際上保存的上面所有的數據

代碼三:恢復模型

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
my_mnist = input_data.read_data_sets("MNIST_data_bak/", one_hot=True)# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation).# Each image is 28 pixels by 28 pixels# 輸入的是一堆圖片,None表示不限輸入條數,784表示每張圖片都是一個784個像素值的一維向量
# 所以輸入的矩陣是None乘以784二維矩陣
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
# 初始化都是0,二維矩陣784乘以10個W值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x, W) + b)
# labels是每張圖片都對應一個one-hot的10個值的向量
y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))saver = tf.train.Saver()with tf.Session() as sess: saver.restore(sess, "./my_model_final.ckpt")#把路徑下面所有的session的數據加載進來 y y_head還有模型都保存下來了。# 評估# tf.argmax()是一個從tensor中尋找最大值的序號,tf.argmax就是求各個預測的數字中概率最大的那一個correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 用tf.cast將之前correct_prediction輸出的bool值轉換為float32,再求平均accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 測試print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))

轉載于:https://www.cnblogs.com/LHWorldBlog/p/8661434.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/278228.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/278228.shtml
英文地址,請注明出處:http://en.pswp.cn/news/278228.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

web頁面鎖屏初級嘗試

因為工作需要&#xff0c;所以在網上找了一些素材來弄這個功能。在我找到的素材中&#xff0c;大多都是不完善的。雖然我的也不是很完善&#xff0c;但是怎么說呢。要求不是很高的話。可以直接拿來用的【需要引用jQuery】。廢話不多說直接上代碼 這部分是js代碼 1 <script&g…

Java 并發工具箱之concurrent包

概述 java.util.concurrent 包是專為 Java并發編程而設計的包。包下的所有類可以分為如下幾大類&#xff1a; locks部分&#xff1a;顯式鎖(互斥鎖和速寫鎖)相關&#xff1b;atomic部分&#xff1a;原子變量類相關&#xff0c;是構建非阻塞算法的基礎&#xff1b;executor部分&…

如何提高gps精度_如何在鍛煉應用程序中提高GPS跟蹤精度

如何提高gps精度l i g h t p o e t/Shutterstocklightpoet /快門Tracking your runs, bike rides, and other workouts is fun because you can see how much you’re improving (or, in my case, dismally failing to improve). For it to be effective, though, you have to …

centos proftp_在CentOS上禁用ProFTP

centos proftpI realize this is probably only relevant to about 3 of the readers, but I’m posting this so I don’t forget how to do it myself! In my efforts to ban the completely insecure FTP protocol from my life entirely, I’ve decided to disable the FTP…

Java通過Executors提供四種線程池

http://cuisuqiang.iteye.com/blog/2019372 Java通過Executors提供四種線程池&#xff0c;分別為&#xff1a;newCachedThreadPool創建一個可緩存線程池&#xff0c;如果線程池長度超過處理需要&#xff0c;可靈活回收空閑線程&#xff0c;若無可回收&#xff0c;則新建線程。n…

一個在線編寫前端代碼的好玩的工具

https://codesandbox.io/ 可以編寫 Angular&#xff0c;React&#xff0c;Vue 等前端代碼。 可以實時編輯和 preview。 live 功能&#xff0c;可以多人協作編輯&#xff0c;不過是收費的功能。 可以增加依賴的包&#xff0c;比如編寫 React 時&#xff0c;可以安裝任意的第三…

MySQL數據庫基礎(五)——SQL查詢

MySQL數據庫基礎&#xff08;五&#xff09;——SQL查詢 一、單表查詢 1、查詢所有字段 在SELECT語句中使用星號“”通配符查詢所有字段在SELECT語句中指定所有字段select from TStudent; 2、查詢指定字段 查詢多個字段select Sname,sex,email from TStudent; 3、查詢指定記錄…

使用生成器創建新的迭代模式

一個函數中需要有一個 yield 語句即可將其轉換為一個生成器。 def frange(start, stop, increment):x startwhile x < stop:yield xx incrementfor i in frange(0, 4, 2):print(i) # 0 2 一個生成器函數主要特征是它只會回應在迭代中使用到的 next 操作 def cutdata(n):p…

前端異常捕獲與上報

在一般情況下我們代碼報錯啥的都會覺得 下圖 然后現在來說下經常用的異常 1.try catch 這個是比較常見的異常捕獲方式通常都是 使用try catch能夠很好的捕獲異常并對應進行相應處理&#xff0c;不至于讓頁面掛掉&#xff0c;但是其存在一些弊端&#xff0c;比如需要在捕獲異常的…

Codeforces 924D Contact ATC (看題解)

Contact ATC 我跑去列方程&#xff0c; 然后就gg了。。。 我們計每個飛機最早到達時間為L[ i ], 最晚到達時間為R[ i ]&#xff0c; 對于面對面飛行的一對飛機&#xff0c; 只要他們的時間有交集則必定滿足條件。 對于相同方向飛行的飛機&#xff0c; 只有其中一個的時間包含另…

基于ZXing Android實現生成二維碼圖片和相機掃描二維碼圖片即時解碼的功能

NextQRCode ZXing開源庫的精簡版 **基于ZXing Android實現生成二維碼圖片和相機掃描二維碼圖片即時解碼的功能原文博客 附源碼下載地址** 與原ZXingMini項目對比 NextQRCode做了重大架構修改&#xff0c;原ZXingMini項目與當前NextQRCode不兼容 dependencies {compile com.gith…

flask sqlalchemy 單表查詢

主要內容: 1 sqlalchemy: 一個python的ORM框架 2 使用sqlalchemy 的流程: 創建一個類 創建數據庫引擎 將所有的類序列化成數據表 進行增刪改查操作 # 1.創建一個 Class from sqlalchemy.ext.declarative import declarative_base Base declarative_base() # Base 是 ORM模型 基…

如何在Windows 7或Vista上安裝IIS

If you are a developer using ASP.NET, one of the first things you’ll want to install on Windows 7 or Vista is IIS (internet information server). Keep in mind that your version of Windows may not come with IIS. I’m using Windows 7 Ultimate edition. 如果您…

Dubbo的使用及原理淺析

https://www.cnblogs.com/wang-meng/p/5791598.html轉載于:https://www.cnblogs.com/h-wt/p/10490345.html

ThinkPHP3.2 實現阿里云OSS上傳文件

為什么80%的碼農都做不了架構師&#xff1f;>>> 0、配置文件Config&#xff0c;加入OSS配置選項&#xff0c;設置php.ini最大上傳大小&#xff08;自行解決&#xff0c;這里不做演示&#xff09; OSS > array(ACCESS_KEY_ID > **************, //從OSS獲得的…

ipad和iphone切圖_如何在iPhone,iPad和Mac上簽名PDF

ipad和iphone切圖Khamosh PathakKhamosh PathakDo you have documents to sign? You don’t need to worry about printing, scanning, or even downloading a third-party app. You can sign PDFs right on your iPhone, iPad, and Mac. 你有文件要簽名嗎&#xff1f; 您無需…

一個頁面上有大量的圖片(大型電商網站),加載很慢,你有哪些方法優化這些圖片的加載,給用戶更好的體驗。...

a. 圖片懶加載&#xff0c;滾動到相應位置才加載圖片。 b. 圖片預加載&#xff0c;如果為幻燈片、相冊等&#xff0c;將當前展示圖片的前一張和后一張優先下載。 c. 使用CSSsprite&#xff0c;SVGsprite&#xff0c;Iconfont、Base64等技術&#xff0c;如果圖片為css圖片的話。…

[function.require]: Failed opening required 杰奇cms

在配置杰奇cms移動端的時候&#xff0c;出現了[function.require]: Failed opening required 不要慌&#xff0c;百度一下即可解決。這個就是權限問題。由于移動端要請求pc端的文件&#xff0c;沒權限。加上一個iis_iusrs讀寫權限即可搞定&#xff01;轉載于:https://www.cnblo…

在Ubuntu服務器上打開第二個控制臺會話

Ubuntu Server has the native ability to run multiple console sessions from the server console prompt. If you are working on the actual console and are waiting for a long running command to finish, there’s no reason why you have to sit and wait… you can j…

Cloudstack系統配置(三)

系統配置 CloudStack提供一個基于web的UI&#xff0c;管理員和終端用戶能夠使用這個界面。用戶界面版本依賴于登陸時使用的憑證不同而不同。用戶界面是適用于大多數流行的瀏覽器包括IE7,IE8,IE9,Firefox Chrome等。URL是:(用你自己的管理控制服務器IP地址代替) 1http://<ma…