【模型細節】MHSA:多頭自注意力 (Multi-head Self Attention) 詳細解釋,使用 PyTorch代碼示例說明

MHSA:使用 PyTorch 實現的多頭自注意力 (Multi-head Self Attention) 代碼示例,包含詳細注釋說明:

  1. 線性投影
    通過三個線性層分別生成查詢(Q)、鍵(K)、值(V)矩陣:
    Q=Wq?x,K=Wk?x,V=Wv?xQ = W_q·x, \quad K = W_k·x, \quad V = W_v·xQ=Wq??x,K=Wk??x,V=Wv??x

  2. 分割多頭
    將每個矩陣分割為 hhh 個頭部:
    Q→[Q1,Q2,...,Qh],每個Qi∈Rdk\text{Q} \rightarrow [Q_1, Q_2, ..., Q_h], \quad \text{每個} Q_i \in \mathbb{R}^{d_k}Q[Q1?,Q2?,...,Qh?],每個Qi?Rdk?

  3. 計算注意力分數
    對每個頭部計算縮放點積注意力:
    Attention(Qi,Ki,Vi)=softmax(QiKiTdk)Vi\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_iAttention(Qi?,Ki?,Vi?)=softmax(dk??Qi?KiT??)Vi?

  4. 合并多頭
    拼接所有頭部的輸出并通過線性層:
    MultiHead=Wo?[head1;head2;...;headh]\text{MultiHead} = W_o·[\text{head}_1; \text{head}_2; ... ; \text{head}_h]MultiHead=Wo??[head1?;head2?;...;headh?]

數學原理:

多頭注意力允許模型同時關注不同表示子空間的信息:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^OMultiHead(Q,K,V)=Concat(head1?,...,headh?)WO
其中每個頭的計算為:
headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi?=Attention(QWiQ?,KWiK?,VWiV?)

以下是一個使用 PyTorch 實現的多頭自注意力 (Multi-head Self Attention) 代碼示例,包含詳細注釋說明:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):"""embed_dim: 輸入向量維度num_heads: 注意力頭的數量"""super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads  # 每個頭的維度# 檢查維度是否可整除assert self.head_dim * num_heads == embed_dim# 定義線性變換層self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):"""x: 輸入張量,形狀為 (batch_size, seq_len, embed_dim)"""batch_size = x.shape[0]  #[4,10,512]# 1. 線性投影Q = self.query(x)  # (batch_size, seq_len, embed_dim) #[4,10,512]K = self.key(x)    # (batch_size, seq_len, embed_dim) #[4,10,512]V = self.value(x)  # (batch_size, seq_len, embed_dim) #[4,10,512]# 2. 分割多頭Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  #[4,8,10,64]K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  #[4,8,10,64]V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  #[4,8,10,64]# 現在形狀: (batch_size, num_heads, seq_len, head_dim)# 3. 計算注意力分數# 計算 Q·K^T / sqrt(d_k)energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5) #[4,8,10,64]* #[4,8,64,10] = [4,8,10,10]# 形狀: (batch_size, num_heads, seq_len, seq_len)# 4. 應用softmax獲取注意力權重attention = F.softmax(energy, dim=-1)# 形狀: (batch_size, num_heads, seq_len, seq_len)# 5. 計算加權和out = torch.matmul(attention, V)#[4,8,10,10]* [4,8,10,64] = [4,8,10,64]# 形狀: (batch_size, num_heads, seq_len, head_dim)# 6. 合并多頭out = out.permute(0, 2, 1, 3).contiguous()out = out.view(batch_size, -1, self.embed_dim)# 形狀: (batch_size, seq_len, embed_dim)# 7. 最終線性變換out = self.fc_out(out)return out# 使用示例
if __name__ == "__main__":# 參數設置embed_dim = 512  # 輸入維度num_heads = 8    # 注意力頭數seq_len = 10     # 序列長度batch_size = 4   # 批大小# 創建多頭注意力模塊mha = MultiHeadAttention(embed_dim, num_heads)# 生成模擬輸入數據input_data = torch.randn(batch_size, seq_len, embed_dim)# 前向傳播output = mha(input_data)print("輸入形狀:", input_data.shape)print("輸出形狀:", output.shape)

