文章目錄
- MultiHeadAttentionFormal的實現
- 操作詳解
- 1. 🔍 attention_mask
- 2. 🔍 matmul
- ? 其他實現方式
- 1. 使用 `@` 運算符(推薦簡潔寫法)
- 2. 使用 `torch.einsum()`(愛因斯坦求和約定)
- 3. 使用 `torch.bmm()`(批量矩陣乘法)
- 4. 使用 `unsqueeze` + `squeeze` 控制維度(兼容高維)
- 5. 使用 `F.linear()` 實現投影(不常用)
- 📌 對比總結表
- 💡 示例對比(均等效)
- 3. 🔍 transpose
- 📌 定義
- 🧠 在多頭注意力中的典型應用場景
- ? 其他實現方式
- 1. 使用 `permute(*dims)` —— 更靈活的維度重排
- 2. 使用 `swapaxes(dim0, dim1)` —— 與 transpose 等效
- 📌 總結對比表
- 💡 示例說明
- 🛠 實際應用建議
- 4. 🔍 view()
- 🔄 其他等效實現方式
- 1. `torch.reshape(tensor, shape)`
- 2. 使用 `flatten(start_dim, end_dim)` 合并維度
- 3. 使用 `einops.rearrange`(推薦用于可讀性)
- ? 總結對比
- 💡 實際應用建議
- 5. 🔍 masked_fill()
- 🧠 函數定義
- 示例解析
- ? 實際案例演示
- ?? 注意事項
- 💡 應用場景
- ? 總結
- 📌 最佳實踐建議
- 參考材料
MultiHeadAttentionFormal的實現
import torch
import torch.nn as nn
import mathclass MultiHeadAttentionFormal(nn.Module):def __init__(self, hidden_dim, head_num, attention_dropout=0.1):super().__init__()self.hidden_dim = hidden_dimself.head_num = head_numself.head_dim = hidden_dim // head_num # head_num * head_dim = hidden_dimself.q_proj = nn.Linear(hidden_dim, hidden_dim) # (hidden_dim, head_dim * head_num)self.k_proj = nn.Linear(hidden_dim, hidden_dim)self.v_proj = nn.Linear(hidden_dim, hidden_dim)self.output = nn.Linear(hidden_dim, hidden_dim)self.attention_dropout = nn.Dropout(attention_dropout)def forward(self, x, attention_mask=None):# X (batch_size, seq_len, hidden_dim)batch_size, seq_len, _ = x.shape# Q/K/V的shape: (batch_size, seq_len, hidden_dim)Q = self.q_proj(x)K = self.k_proj(x)V = self.v_proj(x)# (batch_size, seq_len, hidden_dim),其中 hidden_dim = head_num * head_dim# -> (batch_size, seq_len, head_num, head_dim)# -> (batch_size, head_num, seq_len, head_dim)q_state = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)k_state = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)v_state = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)# k_state的轉置# (batch_size, head_num, seq_len, head_dim)# -> (batch_size, head_num, head_dim, seq_len)# 相乘的結果,shape為(batch_size, head_num, seq_len, seq_len)atten_weight = torch.matmul(q_state, k_state.transpose(-2, -1)) / math.sqrt(self.head_dim)print("stage1, atten_weight.shape: ", atten_weight.shape)if attention_mask is not None:atten_weight = atten_weight.masked_fill(attention_mask==0, float("-inf"))print("stage2, atten_weight.shape: ", atten_weight.shape)atten_weight = torch.softmax(atten_weight, dim=-1)print("stage3, atten_weight.shape: ", atten_weight.shape)atten_weight = self.attention_dropout(atten_weight)print("stage4, atten_weight.shape: ", atten_weight.shape)# atten_weight: (batch_size, head_num, seq_len, seq_len)# v_state: (batch_size, head_num, seq_len, head_dim)# => (batch_size, head_num, seq_len, head_dim)output_mid = torch.matmul(atten_weight, v_state)print("stage1, output_mid.shape: ", output_mid.shape, "v_state.shape: ", v_state.shape)# transpose后,張量的內存可能變得不連續,所以需要用contiguous把內存連續化;view()、reshape()、flatten()、torch.nn.Linear、torch.matmul 等操作對輸入張量有連續性的要求。output_mid = output_mid.transpose(1, 2).contiguous()print("stage2, output_mid.shape: ", output_mid.shape)output_mid = output_mid.view(batch_size, seq_len, self.hidden_dim)print("stage3, output_mid.shape: ", output_mid.shape)output = self.output(output_mid)return outputattention_mask = torch.tensor([[1,1],[1,0],[1,0]]
).unsqueeze(1).unsqueeze(2).expand(3, 8, 2, 2)# batch_size, seq_len, hidden_dim
X = torch.rand(3, 2, 128)net = MultiHeadAttentionFormal(128, 8) # hidden_dim = 128, head_num = 8 -> head_dim = 16
net(X, attention_mask)
操作詳解
1. 🔍 attention_mask
首先是創建一個隨機張量,shape為(batch_size, seq_len)
attention_mask = torch.tensor([[1, 1],[1, 0],[1, 0]
])這是一個形狀為 (3, 2) 的張量。
每一行表示一個樣本(batch)的 attention mask:
1 表示該位置是有效的;
0 表示該位置是 padding,需要被屏蔽
然后增加維度
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)---------------------
tensor([[[[1, 1]]],[[[1, 0]]],[[[1, 0]]]])第一次 unsqueeze(1):增加第 1 維度(head_num),形狀變為 (3, 1, 2)
第二次 unsqueeze(2):增加第 2 維度(seq_len),形狀變為 (3, 1, 1, 2)
此時維度含義為:(batch_size, 1, 1, seq_len)
注意:此時還沒有考慮 head_num,只是準備好了 mask 的基本結構
現在擴展到head_num
attention_mask = attention_mask.expand(3, 8, 2, 2)
-------------
tensor([[[[1, 1], [1, 1]], # 頭1[[1, 1], [1, 1]], # 頭2[[1, 1], [1, 1]], # 頭3[[1, 1], [1, 1]], # 頭4[[1, 1], [1, 1]], # 頭5[[1, 1], [1, 1]], # 頭6[[1, 1], [1, 1]], # 頭7[[1, 1], [1, 1]], # 頭8][[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]]],[[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]],[[1, 0], [1, 0]]]])expand() 是 PyTorch 中用于廣播張量的方法,不會復制數據,而是共享內存。
將 (3, 1, 1, 2) 擴展為 (3, 8, 2, 2):
3:batch size
8:head_num,每個 head 都使用相同的 mask
2:query 的序列長度(seq_len)
2:key/value 的序列長度(seq_len)attention_mask.shape為(batch_size, head_num, seq_len, seq_len)
2. 🔍 matmul
以這行代碼為例:
output_mid = torch.matmul(atten_weight, v_state)
其中:
atten_weight.shape = (batch_size, head_num, seq_len, seq_len)
,即注意力權重矩陣(通常是 softmax 后的結果)v_state.shape = (batch_size, head_num, seq_len, head_dim)
,即 value 的狀態
這個操作本質上是將 attention weight 與 value 進行矩陣乘法,得到加權后的輸出。
? 其他實現方式
1. 使用 @
運算符(推薦簡潔寫法)
output_mid = atten_weight @ v_state
- 等價于
torch.matmul
- 更加 Pythonic,代碼更簡潔
- 支持廣播機制
2. 使用 torch.einsum()
(愛因斯坦求和約定)
output_mid = torch.einsum('bhij,bhjd->bhid', atten_weight, v_state)
- 非常靈活,適用于多頭注意力、交叉注意力等復雜結構
- 顯式控制每個維度的運算規則,可讀性略差但表達能力更強
- 在調試或構建復雜模型時非常有用
3. 使用 torch.bmm()
(批量矩陣乘法)
# 將 batch 和 head 合并成一個大 batch 維度
batch_size, head_num, seq_len, _ = atten_weight.shape
atten_weight_flat = atten_weight.view(-1, seq_len, seq_len) # shape: (B*H, T, T)
v_state_flat = v_state.view(-1, seq_len, head_dim) # shape: (B*H, T, D)output_flat = torch.bmm(atten_weight_flat, v_state_flat) # shape: (B*H, T, D)
output_mid = output_flat.view(batch_size, head_num, seq_len, head_dim)
- 只支持 3D 張量,不支持自動廣播
- 性能接近
matmul
,但需要手動處理維度變形
4. 使用 unsqueeze
+ squeeze
控制維度(兼容高維)
output_mid = torch.matmul(atten_weight.unsqueeze(-2), v_state.unsqueeze(-1)
).squeeze(-1)
- 通過添加/刪除維度來精確控制 matmul 操作維度
- 適合在圖像、視頻等 attention 中使用
5. 使用 F.linear()
實現投影(不常用)
雖然不是標準做法,但如果 atten_weight
是某種投影權重矩陣,也可以用線性層模擬。但在 attention 中通常不適用。
📌 對比總結表
方法 | 輸入要求 | 是否支持 batch | 是否支持 broadcasting | 推薦用于 Attention |
---|---|---|---|---|
torch.matmul | 任意維度 | ? | ? | ??? |
@ | 任意維度 | ? | ? | ???(簡潔) |
torch.einsum | 需要指定索引 | ? | ? | ???(多頭) |
torch.bmm | 必須為 3D | ? | ? | ?(簡單 attention) |
unsqueeze + matmul | 手動控制維度 | ? | ? | ?(特殊場景) |
💡 示例對比(均等效)
# 原始寫法
output_mid = torch.matmul(atten_weight, v_state)# 使用 @ 符號
output_mid = atten_weight @ v_state# 使用 einsum
output_mid = torch.einsum('bhij,bhjd->bhid', atten_weight, v_state)# 使用 bmm(需 flatten + reshape)
batch_size, head_num, seq_len, _ = atten_weight.shape
atten_weight_flat = atten_weight.view(-1, seq_len, seq_len)
v_state_flat = v_state.view(-1, seq_len, head_dim)
output_flat = torch.bmm(atten_weight_flat, v_state_flat)
output_mid = output_flat.view(batch_size, head_num, seq_len, -1)
3. 🔍 transpose
output_mid = output_mid.transpose(1, 2)
這行代碼的作用是交換張量的第 1
維和第 2
維。用于處理多頭注意力(Multi-Head Attention)中張量形狀的調整。
📌 定義
torch.Tensor.transpose(dim0, dim1) -> Tensor
- 功能:返回一個新的張量,其中指定的兩個維度被交換。
- 參數:
dim0
: 第一個維度dim1
: 第二個維度
?? 注意:這個操作不會復制數據,而是返回原始張量的一個視圖(view)。如果后續需要使用 view()
或 reshape()
,可能需要調用 .contiguous()
來確保內存連續。
🧠 在多頭注意力中的典型應用場景
# 假設 input shape: (batch_size, head_num, seq_len, head_dim)
output_mid = output_mid.transpose(1, 2)
原始形狀:
output_mid.shape = (batch_size, head_num, seq_len, head_dim)
轉置后形狀:
output_mid.shape = (batch_size, seq_len, head_num, head_dim)
然后一般會進行 view()
操作來合并 head_num
和 head_dim
,得到最終輸出:
output_mid = output_mid.contiguous().view(batch_size, seq_len, -1)
# 最終 shape: (batch_size, seq_len, hidden_dim)
這是將多頭注意力結果重新拼接回原始隱藏層大小的關鍵步驟。
? 其他實現方式
除了使用 transpose()
,還有以下幾種方法可以實現類似效果:
1. 使用 permute(*dims)
—— 更靈活的維度重排
output_mid = output_mid.permute(0, 2, 1, 3)
permute()
可以一次重排多個維度- 示例前后的 shape 對應關系:
# 原 shape: (batch_size, head_num, seq_len, head_dim) # 新 shape: (batch_size, seq_len, head_num, head_dim)
? 推薦用于更復雜的維度變換場景
2. 使用 swapaxes(dim0, dim1)
—— 與 transpose 等效
output_mid = output_mid.swapaxes(1, 2)
- 與
transpose()
功能相同 - 更語義化,適合閱讀時強調“交換”而非“轉置”
📌 總結對比表
方法 | 支持任意維 | 是否返回 view | 是否支持鏈式操作 | 推薦用途 |
---|---|---|---|---|
transpose() | ? 僅限兩個維度 | ? | ? | 簡單交換兩個維度 |
permute() | ? 多維支持 | ? | ? | 高階張量維度重排(推薦) |
swapaxes() | ? | ? | ? | 強調“交換”,語義更強 |
💡 示例說明
假設輸入為:
output_mid.shape = (3, 8, 2, 16) # batch_size=3, head_num=8, seq_len=2, head_dim=16
使用 transpose(1, 2)
:
output_mid = output_mid.transpose(1, 2)
# output_mid.shape
(batch_size, head_num, seq_len, head_dim)
=> (batch_size, seq_len, head_num, head_dim)
(3, 8, 2, 16)
=> (3, 2, 8, 16)
使用 permute(0, 2, 1, 3)
:
output_mid = output_mid.permute(0, 2, 1, 3)
# output_mid.shape => (3, 2, 8, 16)
兩者等價,但 permute()
更具通用性。
🛠 實際應用建議
- 如果只是交換兩個維度 →
transpose()
- 如果涉及多維重排 →
permute()
- 如果要合并/拆分某些維度 →
permute()
+contiguous()
+view()
4. 🔍 view()
在 PyTorch 中,view()
是一個用于 改變張量形狀(reshape) 的函數。它不會修改張量的數據,只是重新解釋其形狀。
語法:
tensor.view(shape)
示例代碼:
output_mid = output_mid.view(batch_size, seq_len, self.hidden_dim)
前提條件:
output_mid
當前的 shape 是(batch_size, seq_len, head_num, head_dim)
head_num * head_dim == hidden_dim
- 所以 view 后變為
(batch_size, seq_len, hidden_dim)
作用:
將多頭注意力中每個 head 的輸出拼接起來,恢復成原始的 hidden_dim
維度。
比如:
# 假設 batch_size=3, seq_len=2, head_num=8, head_dim=16
output_mid.shape = (3, 8, 2, 16) # transpose + contiguous 后
output_mid = output_mid.view(3, 2, 128) # 8*16 = 128
?? 注意:使用
view()
前必須保證張量是連續的(contiguous),否則會報錯。所以前面通常有.contiguous()
調用。
🔄 其他等效實現方式
除了 view()
,還有以下幾種方式可以實現類似功能:
1. torch.reshape(tensor, shape)
與 view()
類似,但更靈活,可以在非連續內存上運行。
output_mid = output_mid.reshape(batch_size, seq_len, self.hidden_dim)
? 推薦使用這個替代 view()
,因為不需要關心是否是連續內存。
2. 使用 flatten(start_dim, end_dim)
合并維度
output_mid = output_mid.transpose(1, 2).flatten(start_dim=2, end_dim=3)
這相當于把第 2 和第 3 維合并,效果等同于 reshape 或 view。
3. 使用 einops.rearrange
(推薦用于可讀性)
來自 einops
庫(einop庫安裝及介紹),提供更直觀的維度操作方式:
from einops import rearrangeoutput_mid = rearrange(output_mid, 'b h s d -> b s (h d)')
優點:
- 更易讀
- 不需要關心是否連續
- 可擴展性強(支持更多復雜變換)
? 總結對比
方法 | 是否要求連續 | 易讀性 | 靈活性 | 推薦場景 |
---|---|---|---|---|
view() | ? 必須連續 | ?? 差 | ?? 一般 | 小規模調試 |
reshape() | ? 不要求 | ?? 好 | ?? 強 | 通用替換 view |
flatten() | ? 不要求 | ?? 好 | ?? 強 | 多維合并 |
einops.rearrange() | ? 不要求 | ?? 很好 | ?? 非常強 | 工程項目 |
💡 實際應用建議
如果你在寫正式項目或模型工程化,推薦使用:
from einops import rearrangeoutput_mid = rearrange(output_mid, 'b h s d -> b s (h d)')
或者安全版本(不依賴連續內存):
output_mid = output_mid.transpose(1, 2)
output_mid = output_mid.flatten(2) # (b, s, h*d)
這樣不僅代碼清晰,也避免了對 .contiguous()
的依賴問題。
5. 🔍 masked_fill()
在 PyTorch 中,masked_fill()
是一個非常常用的函數,用于 根據布爾掩碼(mask)對張量的某些位置進行填充。它常用于 NLP 任務中,比如 Transformer 模型中的 attention mask 處理。
🧠 函數定義
torch.Tensor.masked_fill(mask, value)
參數說明:
mask
: 一個布爾類型的張量(True/False),形狀必須與原張量相同。value
: 要填充的值,可以是標量或廣播兼容的張量。
行為:
- 對于
mask
中為True
的位置,將原張量對應位置的值替換為value
。 False
的位置保持不變。
示例解析
atten_weight = atten_weight.masked_fill(attention_mask == 0, float("-inf"))
解釋:
attention_mask == 0
:- 這是一個布爾操作,生成一個和
attention_mask
形狀相同的布爾張量。 - 所有等于 0 的位置變成
True
,表示這些位置是 pad 或無效 token,不應該參與 attention 計算。
- 這是一個布爾操作,生成一個和
float("-inf")
:- 將這些被 mask 的位置填入負無窮大。
- 在后續 softmax 中,
exp(-inf)
會變成 0,從而實現“忽略這些位置”的效果。
? 實際案例演示
輸入示例:
import torch# 原始 attention 權重 (模擬)
atten_weight = torch.tensor([[0.1, 0.2, 0.3, 0.4],[0.5, 0.6, 0.7, 0.8]
])# attention mask (pad 位置為 0)
attention_mask = torch.tensor([[1, 1, 0, 0],[1, 0, 0, 0]
])# 應用 masked_fill
atten_weight = atten_weight.masked_fill(attention_mask == 0, float("-inf"))
print(atten_weight)
輸出結果:
tensor([[ 0.1000, 0.2000, -inf, -inf],[ 0.5000, -inf, -inf, -inf]])
后續 softmax 結果:
import torch.nn.functional as F
F.softmax(atten_weight, dim=-1)
輸出:
tensor([[0.4621, 0.5379, 0.0000, 0.0000],[1.0000, 0.0000, 0.0000, 0.0000]])
可以看到,mask 為 0 的位置在 softmax 后變成了 0,不會影響最終注意力分布。
?? 注意事項
-
mask 張量的 shape 必須與目標張量一致:
- 如果你有一個
(batch_size, seq_len)
的 mask,而atten_weight
是(batch_size, head_num, seq_len, seq_len)
,你需要通過unsqueeze
和expand
調整 mask 的維度。 - 示例:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # -> (batch_size, 1, 1, seq_len) attention_mask = attention_mask.expand(batch_size, num_heads, seq_len, seq_len)
- 如果你有一個
-
不能直接使用 int 類型的 mask:
masked_fill
只接受布爾類型作為 mask,所以要確保使用了比較操作如==
,!=
等。
💡 應用場景
場景 | 描述 |
---|---|
padding mask | 防止模型關注到 padding 的 token |
look-ahead mask | 防止 decoder 在預測時看到未來 token |
自定義屏蔽機制 | 如屏蔽某些特定詞、句子結構等 |
? 總結
方法 | 作用 | 推薦指數 |
---|---|---|
masked_fill(mask == 0, -inf) | 屏蔽不需要關注的位置 | ????? |
F.softmax(..., dim=-1) | 使屏蔽位置變為 0 | ???? |
mask 維度適配 | 使用 unsqueeze + expand 調整 mask 到與 attn weight 相同 | ????? |
📌 最佳實踐建議
# 假設 attention_mask: (batch_size, seq_len)
# attn_weights: (batch_size, num_heads, seq_len_q, seq_len_k)# Step 1: 添加兩個維度,使其匹配 attn_weights 的 shape
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # -> (B, 1, 1, S)# Step 2: 擴展 mask 使得其與 attn_weights 形狀完全一致
attention_mask = attention_mask.expand_as(attn_weights) # -> same shape as attn_weights# Step 3: 應用 mask,填入 -inf
attn_weights = attn_weights.masked_fill(attention_mask == 0, float('-inf'))
這樣就能保證每個 head 和 query 的位置都能正確屏蔽掉 pad 或無效 token。
參考材料
https://bruceyuan.com/hands-on-code/from-self-attention-to-multi-head-self-attention.html#%E7%AC%AC%E5%9B%9B%E9%87%8D-multi-head-self-attention
einop庫安裝及介紹