[譯] RNN 循環神經網絡系列 2:文本分類

  • 原文地址:RECURRENT NEURAL NETWORKS (RNN) – PART 2: TEXT CLASSIFICATION
  • 原文作者:GokuMohandas
  • 譯文出自:掘金翻譯計劃
  • 本文永久鏈接:github.com/xitu/gold-m…
  • 譯者:Changkun Ou
  • 校對者:yanqiangmiffy, TobiasLee

本系列文章匯總

  1. RNN 循環神經網絡系列 1:基本 RNN 與 CHAR-RNN
  2. RNN 循環神經網絡系列 2:文本分類
  3. RNN 循環神經網絡系列 3:編碼、解碼器
  4. RNN 循環神經網絡系列 4:注意力機制
  5. RNN 循環神經網絡系列 5:自定義單元

RNN 循環神經網絡系列 2:文本分類

在第一篇文章中,我們看到了如何使用 TensorFlow 實現一個簡單的 RNN 架構。現在我們將使用這些組件并將其應用到文本分類中去。主要的區別在于,我們不會像 CHAR-RNN 模型那樣輸入固定長度的序列,而是使用長度不同的序列。

文本分類

這個任務的數據集選用了來自 Cornell 大學的語句情緒極性數據集 v1.0,它包含了 5331 個正面和負面情緒的句子。這是一個非常小的數據集,但足夠用來演示如何使用循環神經網絡進行文本分類了。

我們需要進行一些預處理,主要包括標注輸入、附加標記(填充等)。請參考完整代碼了解更多。

預處理步驟

  1. 清洗句子并切分成一個個 token;
  2. 將句子轉換為數值 token;
  3. 保存每個句子的序列長。

Screen Shot 2016-10-05 at 7.32.36 PM.png

如上圖所示,我們希望在計算完成時立即對句子的情緒做出預測。引入額外的填充符會帶來過多噪聲,這樣的話你模型的性能就會不太好。注意:我們填充序列的唯一原因是因為需要以固定大小的批量輸入進 RNN。下面你會看到,使用動態 RNN 還能避免在序列完成后的不必要計算。

模型

代碼:

class model(object):def __init__(self, FLAGS):# 占位符self.inputs_X = tf.placeholder(tf.int32,shape=[None, None], name='inputs_X')self.targets_y = tf.placeholder(tf.float32,shape=[None, None], name='targets_y')self.dropout = tf.placeholder(tf.float32)# RNN 單元stacked_cell = rnn_cell(FLAGS, self.dropout)# RNN 輸入with tf.variable_scope('rnn_inputs'):W_input = tf.get_variable("W_input",[FLAGS.en_vocab_size, FLAGS.num_hidden_units])inputs = rnn_inputs(FLAGS, self.inputs_X)#initial_state = stacked_cell.zero_state(FLAGS.batch_size, tf.float32)# RNN 輸出seq_lens = length(self.inputs_X)all_outputs, state = tf.nn.dynamic_rnn(cell=stacked_cell, inputs=inputs,sequence_length=seq_lens, dtype=tf.float32)# 由于使用了 seq_len[0],state 自動包含了上一次的對應輸出# 因為 state 是一個帶有張量的元組outputs = state[0]# 處理 RNN 輸出with tf.variable_scope('rnn_softmax'):W_softmax = tf.get_variable("W_softmax",[FLAGS.num_hidden_units, FLAGS.num_classes])b_softmax = tf.get_variable("b_softmax", [FLAGS.num_classes])# Logitslogits = rnn_softmax(FLAGS, outputs)probabilities = tf.nn.softmax(logits)self.accuracy = tf.equal(tf.argmax(self.targets_y,1), tf.argmax(logits,1))# 損失函數self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits, self.targets_y))# 優化self.lr = tf.Variable(0.0, trainable=False)trainable_vars = tf.trainable_variables()# 使用梯度截斷來避免梯度消失和梯度爆炸grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainable_vars), FLAGS.max_gradient_norm)optimizer = tf.train.AdamOptimizer(self.lr)self.train_optimizer = optimizer.apply_gradients(zip(grads, trainable_vars))# 下面是用于采樣的值# (在每個單詞后生成情緒)# 取所有輸出作為第一個輸入序列# (由于采樣,只需一個輸入序列)sampling_outputs = all_outputs[0]# Logitssampling_logits = rnn_softmax(FLAGS, sampling_outputs)self.sampling_probabilities = tf.nn.softmax(sampling_logits)# 保存模型的組件self.global_step = tf.Variable(0, trainable=False)self.saver = tf.train.Saver(tf.all_variables())def step(self, sess, batch_X, batch_y=None, dropout=0.0,forward_only=True, sampling=False):input_feed = {self.inputs_X: batch_X,self.targets_y: batch_y,self.dropout: dropout}if forward_only:if not sampling:output_feed = [self.loss,self.accuracy]elif sampling:input_feed = {self.inputs_X: batch_X,self.dropout: dropout}output_feed = [self.sampling_probabilities]else: # 訓練output_feed = [self.train_optimizer,self.loss,self.accuracy]outputs = sess.run(output_feed, input_feed)if forward_only:if not sampling:return outputs[0], outputs[1]elif sampling:return outputs[0]else: # 訓練return outputs[0], outputs[1], outputs[2]復制代碼

