Attention Is All You Need論文閱讀筆記

Attention is All You Need是如今機器學習研究者必讀的論文,該文章提出的Transformer架構是如今很多機器學習項目的基礎,說該文章極大推動了機器學習領域的研究也不為過。
但這么重要,也是必讀的文章對初學者來說其實并不友好,很多前置知識和背景可能因為篇幅原因并沒有詳細介紹,故本文參考周奕帆的解讀,Transformer注意力以及illustrated transformer的同時,再補充更多基礎知識,希望讓機器學習的初學者也能很快讀懂這篇文章。

前置知識

循環神經網絡RNN

普通神經網絡的大致結構為輸入層=>隱含層=>輸出層,層與層之間全連接,但層內的節點互相沒有連接,所以該網絡只能處理順序遞進的任務,而無法處理前后相關聯的,比如翻譯等任務。

為解決同一層次節點的關聯問題,RNN誕生了,該網絡會記憶前面節點的信息并用于當前輸出的計算。即隱藏層之間的節點不再無連接而是有連接的,并且隱藏層的輸入不僅包括輸入層的輸出還包括上一時刻隱藏層的輸出。

針對翻譯任務,其解決過程大致如下:

I want to eat a hamburger, it is so delicious
-->我想吃
-->我想吃漢堡
-->我想吃漢堡,它/他/她(代指誰要看前面語境,RNN的主要結構創新,用前面的輸出作為下一層的輸入)
-->我想吃漢堡,漢堡實在太好吃了

該網絡的多種結構如下圖,這些不同結構的應用與詳解可見RNN詳解:
RNN多種結構

編碼器與解碼器Encoder-Decoder

原始RNN結構在處理翻譯任務時只能實現等長的輸入輸出,然而翻譯任務的輸入輸出在大多數情景下都是不等長的。
為此人們設計了一種新的架構,該架構將RNN網絡分成兩部分,前半段只有輸入,后半段只有輸出,中間通過一個狀態來轉遞信息,這就是編碼器解碼器結構Encoder-Decoder
Encoder-Deocder結構
該結構不是具體的網絡,而是一種框架的統稱,編碼器Encoder負責將輸入轉換為固定長度的向量,解碼器Decoder負責將固定長度的向量轉換成輸出,因為該結構不限制輸入和輸出長度,所以應用十分廣泛,應用場景有:

1,機器翻譯:Encoder-Decoder的最經典應用,事實上該結構就是在機器翻譯領域最先提出的。
2,文本摘要:輸入是一段文本序列,輸出是這段文本序列的摘要序列。
3,閱讀理解:將輸入的文章和問題分別編碼,再對其進行解碼得到問題的答案。
4,語音識別:輸入是語音信號序列,輸出是文字序列。

注意力模型

經過上面的介紹我們可能想到,輸入和輸出之間只有一個固定長度的中間向量作為連接,當輸入較長時,該定長向量還能否保證信息完整性,這就是RNN的不足。

為此,有人提出了注意力機制attention mechanism,該機制仍然使用編碼器和解碼器,但在不同時間輸入多個中間向量來解決問題,解碼器的輸入是編碼器結果的加權和,而非簡單的中間向量,每個輸入對輸出的權重叫注意力,注意力的大小取決于輸入輸出的相關性。

注意力框架從結構上的變化如下圖:
注意力框架
以翻譯為例,數學計算上的結構如下圖,其中 a i j a_{ij} aij?表示相關性, h i h_i hi? i i i階段輸入, c i c_i ci? i i i階段中間向量:
注意力翻譯的數學結構
注意力機制優化了解碼器與編碼器的信息交流方式,處理長文章時更有效,但所有基于RNN的結構也都面臨同一個問題:因為本輪輸入取決于上一輪的狀態,所以訓練過程必須是線性執行的,RNN的訓練速度較慢

基于上述研究現狀,放棄RNN,完全使用注意力機制的Transformer問世了。

文章摘要

摘要中提到序列傳導模型中性能最好的是用注意力機制連接編碼器和解碼器的框架,但該框架仍基于復雜的遞歸和卷積神經網絡,故文章提出了一種完全基于注意力機制,無需遞歸和卷積網絡的框架。實驗結果顯示模型的質量上乘,并且訓練時間明顯減少。

由前置知識中的注意力模型我們知道,RNN網絡雖然解決了元素相關依賴的問題,但是其輸入取決于上一輪狀態,訓練過程的效率較低。

本文引言中將該問題描述為:“這種固有的 Sequences 性質排除了訓練樣本中的并行化,這在較長的序列長度下變得至關重要,因為內存約束限制了樣本之間的批處理。最近的工作通過因式分解技巧和條件計算實現了計算效率的顯著提高,同時也提高了后者情況下的模型性能。但是,順序計算的基本約束仍然存在。”
該描述引出了Transfomer要解決的問題——提高并行訓練效率

隨后介紹了注意力機制:"注意力機制已成為各種任務中引人注目的序列建模和轉導模型不可或缺的一部分,允許對依賴關系進行建模,而不考慮它們在輸入或輸出序列中的距離。然而,除了少數情況外,這種注意力機制都與循環網絡結合使用。在這項工作中,我們提出了 Transformer,這是一種避免重復出現的模型架構,而是完全依賴注意力機制來繪制輸入和輸出之間的全局依賴關系。Transformer 允許顯著提高并行化水平,并且在 8 個 P100 GPU 上經過短短 12 小時的訓練后,就可以達到翻譯質量的新水平。"
上述內容指出注意力機制可以無視序列中的距離,作者依據該機制提出了一種全新的名為Transformer的架構,可以繪制輸入輸出的依賴關系,訓練速度更快,表現也更好。

