PyTorch 實現圖像版多頭注意力(Multi-Head Attention)和自注意力(Self-Attention)

本文提供一個適用于圖像輸入的多頭注意力機制(Multi-Head Attention)PyTorch 實現,適用于 ViT、MAE 等視覺 Transformer 中的注意力計算。


模塊說明

  • 輸入支持圖像格式 (B, C, H, W)
  • 內部轉換為序列 (B, N, C),其中 N = H * W
  • 多頭注意力計算:查詢(Q)、鍵(K)、值(V)使用線性層投影
  • 結果 reshape 回原圖維度 (B, C, H, W)

多頭注意力機制代碼(適用于圖像輸入)

import torch
import torch.nn as nnclass ImageMultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(ImageMultiHeadAttention, self).__init__()assert embed_dim % num_heads == 0, "embed_dim 必須能被 num_heads 整除"self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads# Q, K, V 的線性映射self.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)# 輸出映射層self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = self.head_dim ** 0.5def forward(self, x):# 輸入 x: (B, C, H, W),需要 reshape 為 (B, N, C)B, C, H, W = x.shapex = x.view(B, C, H * W).permute(0, 2, 1)  # (B, N, C)Q = self.q_proj(x)K = self.k_proj(x)V = self.v_proj(x)# 拆成多頭 (B, num_heads, N, head_dim)Q = Q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)K = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)V = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)# 注意力分數計算attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaleattn_probs = torch.softmax(attn_scores, dim=-1)attn_out = torch.matmul(attn_probs, V)# 合并多頭attn_out = attn_out.transpose(1, 2).contiguous().view(B, H * W, self.embed_dim)# 輸出映射out = self.out_proj(attn_out)# 恢復回原圖維度 (B, C, H, W)out = out.permute(0, 2, 1).view(B, C, H, W)return out# 測試示例
# 假設輸入是一張 14x14 的特征圖(類似 patch embedding 后)
img = torch.randn(4, 64, 14, 14)  # (B, C, H, W)mha = ImageMultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(img)print(out.shape)  # 輸出應為 (4, 64, 14, 14)

PyTorch 實現自注意力機制(Self-Attention)

本節補充自注意力機制(Self-Attention)的核心代碼實現,適用于 ViT 等模型中 patch token 的注意力操作。

自注意力機制代碼(Self-Attention)

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = embed_dim ** 0.5def forward(self, x):# 輸入 x: (B, N, C)B, N, C = x.shape# 一次性生成 Q, K, Vqkv = self.qkv_proj(x)  # (B, N, 3C)Q, K, V = torch.chunk(qkv, chunks=3, dim=-1)  # 各自為 (B, N, C)# 計算注意力分數attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, N, N)attn_probs = torch.softmax(attn_scores, dim=-1)# 得到注意力加權輸出attn_out = torch.matmul(attn_probs, V)  # (B, N, C)# 映射回原維度out = self.out_proj(attn_out)  # (B, N, C)return out#  測試示例
# 假設輸入為 196 個 patch,每個 patch 的嵌入維度為 64
x = torch.randn(2, 196, 64)  # (B, N, C)attn = SelfAttention(embed_dim=64)
out = attn(x)print(out.shape)  # 輸出應為 (2, 196, 64)

📎 拓展說明
? 本實現為單頭自注意力機制
? 可用于 NLP 中的序列特征或 ViT 圖像 patch 序列
? 若需改為多頭注意力,只需將 embed_dim 拆成 num_heads × head_dim 并分別計算后合并


PyTorch 實現圖像輸入的自注意力機制(Self-Attention)

本節介紹一種適用于圖像輸入 (B, C, H, W) 的自注意力機制實現,適合卷積神經網絡與 Transformer 的融合模塊,如 Self-Attention ConvNet、BAM、CBAM、ViT 前層等。

自注意力機制(圖像維度)代碼

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ImageSelfAttention(nn.Module):def __init__(self, in_channels):super(ImageSelfAttention, self).__init__()self.in_channels = in_channelsself.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.key_conv   = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))  # 可學習縮放因子def forward(self, x):# 輸入 x: (B, C, H, W)B, C, H, W = x.size()# 生成 Q, K, Vproj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C//8)proj_key   = self.key_conv(x).view(B, -1, H * W)                      # (B, C//8, N)proj_value = self.value_conv(x).view(B, -1, H * W)                    # (B, C, N)# 注意力矩陣:Q * K^Tenergy = torch.bmm(proj_query, proj_key)         # (B, N, N)attention = F.softmax(energy, dim=-1)             # (B, N, N)# 加權求和 Vout = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)out = out.view(B, C, H, W)# 殘差連接 + 縮放因子out = self.gamma * out + xreturn out#測試用例
x = torch.randn(2, 64, 32, 32)  # 輸入一張圖像:B=2, C=64, H=W=32
self_attn = ImageSelfAttention(in_channels=64)
out = self_attn(x)print(out.shape)  # 輸出形狀應為 (2, 64, 32, 32)

