TensorFlow中RNN實現的正確打開方式

上周寫的文章《完全圖解RNN、RNN變體、Seq2Seq、Attention機制》介紹了一下RNN的幾種結構,今天就來聊一聊如何在TensorFlow中實現這些結構,這篇文章的主要內容為:

  • 一個完整的、循序漸進的學習TensorFlow中RNN實現的方法。這個學習路徑的曲線較為平緩,應該可以減少不少學習精力,幫助大家少走彎路。

  • 一些可能會踩的坑

  • TensorFlow源碼分析

  • 一個Char RNN實現示例,可以用來寫詩,生成歌詞,甚至可以用來寫網絡小說!(項目地址:https://github.com/hzy46/Char-RNN-TensorFlow

一、學習單步的RNN:RNNCell

如果要學習TensorFlow中的RNN,第一站應該就是去了解“RNNCell”,它是TensorFlow中實現RNN的基本單元,每個RNNCell都有一個call方法,使用方式是:(output, next_state) = call(input, state)。

借助圖片來說可能更容易理解。假設我們有一個初始狀態h0,還有輸入x1,調用call(x1, h0)后就可以得到(output1, h1):


TensorFlow中RNN實現的正確打開方式

再調用一次call(x2, h1)就可以得到(output2, h2):

TensorFlow中RNN實現的正確打開方式

也就是說,每調用一次RNNCell的call方法,就相當于在時間上“推進了一步”,這就是RNNCell的基本功能。

在代碼實現上,RNNCell只是一個抽象類,我們用的時候都是用的它的兩個子類BasicRNNCell和BasicLSTMCell。顧名思義,前者是RNN的基礎類,后者是LSTM的基礎類。這里推薦大家閱讀其源碼實現(地址:http://t.cn/RNJrfMl),一開始并不需要全部看一遍,只需要看下RNNCell、BasicRNNCell、BasicLSTMCell這三個類的注釋部分,應該就可以理解它們的功能了。

除了call方法外,對于RNNCell,還有兩個類屬性比較重要:

  • state_size

  • output_size

前者是隱層的大小,后者是輸出的大小。比如我們通常是將一個batch送入模型計算,設輸入數據的形狀為(batch_size, input_size),那么計算時得到的隱層狀態就是(batch_size, state_size),輸出就是(batch_size, output_size)。

可以用下面的代碼驗證一下(注意,以下代碼都基于TensorFlow最新的1.2版本):

import tensorflow as tf

import numpy as np


cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)?# state_size = 128

print(cell.state_size)?# 128


inputs = tf.placeholder(np.float32, shape=(32, 100))?# 32 是 batch_size

h0 = cell.zero_state(32, np.float32)?# 通過zero_state得到一個全0的初始狀態,形狀為(batch_size, state_size)

output, h1 = cell.call(inputs, h0)?#調用call函數


print(h1.shape)?# (32, 128)

對于BasicLSTMCell,情況有些許不同,因為LSTM可以看做有兩個隱狀態h和c,對應的隱層就是一個Tuple,每個都是(batch_size, state_size)的形狀:

import tensorflow as tf

import numpy as np

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)

inputs = tf.placeholder(np.float32, shape=(32, 100))?# 32 是 batch_size

h0 = lstm_cell.zero_state(32, np.float32)?# 通過zero_state得到一個全0的初始狀態

output, h1 = lstm_cell.call(inputs, h0)


print(h1.h)??# shape=(32, 128)

print(h1.c) ?# shape=(32, 128)

二、學習如何一次執行多步:tf.nn.dynamic_rnn

基礎的RNNCell有一個很明顯的問題:對于單個的RNNCell,我們使用它的call函數進行運算時,只是在序列時間上前進了一步。比如使用x1、h0得到h1,通過x2、h1得到h2等。這樣的h話,如果我們的序列長度為10,就要調用10次call函數,比較麻煩。對此,TensorFlow提供了一個tf.nn.dynamic_rnn函數,使用該函數就相當于調用了n次call函數。即通過{h0,x1, x2, …., xn}直接得{h1,h2…,hn}。

具體來說,設我們輸入數據的格式為(batch_size, time_steps, input_size),其中time_steps表示序列本身的長度,如在Char RNN中,長度為10的句子對應的time_steps就等于10。最后的input_size就表示輸入數據單個序列單個時間維度上固有的長度。另外我們已經定義好了一個RNNCell,調用該RNNCell的call函數time_steps次,對應的代碼就是:

