各類神經網絡學習:(十)注意力機制(第2/4集),pytorch 中的多維注意力機制、自注意力機制、掩碼自注意力機制、多頭注意力機制

上一篇下一篇
注意力機制(第1/4集)待編寫

一、pytorch 中的多維注意力機制:

N L P NLP NLP 領域內,上述三個參數都是 向量 , 在 p y t o r c h pytorch pytorch 中參數向量會組成 矩陣 ,方便代碼編寫。

①結構圖

注意力機制結構圖如下:

在這里插入圖片描述

②計算公式詳解

計算注意力分數的方式有很多,目前最常用的就是點乘。具體如下:

當向量 q u e r y \large query query k e y \large key key 長度相同時,即 q 、 k i ∈ R ( 1 × d ) q、k_i∈R^{(1×d)} qki?R(1×d) ,則有:注意力分數 s ( q , k i ) = < q , k i > d k \large s(q,k_i)=\frac{<q,k_i>}{\sqrt{d_k}} s(q,ki?)=dk? ?<q,ki?>? ,符號 < q , k i > <q,k_i> <q,ki?> 表示點乘/內積運算(向量點乘,結果為標量)。其中 d k d_k dk? k i k_i ki? 向量的長度(為什么要在原注意力分數底下除以 d k \sqrt{d_k} dk? ? 后面會詳解)。

當向量組成矩陣時,假設 Q ∈ R ( n × d ) Q∈R^{(n×d)} QR(n×d) K ∈ R ( m × d ) K∈R^{(m×d)} KR(m×d) V ∈ R ( m × v ) V∈R^{(m×v)} VR(m×v) 。每個矩陣都是由參數行向量堆疊組成。則有:

F ( Q ) = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q ? K T d k ) ? V \Large F(Q)=Attention(Q,K,V)=softmax(\frac{Q·K^T}{\sqrt{d_k}})·V F(Q)=Attention(Q,K,V)=softmax(dk? ?Q?KT?)?V
其中 Q K T d ∈ R ( n × m ) \large \frac{QK^T}{\sqrt{d}}∈R^{(n×m)} d ?QKT?R(n×m) 是注意力分數,, s o f t m a x ( Q K T d k ) ∈ R ( n × m ) \large softmax(\frac{QK^T}{\sqrt{d_k}})∈R^{(n×m)} softmax(dk? ?QKT?)R(n×m) 是注意力權重, F ( Q ) ∈ R ( n × v ) \large F(Q)∈R^{(n×v)} F(Q)R(n×v) 是輸出。

這是一種并行化矩陣計算形式,將所有的 q q q 組合成一個矩陣 Q Q Q k k k v v v 類似,都被組合成了矩陣 K K K V V V 。其詳細過程如下:

已知 Q ∈ R ( n × d ) Q∈R^{(n×d)} QR(n×d) K ∈ R ( m × d ) K∈R^{(m×d)} KR(m×d) V ∈ R ( m × v ) V∈R^{(m×v)} VR(m×v) ,該尺寸表示有 n n n q q q m m m k k k m m m v v v 。則:

Q × K T = [ [ ? q 1 ? ] [ ? q 2 ? ] ? [ ? q n ? ] ] ● [ [ ? k 1 ? ? ] [ ? k 2 ? ? ] ? [ ? k m ? ? ] ] = [ q 1 ? k 1 q 1 ? k 2 ? q 1 ? k m q 2 ? k 1 q 2 ? k 2 ? q 2 ? k m ? ? ? ? q n ? k 1 q n ? k 2 ? q n ? k m ] Q \times K^T =\\ \begin{bmatrix} \begin{bmatrix} \cdots & q_1 & \cdots \end{bmatrix} \\ \begin{bmatrix} \cdots & q_2 & \cdots \end{bmatrix} \\ \vdots \\ \begin{bmatrix} \cdots & q_n & \cdots \end{bmatrix} \end{bmatrix} ● \begin{bmatrix} \begin{bmatrix} \vdots \\ k_1 \\ \vdots \\ \vdots \end{bmatrix} & \begin{bmatrix} \vdots \\ k_2 \\ \vdots \\ \vdots \end{bmatrix} & \cdots & \begin{bmatrix} \vdots \\ k_m \\ \vdots \\ \vdots \end{bmatrix} \end{bmatrix}= \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \cdots & q_1 \cdot k_m \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \cdots & q_2 \cdot k_m \\ \vdots & \vdots & \ddots & \vdots \\ q_n \cdot k_1 & q_n \cdot k_2 & \cdots & q_n \cdot k_m \end{bmatrix} Q×KT= ?[??q1????][??q2????]?[??qn????]? ? ? ??k1???? ?? ??k2???? ???? ??km???? ?? ?= ?q1??k1?q2??k1??qn??k1??q1??k2?q2??k2??qn??k2???????q1??km?q2??km??qn??km?? ?

