簡單實現Transformer的自注意力
關注{曉理紫|小李子},獲取技術推送信息,如感興趣,請轉發給有需要的同學,謝謝支持!!
如果你感覺對你有所幫助,請關注我。
源碼獲取:VX關注并回復chatgpt-0獲得
- 實現的功能
假如有八個令牌,現在想讓每一個令牌至于其前面的通信,如第5個令牌不與6,7,8位置的令牌通信(這是未來的令牌),只與4,3,2,1位置的令牌通信。因此只能通過以前的上下文信息猜測后面的;一種弱的通信方式是取前面的平局值。如5位置==5,4,3,2,1位置上的平局值。
- 實現
- 循環的版本
import torch
from torch.nn import functional as F
import torch.nn as nn
torch.manual_seed(1337)B,T,C = 4,8,2 #batch,time,channels
x = torch.randn(B,T,C)
xbow = torch.zeros((B,T,C))
print(f'x: {x[0]}')
for b in range(B):for t in range(T):xprev = x[b,:t+1] #()t,Cxbow[b,t] = torch.mean(xprev,0)
print(f'xbow: {xbow[0]}')#結果
x: tensor([[ 0.1808, -0.0700],[-0.3596, -0.9152],[ 0.6258, 0.0255],[ 0.9545, 0.0643],[ 0.3612, 1.1679],[-1.3499, -0.5102],[ 0.2360, -0.2398],[-0.9211, 1.5433]])
xbow: tensor([[ 0.1808, -0.0700],[-0.0894, -0.4926],[ 0.1490, -0.3199],[ 0.3504, -0.2238],[ 0.3525, 0.0545],[ 0.0688, -0.0396],[ 0.0927, -0.0682],[-0.0341, 0.1332]])
# 每一行至于自己以及自己以前的數據進行通信
- 通過數據矩陣高效實現
a = torch.tril(torch.ones(3,3)) #下三角函數
a = a/torch.sum(a,1,keepdim=True) #對a求平均數
b = torch.randint(0,10,(3,2)).float()
c = a @ bprint(f'a:{a}')
print(f'b:{b}')
print(f'c:{c}')#結果a:tensor([[1.0000, 0.0000, 0.0000],[0.5000, 0.5000, 0.0000],[0.3333, 0.3333, 0.3333]])
b:tensor([[0., 4.],[1., 2.],[5., 5.]])
c:tensor([[0.0000, 4.0000],[0.5000, 3.0000],[2.0000, 3.6667]])
- 使用Softmax
tril = torch.tril(torch.ones(T,T)) #下三角函數
print(f'tril:{tril}')wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,對于tril為0的填充負無窮大
print(f'wei: {wei}')
wei = F.softmax(wei,dim=-1)# softmax對沒一行的每個元素進行求冪,在求平均數
print(f'wei: {wei}')
xbow3 = wei @ xprint(f'xbow3: {xbow3}')
print(torch.allclose(xbow,xbow3))
-
單頭自注意力
- 上面的自注意力是通過相同的方式獲取以往的信息。但是實際上并不希望是統一的方式,因為不同的token標記會發現其他不同的標記。
- 例如:我是元音,那么也許我正在尋找過去的輔音,或與我想知道這些輔音是什么。希望這些信息流向我,所以我現在想以依賴數據的方式收集過去的信息。這就是自注意力解決的問題。
- 方式如下:每個節點或每個位置的每個令牌都會發出兩個向量,一個發出查詢query,一個發出鍵key。查詢向量粗略的說就是我要找的東西,鍵向量粗略的講就是我包含什么。
- 現在在序列中獲取這些標記之間的親和力的方式基本上只是在鍵和查詢之間做一個點乘積。所以我的查詢與所有的其他tokens令牌的所有鍵進行點乘積。并且點積方式變了。如果鍵和查詢有點對齊,它們將交互到非常高的數量,然后我將了解有關特定標記的更多信息,而不是其他不再序列中的任何其他標記。
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #轉置時最后兩個維度為負 (B,T,16) @ (B,16,T) ---> (B,T,T)tril = torch.tril(torch.ones(T,T)) #下三角函數
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,對于tril為0的填充負無窮大 主要是為了避免關注后面信息。如果想讓所有節點進行交流刪除詞句。解碼器中保留,編碼器刪除允許所有節點通信
wei = F.softmax(wei,dim=-1)# softmax對沒一行的每個元素進行求冪,在求平均數 主要為了避免關注過小的信息主要是負數
print(f'wei: {wei[0]}')
out = wei @ x
print(f'out:{out.shape}')
- 但是在真是中并不聚合到x而是計算一個v.x看作為該令牌的私人信息,與不同頭交流的信息存儲在v中
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #轉置時最后兩個維度為負 (B,T,16) @ (B,16,T) ---> (B,T,T)tril = torch.tril(torch.ones(T,T)) #下三角函數
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,對于tril為0的填充負無窮大 主要是為了避免關注后面信息。如果想讓所有節點進行交流刪除詞句。解碼器中保留,編碼器刪除允許所有節點通信
wei = F.softmax(wei,dim=-1)# softmax對沒一行的每個元素進行求冪,在求平均數 主要為了避免關注過小的信息主要是負數
print(f'wei: {wei[0]}')
value = nn.Linear(C,head_size,bias=False)
v = value(x)
out = wei @ v
print(f'out:{out.shape}')
簡單實現自注意力
關注{曉理紫|小李子},獲取技術推送信息,如感興趣,請轉發給有需要的同學,謝謝支持!!
如果你感覺對你有所幫助,請關注我。