# inputs: shape = (batch_size, time_steps, input_size)

# cell: RNNCell

# initial_state: shape = (batch_size, cell.state_size)。初始狀態。一般可以取零矩陣

outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

此時,得到的outputs就是time_steps步里所有的輸出。它的形狀為(batch_size, time_steps, cell.output_size)。state是最后一步的隱狀態,它的形狀為(batch_size, cell.state_size)。

此處建議大家閱讀tf.nn.dynamic_rnn的文檔(地址:https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)做進一步了解。

三、學習如何堆疊RNNCell:MultiRNNCell

很多時候,單層RNN的能力有限,我們需要多層的RNN。將x輸入第一層RNN的后得到隱層狀態h,這個隱層狀態就相當于第二層RNN的輸入,第二層RNN的隱層狀態又相當于第三層RNN的輸入,以此類推。在TensorFlow中,可以使用tf.nn.rnn_cell.MultiRNNCell函數對RNNCell進行堆疊,相應的示例程序如下:

import tensorflow as tf

import numpy as np


# 每調用一次這個函數就返回一個BasicRNNCell

def get_a_cell():
? ?return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用tf.nn.rnn_cell MultiRNNCell創建3層RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])?# 3層RNN

# 得到的cell實際也是RNNCell的子類

# 它的state_size是(128, 128, 128)

# (128, 128, 128)并不是128x128x128的意思

# 而是表示共有3個隱層狀態,每個隱層狀態的大小為128

print(cell.state_size)?# (128, 128, 128)

# 使用對應的call函數

inputs = tf.placeholder(np.float32, shape=(32, 100))?# 32 是 batch_size

h0 = cell.zero_state(32, np.float32)?# 通過zero_state得到一個全0的初始狀態

output, h1 = cell.call(inputs, h0)

print(h1)?# tuple中含有3個32x128的向量

通過MultiRNNCell得到的cell并不是什么新鮮事物,它實際也是RNNCell的子類,因此也有call方法、state_size和output_size屬性。同樣可以通過tf.nn.dynamic_rnn來一次運行多步。

此處建議閱讀MutiRNNCell源碼(地址:http://t.cn/RNJrfMl)中的注釋進一步了解其功能。

四、可能遇到的坑1:Output說明

在經典RNN結構中有這樣的圖:

TensorFlow中RNN實現的正確打開方式

在上面的代碼中,我們好像有意忽略了調用call或dynamic_rnn函數后得到的output的介紹。將上圖與TensorFlow的BasicRNNCell對照來看。h就對應了BasicRNNCell的state_size。那么,y是不是就對應了BasicRNNCell的output_size呢?答案是否定的。

找到源碼中BasicRNNCell的call函數實現:

def call(self, inputs, state):
? ?"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
? ?output = self._activation(_linear([inputs, state], self._num_units, True))
? ?return output, output

這句“return output, output”說明在BasicRNNCell中,output其實和隱狀態的值是一樣的。因此,我們還需要額外對輸出定義新的變換,才能得到圖中真正的輸出y。由于output和隱狀態是一回事,所以在BasicRNNCell中,state_size永遠等于output_size。TensorFlow是出于盡量精簡的目的來定義BasicRNNCell的,所以省略了輸出參數,我們這里一定要弄清楚它和圖中原始RNN定義的聯系與區別。

再來看一下BasicLSTMCell的call函數定義(函數的最后幾行):

new_c = (
? ?c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))

new_h = self._activation(new_c) * sigmoid(o)


if self._state_is_tuple:
?new_state = LSTMStateTuple(new_c, new_h)

else:
?new_state = array_ops.concat([new_c, new_h], 1)

return new_h, new_state

我們只需要關注self._state_is_tuple == True的情況,因為self._state_is_tuple == False的情況將在未來被棄用。返回的隱狀態是new_c和new_h的組合,而output就是單獨的new_h。如果我們處理的是分類問題,那么我們還需要對new_h添加單獨的Softmax層才能得到最后的分類概率輸出。

還是建議大家親自看一下源碼實現(地址:http://t.cn/RNJsJoH)來搞明白其中的細節。

五、可能遇到的坑2:因版本原因引起的錯誤

在前面我們講到堆疊RNN時,使用的代碼是:

# 每調用一次這個函數就返回一個BasicRNNCell

def get_a_cell():
? ?return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用tf.nn.rnn_cell MultiRNNCell創建3層RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])?# 3層RNN