上述運算可以得到每個小 q q q m m m 個小 k k k 的注意力分數,再經過放縮(除以 d k \sqrt{d_k} dk? ? )和 s o f t m a x softmax softmax 函數后得到每個小 q q q m m m 個小 k k k 的注意力權重矩陣,其尺寸為 n × m n×m n×m ,最終和 V V V 相乘,得到 F ( Q ) F(Q) F(Q) ,其尺寸為 n × v n×v n×v ,對應著 n n n q q q v a l u e value value

③公式細節解釋

  1. 第一點:

    使用點乘來計算注意力分數的意義:矩陣點乘 Q ? K T Q·K^T Q?KT 就意味著做點積/內積,(在注意力機制中,點積通常等同于內積,在數學上點積是內積的特例),內積可直接衡量兩個向量的方向對齊程度。若兩個向量方向一致(夾角為 0 ° 0° ),則內積最大;方向相反(夾角為 180 ° 180° 180° ),則內積最小。點乘不僅包含方向信息,還隱含向量長度的乘積。例如,若兩個長向量方向一致,內積值會顯著高于短向量,可能更強調其相關性。

  2. 第二點:

    上述公式中, s o f t m a x softmax softmax 里對注意力分數還除以了 d k \sqrt{d_k} dk? ? ,是因為:由于 s o f t m a x softmax softmax 函數的計算公式用到了 e e e 的次方,當兩個數之間的倍數很大時,比如說 99 和 1 ,經過求 e e e 的次方運算之后,差別會指數倍增加,這樣求出來的概率會很離譜,不是0.99和0.01,而是0.99999999和0.0000000001(很多9和很多0)。讓其中每個元素除以 d k \sqrt{d_k} dk? ? 之后,會降低倍數增加的程度(更數學性的解釋可以看 00 預訓練語言模型的前世今生(全文 24854 個詞) - B站-水論文的程序猿 - 博客園 這篇博客中的有關注意力機制的講解)。其功能類似于防止梯度消失。

  3. 第三點:

    一般來說,在 t r a n s f o r m e r transformer transformer 里, K = V K=V K=V 。當然 K ≠ V K≠V K=V 也可以,不過兩者之間一定是有對應關系,能組成鍵值對的。

二、自注意力機制(Self-Attention)

當上述的三個參數都由一個另外的共同參數 經過不同的線性變換 生成時(即三者同源),就是自注意力機制。其值體現為 Q ≈ K ≈ V Q≈K≈V QKV

這三個矩陣是在同一個矩陣 X X X 上乘以不同的系數矩陣 W Q 、 W K 、 W V W_Q、W_K、W_V WQ?WK?WV? 得到的,因此自注意力機制可以說是在計算 X X X 內部各個 x i x_i xi? 之間的相關性。其后續步驟和注意力機制一樣。(為什么叫自注意力機制,估計是因為這里是計算自己內部之間的相關性吧)

注意】:最終生成的新的 v a l u e value value 其實依然是小 x x x 的向量表示,只不過這個新向量蘊含了其他的小 x x x 的信息。

具體公式如下:
Q = W Q ? X , K = W K ? X , V = W V ? X F ( Q ) = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q ? K T d k ) ? V \large Q=W_Q·X,~~~~K=W_K·X,~~~~V=W_V·X\\ \Large F(Q)=Attention(Q,K,V)=softmax(\frac{Q·K^T}{\sqrt{d_k}})·V Q=WQ??X,????K=WK??X,????V=WV??XF(Q)=Attention(Q,K,V)=softmax(dk? ?Q?KT?)?V
N L P NLP NLP 中,可以舉一個小例子理解一下(矩陣內數值即為注意力權重):