也就是放棄RNN網絡,僅使用注意力機制構造新的結構,以繪制全局依賴關系

注意力機制

感性認識

論文自頂向下介紹設計,但讀者可能并不了解模塊,一下引入過多新概念容易勸退讀者,所以本文跟隨周奕帆的思路,介紹順序修改為自底向上,首先介紹論文的關鍵機制——注意力。

論文首先對注意力作了大致解釋:注意力函數可以描述為將查詢和一組鍵值對映射到輸出,其中查詢、鍵、值和輸出都是向量。其中輸出由值的權重和構成。像是數據庫的查詢操作。

在學習具體算法之前我們通過一個案例對注意力機制先形成一個感性認識:
比如有如下鍵值對姓名key-年齡value的數據集合張三:18, 李四:22, 張五:25,要執行查詢query為:姓張年齡的平均數,查詢條件為key[0]==‘張'

但在神經網絡中很多操作可能并沒有現實意義,只是單純的數學運算,所以我們可以將一切都轉換為數字和向量。比如原有的數據集合可轉變為[1,0,0]:18, [0,2,1]:22, [1,3,2]:25,假設表示為1,那原有對姓張的查詢就可以表示為向量[1,0,0],使用該查詢向量與集合向量作點積可得到權重向量[1,0,1],如果該權重之和大于1,可用激活函數softmax進行歸一化處理,結果為[1/2,0,1/2],獲得的權重向量中的數值就是對不同key的注意力

該過程用數學表示如下:

[1,0,0]:18 # 張三
[0,2,1]:22 # 李四
[1,3,2]:25	# 張五# 查詢向量[1,0,0]
# 查詢過程,最后獲取權重
dot([1,0,0],[1,0,0])=1 
dot([0,2,1],[1,0,0])=0
dot([1,3,2],[1,0,0])=1# 歸一化處理,獲對每個key的注意力
softmax([1,0,1])=[1/2,0,1/2]
dot([1/2,0,1/2],[18,22,25])=21.5
平均年齡為21.5

到這里我們也會有感覺,這注意力明明是一次查詢匹配的歸一化結果,重點應該在查而不在注意上,周也認為該機制稱為“全局信息查詢”才更合理。

縮放點積注意力Scaled Dot-Product Attention

Q Q Q代表查詢向量組,可由多個查詢如 q = [ 1 , 0 , 0 ] q=[1,0,0] q=[1,0,0]的組成;
K K K是鍵key的向量組, V V V則是鍵值value的向量組,用數學符號表示感性認識中的操作為:
m y a t t e n t i o n ( q , K , V ) = s o f t m a x ( q K T ) V myattention(q,K,V)=softmax(qK^T)V myattention(q,K,V)=softmax(qKT)V

有關softmax的知識可見深度學習基礎,該算法用于多分類問題,將結果歸一化成概率可能性,歸一化后的結果數值為正和為1。

論文中的注意力機制與我們的感性認識大致相同,文章將其稱為“縮放點積注意力”,大名鼎鼎的注意力公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k} } )V Attention(Q,K,V)=softmax(dk? ?QKT?)V
多個查詢 q q q的組合就是 Q Q Q,我們的感性認識和該公式只差一步 d k d_k dk?
論文還引入了兩個變量,分別是鍵的維度 d k d_k dk?,鍵值的維度為 d v d_v dv?,在該公式中,因為 Q Q Q K K K要作點積預算,所以 Q Q Q的維度要和 d k d_k dk?一致,而 d v d_v dv?則無要求。

所以 a t t e n t i o n attention attention的本質就是在我們感性認識的基礎上用 k k k的維度對注意力進行了縮放,那為什么要進行這一步操作呢?這涉及深度學習中激活函數的特性,周的解釋為:“softmax在絕對值較大的區域梯度較小,梯度下降的速度比較慢。因此,我們要讓被softmax的點乘數值盡可能小。而一般dk在較大時,也就是向量較長時,點乘的數值會比較大。除以一個和dk相關的量能夠防止點乘的值過大。”

注意力函數除了本文提到點積的方法,還有加性注意力,效率上點積更快一些,故該論文選了這種方法。

縮放點積注意力的執行流程圖如下:
縮放點積注意力流程圖
1,輸入 Q , K , V Q,K,V Q,K,V M a t M u l MatMul MatMul矩陣乘法,計算查詢 Q Q Q和鍵 K K K的相關度,多采用點積計算;
2, S c a l e Scale Scale縮放,除 d k \sqrt{d_k} dk? ?;
3, M a s k Mask Mask掩碼,屏蔽后續信息,將在整體框架中介紹;
4, S o f t M a x SoftMax SoftMax歸一化;
5,最后將查詢歸一化結果與鍵值 V V V作乘法,計算注意力結果。

自注意力Self-Attention

自注意力在文中介紹很少,但在注意力機制中又是很關鍵的內容,所以展開學習一下。

1,感性認識
在翻譯場景下,注意力機制用于實現英文到中文的翻譯轉換,自注意力則是完成序列內詞義的確定,,比如”The animal didn't cross the street because it was too tired”,如何確定后面的it是指animal還是street呢?通過自注意力機制可確定it的翻譯要重點關注animal

