持久化的基于L2正則化和平均滑動模型的MNIST手寫數字識別模型

持久化的基于L2正則化和平均滑動模型的MNIST手寫數字識別模型

覺得有用的話,歡迎一起討論相互學習~Follow Me

參考文獻Tensorflow實戰Google深度學習框架
實驗平臺:
Tensorflow1.4.0
python3.5.0
MNIST數據集將四個文件下載后放到當前目錄下的MNIST_data文件夾下

定義模型框架與前向傳播

import tensorflow as tf# 定義神經網絡結構相關參數
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500# 設置權值函數
# 在訓練時會創建這些變量,在測試時會通過保存的模型加載這些變量的取值
# 因為可以在變量加載時將滑動平均變量均值重命名,所以這個函數可以直接通過同樣的名字在訓練時使用變量本身
# 而在測試時使用變量的滑動平均值,在這個函數中也會將變量的正則化損失加入損失集合def get_weight_variable(shape, regularizer):weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))# 如果使用正則化方法會將該張量加入一個名為'losses'的集合if regularizer != None: tf.add_to_collection('losses', regularizer(weights))return weights# 定義神經網絡前向傳播過程
def inference(input_tensor, regularizer):with tf.variable_scope('layer1'):weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)with tf.variable_scope('layer2'):weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))layer2 = tf.matmul(layer1, weights) + biasesreturn layer2

模型訓練與模型框架及參數持久化

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import os# 配置神經網絡參數
BATCH_SIZE = 100  # 批處理數據大小
LEARNING_RATE_BASE = 0.8  # 基礎學習率
LEARNING_RATE_DECAY = 0.99  # 學習率衰減速度
REGULARIZATION_RATE = 0.0001  # 正則化項
TRAINING_STEPS = 30000  # 訓練次數
MOVING_AVERAGE_DECAY = 0.99  # 平均滑動模型衰減參數
# 模型保存的路徑和文件名
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"def train(mnist):# 定義輸入輸出placeholderx = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')  # 可以直接引用mnist_inference中的超參數y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')# 定義L2正則化器regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)# 在前向傳播時使用L2正則化y = mnist_inference.inference(x, regularizer)global_step = tf.Variable(0, trainable=False)# 在可訓練參數上定義平均滑動模型variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)# tf.trainable_variables()返回的是圖上集合GraphKeys.TRAINABLE_VARIABLES中的元素。這個集合中的元素是所有沒有指定trainable=False的參數variables_averages_op = variable_averages.apply(tf.trainable_variables())cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))cross_entropy_mean = tf.reduce_mean(cross_entropy)# 在交叉熵函數的基礎上增加權值的L2正則化部分loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))# 設置學習率,其中學習率使用逐漸遞減的原則learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE, LEARNING_RATE_DECAY,staircase=True)# 使用梯度下降優化器train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)# with tf.control_dependencies([train_step, variables_averages_op]):# train_op = tf.no_op(name='train')# 在反向傳播的過程中,不僅更新神經網絡中的參數還更新每一個參數的滑動平均值train_op = tf.group(train_step, variables_averages_op)# 定義Saver模型保存器saver = tf.train.Saver()with tf.Session() as sess:tf.global_variables_initializer().run()for i in range(TRAINING_STEPS):xs, ys = mnist.train.next_batch(BATCH_SIZE)_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})# 每1000輪保存一次模型if i%1000 == 0:# 輸出當前的訓練情況,這里只輸出了模型在當前訓練batch上的損失函數大小# 通過損失函數的大小可以大概了解訓練的情況,# 在驗證數據集上的正確率信息會有一個單獨的程序來生成print("After %d training step(s), loss on training batch is %g."%(step, loss_value))# 模型保存saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)def main(argv=None):mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)train(mnist)if __name__ == '__main__':tf.app.run()