在這里插入圖片描述

上圖中,每一個單詞就是一個小 q q q ,單詞用向量表示。(有個誤區:不是說自注意力機制中,小 q q q 和自己的注意力分數就是最大的,這個要看具體語義需求)

其他變種:交叉注意力機制( Q Q Q V V V 不同源, K K K V V V 同源)。

三、掩碼自注意力機制(Masked Self-Attention)

N L P NLP NLP 里,在訓練過程中,比如說我想訓練模型生成:“The cat is cute” 這樣一個句子,并且計算其自注意力權重,這個時候 “The cat is cute” 就是已知的 label 。但是句子是一個一個單詞生成的( The → The cat → The cat is → The cat is cute),第一個生成 The ,第二個生成 cat … 在沒有完全生成之前,都是不能提前告訴模型后面的答案。已知句子總長度為 4 4 4 ,那么注意力權重的個數依次是 1 → 2 → 3 → 4 。如下圖所示:

在這里插入圖片描述

注意了,這里的生成是指訓練時的生成,掩碼機制只在訓練時使用,因為訓練時機器知道有位置信息的句子(句子的長度也已知曉),為了防止窺探到下一個字就要掩碼。但在實際使用模型時(測試時),是沒有參考答案的,所以不需要掩碼!

其實還有其他作用,諸如:避免填充干擾等,后面在 transformer 里會詳解。

四、多頭注意力機制(Multi-Head Self-Attention)

本質上就是: X X X 做完三次線性變換得到 Q 、 K 、 V Q、K、V QKV之后,將 Q 、 K 、 V Q、K、V QKV分割成 8 8 8 塊進行注意力計算,最后將這 8 8 8 個結果拼接,然后線性變換,使其維度和 X X X 一致。(并不是直接對 X 進行切分,也不是對 X 進行重復線性變換)

意義:原論文其實也說不清楚這樣做的意義,反正給人一種能學到更細致的語義信息的感覺(深度學習就是這樣~~)。

流程圖如下:

在這里插入圖片描述

第一步:

輸入序列 X X X 首先經過三次獨立的線性變換,生成查詢( Q u e r y Query Query)、鍵( K e y Key Key)、值( V a l u e Value Value)矩陣:

Q = W Q ? X Q=W_Q·X Q=WQ??X K = W K ? X K=W_K·X K=WK??X V = W V ? X V=W_V·X V=WV??X 。其中, W Q 、 W K 、 W V W_Q、W_K、W_V WQ?WK?WV? 是可學習的權重矩陣。

第二步:

Q 、 K 、 V Q、K、V QKV 矩陣沿特征維度平均分割為多個頭。一般頭數均為 8 8 8(即 h = 8 h=8 h=8),假設 Q 、 K 、 V Q、K、V QKV 的特征維度為 M M M ,則分割之后每個頭的特征維度為 M / 8 M/8 M/8

第三步:

每個頭各自并行計算注意力并得到各自的輸出(先點積,再縮放,再做 s o f t m a x softmax softmax ,再乘以 v a l u e value value )【每個頭學習不同子空間的語義關系】

第四步:

合并多頭輸出,將所有頭的輸出拼接為完整維度,再通過一次線性變換整合信息:

O u t p u t = C o n c a t ( h e a d 1 , … , h e a d h ) ? W O Output=Concat(head_1,…,head_h)?W_O Output=Concat(head1?,,headh?)?WO? 。其中 W O W_O WO? 是最后的線性層的投影矩陣。

