Multi-Head Attention詳解

在這里插入圖片描述
文中大部分內容以及圖片來自:https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a

當使用 multi-head attention 時,通常d_key = d_value =(d_model / n_heads),其中n_heads是頭的數量。研究人員稱,通常使用平行注意層代替全尺寸性,因為該模型能夠“關注來自不同位置的不同表示子空間的信息”。

通過線性層傳遞輸入

計算注意力的第一步是獲得Q、K和V張量;它們分別是查詢張量、鍵張量和值張量。它們是通過采用位置編碼的嵌入來計算的,它將被記為X,同時將張量傳遞給三個線性層,它們被記為Wq, Wk和Wv。這可以從上面的詳細圖像中看到。

  • Q = XWq
  • K = XWk
  • V = XWv
    為了理解乘法是如何發生的,最好將每個組件分解成這個形狀:
  • X的大小為(batch_size, seq_length, d_model)。例如,一批32個序列的長度為10,嵌入為512,其形狀為(32,10,512)。
  • Wq,Wk和Wv的大小為(d_model,d_model)。按照上面的示例,它們的形狀為(512,512)。

因此,可以更好地理解乘法的輸出。每個重量矩陣同時在批處理中 broadcast 每個序列,以創建Q,K和V張量。

  • Q = XWq | (batch_size, seq_length, d_model) x (d_model, d_model) = (batch_size, seq_length, d_model)
  • K = XWk | (batch_size, seq_length, d_model) x (d_model, d_model) = (batch_size, seq_length, d_model)
  • V = XWv | (batch_size, seq_length, d_model) x (d_model, d_model) = (batch_size, seq_length, d_model)

下面的圖片顯示了Q, K和V是如何出現的。每個紫色盒子代表一個序列,每個橙色盒子是序列中的一個 token 或單詞。灰色橢圓表示每個token 的嵌入。
在這里插入圖片描述

下面的代碼加載了Positional Encoding和Embeddings類。

# convert the sequences to integers
sequences = ["I wonder what will come next!","This is a basic example paragraph.","Hello, what is a basic split?"]# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]# index the sequences 
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()# vocab size
vocab_size = len(stoi)# embedding dimensions
d_model = 8# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)# embed the sequence
embeddings = lut(tensor_sequences)# positionally encode the sequences
X = pe(embeddings)
tensor([[[-3.45, -1.34,  4.12, -3.33, -0.81, -1.93, -0.28,  8.25],[ 7.36, -1.09,  2.32,  1.52,  3.50,  1.42,  0.46, -0.95],[-2.26,  0.53, -1.02,  1.49, -3.97, -2.19,  2.86, -0.59],[-3.87, -2.02,  1.46,  6.78,  0.88,  1.08, -2.97,  1.45],[ 1.12, -2.09,  1.19,  3.87, -0.00,  3.73, -0.88,  1.12],[-0.35, -0.02,  3.98, -0.20,  7.05,  1.55,  0.00, -0.83]],[[-4.27,  0.17, -2.08,  0.94, -6.35,  1.99,  5.23,  5.18],[-0.00, -5.05, -7.19,  3.27,  1.49, -7.11, -0.59,  0.52],[ 0.54, -2.33, -1.10, -2.02, -0.88, -3.15,  0.38,  5.26],[ 0.87, -2.98,  2.67,  3.32,  1.16,  0.00,  1.74,  5.28],[-5.58, -2.09,  0.96, -2.05, -4.23,  2.11, -0.00,  0.61],[ 6.39,  2.15, -2.78,  2.45,  0.30,  1.58,  2.12,  3.20]],[[ 4.51, -1.22,  2.04,  3.48,  1.63,  3.42,  1.21,  2.33],[-2.34,  0.00, -1.13,  1.51, -3.99, -2.19,  2.86, -0.59],[-4.65, -6.12, -7.08,  3.26,  1.50, -7.11, -0.59,  0.52],[-0.32, -2.97, -0.99, -2.05, -0.87, -0.00,  0.39,  5.26],[-0.12, -2.61,  2.77,  3.28,  1.17,  0.00,  1.74,  5.28],[-5.64,  0.49,  2.32, -0.00, -0.44,  4.06,  3.33,  3.11]]],grad_fn=<MulBackward0>)

此時,嵌入序列X的形狀為(3,6,8)。有3個序列,包含6個標記,具有8維嵌入。

Wq、Wk和Wv的線性層可以使用nn.Linear(d_model, d_model)來創建。這將創建一個(8,8)矩陣,該矩陣將在跨每個序列的乘法期間廣播。

Wq = nn.Linear(d_model, d_model)          # query weights (8,8)
Wk = nn.Linear(d_model, d_model)          # key weights   (8,8)
Wv = nn.Linear(d_model, d_model)          # value weights (8,8)Wq.state_dict()['weight']
tensor([[ 0.19,  0.34, -0.12, -0.22,  0.26, -0.06,  0.12, -0.28],[ 0.09,  0.22,  0.32,  0.11,  0.21,  0.03, -0.35,  0.31],[-0.34, -0.21, -0.11,  0.34, -0.28,  0.03,  0.26, -0.22],[-0.35,  0.11,  0.17,  0.21, -0.19, -0.29,  0.22,  0.20],[ 0.19,  0.04, -0.07, -0.02,  0.01, -0.20,  0.30, -0.19],[ 0.23,  0.15,  0.22,  0.26,  0.17,  0.16,  0.23,  0.18],[ 0.01,  0.06, -0.31,  0.19,  0.22,  0.08,  0.15, -0.04],[-0.11,  0.24, -0.20,  0.26, -0.01, -0.14,  0.29, -0.32]])

