神經網絡優化(二) - 滑動平均

1 滑動平均概述

滑動平均(也稱為 影子值 ):記錄了每一個參數一段時間內過往值的平均,增加了模型的泛化性。

滑動平均通常針對所有參數進行優化:W 和 b,

簡單地理解,滑動平均像是給參數加了一個影子,參數變化,影子緩慢追隨。

滑動平均的表示公式為

影子 = 衰減率 * 影子 + ( 1 - 衰減率 ) * 參數

滑動平均值 = 衰減率 * 滑動平均值 + ( 1 - 衰減率 )* 參數

備注

影子初值 = 參數初值

衰減率 = min{ MOVING_AVERAGE_DECAY, (1+輪數) / (10 + 輪數 ) }

示例:

MOVING_AVERAGE_DECAY 為 0.99, 參數 w1 為 0,輪數 global_step 為 0,w1的滑動平均值為 0 。

參數w1更新為 1 時,則

 w1的滑動平均值 = min( 0.99, 1/10 ) * 0 + ( 1 - min( 0.99, 1/10 ) * 1 = 0.9

?假設輪數 global_step 為 100 時,參數 w1 更新為 10 時,則

w1滑動平均值 = min(0.99, 101/110) * 0.9 + ( 1 - min( 0.99, 101/110) * 10 = 1.644

再次運行

w1滑動平均值 = min(0.99, 101/110) * 1.644 + ( 1 - min( 0.99, 101/110) * 10 = 2.328

再次運行

w1滑動平均值 = 2.956

?

2 滑動平均在Tensorflow中的表示方式

第一步 實例化滑動平均類ema

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY(滑動平均衰減率),global_step(輪數計數器,表示當前輪數)
)

備注:

MOVING_AVERAGE_DECAY 滑動平均衰減率是超參數,一般設定的值比較大;

global_step - 輪數計數器,表示當前輪數,這個參數與其他計數器公用。

第二步 求算滑動平均節點ema_op

ema_op = ema.apply([])

ema.apply([ ]) 函數表示對 [ ] 中的所有數值求滑動平均。

示例:

ema_op = ema.apply(tf.trainable_variables())

每當運行此代碼時,會對所以待優化參數進行求滑動平均運算。

第三步 具體實現方式

在工程應用中,我們通常會將計算滑動平均 ema_op 和訓練過程 train_step 綁定在一起運行,使其合成一個訓練節點,實現的代碼如下

with tf.control_dependencies([ train_step, ema_op ]):train_op = tf.no_op(name = 'train')

?

另外:

查看某參數的滑動平均值

函數ema.average(參數名) --->? 返回 ’ 參數名 ’ 的滑動平均值,

3 示例代碼

# 待優化參數w1,不斷更新w1參數,求w1的滑動平均(影子)import tensorflow as tf# 1. 定義變量及滑動平均類# 定義一個32位浮點變量并賦初值為0.0,
w1 = tf.Variable(0, dtype=tf.float32)# 輪數計數器,表示NN的迭代輪數,賦初始值為0,同時不可被優化(不參數訓練)
global_step = tf.Variable(0, trainable=False)# 設定衰減率為0.99
MOVING_AVERAGE_DECAY = 0.99# 實例化滑動平均類
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)# ema.apply()函數中的參數為待優化更新列表
# 每運行sess.run(ema_op)時,會對函數中的參數求算滑動平均值
# tf.trainable_variables()函數會自動將所有待訓練的參數匯總為待列表
# 因該段代碼中僅有w1一個參數,ema_op = ema.apply([w1])與下段代碼等價
ema_op = ema.apply(tf.trainable_variables())# 2. 查看不同迭代中變量取值的變化。
with tf.Session() as sess:# 初始化init_op = tf.global_variables_initializer()sess.run(init_op)# 用ema.average(w1)獲取w1滑動平均值 (要運行多個節點,作為列表中的元素列出,寫在sess.run中)# 打印出當前參數w1和w1滑動平均值print("current global_step:", sess.run(global_step))print("current w1", sess.run([w1, ema.average(w1)]))# 參數w1的值賦為1sess.run(tf.assign(w1, 1))sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1", sess.run([w1, ema.average(w1)]))# 更新global_step和w1的值,模擬出輪數為100時,參數w1變為10, 以下代碼global_step保持為100,每次執行滑動平均操作,影子值會更新 sess.run(tf.assign(global_step, 100))sess.run(tf.assign(w1, 10))sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))# 每次sess.run會更新一次w1的滑動平均值
    sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))sess.run(ema_op)print("current global_step:", sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))sess.run(ema_op)print("current global_step:" , sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))sess.run(ema_op)print("current global_step:" , sess.run(global_step))print("current w1:", sess.run([w1, ema.average(w1)]))

