【tensorflow】static_rnn與dynamic_rnn的區別

static_rnn和dynamic_rnn的區別主要在于實現不同。

  • static_rnn會把RNN展平,用空間換時間。 gpu會吃不消(個人測試結果)

  • dynamic_rnn則是使用for或者while循環。

調用static_rnn實際上是生成了rnn按時間序列展開之后的圖。打開tensorboard你會看到sequence_length個rnn_cell
stack在一起,只不過這些cell是share
weight的。因此,sequence_length就和圖的拓撲結構綁定在了一起,因此也就限制了每個batch的sequence_length必須是一致。

調用dynamic_rnn不會將rnn展開,而是利用tf.while_loop這個api,通過Enter, Switch, Merge,
LoopCondition, NextIteration等這些control
flow的節點,生成一個可以執行循環的圖(這個圖應該還是靜態圖,因為圖的拓撲結構在執行時是不會變化的)。在tensorboard上,你只會看到一個rnn_cell,
外面被一群control
flow節點包圍著。對于dynamic_rnn來說,sequence_length僅僅代表著循環的次數,而和圖本身的拓撲沒有關系,所以每個batch可以有不同sequence_length。

static_rnn

導包、加載數據、定義變量
import tensorflow as tf
tf.reset_default_graph() #流式計算圖形graph  循環神經網絡 將名字相同重置了圖
import datetime #打印時間
import os   #保存文件
from tensorflow.examples.tutorials.mnist import input_data# minst測試集
mnist = input_data.read_data_sets('../', one_hot=True)# 每次使用100條數據進行訓練
batch_size = 100
# 圖像向量
width = 28
height = 28
# LSTM隱藏神經元數量
rnn_size = 256
# 輸出層one-hot向量長度的
out_size = 10

聲明變量

def weight_variable(shape, w_alpha=0.01):initial = w_alpha * tf.random_normal(shape)return tf.Variable(initial)def bias_variable(shape, b_alpha=0.1):initial = b_alpha * tf.random_normal(shape)return tf.Variable(initial)# 權重及偏置
w = weight_variable([rnn_size, out_size])
b = bias_variable([out_size])

將數據轉化成RNN所要求的數據

# 按照圖片大小申請占位符
X = tf.placeholder(tf.float32, [None, height, width])
# 原排列[0,1,2]transpose為[1,0,2]代表前兩維裝置,如shape=(1,2,3)轉為shape=(2,1,3)
# 這里的實際意義是把所有圖像向量的相同行號向量轉到一起,如x1的第一行與x2的第一行
x = tf.transpose(X, [1, 0, 2])
# reshape -1 代表自適應,這里按照圖像每一列的長度為reshape后的列長度
x = tf.reshape(x, [-1, width])
# split默任在第一維即0 dimension進行分割,分割成height份,這里實際指把所有圖片向量按對應行號進行重組
x = tf.split(x, height)

構建靜態的循環神經網絡

