本文提供一個適用于圖像輸入的多頭注意力機制(Multi-Head Attention)PyTorch 實現,適用于 ViT、MAE 等視覺 Transformer 中的注意力計算。
模塊說明
- 輸入支持圖像格式
(B, C, H, W)
- 內部轉換為序列
(B, N, C)
,其中N = H * W
- 多頭注意力計算:查詢(Q)、鍵(K)、值(V)使用線性層投影
- 結果 reshape 回原圖維度
(B, C, H, W)
多頭注意力機制代碼(適用于圖像輸入)
import torch
import torch.nn as nnclass ImageMultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(ImageMultiHeadAttention, self).__init__()assert embed_dim % num_heads == 0, "embed_dim 必須能被 num_heads 整除"self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads# Q, K, V 的線性映射self.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)# 輸出映射層self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = self.head_dim ** 0.5def forward(self, x):# 輸入 x: (B, C, H, W),需要 reshape 為 (B, N, C)B, C, H, W = x.shapex = x.view(B, C, H * W).permute(0, 2, 1) # (B, N, C)Q = self.q_proj(x)K = self.k_proj(x)V = self.v_proj(x)# 拆成多頭 (B, num_heads, N, head_dim)Q = Q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)K = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)V = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)# 注意力分數計算attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaleattn_probs = torch.softmax(attn_scores, dim=-1)attn_out = torch.matmul(attn_probs, V)# 合并多頭attn_out = attn_out.transpose(1, 2).contiguous().view(B, H * W, self.embed_dim)# 輸出映射out = self.out_proj(attn_out)# 恢復回原圖維度 (B, C, H, W)out = out.permute(0, 2, 1).view(B, C, H, W)return out# 測試示例
# 假設輸入是一張 14x14 的特征圖(類似 patch embedding 后)
img = torch.randn(4, 64, 14, 14) # (B, C, H, W)mha = ImageMultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(img)print(out.shape) # 輸出應為 (4, 64, 14, 14)
PyTorch 實現自注意力機制(Self-Attention)
本節補充自注意力機制(Self-Attention)的核心代碼實現,適用于 ViT 等模型中 patch token 的注意力操作。
自注意力機制代碼(Self-Attention)
import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = embed_dim ** 0.5def forward(self, x):# 輸入 x: (B, N, C)B, N, C = x.shape# 一次性生成 Q, K, Vqkv = self.qkv_proj(x) # (B, N, 3C)Q, K, V = torch.chunk(qkv, chunks=3, dim=-1) # 各自為 (B, N, C)# 計算注意力分數attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # (B, N, N)attn_probs = torch.softmax(attn_scores, dim=-1)# 得到注意力加權輸出attn_out = torch.matmul(attn_probs, V) # (B, N, C)# 映射回原維度out = self.out_proj(attn_out) # (B, N, C)return out# 測試示例
# 假設輸入為 196 個 patch,每個 patch 的嵌入維度為 64
x = torch.randn(2, 196, 64) # (B, N, C)attn = SelfAttention(embed_dim=64)
out = attn(x)print(out.shape) # 輸出應為 (2, 196, 64)
📎 拓展說明
? 本實現為單頭自注意力機制
? 可用于 NLP 中的序列特征或 ViT 圖像 patch 序列
? 若需改為多頭注意力,只需將 embed_dim 拆成 num_heads × head_dim 并分別計算后合并
PyTorch 實現圖像輸入的自注意力機制(Self-Attention)
本節介紹一種適用于圖像輸入 (B, C, H, W)
的自注意力機制實現,適合卷積神經網絡與 Transformer 的融合模塊,如 Self-Attention ConvNet、BAM、CBAM、ViT 前層等。
自注意力機制(圖像維度)代碼
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ImageSelfAttention(nn.Module):def __init__(self, in_channels):super(ImageSelfAttention, self).__init__()self.in_channels = in_channelsself.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1)) # 可學習縮放因子def forward(self, x):# 輸入 x: (B, C, H, W)B, C, H, W = x.size()# 生成 Q, K, Vproj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1) # (B, N, C//8)proj_key = self.key_conv(x).view(B, -1, H * W) # (B, C//8, N)proj_value = self.value_conv(x).view(B, -1, H * W) # (B, C, N)# 注意力矩陣:Q * K^Tenergy = torch.bmm(proj_query, proj_key) # (B, N, N)attention = F.softmax(energy, dim=-1) # (B, N, N)# 加權求和 Vout = torch.bmm(proj_value, attention.permute(0, 2, 1)) # (B, C, N)out = out.view(B, C, H, W)# 殘差連接 + 縮放因子out = self.gamma * out + xreturn out#測試用例
x = torch.randn(2, 64, 32, 32) # 輸入一張圖像:B=2, C=64, H=W=32
self_attn = ImageSelfAttention(in_channels=64)
out = self_attn(x)print(out.shape) # 輸出形狀應為 (2, 64, 32, 32)
? 本模塊基于圖像 (B, C, H, W) 進行自注意力計算
? 使用卷積進行 Q/K/V 提取,保持局部感知力
? gamma 是可學習縮放因子,用于殘差連接控制注意力貢獻度
自注意力中**縮放因子(scale factor)的處理,在序列維度(如 ViT)和圖片維度(如 Self-Attention Conv)**中有點不一樣。下面我們來詳細解釋一下原因,并對兩種寫法做一個統一和對比分析
兩種縮放因子的區別
- 序列維度的縮放因子
scale = head_dim ** 0.5 # 或者 embed_dim ** 0.5
attn = (Q @ K.T) / scale
? 來源:Transformer 原始論文(Attention is All You Need)
? 原因:在高維向量內積中,為了避免 dot product 的結果數值過大導致梯度不穩定,需要除以 sqrt(d_k)
? 使用場景:多頭注意力機制,輸入是 (B, N, C),應用在 NLP、ViT 等序列結構
- 圖片維度(C, H, W)的注意力機制中沒有縮放,或者使用 softmax 平衡
attn = softmax(Q @ K.T) # 無 scale,或者手動調節
? 來源:Non-local Net、Self-Attention Conv、BAM 等 CNN + Attention 融合方法
? 原因:Q 和 K 都通過 1x1 conv 壓縮成 C//8 或更小的維度,內積的值本身不會太大;同時圖像 attention 主要用 softmax 控制權重范圍
? 縮放因子的控制通常用 γ(gamma)作為殘差通道縮放,不是 QK 內部的數值縮放
💬 如果你覺得這篇整理有幫助,歡迎點贊收藏!