Wq的權重如上圖所示。Wk和Wv形狀相同,但權重不同。當X穿過每一個線性層時,它保持它的形狀,但是現在Q, K和V已經被權值轉換成唯一的張量。

Q = Wq(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
K = Wk(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)
V = Wv(X) # (3,6,8)x(broadcast 8,8) = (3,6,8)Q
tensor([# sequence 0[[-3.13,  2.71, -2.07,  3.54, -2.25, -0.26, -2.80, -4.31],[ 1.70,  1.63, -2.90, -2.90,  1.15,  3.01,  0.49, -1.14],[-0.69, -2.38,  3.00,  3.09,  0.97, -0.98, -0.10,  2.16],[-3.52,  2.08,  2.36,  2.16, -2.48,  0.58,  0.33, -0.26],[-1.99,  1.18,  0.64, -0.45, -1.32,  1.61,  0.28, -1.18],[ 1.66,  2.46, -2.39, -0.97, -0.47,  1.83,  0.36, -1.06]],# sequence 1[[-3.13, -2.43,  3.85,  4.34, -0.60, -0.03,  0.04,  0.62],[-0.82, -2.67,  1.82,  0.89,  1.30, -2.65,  2.01,  1.56],[-1.42,  0.11, -1.40,  1.36, -0.21, -0.87, -0.88, -2.24],[-2.70,  1.88, -0.10,  1.95, -0.75,  2.54, -0.14, -1.91],[-2.67, -1.58,  2.46,  1.93, -1.78, -2.44, -1.76, -1.23],[ 1.23,  0.78, -1.93, -1.12,  1.07,  2.98,  1.82,  0.18]],# sequence 2[[-0.71,  1.90, -1.12, -0.97, -0.23,  3.54,  0.65, -1.39],[-0.87, -2.54,  3.16,  3.04,  0.94, -1.10, -0.10,  2.07],[-2.06, -3.30,  3.63,  2.39,  0.38, -3.87,  1.86,  1.79],[-2.00,  0.02, -0.90,  0.68, -1.03, -0.63, -0.70, -2.77],[-2.76,  1.90,  0.14,  2.34, -0.93,  2.38, -0.17, -1.75],[-1.82,  0.15,  1.79,  2.87, -1.65,  0.97, -0.21, -0.54]]],grad_fn=<ViewBackward0>)

Q K和V都是這個形狀。和前面一樣,每個矩陣是一個序列,每一行都是由嵌入表示的 token。

把Q, K, V分成多頭

通過創建Q,K和V張量,現在可以通過將D_Model的視圖更改為 (n_heads, d_key) 將它們分為各自的頭部。N_heads可以是任意數字,但是使用較大的嵌入時,通常要執行8、10或12。請記住,d_key = (d_model / n_heads)。
在這里插入圖片描述
在前面的圖像中,每個 token 在單個維度中包含d_model嵌入。現在,這個維度被分成行和列來創建一個矩陣;每行是一個包含鍵的頭。這可以從上面的圖片中看到。

每個張量的形狀變成:

  • (batch_size, seq_length, d_model) → (batch_size, seq_length, n_heads, d_key)

假設示例中選擇了四個heads,則(3,6,8)張量將被分成(3,6,4,2)張量,其中有3個序列,每個序列中有6個tokens,每個標記中有4個heads,每個正面中有2個元素。

這可以通過view來實現,它可以用來添加和設置每個維度的大小。由于每個示例的批大小或序列數量是相同的,因此可以設置批大小。同樣,在每個張量中,頭的數量和鍵的數量應該是恒定的。-1可用于表示剩余值,即序列長度。

batch_size = Q.size(0)   
n_heads = 4
d_key = d_model//n_heads # 8/4 = 2# query tensor | -1 = query_length | (3, 6, 8) -> (3, 6, 4, 2)
Q = Q.view(batch_size, -1, n_heads, d_key)# value tensor | -1 = key_length | (3, 6, 8) -> (3, 6, 4, 2) 
K = K.view(batch_size, -1, n_heads, d_key)# value tensor | -1 = value_length | (3, 6, 8) -> (3, 6, 4, 2) 
V = V.view(batch_size, -1, n_heads, d_key)                   Q

下面是Q張量的例子。這3個序列中的每一個都有6個tokens,每個標記是一個tokens,每個標記有4個head(行)和2個keys。

tensor([# sequence 0[[[-3.13,  2.71],[-2.07,  3.54],[-2.25, -0.26],[-2.80, -4.31]],[[ 1.70,  1.63],[-2.90, -2.90],[ 1.15,  3.01],[ 0.49, -1.14]],[[-0.69, -2.38],[ 3.00,  3.09],[ 0.97, -0.98],[-0.10,  2.16]],[[-3.52,  2.08],[ 2.36,  2.16],[-2.48,  0.58],[ 0.33, -0.26]],[[-1.99,  1.18],[ 0.64, -0.45],[-1.32,  1.61],[ 0.28, -1.18]],[[ 1.66,  2.46],[-2.39, -0.97],[-0.47,  1.83],[ 0.36, -1.06]]],# sequence 1[[[-3.13, -2.43],[ 3.85,  4.34],[-0.60, -0.03],[ 0.04,  0.62]],[[-0.82, -2.67],[ 1.82,  0.89],[ 1.30, -2.65],[ 2.01,  1.56]],[[-1.42,  0.11],[-1.40,  1.36],[-0.21, -0.87],[-0.88, -2.24]],[[-2.70,  1.88],[-0.10,  1.95],[-0.75,  2.54],[-0.14, -1.91]],[[-2.67, -1.58],[ 2.46,  1.93],[-1.78, -2.44],[-1.76, -1.23]],[[ 1.23,  0.78],[-1.93, -1.12],[ 1.07,  2.98],[ 1.82,  0.18]]],# sequence 2[[[-0.71,  1.90],[-1.12, -0.97],[-0.23,  3.54],[ 0.65, -1.39]],[[-0.87, -2.54],[ 3.16,  3.04],[ 0.94, -1.10],[-0.10,  2.07]],[[-2.06, -3.30],[ 3.63,  2.39],[ 0.38, -3.87],[ 1.86,  1.79]],[[-2.00,  0.02],[-0.90,  0.68],[-1.03, -0.63],[-0.70, -2.77]],[[-2.76,  1.90],[ 0.14,  2.34],[-0.93,  2.38],[-0.17, -1.75]],[[-1.82,  0.15],[ 1.79,  2.87],[-1.65,  0.97],[-0.21, -0.54]]]], grad_fn=<ViewBackward0>)

為了繼續,最好將序列長度和n個頭(第二次和第三次)調換成以下形狀

  • (batch_size, seq_length, n_heads, d_key) → (batch_size, n_heads, seq_length, d_key)

現在,每個序列被分成n_heads,每個頭接收seq_length長度token中的d_key個元素,而不是d_model個。這達到了研究人員在不同位置關注來自不同表示子空間的信息的目的。

這個張量的可視化如下圖所示。每個序列是紫色的,每個頭是灰色的。在頭部中,每個標記是一行d_key元素。
在這里插入圖片描述

回到前面的例子,Q張量將從(3,6,4,2)轉置到(3,4,6,2)。這個張量現在將表示3個序列,每個序列分為n_heads= 4,每個頭包含 seq_length= 6個tokens,每個tokens有一個 d_key = 2元素鍵。

本質上,每個頭部都包含每個序列 tokens 的副本,但它只有一個 d_key= 2的元素表示,而不是完整的d_model= 8的元素表示。這意味著每個序列同時在n_head= 4個不同的子空間中表示。

下面的代碼使用permute來切換每個張量的第二軸和第三軸。

# query tensor | (3, 6, 4, 2) -> (3, 4, 6, 2) 
Q = Q.permute(0, 2, 1, 3)
# key tensor | (3, 6, 4, 2) -> (3, 4, 6, 2) 
K = K.permute(0, 2, 1, 3)
# value tensor | (3, 6, 4, 2) -> (3, 4, 6, 2) 
V = V.permute(0, 2, 1, 3)Q
tensor([# sequence 0[[[-3.13,  2.71],[ 1.70,  1.63],[-0.69, -2.38],[-3.52,  2.08],[-1.99,  1.18],[ 1.66,  2.46]],[[-2.07,  3.54],[-2.90, -2.90],[ 3.00,  3.09],[ 2.36,  2.16],[ 0.64, -0.45],[-2.39, -0.97]],[[-2.25, -0.26],[ 1.15,  3.01],[ 0.97, -0.98],[-2.48,  0.58],[-1.32,  1.61],[-0.47,  1.83]],[[-2.80, -4.31],[ 0.49, -1.14],[-0.10,  2.16],[ 0.33, -0.26],[ 0.28, -1.18],[ 0.36, -1.06]]],# sequence 1[[[-3.13, -2.43],[-0.82, -2.67],[-1.42,  0.11],[-2.70,  1.88],[-2.67, -1.58],[ 1.23,  0.78]],[[ 3.85,  4.34],[ 1.82,  0.89],[-1.40,  1.36],[-0.10,  1.95],[ 2.46,  1.93],[-1.93, -1.12]],[[-0.60, -0.03],[ 1.30, -2.65],[-0.21, -0.87],[-0.75,  2.54],[-1.78, -2.44],[ 1.07,  2.98]],[[ 0.04,  0.62],[ 2.01,  1.56],[-0.88, -2.24],[-0.14, -1.91],[-1.76, -1.23],[ 1.82,  0.18]]],# sequence 2[[[-0.71,  1.90],[-0.87, -2.54],[-2.06, -3.30],[-2.00,  0.02],[-2.76,  1.90],[-1.82,  0.15]],[[-1.12, -0.97],[ 3.16,  3.04],[ 3.63,  2.39],[-0.90,  0.68],[ 0.14,  2.34],[ 1.79,  2.87]],[[-0.23,  3.54],[ 0.94, -1.10],[ 0.38, -3.87],[-1.03, -0.63],[-0.93,  2.38],[-1.65,  0.97]],[[ 0.65, -1.39],[-0.10,  2.07],[ 1.86,  1.79],[-0.70, -2.77],[-0.17, -1.75],[-0.21, -0.54]]]], grad_fn=<PermuteBackward0>)

雖然擁有完整的視圖很好,但通過檢查單個序列更容易理解。

很容易在這個序列中看到四個heads。每個頭包含六行,這是 tokens,每行有兩個元素,這是鍵。這顯示了如何將序列拆分為四個子空間,以創建同一序列的不同表示。

# select the first sequence from the Query tensor
Q[0]
tensor([# head 0[[-3.13,  2.71],[ 1.70,  1.63],[-0.69, -2.38],[-3.52,  2.08],[-1.99,  1.18],[ 1.66,  2.46]],# head 1[[-2.07,  3.54],[-2.90, -2.90],[ 3.00,  3.09],[ 2.36,  2.16],[ 0.64, -0.45],[-2.39, -0.97]],# head 2[[-2.25, -0.26],[ 1.15,  3.01],[ 0.97, -0.98],[-2.48,  0.58],[-1.32,  1.61],[-0.47,  1.83]],# head 3[[-2.80, -4.31],[ 0.49, -1.14],[-0.10,  2.16],[ 0.33, -0.26],[ 0.28, -1.18],[ 0.36, -1.06]]], grad_fn=<SelectBackward0>)

計算注意力

在這里插入圖片描述
將Q, K和V分成多個頭,現在可以計算Q和K的標量點積。上面的等式表明,第一步是執行張量乘法。然而,K必須先轉置。

接下來,為了清晰起見,每個張量的seq長度形狀將通過其各自的張量,Q_length,K_length 或 V_length 來知道

  • Q has a shape of (batch_size, n_heads, Q_length, d_key)
  • K has a shape of (batch_size, n_heads, K_length, d_key)
  • V has a shape of (batch_size, n_heads, V_length, d_key)

K最右邊的兩個維度必須調換,以改變形狀為(batch_size, n_heads, d_key, K_length)。

現在, Q K T QK^T QKT的輸出是

  • (batch_size, n_heads, Q_length, d_key) x (batch_size, n_heads, d_key, K_length) = (batch_size, n_heads, Q_length, K_length)

每個張量中的相應序列將相互乘法。Q中的第一個序列將乘以K中的第一個序列,Q中的第二個序列與K中的第二個序列相乘。當這些序列相互相乘時,每個頭將在相反的張量中與相應的頭相乘。Q的第一個序列的第一個頭將與K的第一個序列的第一個頭相乘,Q的第一個序列的第二個頭與K的第一個序列的第二個頭相乘。在乘以這些頭時,Q頭中每個形狀為(Q_length,d_key)的token與K頭中的每個token相乘,形狀為(d_key,K_length)。結果是一個(Q-length,K_length)矩陣,顯示每個單詞與包括自身在內的所有其他單詞的強度。這就是“self-attention”這個名字的來源,因為模型通過將單詞乘以另一個自身表示來發現哪些單詞與序列最相關。

Q K T QK^T QKT由 d_key 縮放,以幫助使softmax函數在下一步的輸出不那么集中在0和1附近。在未縮放分布中,接近0和1的值更接近分布的中間。

繼續這個例子,縮放后的點積的輸出形狀為(3, 4, 6, 2) x (3, 4, 2, 6) = (3, 4, 6, 6)。

# calculate scaled dot product
scaled_dot_prod = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(d_key) # (batch_size, n_heads, Q_length, K_length)

這個張量然后通過softmax函數來創建一個概率分布。請注意softmax是如何應用于每個頭部中每個矩陣的每一行的。softmax維度可以設置為-1或3,因為兩者都表示形狀中最右邊的維度,即鍵。

# apply softmax to get context for each token and others
attn_probs = torch.softmax(scaled_dot_prod, dim=-1) # (batch_size, n_heads, Q_length, K_length)

這些注意概率可以使用來自matplotlib的imshow可視化。可以在附錄中找到同時顯示序列的所有頭部的函數,稱為display_attention。白色更接近于1,黑色更接近于0。

# sequence 0
display_attention(["i", "wonder", "what", "will", "come", "next"], ["i", "wonder", "what", "will", "come", "next"], attn_probs[0], 4, 2, 2)

在這里插入圖片描述

# sequence 1
display_attention(["this", "is", "a", "basic", "example", "paragraph"], ["this", "is", "a", "basic", "example", "paragraph"], attn_probs[1], 4, 2, 2)

在這里插入圖片描述

# sequence 2
display_attention(["hello", "what", "is", "a", "basic", "split"], ["hello", "what", "is", "a", "basic", "split"], attn_probs[2], 4, 2, 2)

在這里插入圖片描述

它們演示了每個query(row)和key(column)之間的關系。序列中單詞之間的每個交集都代表了關系的強度。由于這些值是從隨機權重生成的,因此它們目前沒有顯示任何有效的關系。下圖展示了編碼器經過訓練后這些矩陣的樣子。
在這里插入圖片描述
計算出這些概率后,下一步是將它們與V張量相乘,以創建這些分布的總結。每個單詞的上下文本質上是聚合的。
在這里插入圖片描述

# multiply attention and values to get reweighted values
A = torch.matmul(attn_probs, V) # (batch_size, n_heads, Q_length, d_key)

下面是這個示例的每個步驟的圖表。
在這里插入圖片描述

這里到底發生了什么:好吧,Q和K都是同一序列的表示,分為不同頭部的query和key組件。這計算了序列中每個單詞與序列中所有其他單詞之間的關系。這發生在 n_heads 子空間中。計算每個單詞的query表示和每個單詞的key表示之間的點積。這表示每個單詞和其他單詞之間的“強度”或“重量”。通過訓練,這種力量將幫助模型理解哪些單詞之間應該有更高的“權重”;這將表明哪些單詞對上下文和預測最重要。再次強調,query與key相乘,以在每個token和序列中的所有其他token之間生成權重。

softmax張量中的每一行都表示一個token與同一序列中的其他token之間的關系。在V中,每一列是序列的一個表示。將兩個張量相乘以重新加權值,并計算每個頭或子空間中每個token的最重要上下文的摘要。

下面的圖表顯示了序列中單個頭部的 self-attention。

在這里插入圖片描述

Passing It Through the Output Layer

此時,在通過最后的線性層(即多頭注意機制中的最后一層)之前,這些頭部可以重新連接在一起。

串聯將反轉最初執行的分割。第一步是n_heads和Q_length的轉置。第二步是將 n_heads 和 d_key 連接起來,得到 d_model。

一旦完成,A將具有 (batch_size, Q_length, d_model) 的形狀。

# transpose from (3, 4, 6, 2) -> (3, 6, 4, 2)
A = A.permute(0, 2, 1, 3).contiguous()# reshape from (3, 6, 4, 2) -> (3, 6, 8) = (batch_size, Q_length, d_model)
A = A.view(batch_size, -1, n_heads*d_key)A 
tensor([[[ 0.41, -0.71,  0.63, -0.22,  0.79, -3.58,  0.11,  1.71],[-0.15,  0.93,  0.50, -0.40, -0.43, -1.36,  0.11,  1.64],[-1.05, -1.58, -0.14, -1.42,  0.12,  0.21, -0.54, -0.52],[ 0.31, -0.65, -0.17, -1.33,  0.84, -3.78, -0.02,  0.41],[ 0.58, -0.83, -0.56, -1.17,  0.83, -3.70,  0.11,  1.65],[-0.17,  0.99,  0.58, -0.32,  0.65, -3.14,  0.11,  1.61]],[[-1.08, -1.93, -1.62,  3.69,  0.62, -0.34, -1.88, -2.31],[-1.17, -1.84, -1.76,  1.62,  0.60, -0.40, -2.56, -1.59],[-1.29, -0.52, -0.89, -1.06,  0.31,  0.07,  0.90,  1.69],[-0.90, -0.07, -1.43,  1.97,  1.16, -1.30,  0.73,  1.51],[-1.09, -1.92, -1.61,  2.89, -0.21,  0.92,  0.55,  1.32],[-0.92, -1.14, -0.95, -1.66,  0.28, -0.70, -0.91, -0.78]],[[-0.27,  0.87, -1.54, -3.73,  1.00, -1.33, -0.80,  0.07],[-1.13, -1.86, -1.22,  0.61, -0.47,  0.15, -0.10, -3.30],[-1.04, -1.82, -1.48,  0.91, -0.70,  0.45, -1.37, -0.49],[-0.37,  0.57, -1.24, -1.56, -0.29,  0.44, -0.97,  0.25],[-0.22,  1.10, -0.89, -0.33,  1.02, -1.33, -0.80,  0.19],[-0.37,  0.62, -1.02,  0.15,  0.80, -1.09, -0.37, -0.42]]],grad_fn=<ViewBackward0>)

最后一步是讓A通過Wo,它的形狀為 (d_model, d_model)。同樣,權重張量在批處理中的每個序列中廣播。最后的輸出保持其形狀
在這里插入圖片描述

Wo = nn.Linear(d_model, d_model)# (3, 6, 8) x (broadcast 8, 8) = (3, 6, 8)
output = Wo(A)              
tensor([[[-0.39, -0.45, -0.17,  0.18, -0.24, -1.68, -0.35, -0.56],[ 0.38,  0.02,  0.28, -0.42, -0.70, -0.81,  0.05,  0.03],[ 1.01, -0.72,  0.12,  0.18,  1.20, -0.29,  1.10, -0.59],[-0.50, -0.84, -0.07,  0.22,  0.49, -1.58,  0.13, -0.90],[-0.15, -0.95, -0.35,  0.17,  0.15, -1.65, -0.27, -0.79],[-0.47, -0.04,  0.15,  0.03, -0.83, -1.24, -0.04, -0.15]],[[-1.29, -0.85, -1.02,  1.56,  0.32, -0.08, -0.14,  0.40],[-0.45, -1.19, -0.70,  1.23,  0.75, -0.42,  0.46, -0.38],[ 1.33, -0.58, -0.34,  0.10, -0.13,  0.15,  0.44,  0.38],[-0.42, -0.32, -0.97,  0.89, -1.19,  0.01, -0.66,  1.11],[ 0.66, -0.75, -1.36,  0.73, -0.69,  0.47, -0.79,  1.29],[ 0.60, -1.03,  0.01,  0.29,  1.20, -0.50,  1.07, -0.78]],[[ 0.61, -0.66,  0.54, -0.06,  0.97, -0.68,  1.30, -1.08],[-0.22, -1.02, -0.38,  0.62,  1.46,  0.30,  0.74,  0.10],[ 0.67, -1.23, -0.65,  0.47,  0.58, -0.18,  0.31, -0.09],[ 0.94, -0.43,  0.30, -0.22,  0.40, -0.23,  0.78, -0.36],[-0.46, -0.03,  0.16,  0.37, -0.23, -0.55,  0.34, -0.11],[-0.54, -0.15, -0.03,  0.46, -0.06, -0.29,  0.26,  0.13]]],grad_fn=<ViewBackward0>)

該輸出將傳遞到下一層,其中包括殘差加法和layer normalization。這些將在后面的文章中討論。

Multi-Head Attention in Transformers

在解釋了多頭注意力的每個組件之后,實現就很簡單了,并且使用了前面列出的相同組件。唯一增加的是一個dropout層。

代碼中有一個掩碼的實現,但現在可以忽略它。它不會對實現之后的示例產生影響。當描述編碼器和解碼器時,將對此進行解釋。

請注意,在這個實現中,Q、K和V張量是同時分割和排列的,這與上面的實現不同。

class MultiHeadAttention(nn.Module):def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):"""Args:d_model:      dimension of embeddingsn_heads:      number of self attention headsdropout:      probability of dropout occurring"""super().__init__()assert d_model % n_heads == 0            # ensure an even num of headsself.d_model = d_model                   # 512 dimself.n_heads = n_heads                   # 8 headsself.d_key = d_model // n_heads          # assume d_value equals d_key | 512/8=64self.Wq = nn.Linear(d_model, d_model)    # query weightsself.Wk = nn.Linear(d_model, d_model)    # key weightsself.Wv = nn.Linear(d_model, d_model)    # value weightsself.Wo = nn.Linear(d_model, d_model)    # output weightsself.dropout = nn.Dropout(p=dropout)     # initialize dropout layer  def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor = None):"""Args:query:         query vector         (batch_size, q_length, d_model)key:           key vector           (batch_size, k_length, d_model)value:         value vector         (batch_size, s_length, d_model)mask:          mask for decoder     Returns:output:        attention values     (batch_size, q_length, d_model)attn_probs:    softmax scores       (batch_size, n_heads, q_length, k_length)"""batch_size = key.size(0)                  # calculate query, key, and value tensorsQ = self.Wq(query)                       # (32, 10, 512) x (512, 512) = (32, 10, 512)K = self.Wk(key)                         # (32, 10, 512) x (512, 512) = (32, 10, 512)V = self.Wv(value)                       # (32, 10, 512) x (512, 512) = (32, 10, 512)# split each tensor into n-heads to compute attention# query tensorQ = Q.view(batch_size,                   # (32, 10, 512) -> (32, 10, 8, 64) -1,                           # -1 = q_lengthself.n_heads,              self.d_key).permute(0, 2, 1, 3)         # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, q_length, d_key)# key tensorK = K.view(batch_size,                   # (32, 10, 512) -> (32, 10, 8, 64) -1,                           # -1 = k_lengthself.n_heads,              self.d_key).permute(0, 2, 1, 3)         # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, k_length, d_key)# value tensorV = V.view(batch_size,                   # (32, 10, 512) -> (32, 10, 8, 64) -1,                           # -1 = v_lengthself.n_heads, self.d_key).permute(0, 2, 1, 3)         # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, v_length, d_key)# computes attention# scaled dot product -> QK^{T}scaled_dot_prod = torch.matmul(Q,        # (32, 8, 10, 64) x (32, 8, 64, 10) -> (32, 8, 10, 10) = (batch_size, n_heads, q_length, k_length)K.permute(0, 1, 3, 2)) / math.sqrt(self.d_key)      # sqrt(64)# fill those positions of product as (-1e10) where mask positions are 0if mask is not None:scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -1e10)# apply softmax attn_probs = torch.softmax(scaled_dot_prod, dim=-1)# multiply by values to get attentionA = torch.matmul(self.dropout(attn_probs), V)       # (32, 8, 10, 10) x (32, 8, 10, 64) -> (32, 8, 10, 64)# (batch_size, n_heads, q_length, k_length) x (batch_size, n_heads, v_length, d_key) -> (batch_size, n_heads, q_length, d_key)# reshape attention back to (32, 10, 512)A = A.permute(0, 2, 1, 3).contiguous()              # (32, 8, 10, 64) -> (32, 10, 8, 64)A = A.view(batch_size, -1, self.n_heads*self.d_key) # (32, 10, 8, 64) -> (32, 10, 8*64) -> (32, 10, 512) = (batch_size, q_length, d_model)# push through the final weight layeroutput = self.Wo(A)                                 # (32, 10, 512) x (512, 512) = (32, 10, 512) return output, attn_probs                           # return attn_probs for visualization of the scores

現在,可以將它與嵌入層和位置編碼層一起使用,以生成與本文類似的輸出。將使用相同的示例,但將使用該類生成不同的輸出。記住,這假設Embeddings和PositionalEncoding模塊與MultiHeadAttention模塊一起加載。

torch.set_printoptions(precision=2, sci_mode=False)# convert the sequences to integers
sequences = ["I wonder what will come next!","This is a basic example paragraph.","Hello, what is a basic split?"]# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]# index the sequences 
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()# vocab size
vocab_size = len(stoi)# embedding dimensions
d_model = 8# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)# embed the sequence
embeddings = lut(tensor_sequences)# positionally encode the sequences
X = pe(embeddings)# set the n_heads
n_heads = 4# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)# pass X through the attention layer three times to create Q, K, and V
output, attn_probs = attention(X, X, X, mask=None)output

正如預期的那樣,輸出的形狀與輸入的形狀相同,即(3,6,8)。

tensor([[[-0.54,  0.58, -0.86,  0.72,  0.73,  0.26,  0.22, -1.31],[-0.88, -0.50,  0.06, -1.04,  0.79,  0.05,  0.78, -1.34],[-2.34,  0.46,  0.84,  0.15,  1.22,  1.25,  1.99, -1.55],[-2.69,  0.17,  0.57,  0.20,  1.44,  1.89,  1.99, -1.95],[-0.00, -1.09,  0.21, -0.90,  1.34, -0.32, -0.30, -0.81],[-1.25, -0.88,  0.85, -0.05,  1.54,  0.11,  0.77, -1.59]],[[-0.36, -0.52, -0.66, -0.71, -0.46,  0.83,  0.68,  0.19],[-0.45, -0.04, -0.76, -0.12,  0.21,  1.05,  0.54, -0.12],[-0.97,  0.15, -0.32, -0.14, -0.07,  0.96,  1.07, -0.42],[ 0.06, -0.69, -0.71, -0.72,  0.04,  0.32,  0.20,  0.13],[-0.40,  0.14, -0.48,  0.36, -0.85,  0.72,  0.77,  0.45],[-0.17, -0.69, -0.45, -0.98, -0.15,  0.14,  0.52, -0.04]],[[ 0.57,  0.26, -0.24,  0.44,  0.08, -0.66, -0.37, -0.23],[-0.33,  0.75,  0.58,  0.06,  0.32, -0.63,  0.55, -0.10],[-0.50,  0.46, -0.64,  0.87,  0.65,  0.85,  0.29, -0.60],[ 1.54,  0.43,  1.51,  0.09, -0.19, -2.58, -0.84,  1.40],[ 1.46, -0.38, -0.51, -0.06,  0.04, -0.83, -1.10,  1.08],[-0.28,  1.85,  0.19,  1.38, -0.69, -0.01,  0.55, -0.11]]],grad_fn=<ViewBackward0>)

來自注意力的概率也可以使用注意力問題來預覽。下面是第一個序列的heads的注意力分布。

display_attention(["i", "wonder", "what", "will", "come", "next"], ["i", "wonder", "what", "will", "come", "next"], attn_probs[0], 4, 2, 2)

在這里插入圖片描述

Supplementary Images of Attention

在這里插入圖片描述
這是多頭注意力計算的另一種view。下圖是如何在同一序列的不同表示之間計算softmax的示例。
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/710611.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/710611.shtml
英文地址,請注明出處:http://en.pswp.cn/news/710611.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

01-Vue2 介紹與指令的使用

1. Vue核心 1.1. Vue簡介 1.1.1. 官網 中文官網Vue.js - 漸進式 JavaScript 框架 | Vue.js (vuejs.org)https://cn.vuejs.org/ 英文官網Vue.js - The Progressive JavaScript Framework | Vue.js (vuejs.org)https://vuejs.org/ 1.1.2. 介紹與描述 VUE是構建于用戶界面的漸進…

靶機滲透之sar

Name: Sar: 1Date release: 15 Feb 2020Author: LoveSeries: Sar Download: https://drive.google.com/open?id1AFAmM21AwiAEiVFUA0cSr_GeAYaxd3lQ 對于vulnhub中的靶機&#xff0c;我們都需先下載鏡像&#xff0c;然后導入VM&#xff0c;并將網絡連接改為NAT模式。首先我們…

UDP數據報套接字編程入門

目錄 1.TCP和UDP的特點及區別 1.1TCP的特點 1.2UDP的特點 1.3區別 2.UDP Socket的api的介紹 2.1DatagramSocket API 2.2DatagramPacket API 3.回顯客戶端與服務器 3.1回顯服務器 3.1.1UdpEchoServer類的創建 3.1.2服務器的運行方法start() 3.1.3main部分 3.1.4.完整…

C# CAD PaletteSet.Style各種外觀和行為樣式

ps.Style 是 Autodesk.AutoCAD.Windows.PaletteSet 類的一個屬性&#xff0c;用于定義調色板集&#xff08;PaletteSet&#xff09;的各種外觀和行為樣式。它可以是 PaletteSetStyles 枚舉類型的組合值 PaletteSetStyles 枚舉中包含以下一些選項&#xff1a; NameEditable&am…

統計子矩陣

一、題目描述 P8783 [藍橋杯 2022 省 B] 統計子矩陣 二、算法簡析 2.1 二維前綴和 我們知道&#xff0c;只要確定了矩陣的左上頂點和右下頂點&#xff0c;一個矩陣就被固定了。因此&#xff0c;我們可以遍歷這兩個頂點&#xff0c;達到遍歷所有子矩陣的目的&#xff0c;復雜…

AutoSAR(基礎入門篇)12.5-Dem

目錄 一、Dem簡介 二、Dem消抖 1、計數模式 1. 普通增減計數 2. 反向歸零增減模式

在微服務整合dubbo,以為微服務版的若依為例

在微服務整合dubbo&#xff0c;以為微服務版的若依為例 一、環境二、整合過程1、父模塊依賴2、生產者3、消費者 三、修改若依的服務調用方式為dubbo1、改造系統模塊2、改造認證授權中心 四、整合過程遇到的問題1、出現循環引用2、出現依賴沖突3、啟動出現端口號被占用4、出現某…

UVa11726 Crime Scene

題目鏈接 UVa11726 - Crime Scene 題意 給定n&#xff08;n≤100&#xff09;個物體&#xff0c;每個物體都是一個圓或者k&#xff08;k≤10&#xff09;邊形&#xff0c;用長度盡量小的繩子把它們包圍起來。 分析 孟加拉國Manzurur Rahman Khan (Sidky)大神出的難題&#xff…

MySQL 核心模塊揭秘 | 07 期 | 二階段提交 (1) prepare 階段

二階段提交的 prepare 階段&#xff0c;binlog 和 InnoDB 各自會有哪些動作&#xff1f; 本文基于 MySQL 8.0.32 源碼&#xff0c;存儲引擎為 InnoDB。 1. 二階段提交 二階段提交&#xff0c;顧名思義&#xff0c;包含兩個階段&#xff0c;它們是&#xff1a; prepare 階段。…

springboot-基礎-eclipse配置+helloword示例

備份筆記。所有代碼都是2019年測試通過的&#xff0c;如有問題請自行搜索解決&#xff01; 下一篇&#xff1a;springboot-基礎-添加model和controller的簡單例子常用注解含義 目錄 配置helloword示例新建項目創建文件 配置 spring boot官方有定制版eclipse&#xff0c;也就是…

BUUCTF AWD-Test1

打開靶場是這個有些簡陋的界面。 隨便點點&#xff0c;找到這個東西。 看到ThinkPHP&#xff0c;思路瞬間清晰&#xff0c;老熟人了。這個就是ThinkPHP漏洞。根據版本我們去找一下poc。 /index.php/?sIndex/\think\View/display&content%22%3C?%3E%3C?php%20phpinfo();…

SHELL 腳本: 導出NEO4j DUMP并上傳SFTP

前提 開通sftp賬號 安裝expect 示例 NEO4J_HOME/path/to/neo4j # neo4j 安裝目錄 DUMP_PATH/data/dump # DUMP本地保存目錄 DUMP_FILEneo4j_$(date %F).dump #導出文件名稱 UPLOAD_DIR/path/to/stfp/dump/ #上傳目錄 $NEO4J_HOME/bin/neo4j-admin dump --databaseneo4j --t…

Vue-5

Vue 3 的優勢 更容易維護&#xff08;組合式API&#xff09;更快的速度更小的體積更優的數據響應 創建 Vue 3 項目 前提環境條件&#xff1a;已安裝 16.0 或更高版本的 Node.js node -v創建一個 Vue 應用&#xff08;下面的指令將會安裝并執行 create-vue &#xff09; np…

服務端向客戶端推送數據的實現方案

在日常的開發中&#xff0c;我們經常能碰見服務端需要主動推送給客戶端數據的業務場景&#xff0c;比如數據大屏的實時數據&#xff0c;比如消息中心的未讀消息&#xff0c;比如聊天功能等等。 本文主要介紹SSE的使用場景和如何使用SSE。 服務端向客戶端推送數據的實現方案有哪…

MySQL 自增列解析(Auto_increment)

MySQL數據庫為列提供了一種自增屬性&#xff0c;當列被定義為自增時。Insert語句對該列即使不提供值&#xff0c;MySQL也會自動為該列生成遞增的唯一標識&#xff0c;因此這個特性廣泛用于主鍵的自動生成。 一、自增列的用法 自增列具有自動生成序列值&#xff0c;整型&#…

職責鏈模式(Chain of Responsibility Pattern)

定義 職責鏈模式&#xff08;Chain of Responsibility Pattern&#xff09;是一種行為設計模式&#xff0c;它允許對象接收請求并將其沿著處理者鏈傳遞&#xff0c;直到有一個處理者處理它為止。職責鏈模式通過將請求的處理邏輯分布 在職責鏈模式中&#xff0c;通常包含以下幾…

MYSQL04高級_邏輯架構剖析、查詢緩存、解析器、優化器、執行器、存儲引擎

文章目錄 ①. 邏輯架構剖析②. 服務層 - 查詢緩存③. 服務層 - 解析器④. 服務層 - 優化器⑤. 服務層 - 執行器⑥. MySQL8執行原理 ①. 邏輯架構剖析 ①. 服務器處理客戶端請求 ②. 連接層 系統(客戶端)訪問MySQL服務器前,做的第一件事就是建立TCP連接經過三次握手建立連接成…

Linux使用C語言實現通過互斥鎖限制對共享資源的訪問

互斥鎖限制共享資源的訪問 主線程中有兩個線程&#xff0c;分別輸出信息。 #include <stdio.h> #include <pthread.h> #include <unistd.h>int g_data0;void* fun1(void *arg) {printf("t1&#xff1a;%ld thread is create\n", (unsigned long)…

大宋咨詢數據研究在汽車新品上市中的核心作用

隨著汽車行業的快速變革&#xff0c;數據研究已經成為新品上市流程中的不可或缺的一環。從市場定位、產品規劃到營銷策略&#xff0c;數據研究不僅為汽車企業提供了獨特的洞察&#xff0c;還為其提供了決策依據&#xff0c;確保新品在競爭激烈的市場中取得優勢。在這一領域&…

Kubernetes IoTDB系列 | IoTDB搭建 | v1.3.0

目錄 一、IoTDB 介紹二、k8s 部署 IoTDB一、IoTDB 介紹 IoTDB 是一種面向物聯網(IoT)場景的開源時序數據庫。它專門設計用于高效地存儲和查詢大規模物聯網設備產生的時序數據。IoTDB 提供了高吞吐量、低延遲、靈活的數據模型以及多種數據查詢和存儲引擎等特性,使其成為處理…