輸出示例:

輸入形狀: torch.Size([4, 10, 512])
輸出形狀: torch.Size([4, 10, 512])

此實現保持了輸入輸出維度一致,可直接集成到Transformer等架構中。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/91288.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/91288.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/91288.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

PGSQL運維優化:提升vacuum執行時間觀測能力

本文是 IvorySQL 2025 生態大會暨 PostgreSQL 高峰論壇上的演講內容,作者:NKYoung。 6 月底濟南召開的 HOW2025 IvorySQL 生態大會上,我在內核論壇分享了“提升 vacuum 時間觀測能力”的主題,提出了新增統計信息的方法&#xff0c…

神奇的數據跳變

目的 上周遇上了一個非常奇怪的問題,就是軟件的數據在跳變,本來數據應該是158吧,數據一會變成10,一會又變成158,數據在不斷地跳變,那是怎么回事?? 這個問題非常非常的神奇,讓人感覺太不可思議了。 這是這段時間,我遇上的最神奇的事了,沒有之一,最神奇的事,下面…

【跨國數倉遷移最佳實踐3】資源消耗減少50%!解析跨國數倉遷移至MaxCompute背后的性能優化技術

本系列文章將圍繞東南亞頭部科技集團的真實遷移歷程展開,逐步拆解 BigQuery 遷移至 MaxCompute 過程中的關鍵挑戰與技術創新。本篇為第3篇,解析跨國數倉遷移背后的性能優化技術。注:客戶背景為東南亞頭部科技集團,文中用 GoTerra …

【MySQL集群架構與實踐3】使用Dcoker實現讀寫分離

目錄 一. 在Docker中安裝ShardingSphere 二 實踐:讀寫分離 2.1 應用場景 2.2 架構圖 2.3 服務器規劃 2.4 啟動數據庫服務器 2.5. 配置讀寫分離 2.6 日志配置 2.7 重啟ShardingSphere 2.8 測試 2.9. 負載均衡 2.9.1. 隨機負載均衡算法示例 2.9.2. 輪詢負…

maven的阿里云鏡像地址

在 Maven 中配置阿里云鏡像可以加速依賴包的下載,尤其是國內環境下效果明顯。以下是阿里云 Maven 鏡像的配置方式: 配置步驟:找到 Maven 的配置文件 settings.xml 全局配置:位于 Maven 安裝目錄的 conf/settings.xml用戶級配置&am…

大語言模型信息抽取系統解析

這段代碼實現了一個基于大語言模型的信息抽取系統,能夠從金融和新聞類文本中提取結構化信息。下面我將詳細解析整個代碼的結構和功能。1. 代碼整體結構代碼主要分為以下幾個部分:模式定義:定義不同領域(金融、新聞)需要抽取的實體類型示例數據…

Next實習項目總結串聯講解(一)