上面的代碼就是我們的模型代碼,它在訓練的過程中使用了輸入的文本。注意:為了清楚起見,我們決定將批量數據的大小保存在我們的輸入和目標占位符中,但是我們應該讓它們獨立于一個特定的批量大小之外。由于這個特定的批量大小依賴于 batch_size,如果我們這么做,那么我們就還得輸入一個 initial_state。我們通過嵌入他們來為每個數據序列來輸入 token。實踐策略表明,我們在輸入文本上使用 skip-gram 模型預訓練嵌入權重能夠取得更好的性能。

在此模型中,我們再次使用 dynamic_rnn,但是這次我們提供了sequence_length 參數的值,它是一個包含每個序列長度的列表。這樣,我們就可以避免在輸入序列的最后一個詞之后進行的不必要的計算。length 函數就用來獲取這個列表的長度,如下所示。當然,我們也可以在外面計算seq_len,再通過占位符進行傳遞。

def length(data):relevant = tf.sign(tf.abs(data))length = tf.reduce_sum(relevant, reduction_indices=1)length = tf.cast(length, tf.int32)return length復制代碼

由于我們填充符 token 為 0,因此可以使用每個 token 的 sign 性質來確定它是否是一個填充符 token。如果輸入大于 0,則 tf.sign 為 1;如果輸入為 0,則為 tf.sign 為 0。這樣,我們可以逐步通過列索引來獲得 sign 值為正的 token 數量。至此,我們可以將這個長度提供給 dynamic_rnn 了。

注意:我們可以很容易地在外部計算 seq_lens,并將其作為占位符進行傳參。這樣我們就不用依賴于 PAD_ID = 0 這個性質了。

一旦我們從 RNN 拿到了所有的輸出和最終狀態,我們就會希望分離對應輸出。對于每個輸入來說,將具有不同的對應輸出,因為每個輸入長度不一定不相同。由于我們將 seq_len 傳給了 dynamic_rnn,而 state 又是最后一個對應輸出,我們可以通過查看 state 來找到對應輸出。注意,我們必須取 state[0],因為返回的 state 是一個張量的元組。

其他需要注意的事情:我并沒有使用 initial_state,而是直接給 dynamic_rnn 設置 dtype。此外,dropout 將根據 forward_only 與否,作為參數傳遞給 step()

推斷

總的來說,除了單個句子的預測外,我還想為具有一堆樣本句子整體情緒進行預測。我希望看到的是,每個單詞都被 RNN 讀取后,將之前的單詞分值保存在內存中,從而查看預測分值是怎樣變化的。舉例如下(值越接近 0 表明越靠近負面情緒):