有關該部分詳細解釋可參考一文搞定自注意力機制。
自注意力識別語義
公式上變為如下形式:
s e l f ? a t t e n t i o n ( X ) = s o f t m a x ( x x T d k ) X self-attention(X)=softmax(\frac{xx^T}{\sqrt{d_k} } )X self?attention(X)=softmax(dk? ?xxT?)X
其中 X X X是序列中的所有元素, d k d_k dk?是向量的維度。
所以自注意力本質是注意力機制的一種特殊形式,每個輸入元素與序列中的其他元素進行交互,計算得出自己的自身的表示,其實就是輸入token[i]和整個輸入token的注意力計算

2,具體算法
1,自注意力中, Q K V QKV QKV分別由輸入 x x x和三個隨機矩陣作點積得到,這三個隨機矩陣是網絡的參數之一,會在訓練中不斷調整。

2,獲取 Q K V QKV QKV矩陣后,使用注意力公式計算得到注意力 S S S,以animal為例,注意力矩陣 S S S經過softmax歸一化操作后可能是[0.88,0.02,0.05,0.05,0]

3,此時與整個序列作點積,該操作會獲得一個向量并沖淡與該詞關聯性不強的其他單詞。

4,將第三步獲得的序列相加就是最終得到的注意力結果。

舉一個具體例子:
有三個輸入分別為: i n p u t 1 = [ 1 , 0 , 1 ] i n p u t 2 = [ 1 , 1 , 0 ] i n p u t 3 = [ 0 , 0 , 1 ] input1=[1,0,1] input2=[1,1,0] input3=[0,0,1] input1=[1,0,1]input2=[1,1,0]input3=[0,0,1],組合形成的輸入為 i n p u t = [ 1 0 1 1 1 0 0 0 1 ] input=\begin{bmatrix} 1 & 0 & 1\\ 1 & 1 & 0\\ 0 & 0 &1 \end{bmatrix} input= ?110?010?101? ?隨機初始化的參數矩陣假設為
W Q = [ 1 0 3 2 1 0 0 0 1 ] W K = [ 1 1 0 2 1 1 1 0 1 ] W V = [ 1 0 0 0 1 1 1 1 0 ] W^Q=\begin{bmatrix} 1 & 0 & 3\\ 2 & 1 & 0\\ 0 & 0 &1 \end{bmatrix} W^K=\begin{bmatrix} 1 & 1&0 \\ 2& 1& 1\\ 1 & 0&1 \end{bmatrix} W^V=\begin{bmatrix} 1& 0 & 0\\ 0& 1 &1 \\ 1& 1&0 \end{bmatrix} WQ= ?120?010?301? ?WK= ?121?110?011? ?WV= ?101?011?010? ?
計算得到的 Q , K , V Q,K,V Q,K,V分別為:
Q = [ 1 0 4 3 1 3 0 0 1 ] K = [ 2 1 1 3 2 1 1 0 1 ] V = [ 2 1 0 1 1 1 1 1 0 ] Q=\begin{bmatrix} 1 & 0 & 4\\ 3 & 1 & 3\\ 0 & 0 &1 \end{bmatrix} K=\begin{bmatrix} 2 & 1 & 1\\ 3 & 2 & 1\\ 1 & 0 &1 \end{bmatrix} V=\begin{bmatrix} 2 & 1 & 0\\ 1 & 1 & 1\\ 1 & 1 &0 \end{bmatrix} Q= ?130?010?431? ?K= ?231?120?111? ?V= ?211?111?010? ?
使用注意力公式 Q K T V QK^TV QKTV可得所有輸入的自注意力結果,為了簡化計算,只使用 i n p u t 1 input1 input1的查詢 q 1 q_1 q1?作為輸入進行計算,過程如下
q 1 × K T = [ 1 , 0 , 4 ] × [ 2 3 1 1 2 0 1 0 1 ] = [ 6 , 3 , 5 ] q_1\times K^T=[1, 0, 4]\times\begin{bmatrix} 2& 3&1 \\ 1& 2&0 \\ 1& 0 &1 \end{bmatrix}=[6, 3 ,5] q1?×KT=[1,0,4]× ?211?320?101? ?=[6,3,5]
將該向量 [ 6 , 3 , 5 ] [6, 3, 5] [6,3,5]softmax處理后再與V點積運算并相加求和即獲得 i n p u t 1 input1 input1的自注意力結果,也就是其更新后的自身表示。

自注意力計算過程示例圖如下:
自注意力計算圖示

通過對自身所在序列進行注意力計算有什么作用呢?
背景中我們提到,注意力機制的出現是為了彌補EncoderDecoder之間用于聯系的定長向量過于脆弱,在長程輸入下的失效問題。
我的理解是,自注意力利用注意力機制使每個輸入都能獲得自己與輸入整體中其他元素的關聯,使單個內容與上下文的聯系更加緊,單獨元素也攜帶上下文信息。

多頭注意力Multi-Head Attention

