自注意力機制的核心概念
1. Query, Key 和 Value
-
Query(查詢向量):可以看作是你當前在關注的輸入項。假設你正在閱讀一段文字,這就像你當前在讀的句子。
-
Key(鍵向量):表示其他所有輸入項的標識或特征。這就像你在書中已經讀過的所有句子的摘要或要點。
-
Value(值向量):是與每個Key相關聯的具體信息或內容。就像這些句子帶來的詳細信息。
現實比喻:
想象你在圖書館尋找一本特定的書(Query),書架上有很多書,每本書都有一個書名(Key)。根據書名(Key)匹配你的查詢(Query),你從合適的書中獲取詳細內容(Value)。
2. 點積注意力(Dot-Product Attention)
這是計算Query和Key之間相關性的方式。我們通過計算Query和Key的點積來確定它們的關系強度。
比喻:
就像在圖書館,你有一本書的部分標題(Query),你對比書架上所有書的書名(Key),看哪個書名最接近你的標題,然后選出最相關的書(Value)。
3. 縮放(Scaling)
為了防止Query和Key之間的點積結果太大導致數值不穩定,我們將結果除以一個常數——通常是Key向量的維度的平方根。這使得計算更加穩定。
比喻:
假設你在測試你的記憶力,如果你直接用高分數衡量,可能會出現極端值。所以你需要調整分數范圍,使得評估更合理和穩定。
4. Softmax 歸一化
Softmax函數將一組數值轉換為概率分布,使得它們的總和為1。這意味著每個單詞的注意力權重表示它對當前處理單詞的重要性。
比喻:
就像你在評分不同的書,Softmax就像把所有的分數轉換成百分比,這樣你可以看到每本書相對于其他書的重要性。
自注意力機制的工作流程
讓我們更詳細地看看自注意力機制是如何一步一步工作的:
-
生成 Query, Key 和 Value 向量
我們首先通過線性變換將輸入序列的每個單詞轉換成三個不同的向量:Query, Key 和 Value。
query = W_q * input key = W_k * input value = W_v * input
比喻:這是把每個單詞變成三個不同的代表,就像給每個單詞生成了三個不同的標簽,用于不同的目的(查詢、匹配和提供信息)。
-
計算注意力權重
通過計算Query和Key的點積,我們得到它們之間的相關性得分。然后,我們將這些得分除以 d k \sqrt{d_k} dk?? 進行縮放,最后應用Softmax函數來得到權重。
# 計算點積 scores = query.dot(key.T) / sqrt(d_k) # 使用Softmax函數歸一化 attention_weights = softmax(scores)
比喻:這就像你比較當前正在讀的句子(Query)和你已經讀過的所有句子(Key),然后根據它們的相似程度打分。接著,你將這些分數標準化,使它們總和為1,表示每個句子的重要性百分比。
-
加權求和 Value 向量
我們將Value向量按照注意力權重進行加權求和,這樣每個Value對最終輸出的貢獻由它的重要性決定。
# 計算加權的Value output = sum(attention_weights * value)
比喻:就像你根據每本書的重要性百分比(注意力權重),從每本書中提取一定量的信息(Value),最終形成你對整個圖書館信息的理解。
示例和實際應用
假設你在處理一句話“我喜歡吃蘋果,因為蘋果很甜”:
-
Query, Key, Value:
- Query:當前處理的詞是“蘋果”。
- Key:句子中的所有單詞的表示,如“我”,“喜歡”,“吃”,“蘋果”,“因為”,“很”,“甜”。
- Value:這些單詞的具體信息,比如它們的詞義或上下文信息。
-
點積注意力:
- 你在評估“蘋果”和句子中其他詞的關系,比如“蘋果”與“甜”的關系就很重要,而與“我”關系可能不大。
-
Softmax 歸一化:
- 將關系得分轉化為一個概率分布,表示每個單詞對當前詞“蘋果”的重要性。
-
加權求和:
- 最后,根據重要性權重,從每個單詞中提取信息,生成“蘋果”的最終表示,這樣“蘋果”就包含了它和“甜”的關系。
自注意力機制代碼示例
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = nn.Linear(self.head_dim, embed_size, bias=False)self.keys = nn.Linear(self.head_dim, embed_size, bias=False)self.queries = nn.Linear(self.head_dim, embed_size, bias=False)self.fc_out = nn.Linear(embed_size, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# 1. 生成 Query, Key 和 Value 向量values = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)# 2. 計算注意力權重energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)# 3. 加權求和 Value 向量out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.embed_size)out = self.fc_out(out)return out
關鍵概念總結
-
自注意力機制:允許模型在處理一個輸入時,同時關注到整個輸入序列中的所有其他輸入。提高了捕捉長距離依賴關系的能力。
-
Query, Key 和 Value:分別代表當前處理的焦點、其他輸入的標識和它們攜帶的信息。
-
點積注意力:通過計算Query和Key的相似性來確定它們之間的關系強度。
-
縮放:對點積結果進行調整,防止數值過大導致計算不穩定。
-
Softmax 歸一化:將相似性得分轉化為概率分布,表示每個輸入的重要性。
通過這些步驟,自注意力機制能夠幫助模型在處理每一個輸入時同時考慮整個序列,從而更好地理解上下文和詞語之間的關系。