Screen Shot 2016-10-05 at 8.34.51 PM.png

注意:這是一個非常簡單的模型,其數據集非常有限。主要目的只是為了闡明它是如何搭建以及如何運行的。為了獲得更好的性能,請嘗試使用數據量更大的數據集,并考慮具體的網絡架構,比如 Attention 模型、Concept-Aware 詞嵌入以及隱喻(symbolization to name)等等。

損失屏蔽(這里不需要)

最后,我們來計算 cost。你可能會注意到我們沒有做任何損失屏蔽(loss masking)處理,因為我們分離了對應輸出,僅用于計算損失函數。然而,對于其他諸如機器翻譯的任務來說,我們的輸出很有可能還來自填充符 token。我們不想考慮這些輸出,因為傳遞了 seq_lens 參數的 dynamic_rnn 將返回 0。下面這個例子比較簡單,只用來說明這個實現大概是怎么回事;我們這里再一次使用了填充符 token 為 0 的性質:

# 向量化 logits 和目標
targets = tf.reshape(targets, [-1]) # 將張量 targets 轉為向量
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
mask = tf.sign.(tf.to_float(targets)) # targets 為 0 則輸出為 0, target < 0 則輸出為 -1, 否則 為 1
masked_losses = mask*losses # 填充符所在位置的貢獻為 0復制代碼

首先我們要將 logits 和 targets 向量化。為了使 logits 向量化,一個比較好的辦法是將 dynamic_rnn 的輸出向量化為 [-1,num_hidden_units] 的形狀,然后乘以 softmax 權重 [num_hidden_units,num_classes]。通過損失屏蔽操作,就可以消除填充符所在位置貢獻的損失。

代碼

GitHub 倉庫 (正在更新,敬請期待!)

張量形狀變化的參考

原始未處理過的文本 X 形狀為 [N,]y 的形狀為 [N, C],其中 C 是輸出類別的數量(這些是手動完成的,但我們需要使用獨熱編碼來處理多類情況)。

然后 X 被轉化為 token 并進行填充,變成了 [N, <max_len>]。我們還需要傳遞形狀為 [N,]seq_len 參數,包含每個句子的長度。

現在 Xseq_leny 通過這個模型首先嵌入為 [NXD],其中 D 是嵌入維度。X 便從 [N, <max_len>] 轉換為了 [N, <max_len>, D]。回想一下,X 在這里有一個中間表示,它被獨熱編碼為了 [N, <max_len>, <num_words>]。但我們并不需要這么做,因為我們只需要使用對應詞的索引,然后從詞嵌入權重中取值就可以了。

我們需要將這個嵌入后的 X 傳遞給 dynamic_rnn 并返回 all_outputs[N, <max_len>, D])以及 state[1, N, D])。由于我們輸入了 seq_lens,對于我們而言它就是最后一個對應的狀態。從維度的角度來說,你可以看到, all_outputs 就是來自 RNN 的對于每個句子中的每個詞的全部輸出結果。然而,state 僅僅只是每個句子的最后一個對應輸出。

現在我們要輸入 softmax 權重,但在此之前,我們需要通過取第一個索引(state[0])來把狀態從 [1,N,D] 轉換為[N,D]。如此便可以通過與 softmax 權重 [D,C] 的點積,來得到形狀為 [N,C] 的輸出。其中,我們做指數級 softmax 運算,然后進行正則化,最終結合形狀為 [N,C]target_y 來計算損失函數。

注意:如果你使用了基本的 RNN 或者 GRU,從 dynamic_rnn 返回的 all_outputsstate 的形狀是一樣的。但是如果使用 LSTM 的話,all_outputs 的形狀就是 [N, <max_len>, D]state 的形狀為 [1, 2, N, D]