? 本模塊基于圖像 (B, C, H, W) 進行自注意力計算
? 使用卷積進行 Q/K/V 提取,保持局部感知力
? gamma 是可學習縮放因子,用于殘差連接控制注意力貢獻度


自注意力中**縮放因子(scale factor)的處理,在序列維度(如 ViT)和圖片維度(如 Self-Attention Conv)**中有點不一樣。下面我們來詳細解釋一下原因,并對兩種寫法做一個統一和對比分析

兩種縮放因子的區別
  1. 序列維度的縮放因子
scale = head_dim ** 0.5  # 或者 embed_dim ** 0.5
attn = (Q @ K.T) / scale

? 來源:Transformer 原始論文(Attention is All You Need)
? 原因:在高維向量內積中,為了避免 dot product 的結果數值過大導致梯度不穩定,需要除以 sqrt(d_k)
? 使用場景:多頭注意力機制,輸入是 (B, N, C),應用在 NLP、ViT 等序列結構

  1. 圖片維度(C, H, W)的注意力機制中沒有縮放,或者使用 softmax 平衡
attn = softmax(Q @ K.T)   # 無 scale,或者手動調節

? 來源:Non-local Net、Self-Attention Conv、BAM 等 CNN + Attention 融合方法
? 原因:Q 和 K 都通過 1x1 conv 壓縮成 C//8 或更小的維度,內積的值本身不會太大;同時圖像 attention 主要用 softmax 控制權重范圍
? 縮放因子的控制通常用 γ(gamma)作為殘差通道縮放,不是 QK 內部的數值縮放


💬 如果你覺得這篇整理有幫助,歡迎點贊收藏!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/900364.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/900364.shtml
英文地址,請注明出處:http://en.pswp.cn/news/900364.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

每日一題(小白)字符串娛樂篇16

分析題意可以了解到本題要求在一串字符串中找到所有組合起來排序遞增的字符串。我們可以默認所有字符在字符串中的上升序列是1,從第一個字符開始找,如果后面的字符大于前面的字符就說明這是一個上序列那么后面字符所在的數組加一,如果連接不上…

Ubuntu 22 Linux上部署DeepSeek R1保姆式操作詳解(Xinference方式)

一、安裝步驟 1.基礎環境安裝 安裝顯卡驅動、cuda,根據自己硬件情況查找相應編號,本篇不介紹這部分內容,只給出參考指令,詳情請讀者自行查閱互聯網其它參考資料。 sudo apt install nvidia-utils-565-server sudo apt install…

Immutable.js 完全指南:不可變數據的藝術與實踐

引言 在現代前端開發中,狀態管理是一個核心挑戰。隨著應用復雜度增加,如何高效、安全地管理應用狀態變得至關重要。Immutable.js 是 Facebook 推出的一個 JavaScript 庫,它提供了持久化不可變數據結構,可以幫助開發者更好地管理應…

字符串數據類型的基本運算

任務描述 本關任務:從后臺輸入任意三個字符串,求最大的字符串。 相關知識 字符串本身是存放在一塊連續的內存空間中,并以’\0’作為字符串的結束標記。 字符指針變量本身是一個變量,用于存放字符串的第 1 個字符的地址。 字符數…

Ubuntu 22.04 一鍵部署openManus

openManus 前言 OpenManus-RL,這是一個專注于基于強化學習(RL,例如 GRPO)的方法來優化大語言模型(LLM)智能體的開源項目,由來自UIUC 和 OpenManus 的研究人員合作開發。 前提要求 安裝deepseek docker方式安裝 ,windows 方式安裝,Linux安裝方式

PDF 轉圖片,一行代碼搞定!批量支持已上線!

大家好,我是程序員晚楓。今天我要給大家帶來一個超實用的功能——popdf 現在支持 PDF 轉圖片了,而且還能批量操作!是不是很激動?別急,我來手把手教你玩轉這個功能。 1. 一行代碼搞定單文件轉換 popdf 的核心就是簡單暴…

《比特城的機密郵件:加密、簽名與防篡改的守護之戰》

點擊下面圖片帶您領略全新的嵌入式學習路線 🔥爆款熱榜 88萬閱讀 1.6萬收藏 第一章:風暴前的密令 比特城的議會大廳內,首席長老艾德文握著一卷足有半人高的羊皮紙,眉頭緊鎖。紙上是即將頒布的《新紀元法典》——這份文件不僅內…

8.用戶管理專欄主頁面開發

用戶管理專欄主頁面開發 寫在前面用戶權限控制用戶列表接口設計主頁面開發前端account/Index.vuelangs/zh.jsstore.js 后端Paginator概述基本用法代碼示例屬性與方法 urls.pyviews.py 運行效果 總結 歡迎加入Gerapy二次開發教程專欄! 本專欄專為新手開發者精心策劃了…

http://noi.openjudge.cn/_2.5基本算法之搜索_1804:小游戲

