機器學習深度學習——自注意力和位置編碼(數學推導+代碼實現)

👨?🎓作者簡介:一位即將上大四,正專攻機器學習的保研er
🌌上期文章:機器學習&&深度學習——注意力分數(詳細數學推導+代碼實現)
📚訂閱專欄:機器學習&&深度學習
希望文章對你們有所幫助

自注意力和位置編碼

  • 引入
  • 自注意力
    • 多頭注意力
    • 基于多頭注意力實現自注意力
  • 比較CNN、RNN和self-attention
    • 結論
    • 剖析——CNN
    • 剖析——RNN
    • 剖析——self-attention
    • 總結
  • 位置編碼
    • 絕對位置信息
    • 相對位置信息
  • 小結

引入

在深度學習中,經常使用CNN和RNN對序列進行編碼。有了自注意力之后,我們將詞元序列輸入注意力池化中,以便同一組詞元同時充當查詢、鍵和值。具體來說,每個查詢都會關注所有的鍵-值對并生成一個注意力輸出。由于查詢、鍵和值來自同一組輸入,因此被稱為自注意力(self-attention)。下面將使用自注意力進行序列編碼。

import math
import torch
from torch import nn
from d2l import torch as d2l

自注意力

給定一個由詞元組成的序列:
x 1 , . . . , x n 其中任意 x i ∈ R d x_1,...,x_n\\ 其中任意x_i∈R^d x1?,...,xn?其中任意xi?Rd
該序列的自注意力輸出為一個長度相同的序列:
y 1 , . . . , y n 其中 y i = f ( x i , ( x 1 , x 1 ) , . . . , ( x n , x n ) ) ∈ R d y_1,...,y_n\\ 其中y_i=f(x_i,(x_1,x_1),...,(x_n,x_n))∈R^d y1?,...,yn?其中yi?=f(xi?,(x1?,x1?),...,(xn?,xn?))Rd
自注意力就是這樣,任意的xi都是既當key,又當value,還當query。
下面的代碼片段是基于多頭注意力對一個張量完成自注意力的計算,張量形狀為(批量大小,時間步數目或詞元序列長度,d)。輸出與輸入的張量形狀相同。
而在此之前,簡單講解下多頭注意力,接著基于多頭注意力實現自注意力。

多頭注意力

當給定相同的查詢、鍵和值的集合時,我們希望模型可以基于相同的注意力機制學習到不同的行為,然后將不同的行為作為知識組合起來,捕獲序列內各種范圍的依賴關系。因此允許注意力機制組合使用查詢、鍵和值的不同子空間表示是有益的。
因此,與其只使用一個注意力池化,我們可以獨立學習得到h組不同的線性投影來變換查詢、鍵和值。然后,這h組變換后的查詢、鍵和值將并行地送到注意力池化中。最后將這h個注意力池化的輸出拼接在一起,并通過另一可以學習的線性投影進行變換,來產生最終輸出。這就是多頭注意力(multihead attention),如下圖所示:
在這里插入圖片描述
而多頭注意力的實現過程通常使用的是縮放點積注意力來作為每一個注意力頭,我們設定:
p q = p k = p v = p o / h p_q=p_k=p_v=p_o/h pq?=pk?=pv?=po?/h
值得注意的是,如果將查詢、鍵和值的線性變化的輸出數量設置為:
p q h = p k h = p v h = p o p_qh=p_kh=p_vh=p_o pq?h=pk?h=pv?h=po?
就可以并行計算h個頭,下面代碼中的po是通過num_hiddens指定的。

代碼如下:

#@save
class MultiHeadAttention(nn.Module):"""多頭注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形狀:# (batch_size,查詢或者“鍵-值”對的個數,num_hiddens)# valid_lens 的形狀:# (batch_size,)或(batch_size,查詢的個數)# 經過變換后,輸出的queries,keys,values 的形狀:# (batch_size*num_heads,查詢或者“鍵-值”對的個數,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在軸0,將第一項(標量或者矢量)復制num_heads次,# 然后如此復制第二項,然后諸如此類。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形狀:(batch_size*num_heads,查詢的個數,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形狀:(batch_size,查詢的個數,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)#@save
def transpose_qkv(X, num_heads):"""為了多注意力頭的并行計算而變換形狀"""# 輸入X的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_hiddens)# 輸出X的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 輸出X的形狀:(batch_size,num_heads,查詢或者“鍵-值”對的個數,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最終輸出的形狀:(batch_size*num_heads,查詢或者“鍵-值”對的個數,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆轉transpose_qkv函數的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

基于多頭注意力實現自注意力

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()

可以輸出驗證一下:

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)

輸出結果:

torch.Size([2, 4, 100])

比較CNN、RNN和self-attention

首先看這個圖:
在這里插入圖片描述
接下來進行CNN、RNN以及self-attention三個架構的比較,首先這三個架構目標都是要將n個詞元組成的序列映射到另一個長度相同的序列,其中的每個輸入詞元或輸出詞元都由d維向量表示。我們的比較將基于計算的復雜性、順序操作和最大路徑長度,先給出結論再進行剖析解釋。
我們首先要知道,順序操作會妨礙并行計算,而任意的序列位置組合之間的路徑越短,則能更輕松地學習序列中的遠距離依賴關系。

結論

計算復雜度并行度最大路徑長度
CNNO(knd2)O(n)O(n/k)
RNNO(nd2)O(1)O(n)
self-attentionO(n2d)O(n)O(1)

剖析——CNN

考慮一個卷積核大小為k的卷積層,由于序列長度是n,輸入和輸出的通道數量都是d,所以卷積層的計算復雜度為O(knd2)。而如上圖所示,可以看出CNN網絡是分層的,因此會有O(1)個順序操作,那么這代表著通道可以并行執行n個詞元,那么并行度就是O(n)。
上圖中可以看出k=3,因為這樣剛好就使得x1和x5處于這個卷積核大小為3的雙層卷積神經網絡的感受野內。因此最大的路徑長度一定是不會超過n/k的,下標為n的也會因為卷積核被限制到一個感受野內,因此可以知道最大路徑長度為O(n/k)。

剖析——RNN

當更新RNN的隱狀態時,d×d權重矩陣和d維隱狀態的乘法計算復雜度為O(d2),再加上序列長度為n,因此RNN的計算復雜度為O(nd2),由上圖也可以看出n個序列的順序操作是沒辦法并行化的,則并行度為O(1),最大路徑長度是O(n)(可以理解成當我們要組合y1和yn的時候,這時候長度為n)。

剖析——self-attention

查詢、鍵、值都是n×d矩陣。計算過程為:n×d矩陣乘以d×n矩陣,之后得到的n×n矩陣再乘以n×d矩陣,因此自注意力有O(n2d)的計算復雜度。而上圖展示了自注意力的強大,O(n)的并行度顯而易見,同時最大路徑長度是O(1),因為他們可以任意組合。

總結

總而言之,卷積神經網絡和自注意力都擁有并行計算的優勢,而且自注意力的最大路徑長度最短。
但是因為其計算復雜度是關于序列長度的二次方,所以在很長的序列中計算會非常慢。

位置編碼

在處理詞元序列時,循環神經網絡是逐個的重復地處理詞元的,而自注意力則因為并行計算而放棄了順序操作。為了使用序列的順序信息,通過在輸入表示中添加位置編碼來注入絕對的或相對的位置信息。
位置編碼可以通過學習得到也可以直接固定得到,下面講解基于正弦函數和余弦函數的固定位置編碼。
假設輸入表示X∈Rn×d包含一個序列中n個詞元的d維嵌入表示。位置編碼使用相同形狀的位置嵌入矩陣P∈Rn×d輸出X+P,矩陣第[i,2j](偶數列)和[i,2j+1](奇數列)列上的元素為:
p i , 2 j = s i n ( i 1000 0 2 j / d ) , p i , 2 j + 1 = c o s ( i 1000 0 2 j / d ) p_{i,2j}=sin(\frac{i}{10000^{2j/d}}),\\ p_{i,2j+1}=cos(\frac{i}{10000^{2j/d}}) pi,2j?=sin(100002j/di?),pi,2j+1?=cos(100002j/di?)
看起來很奇怪,在后面講解的時候就能看出來了,先定義一個類來實現它:

#@save
class PositionalEncoding(nn.Module):"""位置編碼"""def __init__(self, num_hiddens, dropout, max_len=1000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(dropout)# 創建一個足夠長的Pself.P = torch.zeros((1, max_len, num_hiddens))X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)self.P[:, :, 0::2] = torch.sin(X)self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):X = X + self.P[:, :X.shape[1], :].to(X.device)return self.dropout(X)

我們可以進行打印圖像,可以清晰看到6、7列比8、9列頻率高,而6與7(8與9同理)由于正余弦函數的相位交替,而導致偏移量不同。

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
d2l.plt.show()

運行結果:
在這里插入圖片描述

絕對位置信息

其實就是二進制了,想象一下0-7的二進制表示是各不相同的,而且容易知道:較高比特位的交替頻率低于較低比特位(而使用三教函數的話輸出的是浮點數,顯然會更省空間)。

相對位置信息

除了捕獲絕對位置信息之外,上述的位置編碼還允許模型學習得到輸入序列中相對位置信息。這是因為對于任何確定的位置偏移σ,位置i+σ處的位置編碼可以線性投影位置i處的位置編碼來表示。
用數學來表示:
令 w j = 1 / 1000 0 2 j / d ,對于任何確定的位置偏移 σ : [ c o s ( σ w j ) s i n ( σ w j ) ? s i n ( σ w j ) c o s ( σ w j ) ] [ p i , 2 j p i , 2 j + 1 ] = [ c o s ( σ w j ) s i n ( i w j ) + s i n ( σ w j ) c o s ( i w j ) ? s i n ( σ w j ) s i n ( i w j ) + c o s ( σ w j ) c o s ( i w j ) ] = [ s i n ( ( i + σ ) w j ) c o s ( ( i + σ ) w j ) ] ——積化和差 = [ p i + σ , 2 j p i + σ , 2 j + 1 ] 令w_j=1/10000^{2j/d},對于任何確定的位置偏移σ:\\ \begin{bmatrix} cos(σw_j)&sin(σw_j)\\ -sin(σw_j)&cos(σw_j) \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix}\\ =\begin{bmatrix} cos(σw_j)sin(iw_j)+sin(σw_j)cos(iw_j)\\ -sin(σw_j)sin(iw_j)+cos(σw_j)cos(iw_j) \end{bmatrix}\\ =\begin{bmatrix} sin((i+σ)w_j)\\ cos((i+σ)w_j) \end{bmatrix}——積化和差\\ =\begin{bmatrix} p_{i+σ,2j}\\ p_{i+σ,2j+1} \end{bmatrix} wj?=1/100002j/d,對于任何確定的位置偏移σ[cos(σwj?)?sin(σwj?)?sin(σwj?)cos(σwj?)?][pi,2j?pi,2j+1??]=[cos(σwj?)sin(iwj?)+sin(σwj?)cos(iwj?)?sin(σwj?)sin(iwj?)+cos(σwj?)cos(iwj?)?]=[sin((i+σ)wj?)cos((i+σ)wj?)?]——積化和差=[pi+σ,2j?pi+σ,2j+1??]
2×2投影矩陣不依賴于任何位置的索引i。

小結

1、在自注意力中,查詢、鍵和值都來自同一組輸入。
2、卷積神經網絡和自注意力都擁有并行計算的優勢,而且自注意力的最大路徑長度最短。但是因為其計算復雜度是關于序列長度的二次方,所以在很長的序列中計算會非常慢。
3、為了使用序列的順序信息,可以通過在輸入表示中添加位置編碼,來注入絕對的或相對的位置信息。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/40133.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/40133.shtml
英文地址,請注明出處:http://en.pswp.cn/news/40133.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Cat(2):下載與安裝

1 github源碼下載 要安裝CAT,首先需要從github上下載最新版本的源碼。 官方給出的建議如下: 注意cat的3.0代碼分支更新都發布在master上,包括最新文檔也都是這個分支注意文檔請用最新master里面的代碼文檔作為標準,一些開源網站…

node.js內置模塊fs,path,http使用方法

NodeJs中分為兩部分 一是V8引擎為了解析和執行JS代碼。 二是內置API,讓JS能調用這些API完成一些后端操作。 內置API模塊(fs、path、http等) 第三方API模塊(express、mysql等) fs模塊 fs.readFile()方法,用于讀取指定文件中的內容。 fs.writeFile()方…

MySQL— 基礎語法大全及操作演示!!!(上)

MySQL—— 基礎語法大全及操作演示(上) 一、MySQL概述1.1 、數據庫相關概念1.1.1 MySQL啟動和停止 1.2 、MySQL 客戶端連接1.3 、數據模型 二、SQL2.1、SQL通用語法2.2、SQL分類2.3、DDL2.3.1 DDL — 數據庫操作2.3.1 DDL — 表操作 2.4、DML2.4.1 DML—…

等保案例 5

用戶簡介 四川省人民代表大會常務委員會,作為省人民代表大會地常設機關,隨著政府部門信息化程度地提高,對信息系統地依賴程度越來越高,同時由于網絡安全形勢日益嚴峻、新型攻擊層出不窮,單位信息化所面臨地各種風險也…

途樂證券-寧德時代發力超充賽道,高壓快充概念強勢拉升,泰永長征漲停

高壓快充概念17日盤中強勢拉升,到發稿,泰永長征漲停,萬祥科技漲超9%,英可瑞漲逾8%,迦南智能漲超4%。 消息面上,8月16日,寧德時代舉行線下新品發布會,正式發布全球首款磷酸鐵鋰4C超充…

Spark第二課RDD的詳解

1.前言 RDD JAVA中的IO 1.小知識點穿插 1. 裝飾者設計模式 裝飾者設計模式:本身功能不變,擴展功能. 舉例: 數據流的讀取 一層一層的包裝,進而將功能進行進一步的擴展 2.sleep和wait的區別 本質區別是字體不一樣,sleep斜體,wait正常 斜體是靜態方法…

經過幾天的亂搞,已經搞出來第一次stm32點燈程序

看吧那個燈泡已經亮了 stm32跟51不同的地方是這里引腳一組16個,如PA0,PA1,PA2,,,,,,PA15 51一組8個 例如P00,P01,P02,,,,P07

全新重構,探尋 24 歲 QQ 大重構背后的思考

在瞬息萬變的互聯網行業中,年過二十四的 QQ 堪稱超長壽的產品,見證了中國互聯網崛起的完整歷程。然而,如今這個元老級產品經歷了一次從內到外徹底的重構。 在這次重構中,QQ 選擇了 Electron 作為 UI 跨平臺開發框架。盡管 Electron 被 Slack、Visual Studio Code 和 Disco…

[Go版]算法通關村第十一關青銅——理解位運算的規則

目錄 數字在計算機中的表示:機器數、真值對機器數進一步細化:原碼、反碼、補碼為何會有原碼、反碼和補碼為何計算機中的按位運算使用的是補碼?位運算規則與、或、異或和取反移位運算移位運算與乘除法的關系位運算常用技巧?? 操作某個位的數…

Unity用NPOI創建Exect表,保存數據,和修改刪除數據。以及打包后的坑——無法打開新創建的Exect表

先說坑花了一下午才找到解決方法解決, 在Unity編輯模式下點擊物體創建對應的表,獲取物體名字與在InputText填寫的注釋數據。然后保存。創建Exect表可以打開,打包PC后,點擊物體創建的表,打不開文件破損 解決方法&#…

大數據培訓前景怎么樣?企業需求量大嗎

大數據行業對大家來說并不陌生,大數據行業市場人才需求量大,越早入行越有優勢,發展機會和上升空間等大。不少人通過大數據培訓來提升自己的經驗和自身技術能力,以此來獲得更好的就業機會。 2023大數據培訓就業前景怎么樣呢?企業需…

ubuntu18 下更改 mysql 數據目錄

一、修改步驟 更改 MySQL 的數據目錄需要注意以下幾個步驟: 停止 MySQL 服務 在 Ubuntu 中,你可以使用以下命令停止 MySQL 服務: sudo systemctl stop mysql 復制現有數據 假設你的新的數據目錄是 /new/dir/mysql,你應該使用 rsy…

區間覆蓋 線段覆蓋 二分

4195. 線段覆蓋 - AcWing題庫 P2082 區間覆蓋&#xff08;加強版&#xff09; - 洛谷 | 計算機科學教育新生態 (luogu.com.cn) 做法&#xff1a; void solve() {int n; cin>>n;vector<array<LL,2>> seg(n);for(auto &t: seg) cin>>t[0]>>…

從視覺裝備到智能駕駛,天準科技能否打造第二增長極?

智能網聯汽車已經成為了上市公司跨界布局的熱門賽道。 天準科技是工業視覺智能裝備領域的龍頭企業&#xff0c;主要客戶包括蘋果、三星等企業。招股說明書顯示&#xff0c;2016年至2018年&#xff0c;天準科技來源于蘋果公司及其供應商的收入合計占比達到49.98%、67.99%及76.0…

Spark操作Hive表冪等性探索

前言 旁邊的實習生一邊敲著鍵盤一邊很不開心的說:做數據開發真麻煩,數據bug排查太繁瑣了,我今天數據跑的有問題,等我處理完問題重新跑了代碼,發現報表的數據很多重復,準備全部刪了重新跑。 我:你的數據操作具備冪等性嗎? 實習生:啥是冪等性?數倉中的表還要考慮冪等…

JVS開源基礎框架:平臺基本信息介紹

JVS是面向軟件開發團隊可以快速實現應用的基礎開發腳手架&#xff0c;主要定位于企業信息化通用底座&#xff0c;采用微服務分布式框架&#xff0c;提供豐富的基礎功能&#xff0c;集成眾多業務引擎&#xff0c;它靈活性強&#xff0c;界面化配置對開發者友好&#xff0c;底層容…

互聯網賬號被封禁解決辦法,以qq為例

百度搜索&#xff1a;互聯網信息服務投訴平臺 電腦端瀏覽器&#xff1a;打開 ts.isc.org.cn 推薦使用360極速瀏覽器 谷歌瀏覽器 提交完成后&#xff0c;將投訴碼保存&#xff0c;可以在“查詢評價”處用投訴碼查詢進度

windows安裝go,以及配置工作區,配置vscode開發環境

下載安裝go 我安裝在D:\go路徑下配置環境變量 添加GOROOT value為D:\go修改path 添加%GOROOT%\bin添加GOPATH value為%USERPROFILE%\go 其中GOPATH 是我們自己開發的工作區&#xff0c;其中包含三個folder bin,pkg,以及src&#xff0c;其中src為我們編寫代碼的位置 配置vscod…

Vue路由守衛

目錄 一、全局路由守衛二、獨享路由守衛三、組件內路由守衛 一、全局路由守衛 作用全局 router.beforeEach全局前置路由守衛—初始化的時候被調用、每次路由切換之前被調用router.afterEach全局后置路由守衛—初始化的時候被調用、每次路由切換之后被調用 配置 // 該文件專…

git使用規范

Git規范&#xff08;公司使用gitlab&#xff09; 版本規范 前端項目使用語義化版本進行發布: 版本格式&#xff1a;主版本號.次版本號.修訂號&#xff0c;版本號遞增規則如下&#xff1a; 主版本號&#xff1a;當你做了不兼容的 API 修改&#xff0c;次版本號&#xff1a;當…