模型恢復與評價測試集上的效果

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train# 每10秒加載一次最新的模型
# 加載的時間間隔。
EVAL_INTERVAL_SECS = 10def evaluate(mnist):with tf.Graph().as_default() as g:x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}# 直接通過調用封裝好的函數來計算前向傳播的結果,因為測試時不關注正則化損失的值所以這里用于計算正則化損失的函數被設置為Noney = mnist_inference.inference(x, None)correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 如果需要離線預測未知數據的類別,只需要將計算正確率的部分改為答案的輸出即可。accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 通過獲取變量重命名的方式來加載模型,這樣在前向傳播的過程中就不需要調用滑動平均的函數來獲取平均值# 這樣可以完全共用mnist_inference.py重定義的前向傳播過程variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)while True:with tf.Session() as sess:# tf.train.get_checkpoint_state函數會通過checkpoint文件自動找到目錄中最新模型的文件名ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)if ckpt and ckpt.model_checkpoint_path:# 加載模型saver.restore(sess, ckpt.model_checkpoint_path)# 通過文件名得到模型保存是迭代的輪數global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]accuracy_score = sess.run(accuracy, feed_dict=validate_feed)print("After %s training step(s), validation accuracy = %g"%(global_step, accuracy_score))else:print('No checkpoint file found')returntime.sleep(EVAL_INTERVAL_SECS)# 每次運行都是讀取最新保存的模型,并在MNIST驗證數據集上計算模型的正確率# 每隔EVAL_INTERVAL_SECS秒來調用一側計算正確率的過程以檢驗訓練過程中的正確率變化# ###  主程序
def main(argv=None):mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)evaluate(mnist)if __name__ == '__main__':main()# After 29001 training step(s), validation accuracy = 0.9854

轉載于:https://www.cnblogs.com/cloud-ken/p/9318037.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/277643.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/277643.shtml
英文地址,請注明出處:http://en.pswp.cn/news/277643.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

怎樣制作滴滴截圖_滴滴老了嗎?

作者 / 薛靜 來源 / 盒飯財經(ID:daxiongfan)滴滴最近有點忙。6月11日,滴滴地圖與公交事業部負責人柴華還在忙于解答消費者對于滴滴司機繞路的質疑,網上就流傳出了滴滴司機直播性侵的消息。當晚,滴滴急忙在官方微博中做出回應稱已…

mysql Backup recovery