文章目錄 題目深搜代碼寬搜代碼深搜數據演示圖總結 題目 1804:小游戲 總時間限制: 1000ms 內存限制: 65536kB 描述 一天早上,你起床的時候想:“我編程序這么牛,為什么不能靠這個賺點小錢呢?”因此你決定編寫一個小游戲。 游戲在一…

發生梯度消失, 梯度爆炸問題的原因,怎么解決?

目錄 一、梯度消失的原因 二、梯度爆炸的原因 三、共同的結構性原因 四、解決辦法 五、補充知識 一、梯度消失的原因 梯度消失指的是在反向傳播過程中,梯度隨著層數的增加指數級減小(趨近于0),導致淺層網絡的權重幾乎無法更新…

【USRP】srsRAN 開源 4G 軟件無線電套件

srsRAN 是SRS開發的開源 4G 軟件無線電套件。 srsRAN套件包括: srsUE - 具有原型 5G 功能的全棧 SDR 4G UE 應用程序srsENB - 全棧 SDR 4G eNodeB 應用程序srsEPC——具有 MME、HSS 和 S/P-GW 的輕量級 4G 核心網絡實現 安裝系統 Ubuntu 20.04 USRP B210 sudo …

ChatGPT 4:解鎖AI文案、繪畫與視頻創作新紀元

文章目錄 一、ChatGPT 4的技術革新二、AI文案創作:精準生成與個性化定制三、AI繪畫藝術:從文字到圖像的神奇轉化四、AI視頻制作:自動化剪輯與創意實現五、知識庫與ChatGPT 4的深度融合六、全新的變革和機遇《ChatGPT 4 應用詳解:A…

在js中數組相關用法講解

數組 uniqueArray 簡單數組去重 /*** 簡單數組去重* param arr* returns*/ export const uniqueArray <T>(arr: T[]) > [...new Set(arr)];const arr1 [1,1,1,1 2, 3];uniqueArray(arr); // [1,2,3]uniqueArrayByKey 根據 key 數組去重 /*** 根據key數組去重* …

RT-Thread ulog 日志組件深度分析

一、ulog 組件核心功能解析 輕量化與實時性 ? 資源占用&#xff1a;ulog 核心代碼僅需 ROM<1KB&#xff0c;RAM<0.2KB&#xff0c;支持在資源受限的MCU&#xff08;如STM32F103&#xff09;中運行。 ? 異步/同步模式&#xff1a;默認采用異步環形緩沖區&#xff08;rt_…

T113s3遠程部署Qt應用(dropbear)

T113-S3 是一款先進的應用處理器&#xff0c;專為汽車和工業控制市場而設計。 它集成了雙核CortexTM-A7 CPU和單核HiFi4 DSP&#xff0c;提供高效的計算能力。 T113-S3 支持 H.265、H.264、MPEG-1/2/4、JPEG、VC1 等全格式解碼。 獨立的硬件編碼器可以編碼為 JPEG 或 MJPEG。 集…

12.青龍面板自動化我的生活

安裝 docker方式 docker run -dit \ -v /root/ql:/ql/data \ -p 5700:5700 \ -e ENABLE_HANGUPtrue \ -e ENABLE_WEB_PANELtrue \ --name qinglong \ --hostname qinglong \ --restart always \ whyour/qinglongk8s方式 https://truecharts.org/charts/stable/qinglong/ he…

Maven 遠程倉庫推送方法

步驟 1&#xff1a;配置 pom.xml 中的遠程倉庫地址 在項目的 pom.xml 文件中添加 distributionManagement 配置&#xff0c;指定遠程倉庫的 URL。 xml 復制 <project>...<distributionManagement><!-- 快照版本倉庫 --><snapshotRepository><id…

Spring Boot 日志 配置 SLF4J 和 Logback

文章目錄 一、前言二、案例一&#xff1a;初識日志三、案例二&#xff1a;使用Lombok輸出日志四、案例三&#xff1a;配置Logback 一、前言 在開發 Java 應用時&#xff0c;日志記錄是不可或缺的一部分。日志可以記錄應用的運行狀態、錯誤信息和調試信息&#xff0c;幫助開發者…

JS API 事件監聽

焦點事件案例&#xff1a;搜索框激活下拉菜單 事件對象 事件對象存儲事件觸發時的相關信息 可以判斷用戶按鍵&#xff0c;點擊元素等內容 如何獲取 事件綁定的回調函數中的第一個形參就是事件對象 一般命名為e,event 事件對象常用屬性 type類型 click mouseenter client…

DDD與MVC擴展能力對比

一、架構設計理念的差異二、擴展性差異的具體表現三、DDD擴展性優勢的深層原因四、MVC擴展性不足的典型場景五、總結&#xff1a;架構的本質與選擇六、例子1&#xff09;場景描述2&#xff09;MVC實現示例&#xff08;三層架構&#xff09;3&#xff09;DDD實現示例&#xff08…