自注意力模式下,所有輸入都只關注序列中其他位置的向量與自身的關系,也就是注意力都在自身,對整個序列信息的關注就可能相對不足,論文作者提出了多頭注意力機制解決這一問題,即將學習到不同線性投影的查詢、鍵和值,線性投影到 dk、dk 和 dv 維度是有益的。多頭注意力允許模型共同關注來自不同位置的不同表示子空間的信息,緩和單個注意力頭的波動
多頭注意力計算流程圖
其實就是作者發現組合多個自注意力結果的效果更好,通過 C o n c a t Concat Concat組合器將多個縮放點積注意力組合起來生成效果,修改后的數學表示為:
h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i K ) head_i=Attention(QW_i^Q, KW^K_i, VW_i^K ) headi?=Attention(QWiQ?,KWiK?,VWiK?)
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d i ) W O MultiHead(Q, K, V )=Concat(head_1, ... , head_i)W^O MultiHead(Q,K,V)=Concat(head1?,...,headi?)WO

其中 d m o d e l d_{model} dmodel?為輸出長度,即每個輸入會映射為 d m o d e l d_{model} dmodel?個輸出,默認設置為512, n n n個輸入的維度為 n × d m o d e l n\times d_{model} n×dmodel?
W Q , W K W^Q, W^K WQ,WK形狀為 d m o d e l × d k d_{model}\times d_k dmodel?×dk? W V W^V WV形狀為 d m o d e l × d v d_{model}\times d_v dmodel?×dv? W O W^O WO形狀為 h d v × d m o d e l hd_{v}\times d_{model} hdv?×dmodel?

通常注意力頭個數 h h h選擇為8,維度 d k = d v = d m o d e l / h = 64 d_k=d_v=d_{model}/h=64 dk?=dv?=dmodel?/h=64 h e a d i head_i headi?的維度為 n × d k × d k × n × n × d v = n × d v n\times d_{k}\times d_k \times n\times n\times d_v=n\times d_v n×dk?×dk?×n×n×dv?=n×dv?,最后得到的 M u l t i H e a d MultiHead MultiHead維度為 n × d m o d e l n\times d_{model} n×dmodel?
而如果不適用多頭縮小 h h h倍,單頭計算維度為 n × d m o d e l × d m o d e l × d m o d e l = n × d m o d e l n\times d_{model} \times d_{model}\times d_{model}=n\times d_{model} n×dmodel?×dmodel?×dmodel?=n×dmodel?,可見多頭將維度縮小為原來的 1 h \frac{1}{h} h1?,但組合了 h h h個結果,所以總計算成本與全維的單頭部注意力相似。

多頭注意力計算組合如下圖:
多頭注意力計算組合
該過程可同時訓練多組 W Q , W K , W V W^Q, W^K, W^V WQ,WK,WV提高關聯性的同時增加了自注意力表示的豐富度

Transformer模型

了解了注意力機制,我們回過頭來學習Transformer的整體架構,Transformer 遵循這種整體架構,編碼器和解碼器都使用堆疊式自注意力層和逐點全連接層,分別如圖的左半部分和右半部分所示
Transformer整體架構

主干結構

輸入輸出的處理我們暫且擱置,先只看主干結構:
Transformer主干結構
圖中橙色部分的Multi-Head Attention我們已經了解了,文中提到編碼器解碼器分別由 N = 6 N=6 N=6個相同層組成,主干結構的大致流程為:
輸入 i n p u t s inputs inputs經過編碼器處理進入解碼器,上一輪的輸出 O u t p u t s Outputs Outputs直接進入解碼器,經過單獨一層處理后與編碼器輸入一同處理生成新的輸出。
但對模型的主干結構還有三個疑問:
1,Add&Norm是做什么的?
2,Feed Forward的功能是什么?
3,右圖下方的多頭注意力為什么加了Masked?

原文中提到:
編碼器部分第一個是多頭自注意力機制,第二個是簡單的、位置完全連接的前饋網絡。我們在兩個子層中的每一個周圍采用殘差連接,然后進行層歸一化
而解碼器部分除了每個編碼器層中的兩個子層之外,解碼器還插入了第三個子層,該子層對編碼器堆棧的輸出執行多頭注意。我們還修改了解碼器堆棧中的 self-attention 子層,以防止 positions 關注后續位置。這種掩碼,再加上輸出嵌入向量偏移一個位置的事實,確保對位置 i 的預測只能依賴于小于 i 的位置的已知輸出

也就是說:
Add&Norm是殘差連接和歸一化;
Feed Forward是完全連接的前饋網絡;
Masked是為了保證輸出不受預測影響。

下面依次對這些知識進行補充介紹。

殘差連接add&Norm

本小節參考ResNet核心思想。

殘差連接的基礎是跳躍連接(skip connection),指的是越過某些中間層,將數據直接添加到中間節點上。
該操作使得信息可以更自由地流動,并且保留了原始輸入數據中的細節和語義信息。 使信息更容易傳播到后面的層次,避免了信息丟失。
比如網絡層層遞進的結構中,尤其卷積操作,輸入細節可能不斷丟失,隨著層數的深入,學習效果反而可能變差。
此時添加一個跳躍連接,將輸入與最后一層拼接起來,可以保留輸入的更多原始語義和信息。

跳躍連接示意圖如下:
跳躍連接
殘差網絡ResNet則是專門用于解決網絡層數過深導致的梯度消失和訓練困難問題,其核心思想是引入殘差塊構建網絡,并使用跳躍連接將輸入直接添加到輸出層上

該部分并非Transformer核心內容,有關原理講解可見ResNet模型詳解。