如果您要在MySQL數據庫中存儲任何您不想丟失的內容,那么定期備份數據以保護數據免受損失非常重要。本教程將向您展示兩種簡單的方法來備份和恢復MySQL數據庫中的數據。您還可以使用此過程將數據移動到新的Web服務器。 從命令行備份(使用mysqldump&#x…

Kinect開發筆記之三Kinect開發環境配置詳解

0、前言:首先說一下我的開發環境,Visual Studio是2013的,系統是win8的64位版本,SDK是Kinect for windows SDK 1.8版本。雖然前一篇博文費了半天勁,翻譯了2.0SDK的新特性,但我還是決定要回退一個版本。其實我…

opencv python 圖像縮放/圖像平移/圖像旋轉/仿射變換/透視變換

Geometric Transformations of Images 1圖像轉換 OpenCV提供了兩個轉換函數cv2.warpAffine和cv2.warpPerspective,可以使用它們進行各種轉換。 cv2.warpAffine采用2x3變換矩陣,而cv2.warpPerspective采用3x3變換矩陣作為輸入。 2圖像縮放 縮放只是調整圖…

.net調用c++方法時如何釋放c++中分配的內存_C/C++編程筆記:C語言編程知識要點總結!大一C語言知識點(全)...

一、C語言程序的構成與C、Java相比,C語言其實很簡單,但卻非常重要。因為它是C、Java的基礎。不把C語言基礎打扎實,很難成為程序員高手。1、C語言的結構先通過一個簡單的例子,把C語言的基礎打牢。C語言的結構要掌握以下幾點&#x…

Django 使用 mysql 數據庫連接

啟用 mysql 數據庫連接 修改 app01 下的 __init__.py import pymysqlpymysql.install_as_MySQLdb() 修改 settings.py DATABASES {default: {ENGINE: django.db.backends.mysql,NAME: django,USER: django,PASSWORD: django,HOST: 192.168.0.200,PORT: 3306,} } 測試 #生成同步…

Kinect開發筆記之四檢測并調試Kinect設備

之前我們已經裝好了Developer Toolkit 1.8,下面我們來做進一步的測試。首先到開始菜單中找到Kinect for Windows SDK v1.8,點擊其中的Developer Toolkit Browser v1.8.0。打開后,有許多東西,我們選擇最右邊的Tools來篩選一下&…

c語言雙引號和單引號的區別_Python中的單引號和雙引號有什么區別?

在Python中使用單引號或雙引號是沒有區別的,都可以用來表示一個字符串。但是這兩種通用的表達方式可以避免出錯之外,還可以減少轉義字符的使用,使程序看起來更清晰。舉兩個例子:1、包含單引號的字符串定義一個字符串m…

mysql 開發基礎系列22 SQL Model(帶遷移事項)

一.概述 與其它數據庫不同,mysql 可以運行不同的sql model 下, sql model 定義了mysql應用支持的sql語法,數據校驗等,這樣更容易在不同的環境中使用mysql。 sql model 常用來解決下面幾類問題: (1) 通過設置sql mode, …

五月28學習筆記

<!DOCTYPE html><html> <head> <meta charset"UTF-8"> <title></title> </head> <body> <!--鏈接標簽--> <!--核心屬性就是href 屬性值可以是一個跳轉的地址--&…

Kinect開發筆記之五使用PowerShell控制Kinect

這是第一次用MarkDown編輯器來寫博客&#xff0c;挺喜歡這種沒有任何格式舒服的編輯器&#xff0c;自由灑脫更加易讀&#xff0c;留一個不自然的自然段紀念下找到舒服的編輯器。 這次要記錄使用win7/win8內建的PowerShell來控制Kinect&#xff0c;改變Kinect的俯仰角度。 在我…

可轉債數據一覽表集思錄_可轉債股票數據一覽表

128107交科轉債720612061浙江交科-11.90%25113578全筑轉債754030603030全筑股份-1.26%3.84113573縱橫轉債754602603602縱橫通信5.79%2.7113577春秋轉債754890603890春秋電子-9.46%2.4123050聚飛轉債370303300303聚飛光電2.52%7.05110070凌鋼轉債733231600231凌鋼股份24.44%4.41…

國標流媒體H5實現無插件視頻監控按需直播

介紹 按需直播肯定是為了減少帶寬流量和服務器性能占用。安防行業GB28181協議天生就是按需播放的&#xff0c;有人請求播放時服務端才從設備端獲取設備的直播流或錄像視頻&#xff0c;停止播放時就會停止獲取視頻流。同時GB28181協議又是目前安防設備廠商都支持的統一的協議&am…

ipa 安裝包不用市場如果掃碼下載安裝 免費IOS安裝API

在做開發過程中可能會用于生成測試包的情況,不過測試包不能直接安裝,非常不方便,所以我提供給大家一下可通過鏈接下載安裝的方法也可以把鏈接生成二維碼掃碼下載 api地址: https://tool.bitefu.net/ipa/ 文件地址:http://tool.bitefu.net/showdoc/web/#/3 源碼下載:http://tado…

Kinect開發筆記之六Kinect Studio的應用

這一次我們來操作一下Kinect Studio&#xff0c;體驗一下它給我們帶來的功能。 首先我們需要打開Developer Toolkit Browser 1.8&#xff0c;打開后在默認情況下&#xff0c;光標是選擇在All選項卡上的&#xff0c;即我們現在所有Developer Toolkit Browser中的部件都可以看得…

antd picker 使用 如何_如何打造 Serverless JavaScript 全棧商業級應用?

2019 年底我們發布過一篇《O’Reilly 1500 份問卷調研&#xff1a;2019 年 Serverless 落地到底香不香&#xff1f;》&#xff0c;揭示了海外 Serverless 的落地情況&#xff0c;但中國 Serverless 的落地實踐分享相對較少&#xff0c;似乎誰都在喊 Serverless&#xff0c;誰都…

【Android Studio安裝部署系列】十三、Android studio添加和刪除Module 2

版權聲明&#xff1a;本文為HaiyuKing原創文章&#xff0c;轉載請注明出處&#xff01; 概述 新建、導入、刪除Module是常見的操作&#xff0c;這里簡單介紹下。 新建Module File——New——New Module... 選中Android Library 修改Library名稱 在項目工程中修改依賴 和添加下面…

Kinect開發筆記之七Visual Studio結合C#調控Kinect俯仰角度

總感覺自己前面啰啰嗦嗦寫了好多&#xff0c;卻一直都沒有使用用開發kinect的重型武器——Visual Studio。 那么本次我們就借助于Visual Studio&#xff0c;寫一個C#程序&#xff0c;連接Kinect并調用Kinect SDK標準函數庫來改變Kinect的俯仰角。 首先我們打開VS創建一個項目…

hadoop HDFS常用文件操作命令

命令基本格式: hadoop fs -cmd < args >1.ls hadoop fs -ls /列出hdfs文件系統根目錄下的目錄和文件 hadoop fs -ls -R /列出hdfs文件系統所有的目錄和文件 2.put hadoop fs -put < local file > < hdfs file >hdfs file的父目錄一定要存在&#xff0c;否則…

定量庫存控制模型_探索全面流動管理TFM 庫存控制與低減的理性策略

庫存乃萬惡之源庫存不僅占用了資金&#xff0c;還占用了各種管理性資源&#xff0c;形成了“財務性顯性成本“而且過多的庫存導致“緩沖區”的存在&#xff0c;還使得各類問題變得不那么緊迫&#xff0c;從而掩蓋了各類隱藏的問題&#xff0c;這被稱為“隱形成本”零庫存不僅做…