# LSTM
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
# 這里RNN會有與輸入層相同數量的輸出層,我們只需要最后一個輸出
outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)#取最后一個進行矩陣乘法
y_conv = tf.add(tf.matmul(outputs[-1], w), b)
# 最小化損失優化
Y = tf.placeholder(dtype=tf.float32,shape = [None,10])
#損失使用的交叉熵
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_conv, labels=Y))
optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)
# 計算準確率
correct = tf.equal(tf.argmax(y_conv, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

模型的訓練

# 啟動會話.開始訓練
saver = tf.train.Saver()
session = tf.Session()
session.run(tf.global_variables_initializer())
step = 0
acc_rate = 0.90
while 1:batch_x, batch_y = mnist.train.next_batch(batch_size)batch_x = batch_x.reshape((batch_size, height, width))session.run(optimizer, feed_dict={X:batch_x,Y:batch_y})# 每訓練10次測試一次if step % 10 == 0:batch_x_test = mnist.test.imagesbatch_y_test = mnist.test.labelsbatch_x_test = batch_x_test.reshape([-1, height, width])acc = session.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test})print(datetime.datetime.now().strftime('%c'), ' step:', step, ' accuracy:', acc)# 偏差滿足要求,保存模型if acc >= acc_rate:
#             os.sep = ‘/’model_path = os.getcwd() + os.sep + str(acc_rate) + "mnist.model"saver.save(session, model_path, global_step=step)breakstep += 1
session.close()

Wed Dec 18 10:08:45 2019 step: 0 accuracy: 0.1006
Wed Dec 18 10:08:46 2019 step: 10 accuracy: 0.1009
Wed Dec 18 10:08:46 2019 step: 20 accuracy: 0.1028

Wed Dec 18 10:08:57 2019 step: 190 accuracy: 0.9164

dynamic_rnn

加載數據,聲明變量
import tensorflow as tf
tf.reset_default_graph()
from tensorflow.examples.tutorials.mnist import input_data# 載入數據
mnist = input_data.read_data_sets("../", one_hot=True)# 輸入圖片是28
n_input = 28
max_time = 28
lstm_size = 100  # 隱藏單元 可調
n_class = 10  # 10個分類
batch_size = 100   # 每次50個樣本 可調
n_batch_size = mnist.train.num_examples // batch_size    # 計算一共有多少批次

Extracting …/train-images-idx3-ubyte.gz
Extracting …/train-labels-idx1-ubyte.gz
Extracting …/t10k-images-idx3-ubyte.gz
Extracting …/t10k-labels-idx1-ubyte.gz

占位符、權重

# 這里None表示第一個維度可以是任意長度
# 創建占位符
x = tf.placeholder(tf.float32,[None, 28*28])
# 正確的標簽
y = tf.placeholder(tf.float32,[None, 10])# 初始化權重 ,stddev為標準差
weight = tf.Variable(tf.truncated_normal([lstm_size, n_class], stddev=0.1))
# 初始化偏置層
biases = tf.Variable(tf.constant(0.1, shape=[n_class]))

構建動態RNN、損失函數、準確率

# 定義RNN網絡
def RNN(X, weights, biases):#  原始數據為[batch_size,28*28]# input = [batch_size, max_time, n_input]input_ = tf.reshape(X,[-1, max_time, n_input])# 定義LSTM的基本單元
#     lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)# final_state[0] 是cell state# final_state[1] 是hidden statoutputs, final_state = tf.nn.dynamic_rnn(lstm_cell, input_, dtype=tf.float32)display(final_state)results = tf.nn.softmax(tf.matmul(final_state[1],weights)+biases)return results
# 計算RNN的返回結果
prediction = RNN(x, weight, biases)
# 損失函數
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
# 使用AdamOptimizer進行優化
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
# 將結果存下來
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# 計算正確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

LSTMStateTuple(c=<tf.Tensor ‘rnn/while/Exit_3:0’ shape=(?, 100) dtype=float32>, h=<tf.Tensor ‘rnn/while/Exit_4:0’ shape=(?, 100) dtype=float32>)

訓練數據

saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(6):for batch in range(n_batch_size):# 取出下一批次數據batch_xs,batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict={x: batch_xs,y: batch_ys})if(batch%100==0):print(str(batch)+"/" + str(n_batch_size))acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})print("Iter" + str(epoch) + " ,Testing Accuracy = " + str(acc))if acc >0.9:saver.save(sess,'./rnn_dynamic')break

0/550
100/550
200/550
300/550
400/550
500/550
Iter0 ,Testing Accuracy = 0.5903

Iter5 ,Testing Accuracy = 0.9103

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/456006.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/456006.shtml
英文地址,請注明出處:http://en.pswp.cn/news/456006.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

pcie1 4 速度_太陽系行星們誰轉得最快?八大行星自轉速度排行榜,地球排第五...

不知道大家有沒有玩兒過陀螺呢&#xff1f;玩兒陀螺的技術如果很好的話&#xff0c;它可以在地上飛快地旋轉并且能夠旋轉很長的時間。有趣的是&#xff0c;宇宙中的很多星球就像陀螺一樣繞著一個中心軸旋轉著。這就是星球的自轉。在太陽系中有八顆大行星&#xff0c;它們都在自…

python中時間模塊

時間日期相關的模塊 calendar 日歷模塊time   時間模塊datetime 日期時間模塊timeit   時間檢測模塊 日歷模塊 calendar() 功能&#xff1a;獲取指定年份的日歷字符串 格式&#xff1a;calendar.calendar&#xff08;年份,w2,l1&#xff0c;c6,m3&#xff09; 返回值&…

硬盤接口詳細解釋

硬盤是電腦主要的存儲媒介之一&#xff0c;由一個或者多個鋁制或者玻璃制的碟片組成。碟片外覆蓋有鐵磁性材料。硬盤有固態硬盤&#xff08;SSD 盤&#xff0c;新式硬盤&#xff09;、機械硬盤&#xff08;HDD 傳統硬盤&#xff09;、混合硬盤&#xff08;HHD 一塊基于傳統機械…

【Keras】30 秒上手 Keras+實例對mnist手寫數字進行識別準確率達99%以上

本文我們將學習使用Keras一步一步搭建一個卷積神經網絡。具體來說&#xff0c;我們將使用卷積神經網絡對手寫數字(MNIST數據集)進行識別&#xff0c;并達到99%以上的正確率。 為什么選擇Keras呢&#xff1f; 主要是因為簡單方便。更多細節請看&#xff1a;https://keras.io/ …

分布式資本沈波:未來區塊鏈殺手級應用將出現在“+區塊鏈”

雷鋒網5月22日報道&#xff0c;日前“區塊鏈技術和應用峰會”在杭州國際博覽中心舉行。會上&#xff0c;分布式資本創始管理人沈波作了《區塊鏈的投資現狀與發展趨勢》演講。 沈波表示&#xff0c;由于區塊鏈的共識機制和無法篡改兩大特點&#xff0c;它在各行各業皆有應用潛力…

幀間預測小記

幀間預測后&#xff0c;在比特流中會有相應的信息&#xff1a;殘差信息&#xff0c;運動矢量信息&#xff0c;所選的模式。 宏塊的色度分量分辨率是亮度分辨率的一半&#xff08;Cr和Cb&#xff09;&#xff0c;水平和垂直均一半。色度塊采用和亮度塊一致的分割模式&#xff0…

ImageJ Nikon_科研論文作圖之ImageJ

各位讀者朋友們又見面了&#xff0c;今天給大家介紹一款圖片處理軟件——ImageJ&#xff0c;這是一款免費的科學圖像分析工具&#xff0c;廣泛應用于生物學研究領域。ImageJ軟件能夠對圖像進行縮放、旋轉、扭曲、模糊等處理&#xff0c;也可計算選定區域內分析對象的一系列幾何…

python中面向對象

面向對象 Object Oriented 面向對象的學習&#xff1a; 面向對象的語法&#xff08;簡單&#xff0c;記憶就可以搞定&#xff09;面向對象的思想&#xff08;稍難&#xff0c;需要一定的理解&#xff09; 面向過程和面向對象的區別 面向過程開發&#xff0c;以函數作為基本結構…

【urllib】url編碼問題簡述

對url編解碼總結 需要用到urllib庫中的parse模塊 import urllib.parse # Python3 url編碼 print(urllib.parse.quote("天天")) # Python3 url解碼 print(urllib.parse.unquote("%E5%A4%E5%A4%")) urlparse() # urllib.parse.urlparse(urlstring,scheme,…

冷知識 —— 地理

西安1980坐標系&#xff1a; 1978 年 4 月在西安召開全國天文大地網平差會議&#xff0c;確定重新定位&#xff0c;建立我國新的坐標系。為此有了 1980 國家大地坐標系。1980 國家大地坐標系采用地球橢球基本參數為 1975 年國際大地測量與地球物理聯合會第十六屆大會推薦的數據…

獨家| ChinaLedger白碩:區塊鏈中的隱私保護

隱私問題一直是區塊鏈應用落地的障礙問題之一&#xff0c;如何既能滿足監管&#xff0c;又能不侵害數據隱私&#xff0c;是行業都在攻克的問題。那么&#xff0c;到底隱私問題為何難&#xff1f;有什么解決思路&#xff0c;以及實踐創新呢&#xff1f;零知識證明、同態加密等技…

手機處理器排行榜2019_手機處理器AI性能排行榜出爐,高通驍龍第一,華為排在第十名...

↑↑↑擊上方"藍字"關注&#xff0c;每天推送最新科技新聞安兔兔在近日公布了今年四月份Android手機處理器AI性能排行榜&#xff0c;榜單顯示高通驍龍865處理器的AI性能在Android陣營中排在第一名——該處理器的AI性能得分接近46萬分&#xff0c;今年的小米10、三星G…

芯片支持的且會被用到的H.264特性 預測編碼基本原理

視頻壓縮&#xff1a; 1.H.264基本檔次和主要檔次&#xff1b;2.CAVLC熵編碼&#xff0c;即基于上下文的自適應變長編碼&#xff1b;&#xff08;不支持CABAC&#xff0c;即基于上下文的自適應算術編碼&#xff09;分辨率&#xff1a;僅用到1080p60&#xff0c;即分辨率為1920*…

MongoDB 數據庫 【總結筆記】

一、MongoDB 概念解析 什么是MongoDB&#xff1f; ? 1、MongoDB是有C語言編寫的&#xff0c;是一個基于分布式文件存儲的開源數據庫系統&#xff0c;在高負載的情況下&#xff0c;添加更多節點&#xff0c;可以保證服務器的性能 ? 2、MongoDB為web應用提供了高性能的數據存儲…

PHP 函數截圖 哈哈哈

轉載于:https://www.cnblogs.com/bootoo/p/6714676.html

python中的魔術方法

魔術方法 魔術方法就是一個類/對象中的方法&#xff0c;和普通方法唯一的不同時&#xff0c;普通方法需要調用&#xff01;而魔術方法是在特定時刻自動觸發。 1.__init__ 初始化魔術方法 觸發時機&#xff1a;初始化對象時觸發&#xff08;不是實例化觸發&#xff0c;但是和實…

2016年光伏電站交易和融資的十大猜想

1領跑者計劃備受關注&#xff0c;競價上網或從試點開始 領跑者計劃規模大&#xff0c;上網條件好&#xff0c;又有政府背書&#xff0c;雖說價格也不便宜&#xff0c;但省去很多隱性成本&#xff0c;對于致力于規模化發展的大型企業來說仍是首要選擇。同時&#xff0c;從能源管…

loading gif 透明_搞笑GIF:有這樣的女朋友下班哪里都不想去

原標題&#xff1a;搞笑GIF&#xff1a;有這樣的女朋友下班哪里都不想去這樣的廣場舞看著不涼快嗎&#xff1f;大哥慢點&#xff0c;機器經受不住你這樣的速度求孩子的心里陰影面積生孩子就是用來玩的。有這樣的媳婦做飯&#xff0c;下班哪里也不想去1.領導在門外用門夾核桃&am…

Redis數據庫 【總結筆記】

一、NoSql&#xff08;非關系型數據庫&#xff09; NoSQL&#xff1a;NoSQL Not Only SQL 非關系型數據庫 ? NoSQL&#xff0c;泛指非關系型的數據庫。隨著互聯網web2.0網站的興起&#xff0c;傳統的關系數據庫在應付web2.0網站&#xff0c;特別是超大規模和高并發的SNS類型…

基于IP的H.264關鍵技術

一、 引言 H.264是ITU-T最新的視頻編碼標準&#xff0c;被稱作ISO/IEC14496-10或MPEG-4 AVC&#xff0c;是由運動圖像專家組(MPEG)和ITU的視頻編碼專家組共同開發的新產品。H.264分兩層結構&#xff0c;包括視頻編碼層和網絡適配層。視頻編碼層處理的是塊、宏塊和片的數據&…