運行

current global_step: 0
current w1 [0.0, 0.0]
current global_step: 0
current w1 [1.0, 0.9]
current global_step: 100
current w1: [10.0, 1.6445453]
current global_step: 100
current w1: [10.0, 2.3281732]
current global_step: 100
current w1: [10.0, 2.955868]
current global_step: 100
current w1: [10.0, 3.532206]
current global_step: 100
current w1: [10.0, 4.061389]

?

w1 的滑動平均值都向參數 w1 靠近。可見,滑動平均追隨參數的變化而變化。

轉載于:https://www.cnblogs.com/gengyi/p/9901502.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/450826.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/450826.shtml
英文地址,請注明出處:http://en.pswp.cn/news/450826.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Docker完全自學手冊

阿里云大學免費課程:Docker完全自學手冊課程介紹:Docker 是 PaaS 提供商 dotCloud 開源的一個基于 LXC 的高級容器引擎,源代碼托管在 Github 上, 基于go語言并遵從Apache2.0協議開源。Docker 是一個開源的應用容器引擎,讓開發者可…

Spring 之注解事務 @Transactional

前些天發現了一個巨牛的人工智能學習網站,通俗易懂,風趣幽默,忍不住分享一下給大家。點擊跳轉到教程。 先讓我們看代碼吧! 以下代碼為在“Spring3事務管理——基于tx/aop命名空間的配置”基礎上修改。首先修改applicationContext…

超級程序員神話

摘要:大部分的程序員在思想里都會某種程度的承認,承認自己只是一個普通的程序員,但這世界上確實有一些超級程序員,在一個為企業開發應用的程序員和一個為谷歌寫搜索算法的程序員之間,或和一個開發用來控制讀寫頭從磁盤…

HashMap30連問,徹底搞懂HashMap

文章目錄一、背景知識1、什么是Map?2、什么是Hash?3、什么是哈希表?4、什么是HashMap?5、如何使用HashMap?6、HashMap有哪些核心參數?7、HashMap與HashTable的對比?8、HashMap和HashSet的區別?…

博弈論的算法總結

開頭先啰嗦一句:想學好博弈,必然要花費很多的時間,深入學習,不要存在一知半解,應該是一看到題目,就想到博弈的類型。 以及,想不斷重復不斷重復,做大量各大oj網站的題目,最…

Slog55_lua面向對象之lua類

Slog55_lua面向對象之lua類 ArthurSlog SLog-55 Year1 GuangzhouChina Aug 30th 2018 微信掃描二維碼,關注我的公眾號GitHub 掘金主頁 簡書主頁 segmentfault 現實中的事情不是根據人的喜好而定的 比如長在你嘴里的智齒 大部分情況下 你會因為自己&#xff0…

Spring中的@scope注解