值得一提的是:針對 “將 Q 、 K 、 V Q、K、V QKV分割成 8 8 8 塊” 這個步驟,《Attention Is All You Need》論文原文說的是: linearly project h times ,意思就是將 Q 、 K 、 V Q、K、V QKV通過線性層將其變換為 8 8 8 個新的特征維度為 M / 8 M/8 M/8 Q ′ 、 K ′ 、 V ′ Q^{'}、K^{'}、V^{'} QKV 。不過這在數學上等效于直接分割成 8 8 8 塊,并且后者在算法實現上能提高效率。代碼如下:

Q = torch.randn(batch_size, seq_len, h*d_k)
Q = Q.view(batch_size, seq_len, h, d_k)  # 分割為 h 個頭

.view() 函數的作用是變換尺寸,將原來的三維張量,變成四維張量( h 個三維張量),元素的值不變,元素的總數也不變,其效果等于切割。


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/76397.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/76397.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/76397.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

uni-app初學

文章目錄 1. pages.json 頁面路由2. 圖標3. 全局 CSS4. 首頁4.1 整體框架4.2 完整代碼4.3 輪播圖 swiper4.3.1 image 4.4 公告4.4.1 uni-icons 4.5 分類 uni-row、uni-col4.6 商品列表 小程序開發網址&#xff1a; 注冊小程序賬號 微信開發者工具下載 uniapp 官網 HbuilderX 下…

VBA將Word文檔內容逐行寫入Excel

如果你需要將Word文檔的內容導入Excel工作表來進行數據加工&#xff0c;使用下面的代碼可以實現&#xff1a; Sub ImportWordToExcel()Dim wordApp As Word.ApplicationDim wordDoc As Word.DocumentDim excelSheet As WorksheetDim filePath As VariantDim i As LongDim para…

MySQL運行一段時間后磁盤出現100%讀寫

MySQL運行一段時間后磁盤出現100%讀寫的情況&#xff0c;可能是由多種原因導致的&#xff0c;以下是一些常見原因及解決方法&#xff1a; 可能的原因 1. 磁盤I/O壓力過大[^0^]&#xff1a;數據量過大&#xff0c;數據庫查詢和寫入操作消耗大量I/O資源。索引效率低&#xff0c…

【RabbitMQ】延遲隊列

1.概述 延遲隊列其實就是隊列里的消息是希望在指定時間到了以后或之前取出和處理&#xff0c;簡單來說&#xff0c;延時隊列就是用來存放需要在指定時間被處理的元素的隊列。 延時隊列的使用場景&#xff1a; 1.訂單在十分鐘之內未支付則自動取消 2.新創建的店鋪&#xff0c;…

Linux筆記之Ubuntu系統設置自動登錄tty1界面

Ubuntu22.04系統 編輯getty配置文件 vim /etc/systemd/system/gettytty1.service.d/override.conf如果該目錄或者文件不存在&#xff0c;進行創建。 在override.conf文件中進行編輯&#xff1a; [Service] ExecStart ExecStart-/sbin/agetty --autologin yourusername --no…

C++程序詩篇的靈動賦形:多態

文章目錄 1.什么是多態&#xff1f;2.多態的語法實現2.1 虛函數2.2 多態的構成2.3 虛函數的重寫2.3.1 協變2.3.2 析構函數的重寫 2.4 override 和 final 3.抽象類4.多態原理4.1 虛函數表4.2 多態原理實現4.3 動態綁定與靜態綁定 5.繼承和多態常見的面試問題希望讀者們多多三連支…

算法訓練之動態規劃(三)

???~~~~~~歡迎光臨知星小度博客空間~~~~~~??? ???零星地變得優秀~也能拼湊出星河~??? ???我們一起努力成為更好的自己~??? ???如果這一篇博客對你有幫助~別忘了點贊分享哦~??? ???如果有什么問題可以評論區留言或者私信我哦~??? ?????? 個…

$_GET變量

$_GET 是一個超級全局變量&#xff0c;在 PHP 中用于收集通過 URL 查詢字符串傳遞的參數。它是一個關聯數組&#xff0c;包含了所有通過 HTTP GET 方法發送到當前腳本的變量。 預定義的 $_GET 變量用于收集來自 method"get" 的表單中的值。 從帶有 GET 方法的表單發…

jQuery多庫共存

在現代Web開發中&#xff0c;項目往往需要集成多種JavaScript庫或框架來滿足不同的功能需求。然而&#xff0c;當多個庫同時使用時&#xff0c;可能會出現命名沖突、功能覆蓋等問題。幸運的是&#xff0c;jQuery提供了一些機制來確保其可以與其他庫和諧共存。本文將探討如何實現…

MySQL 中的聚簇索引和非聚簇索引有什么區別?

MySQL 中的聚簇索引和非聚簇索引有什么區別&#xff1f; 1. 從不同存儲引擎去考慮 在MySIAM存儲引擎中&#xff0c;索引和數據是分開存儲的&#xff0c;包括主鍵索引在內的所有索引都是“非聚簇”的&#xff0c;每個索引的葉子節點存儲的是數據記錄的物理地址&#xff08;指針…

Java從入門到“放棄”(精通)之旅——啟航①

&#x1f31f;Java從入門到“放棄 ”精通之旅&#x1f680; 今天我將要帶大家一起探索神奇的Java世界&#xff01;希望能幫助到同樣初學Java的你~ (??????)?? &#x1f525; Java是什么&#xff1f;為什么這么火&#xff1f; Java不僅僅是一門編程語言&#xff0c;更…

三相電為什么沒零線也能通電

要理解三相電為什么沒零線也能通電&#xff0c;就要從發電的原理說起 1、弧形磁鐵中加入電樞&#xff0c;旋轉切割磁感線會產生電流 隨著電樞旋轉的角度變化&#xff0c;電樞垂直切割磁感線 電樞垂直切割磁感線&#xff0c;此時會產生最大電壓 當轉到與磁感線平行時&#xf…

文件上傳做題記錄

1&#xff0c;[SWPUCTF 2021 新生賽]easyupload2.0 直接上傳php 再試一下phtml 用蟻劍連發現連不上 那就只要命令執行了 2&#xff0c;[SWPUCTF 2021 新生賽]easyupload1.0 當然&#xff0c;直接上傳一個php是不行的 phtml也不行&#xff0c;看下是不是前端驗證&#xff0c;…

【Pandas】pandas DataFrame head

Pandas2.2 DataFrame Indexing, iteration 方法描述DataFrame.head([n])用于返回 DataFrame 的前幾行 pandas.DataFrame.head pandas.DataFrame.head 是一個方法&#xff0c;用于返回 DataFrame 的前幾行。這個方法非常有用&#xff0c;特別是在需要快速查看 DataFrame 的前…

日語學習-日語知識點小記-構建基礎-JLPT-N4階段(1):承上啟下,繼續上路

日語學習-日語知識點小記-構建基礎-JLPT-N4階段(1):承上啟下,繼續上路 1、前言(1)情況說明(2)工程師的信仰2、知識點(1)普通形(ふつうけい)と思います(2)辭書形ことができます(3)Vたことがあります。(4)Vた とき & Vる とき3、單詞(1)日語單詞(2…

碼率自適應(ABR)相關論文閱讀簡報

標題&#xff1a;Quality Enhanced Multimedia Content Delivery for Mobile Cloud with Deep Reinforcement Learning 作者&#xff1a;Muhammad Saleem , Yasir Saleem, H. M. Shahzad Asif, and M. Saleem Mian 單位: 巴基斯坦拉合爾54890工程技術大學計算機科學與工程系 …

匯編語言:指令詳解

零、前置知識 1、數據類型修飾符 名稱解釋byte一個字節&#xff0c;8bitword單字&#xff0c;占2個字節&#xff0c;16bitdword雙字&#xff0c;占4個字節&#xff0c;32bitqword四字&#xff0c;占8個字節&#xff0c;64bit 2、關鍵詞解釋 ptr&#xff1a;它代表 pointer&a…

藍橋杯c ++筆記(含算法 貪心+動態規劃+dp+進制轉化+便利等)

藍橋杯 #include <iostream> #include <vector> #include <algorithm> #include <string> using namespace std; //常使用的頭文件動態規劃 小藍在黑板上連續寫下從 11 到 20232023 之間所有的整數&#xff0c;得到了一個數字序列&#xff1a; S12345…

【C++算法】54.鏈表_合并 K 個升序鏈表

文章目錄 題目鏈接&#xff1a;題目描述&#xff1a;解法C 算法代碼&#xff1a; 題目鏈接&#xff1a; 23. 合并 K 個升序鏈表 題目描述&#xff1a; 解法 解法一&#xff1a;暴力解法 每個鏈表的平均長度為n&#xff0c;有k個鏈表&#xff0c;時間復雜度O(nk^2) 合并兩個有序…

Java中的注解技術講解

Java中的注解&#xff08;Annotation&#xff09;是一種在代碼中嵌入元數據的機制&#xff0c;不直接參與業務邏輯&#xff0c;而是為編譯器、開發工具以及運行時提供額外的信息和指導。下面我們將由淺入深地講解Java注解的概念、實現原理、各種應用場景&#xff0c;并通過代碼…