一、單頭注意力
單頭注意力的大致流程如下:
① 查詢編碼向量、鍵編碼向量和值編碼向量分別經過自己的全連接層(Wq、Wk、Wv)后得到查詢Q、鍵K和值V;
②?查詢Q和鍵K經過注意力評分函數(如:縮放點積運算)得到值權重矩陣;
③ 權重矩陣與值向量相乘,得到輸出結果。
?圖1 單頭注意力模型
?
二、多頭注意力?
2.1 使用多頭注意力的意義? ? ??
? ? ? ? 看了一些對多頭注意力機制解釋的視頻,我自己的淺顯理解是:在實踐中,我們會希望查詢Q能夠從給定內容中盡可能多地匹配到與自己相關的語義信息,從而得到更準確的預測輸出。而多頭注意力將查詢、鍵和值分成不同的子空間表示(representation subspaces)(有點類似于子特征?),使得匹配過程更加細化。
2.2 代碼實現
????????也許直接看代碼能更快地理解這個過程:
import torch
from torch import nn
from attentionScore import DotProductAttention
# 多頭注意力模型
class MultiHeadAttention(nn.Module):def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)# queries:(batch_size,查詢的個數,query_size)# keys:(batch_size,“鍵-值”對的個數,key_size)# values:(batch_size,“鍵-值”對的個數,value_size)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_hiddens)queries = self.W_q(queries)keys = self.W_k(keys)values = self.W_v(values)# 經過變換后,輸出的queries,keys,values的形狀:(batch_size*num_heads,查詢或者“鍵-值”對的個數,num_hiddens/num_heads)queries = transpose_qkv(queries, self.num_heads)keys = transpose_qkv(keys, self.num_heads)values = transpose_qkv(values, self.num_heads)# valid_lens的形狀:(batch_size,)或(batch_size,查詢的個數)if valid_lens is not None:# 在軸0,將第一項(標量或者矢量)復制num_heads次,然后如此復制第二項,然后諸如此類。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形狀:(batch_size*num_heads,查詢的個數,num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形狀:(batch_size,查詢的個數,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)
# 為了多注意力頭的并行計算而變換形狀
def transpose_qkv(X, num_heads):# 輸入X的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_hiddens)# 輸出X的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_heads,num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 輸出X的形狀:(batch_size,num_heads,查詢或者“鍵-值”對的個數, num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最終輸出的形狀:(batch_size*num_heads,查詢或者“鍵-值”對的個數, num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])
# 逆轉transpose_qkv函數的操作
def transpose_output(X, num_heads):X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)
????????可以發現,前面的處理流程和單頭注意力的第①步是一樣的,都是使用全連接層計算查詢Q、鍵K、值V。但在進行點積運算之前,模型使用transpose_qkv函數對QKV進行了切割變換,下圖可以幫助理解這個過程:
圖2 transpose_qkv函數處理Q
圖3?transpose_qkv函數處理K?
? ? ? ? 這個過程就像是把一個整體劃分為了很多小的子空間。一個不知道恰不恰當的比喻,就像是把“父母”這個詞拆分成了“長輩”、“養育者”、“監護人”、“爸媽”多重含義。
? ? ? ? 對切割變換后的QK進行縮放點積運算,過程如下圖所示:
?圖4 對切割變換后的Q和K進行縮放點積運算
? ? ? ? transpose_output后的輸出結果:
圖5 對值加權結果進行transpose_output變換后?
????????對比單頭注意力的值加權輸出,原來的每個查詢Q匹配到了更多的value:
圖6 多頭注意力與單頭注意力的值加權結果對比
????????整個過程就像是把一個父需求分割成不同的子需求,子需求單獨與不同的子特征進行匹配,最后使得每個父需求獲得了更多的語義信息。