掘金翻譯計劃 是一個翻譯優質互聯網技術文章的社區,文章來源為 掘金 上的英文分享文章。內容覆蓋 Android、iOS、React、前端、后端、產品、設計 等領域,想要查看更多優質譯文請持續關注 掘金翻譯計劃、官方微博、知乎專欄。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/257206.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/257206.shtml
英文地址,請注明出處:http://en.pswp.cn/news/257206.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[置頂] Android開發者官方網站文檔 - 國內踏得網鏡像

Mark 一下&#xff1a; 鏡像地址&#xff1a;http://wear.techbrood.com/index.html Android DevelopTools: http://www.androiddevtools.cn/ 轉載于:https://www.cnblogs.com/superle/p/4561856.html

Java實現選擇排序

選擇排序思想就是選出最小或最大的數與第一個數交換&#xff0c;然后在剩下的數列中重復完成該動作。 package Sort;import java.util.Arrays;public class SelectionSort {public static int selectMinKey(int[] list, int beginIdx) {int idx beginIdx;int temp list[begin…

ASP.NET MVC中ViewData、ViewBag和TempData

1.ViewData 1.1 ViewData繼承了IDictionary<string, object>,因此在設置ViewData屬性時,傳入key必須要字符串型別,value可以是任意類型。 1.2 ViewData它只會存在這次的HTTP要求而已,而不像Session可以將數據帶到下HTTP要求。 public class TestController : Controller{…

java 正則表達式驗證郵箱格式是否合規 以及 正則表達式元字符

package com.ykmimi.testtest; /*** 測試郵箱地址是否合規* author ukyor**/ public class EmailTest {public static void main(String[] args) {//定義要匹配的Email地址的正則表達式//其中\w代表可用作標識符的字符,不包括$. \w表示多個// \\.\\w表示點.后面有\w 括號{2,3}…

鏡頭選型

景深&#xff1a; 光圈越大&#xff0c;光圈值越小&#xff0c;景深越小 光圈越小&#xff0c;光圈值越大&#xff0c;景深越深 焦距越長&#xff0c;視角越小&#xff0c;主體像越大&#xff0c;景深越小 主體越近&#xff0c;景深越小

迅雷賬號

賬號 jiangchnangli:1 密碼 892812 網址 http://www.s8song.net/read-htm-tid-4906661.html漫晴xydcq7681轉載于:https://www.cnblogs.com/wlzhang/p/4563118.html

【Swift學習】Swift編程之旅---ARC(二十)

Swift使用自動引用計數(ARC)來跟蹤并管理應用使用的內存。大部分情況下&#xff0c;這意味著在Swift語言中&#xff0c;內存管理"仍然工作"&#xff0c;不需要自己去考慮內存管理的事情。當實例不再被使用時&#xff0c;ARC會自動釋放這些類的實例所占用的內存。然而…

像元大小及精度

說完了光學系統的分辨率之后我們來看看相機的圖像分辨率。圖像分辨率比較好理解&#xff0c;就是單位距離內的像用多少個像素來顯示。以我們的ORCA-Flash4.0為例&#xff0c;芯片的像元大小為 6.5 μm&#xff0c;在 40X物鏡的放大倍率下&#xff0c;1 μm的物經光學系統放大為…

轉:傳入的表格格式數據流(TDS)遠程過程調用(RPC)協議流不正確 .

近期在做淘寶客的項目&#xff0c;大家都知道&#xff0c;淘寶的商品詳細描述字符長度很大&#xff0c;所以就導致了今天出現了一個問題 VS的報錯是這樣子的 ” 傳入的表格格式數據流(TDS)遠程過程調用(RPC)協議流不正確“ 還說某個desricption 過長之類的話 直覺告訴我&#…

合并bin文件-----帶boot發布版本比較好用的bat(便捷版)

直接上圖上代碼&#xff08;代碼在結尾&#xff09;&#xff0c;有不會用的可以留言&#xff1a; 第一步&#xff1a;工程介紹&#xff0c;關鍵點--- 1.bat文件放所在app和boot工程的同級目錄下 2.release為運行bat自動生成文件夾 第二步&#xff1a;合版.bat 針對具體項目需…

第五天 斷點續傳和下載

1 斷點續傳&#xff0c; 2.多線程下載原理 3.httpUtils 多線程斷點下載的使用。 ------------- 1.拿到需要下載的文件的大小&#xff0c;和需要初始的線程數 2.得到每個線程需要下載的大小&#xff0c;最后一個線程負責將剩下的數據全部下載。 3.同時需要設置一個與下載文件同大…

關于cmake從GitHub上下載的源碼啟動時報錯的問題

關于cmake從GitHub上下載的源碼啟動時報錯的問題&#xff1a; 由于cmake會產生all_build和zero_check兩個project&#xff0c;此時需要右擊鼠標將需要運行的項目設為啟動項&#xff0c;在進行編譯&#xff0c;現只針對“找不到all_build文件“的出錯信息&#xff0c;若有相關編…

一個人的Scrum之準備工作

在2012年里&#xff0c;我想自己一人去實踐一下Scrum&#xff0c;所以才有了這么一個開篇。 最近看了《輕松的Scrum之旅》這本書&#xff0c;感覺對我非常有益。書中像講述故事一樣描述了在執行Scrum過程中的點點滴滴&#xff0c; 仿佛我也跟著進行了一次成功的Scrum。同樣的&a…

Elementary OS安裝Chrome

elementary os 官方網站&#xff1a;https://elementary.io/ 這os是真好看&#xff01;首先這是基于ubuntu的&#xff0c;所以可以安裝ubuntu的軟件&#xff01; 電腦必備瀏覽器必須是chrome呀&#xff01;下載地址&#xff1a; https://www.chrome64bit.com/index.php/google…

vs+opencv編譯出現內存問題

將圖片路徑改為項目下的相對路徑&#xff0c;如 …\data\01.jpg; 其中…表示項目所在目錄的上級目錄&#xff0c;不要用絕對路徑&#xff0c;具體原因未知&#xff0c;同時&#xff0c;出現opencv_worldxxx.lib找不到情況&#xff0c;1.鏈接中依賴項是否寫錯&#xff08;英文輸…

runtime--實現篇02(Category增加屬性)

在iOS設計Category中&#xff0c;默認不能直接添加屬性&#xff0c;如果分類中通過property修飾的屬性&#xff0c;只會生成setter和getter的聲明&#xff0c; 不會生成其實現&#xff1b;因此&#xff0c;如果一定要添加屬性的話&#xff0c;需要借助runtime特性&#xff0c;通…

spark、oozie、yarn、hdfs、zookeeper、

為什么80%的碼農都做不了架構師&#xff1f;>>> spark、 oozie:任務調度 yarn:資源調度 hdfs:分布式文件系統 zookeeper、 轉載于:https://my.oschina.net/u/3709135/blog/1556661

關于halcon多區域挑選有關算法的自我理解(tuple_sort_index)

多區域根據面積挑選想要的obj area_center&#xff08;regions&#xff0c;areas&#xff09; tuple_sort_index(areas&#xff0c;indexs) tuple_sort_index算子將一組數組進行升序排列&#xff0c;然后將其在原數組的index按升序放入indexs中&#xff0c; 例如原數組areas[20…

JLOI2016 方

bzoj4558 真是一道非常excited的題目啊…JLOI有毒 題目大意&#xff1a;給一個(N1)*(M1)的網格圖&#xff0c;格點坐標為(0~N,0~M)&#xff0c;現在挖去了K個點&#xff0c;求剩下多少個正方形&#xff08;需要注意的是正方形可以是斜著的&#xff0c;多斜都可以&#xff09; N…

opencv 直方圖反向投影

轉載至&#xff1a;http://www.cnblogs.com/zsb517/archive/2012/06/20/2556508.html 直方圖反向投影式通過給定的直方圖信息&#xff0c;在圖像找到相應的像素分布區域&#xff0c;opencv提供兩種算法&#xff0c;一個是基于像素的&#xff0c;一個是基于塊的。 使用方法不寫了…