前些天發現了一個巨牛的人工智能學習網站,通俗易懂,風趣幽默,忍不住分享一下給大家。點擊跳轉到教程。 Scope 簡單點說就是用來指定bean的作用域作用域 (官方解釋:scope用來聲明IOC容器中的對象應該處的限定場景或者…

編程語言大比拼——誰的效率高

摘要:C、C、Java這幾個屹立不倒的開發語言,如果以功能點作為單位的話,誰的效率最高呢?如果在項目初期就能確定功能點數量,那么就可以很好的預測項目完成時間。這一點是不是對你很有幫助呢? 一份6000個項目的…

Hadoop之Flume詳解

1、日志采集框架Flume   1.1 Flume介紹     Flume是一個分布式、可靠、和高可用的海量日志采集、聚合和傳輸的系統。     Flume可以采集文件,socket數據包等各種形式源數據,又可以將采集到的數據輸出到HDFS、hbase、hive、     kafka等眾多…

搞懂Java的反射機制

搞懂Java的反射機制 1.什么是反射? java的反射機制是指可以在運行狀態下獲取類和對象的所有屬性和方法。 2.反射的作用? 1、在運行時獲取一個類/對象的成員變量和方法 2、在運行時創建一個類的對象 3、在運行時判斷一個對象是否屬于一個類 3.反射有哪些…

表單oninput和onchange事件區別

oninput事件是元素value發生變化是立刻觸發,而onchange是元素發生變化并且失去焦點時才會觸發。 轉載于:https://www.cnblogs.com/ykli/p/9565601.html

Struts2中<s:iterator>基本用法及示例

前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊跳轉到教程。 Struts2中<s:iterator>基本用法及示例 Iterator用于遍歷集合&#xff08;java.util.Collection&#xff09;或枚舉值&#xff08;j…

如何使用postman做接口測試

1、get請求傳參 只要是get請求都可以在瀏覽器中直接發&#xff1a; 在訪問地址后面拼 ?keyvalue&keyvalue 例如&#xff1a;在瀏覽器中直接輸入訪問地址&#xff0c;后面直接拼需要傳給服務器的參數http://api.nnzhp.cn/api/user/stu_info?stu_name小黑2、post請求&…

【狂神說】分析前后端分離開源項目?

文章目錄1.如何分析開源項目項目簡介項目源碼2.觀察開源項目3.開源項目下載4.跑起來是第一步5.前后端分離項目固定套路6.如何找到一個開源項目1.如何分析開源項目 學習的方式&#xff1a; 不知道這個代碼怎么來的這個代碼跑不起來這個項目對我們有什么幫助&#xff0c;不會模…

設計公共API的六個注意事項

摘要&#xff1a;俗話說&#xff1a;“好東西就要貢獻出來和大家一起分享”&#xff0c;尤其是在互聯網業務高度發達的今天&#xff0c;如果你的創業公司提供了一項很酷的技術或者服務&#xff0c;并且其他用戶也非常喜歡該產品&#xff0c;在這種情況下&#xff0c;最好的解決…

go 交叉編譯

golang中windows交叉編譯 env GOOSlinux GOARCHamd64 go build .打包鏡像 FROM alpineMAINTAINER "congge"ADD ./casino_niuniu /usr/local/casino_niuniu/bin/casino_niuniu ADD ./templates /usr/loca/lcasino_niuniu/bin/templates ADD ./public /usr/local/casin…

IntelliJ Idea 2017 免費激活方法

見&#xff1a;https://www.cnblogs.com/suiyueqiannian/p/6754091.html 1. 到網站 http://idea.lanyus.com/ 獲取注冊碼。 2.填入下面的license server: http://intellij.mandroid.cn/   http://idea.imsxm.com/   http://idea.iteblog.com/key.php 以上方法驗證均可以

P3193 [HNOI2008]GT考試

傳送門 容易看出是道DP 考慮一位一位填數字 設 f [ i ] [ j ] 表示填到第 i 位&#xff0c;在不吉利串上匹配到第 j 位時不出現不吉利數字的方案數 設 g [ i ] [ j ] 表示不吉利串匹配到第 i 位&#xff0c;再添加一個數字&#xff0c;使串匹配到第 j 位的方案數 那么方程顯然為…

LeetCode刷題攻略

目錄 一、LeetCode簡介 二、刷leetcode的主要目的 三、常用的數據結構 四、常用的算法思想 五、選擇算法題 1、刷題選擇 2、刷題方法 方法一&#xff1a;順序法 方法二&#xff1a;標簽法 方法三&#xff1a;隨機法 方法四&#xff1a;必殺法 六、刷題攻略 TIP 1&…

SQLserver數據庫反編譯生成Hibernate實體類和映射文件

一、建立項目和sqlserver數據庫 eclipse&#xff0c;我使用的版本是neon3 二、Data Source Explorer 選擇OK 在data source Explorer的Database Connections 選擇New 填寫好General的連接信息 新建New Driver Definition 填寫完選擇OK 選擇剛才的Drivers Test Connetion測試 N…