對該模型作大致介紹,輸入 x x x,經過卷積層、批歸一化(BN)、激活函數(ReLU)等一通操作構成 F ( x ) F(x) F(x),輸出為 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x通過跳躍連接將輸入與之相加,最后歸一化操作。

本文使用的方法與之類似,子層的輸出是 L a y e r N o r m ( x + S u b l a y e r ( x ) ) LayerNorm(x + Sublayer(x)) LayerNorm(x+Sublayer(x)),其中 S u b l a y e r ( x ) Sublayer(x) Sublayer(x)是由子層本身實現的函數, L a y e r N o r m LayerNorm LayerNorm是歸一化方法。

殘差連接要求 x x x F ( x ) + x F(x)+x F(x)+x等長,所以在Transformer中模型子層和嵌入層的輸入都是 d m o d e l = 512 d_{model}=512 dmodel?=512

前饋網絡Feed Forward

就是一個網絡,其中包括兩個線性變換,中間有一個 ReLU 激活,執行公式為:
F N N ( X ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 FNN(X)=max(0,xW_1+b_1)W_2+b_2 FNN(X)=max(0,xW1?+b1?)W2?+b2?
其中隱藏層維數 d f f = 2048 d_{ff}=2048 dff?=2048

因為注意力機制本質是線性的矩陣變換,該網絡 F N N FNN FNN通過激活函數引入非線性特征,使模型能擬合復雜函數關系;同時將 d m o d e l = 512 d_{model}=512 dmodel?=512維上升到 d f f = 2048 d_{ff}=2048 dff?=2048維,擴展特征空間,方便更深層次提取特征。

右移操作shifted right

在講掩碼多頭注意力之前不得不轉到主干結構之外,解碼器的另一個輸入來自上一個解碼器的輸出,但該 o u t p u t s outputs outputs使用了一次shifted right操作,這步操作有什么含義?
shifted right操作
首先看一下Transformer對注意力機制的使用:
1,在編碼器解碼器層中,查詢 Q Q Q來自前一個解碼器,鍵 K K K和值 V V V來自編碼器。編碼器與輸入 i n p u t input input一一對應。
2,編碼器使用自注意力。該層注意力的 Q , K , V Q, K, V Q,K,V都來自前一層編碼器。
3,解碼器使用自注意力,該機制使整個序列對當前生成的元素都可見。

這種機制有兩個問題
1,當編碼器接受第0個字符時,上一層的輸出該如何設置使編碼器解碼器處理對齊?
2,Transformer并行訓練時,序列 t 0 , . . . , t k t_0, ..., t_k t0?,...,tk?都進入解碼器,但此時才輸出到 t 2 t_2 t2?,如何避免后續序列對輸出的干擾。

shiftedright右移操作解決第一個對齊問題。

該操作將輸出序列整體右移一位,解碼器訓練數據為 1 , 2 , 3 1,2,3 1,2,3,整體右移變為 < s o s > , 1 , 2 , 3 <sos>,1,2,3 <sos>,1,2,3,此時編碼器處理第0個數據, < s o s > <sos> <sos>就可以作為上一個解碼器的結果輸出給當前解碼器,實現對齊

掩碼多頭注意力Masked Multi-Head Attention

該機制解決前面提到的第二個問題,輸出序列當前步 i i i預測輸出應該只根據該步以前的結果,即 i ? 1 i-1 i?1實現,但用于解碼器訓練的輸出序列全部輸入解碼器中,為防止后續元素對解碼過程的干擾,應該把后續元素給“蓋住”,類似如下過程:
掩碼注意力類比
掩碼注意力的mask就是實現“蓋”的功能,論文中使用將注意力 s o f t m a x softmax softmax的輸入設置為 ? ∞ -\infty ?實現,該設置導致權重為0,被遮住的輸出就全為0。

訓練并行,推理串行

首先聲明一個結論,Transformer是訓練并行,推理串行的結構

前面提到的第二個問題我反反復復看了好幾遍,假如有輸入序列 < s o s > , 1 , 2 , 3 <sos>, 1, 2, 3 <sos>,1,2,3要進入Transformer,編碼器輸出 K i , V i K_i, V_i Ki?,Vi?給解碼器,解碼器接收上一步解碼器輸出 Q i ? 1 Q_{i-1} Qi?1?與本次編碼輸出 K i , V i K_i, V_i Ki?,Vi?得出本次輸出 Q i Q_i Qi?,這樣一個串行執行的結構怎么會涉及避免后續元素的干擾問題?

我忽略了一個關鍵問題,Transformer就是為了改進RNN串行訓練速度慢才誕生的,所以其訓練過程一定是并行的。
對于這個問題很多資料一帶而過,本節單獨對該內容進行詳細說明。

傳統RNN網絡的訓練過程應該是輸入序列 X X X輸入編碼器,生成 K , V K, V K,V給解碼器,解碼器從 < s o s > <sos> <sos>開始不斷推理生成預測 P P P,最后與正確輸出 P P P進行比較,使用損失函數更新編碼器解碼器參數,大致流程如下:
RNN訓練流程圖
該方案的缺陷是序列P要逐步生成并代替 < s o s > <sos> <sos>,這種串行的訓練模式效率十分低下。

傳統RNN采用該方案的原因是當前生成的元素無法有效關聯其他元素,必須逐步生成,將已有元素輸入下一步預測節點。
而Transformer使用注意力機制,生成的每個元素都清楚知道自己和其他元素的關系,所以對任意元素的預測都可以獨立進行。注意力允許Transformer同步生成所有序列,即同時生成 t 0 , t 1 , . . . , t n t_0, t_1, ..., t_n t0?,t1?,...,tn?

取消遞推訓練后的解碼器上一步輸入 Q Q Q又由誰決定呢?Transformer使用一種名為“教師強迫”Teacher forcing的訓練模式,即每個節點都將正確的輸出作為已知的序列進行預測,使用這種模式可以實現所有節點無需等待上一步預測,獨立進行訓練。

但正確的輸出過于完整,節點 i i i應該只基于 0 ? i 1 0-i_1 0?i1?個序列進行預測,所以作者采用了掩碼注意力機制,將i及以后的輸出對當前節點隱藏

假設有一訓練集為x=我想吃漢堡,p=I want to eat a hamburger,Transformer訓練過程如下:
transformer并行訓練
該過程比較復雜,再次簡要總結:
1,注意力建立序列關聯性,使訓練無需等待前一次結果獨立進行;
2,Teacher forcing機制讓解碼器將正確輸出作為上一次結果,實現并行訓練;
3,掩碼注意力機制為當前節點屏蔽后續正確輸出,使訓練依據局限在當前序列。

如果還不能理解可以看transformer如何實現并行化,該文章角度中,編碼器并行由注意力機制實現,解碼器并行由Teacher forcing和掩碼注意力共同作用實現。

但推理過程仍然是串行的,解碼器不斷迭代生成序列用于下一次輸出。

嵌入層Embedding

主干網絡之后我們來看模型對輸入輸出的處理:
輸入輸出處理
和其他序列轉換模型一樣,Transformers也使用詞嵌入處理輸入輸出,將其轉換為維度 d m o d e l d_{model} dmodel?的向量。
詞嵌入就是將詞轉換為向量的方法,如貓=>[0,1,1] 狗=>[0,0,1],轉換后的形式更容易被計算機識別并計算,有關具體方法可見從0開始詞嵌入。

隨后使用通常的線性變換和softmax將解碼器輸出轉換為預測的下一個標記概率。

文中提到:兩個嵌入層和線性變換層之間共享相同的權重矩陣,并在嵌入層中將權重乘 d m o d e l \sqrt{d_{model}} dmodel? ?

位置編碼Positional Encoding

Transformer現在只剩下一個位置沒有解讀了,就是輸入后的位置編碼:
位置編碼
因為Transformer中不含遞歸和卷積,注意力機制為了并行計算也只具有全局查詢功能,因此為了讓模型利用序列的順序,必須注入一些關于序列中標記的相對或絕對位置的信息。
為此,作者設計了一種“位置編碼”,將序列順序信息直接與編碼器和解碼器輸入的 E m b e d d i n g Embedding Embedding相加,保證位置信息參與模型的后續運算,所以真正的輸入 i n p u t = e m b e d d i n g + p o s i t i o n a l e n c o d i n g input=embedding+positional encoding input=embedding+positionalencoding

位置編碼與詞嵌入具有相同維度 d m o d e l d_{model} dmodel?,使用 p o s pos pos表示位置, i i i表示維度,
作者使用了不同頻率的正弦和余弦函數作為位置編碼:
P E ( p o s , 2 i ) = s i n ( p o s / 10000 2 i / d m o d e l ) PE(pos,2i)=sin(pos/10000^{2i/d_{model}}) PE(pos,2i)=sin(pos/100002i/dmodel?)
P E ( p o s , 2 i + 1 ) = c o s ( p o s / 10000 2 i / d m o d e l ) PE(pos,2i+1)=cos(pos/10000^{2i/d_{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel?)

比如輸入dog eat food,每個詞對應的 p o s = 0 , 1 , 2 pos=0,1,2 pos=0,1,2,假設轉換的 E m b e d d i n g Embedding Embedding維度是20,計算過程示例如下:
位置編碼計算

在該公式中,當 i = 0 i=0 i=0即低維情景下,頻率 1 10000 2 i / d m o d e l = 1 \frac{1}{10000^{2i/d_{model}}}=1 100002i/dmodel?1?=1,當 i = d m o d e l i=d_{model} i=dmodel?即高維場景下,頻率等于 1 10000 1 2 ≈ 0.01 \frac{1}{10000^{\frac{1}{2}}}\approx 0.01 1000021?1?0.01

也就是低維頻率大,高維頻率小,這進一步導致低維位置編碼變化大,高維位置編碼變化小,相加得到的結果輸入網絡后會使網絡通過低維位置編碼關注局部差異,又通過高維建立序列整體聯系

這種設計同時可保證編碼唯一,并且不依賴最大序列長度,支持任意長度的輸入。
同時作者還嘗試了可學習的位置編碼函數,發現二者結果幾乎相同,最終選擇了該版本也是因為其支持任意長度的輸入序列,可學習的位置編碼只能處理訓練時的序列長度。

有關位置編碼更詳細的介紹可參考詳解自注意力機制中的位置編碼。

實驗結果

作者對英語-德語,英語-法語兩項任務進行了訓練,使用8張P100顯卡,每個訓練步驟需要0.4s,基本模型用12小時就可以訓練結束,大模型3天半也能訓練完成。

Transformer使用更少的訓練成本實現了最好的訓練效果。
Transformer訓練成果
另外進行了消融實驗,即分別修改實驗條件和參數,確定對結果影響更大的因素。
Transformer消融實驗
實驗A表明,頭部數量 h h h和鍵的維度 d k d_k dk?比例要適中,太大太小效果都不好,而且另外表明多頭注意力要優于單頭。

實驗B表明,鍵 d k d_k dk?減小會損失性能,作者認為相關性的建立比較復雜,如果能有比點積更好的方法可能會帶來性能的提升。

實驗CD表明,大模型效果更好,并且使用dropout是有必要的。

實驗E則是證明可學習的嵌入位置編碼與文章使用的結果相近。

總結

針對循環神經網絡RNN串行訓練效率低的問題,Transformer提出了一種完全基于注意力機制的模型結構,更適合長序列任務。

對該文章解決問題的核心方法總結如下:

自注意力關聯輸入序列。在該結構中,以往結合本次輸入和上步預測的信息關聯手段被自注意力計算取代,自注意力計算下每個輸入元素都能獲得自己與整個輸入序列的關聯性。

Teacher Forcing提供并行訓練數據。并行時仍采用迭代訓練方法,訓練數據采用Teacher Forcing方法,將正確輸出作為上一步預測直接輸入到解碼器中,與自注意力一起作為并行訓練的基礎。

掩碼注意力屏蔽后續輸出干擾。當下節點的預測結果應該只根據已有數據,但Teacher Forcing將正確輸出全部輸入到解碼器,故在解碼器第一層設置掩碼注意力,屏蔽當前節點以后的數據,保證預測只來自已有信息。

位置編碼提供位置信息。使用自注意力只能使序列獲得整體關聯信息,而無法使模型利用序列順序,將位置編碼與詞嵌入相加可使模型從順序角度對序列進行關聯。

總的來看,Transformer模型中包含更多信息,又摒棄了RNN串行訓練的方式,提高訓練效率的同時也提高了訓練質量,應用領域也由自然語言處理發展到計算機視覺,有如此大的名氣確實是實至名歸。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/81829.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/81829.shtml
英文地址,請注明出處:http://en.pswp.cn/web/81829.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【MAC】YOLOv8/11/12 轉換為 CoreML 格式并實現實時目標檢測

在本文中,我們將詳細介紹如何將 YOLOv8/11/12 模型轉換為 CoreML 格式,并使用該模型在攝像頭實時檢測中進行目標檢測。主要適用于M1、M2、M3、M4芯片的產品。 以下教程在YOLOv8/11/12均適用,此處就以 YOLOv11 舉例 目錄 前提條件YOLOv8/11/12 轉換為 CoreML實時目標檢測結論…

Redis--緩存擊穿詳解及解決方案

緩存擊穿 緩存擊穿問題也稱熱點key問題&#xff0c;就是一個高并發訪問&#xff08;該key訪問頻率高&#xff0c;訪問次數多&#xff09;并且緩存重建業務比較復雜的key突然失效了&#xff0c;大量的請求訪問會在瞬間給數據庫帶來巨大的沖擊。 緩存重建業務比較復雜&#xff…

UniApp X:鴻蒙原生開發的機會與DCloud的崛起之路·優雅草卓伊凡

UniApp X&#xff1a;鴻蒙原生開發的機會與DCloud的崛起之路優雅草卓伊凡 有句話至少先說&#xff0c;混開框架中目前uniapp x是率先支持了鴻蒙next的開發的&#xff0c;這點來說 先進了很多&#xff0c;也懂得審時度勢。 一、UniApp X如何支持鴻蒙原生應用&#xff1f; UniAp…

域名解析怎么查詢?有哪些域名解析查詢方式?

在互聯網的世界里&#xff0c;域名就像是我們日常生活中的門牌號&#xff0c;幫助我們快速定位到想要訪問的網站。而域名解析則是將這個易記的域名轉換為計算機能夠識別的IP地址的關鍵過程。當我們想要了解一個網站的域名解析情況&#xff0c;或者排查網絡問題時&#xff0c;掌…

算力卡上部署OCR文本識別服務與測試

使用modelscope上的圖像文本行檢測和文本識別模型進行本地部署并轉為API服務。 本地部署時把代碼中的檢測和識別模型路徑改為本地模型的路徑。 關于模型和代碼原理可以參見modelscope上這兩個模型相關的頁面&#xff1a; iic/cv_resnet18_ocr-detection-db-line-level_damo iic…

大語言模型的完整訓練周期從0到1的體系化拆解

以下部分內容參考了AI。 要真正理解大語言模型&#xff08;LLM&#xff09;的創生過程&#xff0c;我們需要將其拆解為一個完整的生命周期&#xff0c;每個階段的關鍵技術相互關聯&#xff0c;共同支撐最終模型的涌現能力。以下是體系化的訓練流程框架&#xff1a; 階段一&am…

吃水果(貪心)

文章目錄 題目描述輸入格式輸出格式樣例輸入樣例輸出提交鏈接提示 解析參考代碼 題目描述 最近米咔買了 n n n 個蘋果和 m m m 個香蕉&#xff0c;他每天可以選擇吃掉一個蘋果和一個香蕉&#xff08;必須都吃一個&#xff0c;即如果其中一種水果的數量為 0 0 0&#xff0c;則…

【FAQ】HarmonyOS SDK 閉源開放能力 —Account Kit(4)

1.問題描述&#xff1a; LoginWithHuaweiIDButton不支持深色模式下定制文字和loading樣式&#xff1f; 解決方案&#xff1a; LoginWithHuaweiIDButtonParams 中的有個supportDarkMode屬性&#xff0c;設置為true后&#xff0c;需要自行響應系統的變化&#xff0c;見文檔&am…

【C語言】指針詳解(接)

前言&#xff1a; 文接上章&#xff0c;在上章節講解了部分指針知識點&#xff0c;在本章節為大家繼續提供。 六指針與字符串&#xff1a;C 語言字符串的本質 在 C 語言中&#xff0c;字符串實際上是一個以\0結尾的字符數組。字符串常量本質上是指向字符數組首元素的指針&…

第5講、Odoo 18 CLI 模塊源碼全解讀

Odoo 作為一款強大的企業級開源 ERP 系統&#xff0c;其命令行工具&#xff08;CLI&#xff09;為開發者和運維人員提供了極大的便利。Odoo 18 的 odoo/cli 目錄&#xff0c;正是這些命令行工具的核心實現地。本文將結合源碼&#xff0c;詳細解讀每個 CLI 文件的功能與實現機制…

如何將 PDF 文件中的文本提取為 YAML(教程)

這篇博客文章將向你展示如何將 PDF 轉換為 YAML&#xff0c;通過提取帶有結構標簽的標記內容來實現。 什么是結構化 PDF&#xff1f; 一些 PDF 文件包含結構化內容&#xff0c;也稱為帶標簽&#xff08;tagged&#xff09;或標記內容&#xff08;marked content&#xff09;&…

銀發團扎堆本地游,“微度假”模式如何盤活銀發旅游市場?

? 銀發微度假&#xff0c;席卷江浙滬 作者 | AgeClub呂嬈煒 前言 均價200-300元的兩天一夜微度假產品&#xff0c;正在中老年客群中走紅。 “我們屬于酒店直營&#xff0c;沒有中間商賺差價&#xff0c;老年人乘坐地鐵到目的地站&#xff0c;會有大巴負責接送&#xff0c;半…

蘋果iOS應用ipa文件進行簽名后無法連接網絡,我們該怎么解決

蘋果iOS應用ipa文件在經過簽名處理后&#xff0c;如果發現無法連接網絡&#xff0c;這可能會給用戶帶來極大的不便。為了解決這一問題&#xff0c;可以采取一系列的排查和解決步驟&#xff0c;以確保應用能夠順利地訪問互聯網。 首先&#xff0c;確保你的設備已經連接到一個穩…

MySQL 中 ROW_NUMBER() 函數詳解

MySQL 中 ROW_NUMBER() 函數詳解 ROW_NUMBER() 是 SQL 窗口函數中的一種&#xff0c;用于為查詢結果集中的每一行分配一個??唯一的連續序號??。與 RANK() 和 DENSE_RANK() 不同&#xff0c;ROW_NUMBER() 不會處理重復值&#xff0c;即使排序字段值相同&#xff0c;也會嚴格…

Leetcode百題斬-二叉樹

二叉樹作為經典面試系列&#xff0c;那么當然要來看看。總計14道題&#xff0c;包含大量的簡單題&#xff0c;說明這確實是個比較基礎的專題。快速過快速過。 先構造一個二叉樹數據結構。 public class TreeNode {int val;TreeNode left;TreeNode right;TreeNode() {}TreeNode…

Asp.Net Core 如何配置在Swagger中帶JWT報文頭

文章目錄 前言一、配置方法二、使用1、運行應用程序并導航到 /swagger2、點擊右上角的 Authorize 按鈕。3、輸入 JWT 令牌&#xff0c;格式為 Bearer your_jwt_token。4、后續請求將自動攜帶 Authorization 頭。 三、注意事項總結 前言 配置Swagger支持JWT 一、配置方法 在 …

MySQL 定時邏輯備份

文章目錄 配置密碼編寫備份腳本配置權限定時任務配置檢查效果如果不想保留明文密碼手工配置備份密碼修改備份命令 配置密碼 cat >> /root/.my.cnf <<"EOF" [client] userroot passwordYourPassword EOF編寫備份腳本 cat > /usr/local/bin/mysql_dum…

在qt中使用c++實現與Twincat3 PLC變量通信

這是一個只針對新手的教程&#xff0c;下載安裝就不說了&#xff0c;我下的是TC31-Full-Setup.3.1.4024.66.exe是這個版本&#xff0c;其他版本應該問題不大。 先創建一個項目 選中SYSTEM&#xff0c;在右側點擊Choose Target&#xff08;接下來界面跟我不一樣沒關系&#xf…

云原生微服務devops項目管理英文表述詳解

文章目錄 1.云原生CNCF trail map云原生技術棧路線圖 2. 微服務單體應用與微服務應用架構區別GraphQLKey differences: GraphQL and REST 3.容器化&編排dockerKubernetesContainers and ContainerizationContainer Basics 4. DevOps & CI/CDTerms and Definitions 5.Ag…

pyside 使用pyinstaller導出exe(含ui文件)

第一步&#xff1a;首先確保安裝好pyinstall&#xff0c;終端運行 pyinstaller -w main.py 生成兩個文件夾 打開exe文件報錯&#xff0c;問題是ui文件找不到 第二步&#xff1a;將ui文件復制到exe所在文件夾&#xff0c;打開成功 ![在這里插入圖片描述](https://i-blog.csdni…