這個代碼在TensorFlow 1.2中是可以正確使用的。但在之前的版本中(以及網上很多相關教程),實現方式是這樣的:

one_cell = ?tf.nn.rnn_cell.BasicRNNCell(num_units=128)

cell = tf.nn.rnn_cell.MultiRNNCell([one_cell] * 3)?# 3層RNN

如果在TensorFlow 1.2中還按照原來的方式定義,就會引起錯誤!

六、一個練手項目:Char RNN

上面的內容實際上就是TensorFlow中實現RNN的基本知識了。這個時候,建議大家用一個項目來練習鞏固一下。此處特別推薦Char RNN項目,這個項目對應的是經典的RNN結構,實現它使用的TensorFlow函數就是上面說到的幾個,項目本身又比較有趣,可以用來做文本生成,平常大家看到的用深度學習來寫詩寫歌詞的基本用的就是它了。

Char RNN的實現已經有很多了,可以自己去Github上面找,我這里也做了一個實現,供大家參考。項目地址為:hzy46/Char-RNN-TensorFlow(地址:https://github.com/hzy46/Char-RNN-TensorFlow)。代碼的部分實現來自于《安娜卡列尼娜文本生成——利用TensorFlow構建LSTM模型》

這篇專欄,在此感謝?@天雨粟?。

我主要向代碼中添加了embedding層,以支持中文,另外重新整理了代碼結構,將API改成了最新的TensorFlow 1.2版本。

可以用這個項目來寫詩(以下詩句都是自動生成的):

何人無不見,此地自何如。
一夜山邊去,江山一夜歸。
山風春草色,秋水夜聲深。
何事同相見,應知舊子人。
何當不相見,何處見江邊。
一葉生云里,春風出竹堂。
何時有相訪,不得在君心。

還可以生成代碼:

static int page_cpus(struct flags *str)
{
? ? ? ?int rc;
? ? ? ?struct rq *do_init;
};

/*
* Core_trace_periods the time in is is that supsed,
*/
#endif

/*
* Intendifint to state anded.
*/
int print_init(struct priority *rt)
{ ? ? ? /* Comment sighind if see task so and the sections */
? ? ? ?console(string, &can);
}

此外生成英文更不是問題(使用莎士比亞的文本訓練):

LAUNCE:
The formity so mistalied on his, thou hast she was
to her hears, what we shall be that say a soun man
Would the lord and all a fouls and too, the say,
That we destent and here with my peace.

PALINA:
Why, are the must thou art breath or thy saming,
I have sate it him with too to have me of
I the camples.

最后,如果你腦洞夠大,還可以來做一些更有意思的事情,比如我用了著名的網絡小說《斗破蒼穹》訓練了一個RNN模型,可以生成下面的文本:

聞言,蕭炎一怔,旋即目光轉向一旁的那名灰袍青年,然后目光在那位老者身上掃過,那里,一個巨大的石臺上,有著一個巨大的巨坑,一些黑色光柱,正在從中,一道巨大的黑色巨蟒,一股極度恐怖的氣息,從天空上暴射而出 ,然后在其中一些一道道目光中,閃電般的出現在了那些人影,在那種靈魂之中,卻是有著許些強者的感覺,在他們面前,那一道道身影,卻是如同一道黑影一般,在那一道道目光中,在這片天地間,在那巨大的空間中,彌漫而開……

“這是一位斗尊階別,不過不管你,也不可能會出手,那些家伙,可以為了這里,這里也是能夠有著一些異常,而且他,也是不能將其他人給你的靈魂,所以,這些事,我也是不可能將這一個人的強者給吞天蟒,這般一次,我們的實力,便是能夠將之擊殺……”

“這里的人,也是能夠與魂殿強者抗衡。”

蕭炎眼眸中也是掠過一抹驚駭,旋即一笑,旋即一聲冷喝,身后那些魂殿殿主便是對于蕭炎,一道冷喝的身體,在天空之上暴射而出,一股恐怖的勁氣,便是從天空傾灑而下。

“嗤!”

還是挺好玩的吧,另外還嘗試了生成日文等等。

七、學習完整版的LSTMCell

上面只說了基礎版的BasicRNNCell和BasicLSTMCell。TensorFlow中還有一個“完全體”的LSTM:LSTMCell。這個完整版的LSTM可以定義peephole,添加輸出的投影層,以及給LSTM的遺忘單元設置bias等,可以參考其源碼(地址:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L417)了解使用方法。

八、學習最新的Seq2Seq API

Google在TensorFlow的1.2版本(1.3.0的rc版已經出了,貌似正式版也要出了,更新真是快)中更新了Seq2Seq API,使用這個API我們可以不用手動地去定義Seq2Seq模型中的Encoder和Decoder。此外它還和1.2版本中的新數據讀入方式Datasets兼容。可以閱讀此處的文檔(地址:http://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq)學習它的使用方法。

九、總結

最后簡單地總結一下,這篇文章提供了一個學習TensorFlow RNN實現的詳細路徑,其中包括了學習順序、可能會踩的坑、源碼分析以及一個示例項目hzy46/Char-RNN-TensorFlow(地址:https://github.com/hzy46/Char-RNN-TensorFlow),希望能對大家有所幫助。




本文作者:Non
本文轉自雷鋒網禁止二次轉載,原文鏈接

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/287426.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/287426.shtml
英文地址,請注明出處:http://en.pswp.cn/news/287426.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【遙感物候】Hants NDVI時間序列諧波分析法數據重構,植被生長季曲線效果可佳(附Hants軟件下載)

NDVI時間序列諧波分析法(Harmonic Analysis of NDVI Time-Series)(簡稱Hants )對時間序列數據進行平滑。該方法是一種新的物候分析方法,可用于定量化的監測植被動態變化。其核心算法是傅里葉變換和最小二乘法擬合, 即把時間波譜數據分解成許多不同頻率的正弦曲線和余弦曲線,…

Android之在Java socket作為服務器里面返回數據頭部怎么寫入瀏覽器需要下載文件的文件名

1 問題 Android app里面寫了一個Java socket的簡單服務器,在瀏覽器里面輸入相應的IP和端口訪問服務器下載文件,Java socket怎么寫返回數據的頭部信息,瀏覽器才知道需要下載文件的名字呢? 2 關于Content-Disposition 在常規的HTTP應答中,Content-Disposition 響應頭指示回…

java中hasnext的作用_java中Scanner的hasNext()的疑問

第一個問題,兩段代碼的區別在于阻塞的位置不同,加上一行輸出代碼就可以很明顯地看到差別。Test.javaimport java.util.Scanner;public class Test {public static void main(String[] args) {Scanner s new Scanner(System.in);while(s.hasNext()){Syst…

《看聊天記錄都學不會C語言?太菜了吧》(2)我說編程很容易你們不服?

若是大一學子或者是真心想學習剛入門的小伙伴可以私聊我,若你是真心學習可以送你書籍,指導你學習,給予你目標方向的學習路線,無套路,博客為證。 本系列文章將會以通俗易懂的對話方式進行教學,對話中將涵蓋…

ABAP的自學之路 ,初步認識ABAP 一

由于工作的關系,最近需要對SAP系統進行二次開發,于是開始學習ABAP。鑒于網上對于ABAP的資料少之又少,所以自己整理一些資料。 第一章 ABAP 開發環境和總體介紹1.1 ABAP 開發環境ABAP 開發的三種環境:(1)SAP…

LCD1602,4位數據總線液晶屏時鐘,STC12C5A60S2的10位ADC功能程序

/* 程序名:    LCD1602,4位數據總線液晶屏時鐘,STC12C5A60S2的10位ADC功能程序 編寫時間:  2015年10月4日 硬件支持:  LCD1602液晶屏 STC12C5A60S2 外部12MHZ晶振 接線定義: DB7 --> P1^7DB6…

WPF|黑暗模式的錢包支付儀表盤界面設計

收集下大家的意見,是否需要在文中貼上源碼(文末會給出源碼鏈接),請大家踴躍留言。閱讀目錄效果展示準備簡單說明 源碼結尾(視頻及源碼倉庫)1. 效果展示欣賞效果:2. 準備創建一個WPF工程&#x…

量子計算機的現狀和趨勢

量子計算機概述 計算機是一種新型的運算 它具有具有強大的并行處理數據的能力,可解決現有計算機難以運算的數學問題。因此,它成為世界各國戰略競爭的焦點。 量子計算機的優勢 量子計算機與現有的電子計算機以及正在研究的光計算機,生物計算機…

【空間數據庫】Windows操作系統PostgreSQL+PostGIS環境搭建圖文安裝教程

PostgreSQL是一種特性非常齊全的自由軟件的對象-關系型數據庫管理系統(ORDBMS),PostgreSQL支持大部分的SQL標準并且提供了很多其他現代特性,如復雜查詢、外鍵、觸發器、視圖、事務完整性、多版本并發控制等。同樣,PostgreSQL也可以用許多方法擴展,例如通過增加新的數據類…

Android之gravity=“center_vertical“和layout_gravity=“center“的效果

1、兩控件分別加上2個下面的屬性 gravity="center_vertical" android:layout_gravity="center" 代碼如下 <LinearLayoutandroid:id="@+id/ll_no_love"android:layout_width="match_parent"android:layout_height="match…

《看聊天記錄都學不會C語言?太菜了吧》(3)人艱不拆,代碼都在談戀愛?!

若是大一學子或者是真心想學習剛入門的小伙伴可以私聊我&#xff0c;若你是真心學習可以送你書籍&#xff0c;指導你學習&#xff0c;給予你目標方向的學習路線&#xff0c;無套路&#xff0c;博客為證。 本系列文章將會以通俗易懂的對話方式進行教學&#xff0c;對話中將涵蓋…

spark java 計數_spark程序——統計包含字符a或者b的行數

本篇分析一個spark例子程序。程序實現的功能是&#xff1a;分別統計包含字符a、b的行數。java源碼如下&#xff1a;package sparkTest;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import…

golang reflect

reflect包實現了運行時反射&#xff0c;允許程序操作任意類型的對象。典型用法是用靜態類型interface{}保存一個值&#xff0c;通過調用TypeOf獲取其動態類型信息&#xff0c;該函數返回一個Type類型值。調用ValueOf函數返回一個Value類型值&#xff0c;該值代表運行時的數據。…

DB2常用命令

查看DB2License信息 DB2基礎命令 轉載于:https://www.cnblogs.com/arcer/p/5573317.html

.NET7 Preview4之MapGroup

這篇是“聞(看)香(碼)識(學)女(技)人(術)”。這也是一個有意思的功能&#xff0c;路由分組&#xff0c;啥也不說了&#xff0c;看代碼看結果&#xff1a;using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.OpenApi;var builder WebApplication.Create…

【空間數據庫】ArcGIS 10.6 Database_Server_Desktop安裝、連接數據庫服務、創建企業級數據庫(附server10.6.ecp)

由于作者一直使用SQL Server 2008 R2開發版,之前在ArcGIS中創建企業級數據庫都是基于單獨安裝的SQL Server 2008 R2開發版,今天我們演示安裝ArcGIS10.6自帶的數據庫服務(SQL Server 2014 Express版本)、連接數據庫服務和創建數據庫。 首先,我們來看一下完整的ArcGIS10.6安…

(一)easyUI之樹形網絡

樹形網格&#xff08;TreeGrid&#xff09;可以展示有限空間上帶有多列和復雜數據電子表 一、案例一&#xff1a;按tree的數據結構來生成 前臺<% page language"java" contentType"text/html; charsetUTF-8"pageEncoding"UTF-8"%> <!DO…

《看聊天記錄都學不會C語言?太菜了吧》(4)零基礎的我原來早就學會編程了?

若是大一學子或者是真心想學習剛入門的小伙伴可以私聊我&#xff0c;若你是真心學習可以送你書籍&#xff0c;指導你學習&#xff0c;給予你目標方向的學習路線&#xff0c;無套路&#xff0c;博客為證。 本系列文章將會以通俗易懂的對話方式進行教學&#xff0c;對話中將涵蓋…

Android之華為平板打日志提示Permission denied

1 問題 $ adb logcat | grep ssfsafaf int logctl_get(): open /dev/hwlog_switch fail -1, 13. Permission deniedNote: log switch off, only log_main and log_events will have logs!2 解決辦法 1&#xff09;、如果是華為手機&#xff0c;打開手機的撥號界面&#xff0c…

二叉樹結構 codevs 1029 遍歷問題

codevs 1029 遍歷問題 時間限制: 1 s空間限制: 128000 KB題目等級 : 鉆石 Diamond題目描述 Description我們都很熟悉二叉樹的前序、中序、后序遍歷&#xff0c;在數據結構中常提出這樣的問題&#xff1a;已知一棵二叉樹的前序和中序遍歷&#xff0c;求它的后序遍歷&#xff0c;…