下面是一些 Next.js 前端面試中常見且具深度的問題,按照邏輯模塊整理,同時提供示范回答建議,便于你條理清晰地展示理解與實踐經驗。 ? 面試講述結構建議 先講 Next.js 是什么,它為什么比 React 更高級。(支持 SSR/SSG/ISR,提升S…

React開發依賴分析

1. React小案例: 在界面顯示一個文本:Hello World點擊按鈕后,文本改為為:Hello React 2. React開發依賴 2.1. 開發React必須依賴三個庫: 2.1.1. react: 包含react所必須的核心代碼2.1.2. react-dom: react渲染在不同平…

工具(一)Cursor

目錄 一、介紹 二、如何打開文件 1、從idea跳轉文件 2、單獨打開項目 三、常見使用 1、Chat 窗口 Ask 對話模式 1.1、使用技巧 1.2 發送和使用 codebase 發送區別 1.3、問題快速修復 2、Chat 窗口 Agent 對話模式 2.1、agent模式功能 2.2、Chat 窗口回滾&撤銷 2.3…

Prompt編寫規范指引

1、📖 引言 隨著人工智能生成內容(AIGC)技術的快速發展,越來越多的開發者開始利用AIGC工具來輔助代碼編寫。然而,如何編寫有效的提示詞(Prompt)以引導AIGC生成高質量的代碼,成為了許…

自我學習----繪制Mark點

在PCB的Layout過程中我們需在光板上放置Mark點以方便生產時的光學定位(三點定位);我個人Mark點繪制步驟如下: layer層:1.放置直徑1mm的焊盤(無網絡連接) 2.放置一個圓直徑2mm,圓心與…

2025年財稅行業拓客破局:小藍本財稅版AI拓客系統助力高效拓客

2025年,在"金稅四期"全面實施的背景下,中國財稅服務市場迎來爆發式增長,根據最新的市場研究報告,2025年中國財稅服務行業產值將達2725.7億元。然而,行業高速發展的背后,80%的財稅公司卻陷入獲客成…

雙向鏈表,對其實現頭插入,尾插入以及遍歷倒序輸出

1.創建一個節點,并將鏈表的首節點返回創建一個獨立節點,沒有和原鏈表產生任何關系#include "head.h"typedef struct Node { int num; struct Node*pNext; struct Node*pPer; }NODE;后續代碼:NODE*createNode(int value) {NODE*new …

2025年自動化工程與計算機網絡國際會議(ICAECN 2025)

2025年自動化工程與計算機網絡國際會議(ICAECN 2025) 2025 International Conference on Automation Engineering and Computer Networks一、大會信息會議簡稱:ICAECN 2025 大會地點:中國柳州 審稿通知:投稿后2-3日內通…

12.Origin2021如何繪制誤差帶圖?

12.Origin2021如何繪制誤差帶圖?選中Y3列→點擊統計→選擇描述統計→選擇行統計→選擇打開對話框輸入范圍選擇B列到D列點擊輸出量→勾選均值和標準差Control選擇下面三列點擊繪圖→選擇基礎2D圖→選擇誤差帶圖雙擊圖像→選擇符號和顏色點擊第二個Sheet1→點擊誤差棒→連接選擇…

如何使用API接口獲取淘寶店鋪訂單信息

要獲取淘寶店鋪的訂單信息,您需要通過淘寶開放平臺(Taobao Open Platform, TOP)提供的API接口來實現。以下是詳細步驟:1. 注冊淘寶開放平臺賬號訪問淘寶開放平臺注冊開發者賬號并完成實名認證創建應用獲取App Key和App Secret2. 申請API權限在"我的…

【Kiro Code 從入門到精通】重要的功能

一、Kiro 是什么? Kiro 是一款智能型集成開發環境(IDE),借助規格說明(specs)、向導(steer)、鉤子(hooks)幫助你高效完成工作。 二、Specs 規格說明 規范&…

直播間里的酒旅新故事:內容正在重構消費鏈路

文/李樂編輯/子夜今年暑期,旅游的熱浪席卷全國。機場、火車站人潮涌動,電子屏上滾動的航班信息與檢票口前的長隊交織成繁忙的出行圖景,酒店預訂量也在這股熱潮中節節攀升。連線 Insight關注到,今年的暑期游有了一些新變化&#xf…

50天50個小項目 (Vue3 + Tailwindcss V4) ? | VerifyAccountUi(驗證碼組件)

&#x1f4c5; 我們繼續 50 個小項目挑戰&#xff01;—— VerifyAccountUi組件 倉庫地址&#xff1a;https://github.com/SunACong/50-vue-projects 項目預覽地址&#xff1a;https://50-vue-projects.vercel.app/ 使用 Vue 3 的 <script setup> 語法結合 Tailwind CS…

AbstractAuthenticationToken 認證流程中??認證令牌的核心抽象類詳解

AbstractAuthenticationToken 認證流程中??認證令牌的核心抽象類詳解在 Spring Security 中&#xff0c;AbstractAuthenticationToken 是 Authentication 接口的??抽象實現類??&#xff0c;其核心作用是為具體的認證令牌&#xff08;如用戶名密碼令牌、JWT 令牌等&#x…