縮放點積注意力

Scaled Dot-Product Attention

  • 論文地址

    https://arxiv.org/pdf/1706.03762

注意力機制介紹

  • 縮放點積注意力是Transformer模型的核心組件,用于計算序列中不同位置之間的關聯程度。其核心思想是通過查詢向量(query)和鍵向量(key)的點積來獲取注意力分數,再通過縮放和歸一化處理,最后與值向量(value)加權求和得到最終表示。

    ? image-20250423201641471

數學公式

  • 縮放點積注意力的計算過程可分為三個關鍵步驟:

    1. 點積計算與縮放:通過矩陣乘法計算查詢向量與鍵向量的相似度,并使用 d k \sqrt{d_k} dk? ? 縮放
    2. 掩碼處理(可選):對需要忽略的位置施加極大負值掩碼
    3. Softmax歸一化:將注意力分數轉換為概率分布
    4. 加權求和:用注意力權重對值向量進行加權

    公式表達為:
    Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V Attention(Q,K,V)=softmax(dk? ?QKT?)V
    其中:

    • Q ∈ R s e q _ l e n × d _ k Q \in \mathbb{R}^{seq\_len \times d\_k} QRseq_len×d_k:查詢矩陣
    • K ∈ R s e q _ l e n × d _ k K \in \mathbb{R}^{seq\_len \times d\_k} KRseq_len×d_k:鍵矩陣
    • V ∈ R s e q _ l e n × d _ k V \in \mathbb{R}^{seq\_len \times d\_k} VRseq_len×d_k:值矩陣

    s e q _ l e n seq\_len seq_len 為序列長度, d _ k d\_k d_k 為embedding的維度。

代碼實現

  • 計算注意力分數

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    import torchdef calculate_attention(query, key, value, mask=None):"""計算縮放點積注意力分數參數說明:query: [batch_size, n_heads, seq_len, d_k]key:   [batch_size, n_heads, seq_len, d_k] value: [batch_size, n_heads, seq_len, d_k]mask:  [batch_size, seq_len, seq_len](可選)"""d_k = key.shape[-1]key_transpose = key.transpose(-2, -1)  # 轉置最后兩個維度# 計算縮放點積 [batch, h, seq_len, seq_len]att_scaled = torch.matmul(query, key_transpose) / d_k ** 0.5# 掩碼處理(解碼器自注意力使用)if mask is not None:att_scaled = att_scaled.masked_fill_(mask=mask, value=-1e9)# Softmax歸一化att_softmax = torch.softmax(att_scaled, dim=-1)# 加權求和 [batch, h, seq_len, d_k]return torch.matmul(att_softmax, value)
    
  • 相關解釋

    1. 輸入張量 query, key, value的形狀

      如果是直接計算的話,那么shape是 [batch_size, seq_len, d_model]

      當然為了學習更多的表征,一般都是多頭注意力,這時候shape則是[batch_size, n_heads, seq_len, d_k]

      其中

      • batch_size:批量

      • n_heads:注意力頭的數量

      • seq_len: 序列的長度

      • d_model: embedding維度

      • d_k: d_k = d_model / n_heads

    2. 代碼中的shape轉變

      • key_transpose :key的轉置矩陣

        由 key 轉置了最后兩個維度,維度從 [batch_size, n_heads, seq_len, d_k] 轉變為 [batch_size, n_heads, d_k, seq_len]

      • **att_scaled **:縮放點積

        由 query 和 key 通過矩陣相乘得到

        [batch_size, n_heads, seq_len, d_k] @ [batch_size, n_heads, d_k, seq_len] --> [batch_size, n_heads, seq_len, seq_len]

      • att_score: 注意力分數

        由兩個矩陣相乘得到

        [batch_size, n_heads, seq_len, seq_len] @ [batch_size, n_heads, seq_len, d_k] --> [batch_size, n_heads, seq_len, d_k]


使用示例

  • 測試代碼

    if __name__ == "__main__":# 模擬輸入:batch_size=2, 8個注意力頭,序列長度512,d_k=64x = torch.ones((2, 8, 512, 64))# 計算注意力(未使用掩碼)att_score = calculate_attention(x, x, x)print("輸出形狀:", att_score.shape)  # torch.Size([2, 8, 512, 64])print("注意力分數示例:\n", att_score[0,0,:3,:3])
    

    在實際使用中通常會將此實現封裝為nn.Module并與位置編碼、殘差連接等組件配合使用,構建完整的Transformer層。


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/77912.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/77912.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/77912.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

可吸收聚合物:醫療科技與綠色未來的交匯點

可吸收聚合物(Biodegradable Polymers)作為生物醫學工程的核心材料,正引領一場從“金屬/塑料植入物”到“智能降解材料”的范式轉移。根據QYResearch(恒州博智)預測,2031年全球可吸收聚合物市場銷售額將突破…

房地產項目績效考核管理制度與績效提升

房地產項目績效考核管理制度的核心目的是通過合理的績效考核機制,提升項目的整體運作效率,并鼓勵項目團隊成員的積極性。該制度適用于所有房地產項目部工作人員,涵蓋了項目經理和項目成員的考核。考核的主要內容包括項目經理和項目部成員的工…

【算法筆記】動態規劃基礎(一):dp思想、基礎線性dp

目錄 前言動態規劃的精髓什么叫“狀態”動態規劃的概念動態規劃的三要素動態規劃的框架無后效性dfs -> 記憶化搜索 -> dp暴力寫法記憶化搜索寫法記憶化搜索優化了什么?怎么轉化成dp?dp寫法 dp其實也是圖論首先先說結論:狀態DAG是怎樣的…

pytorch 51 GroundingDINO模型導出tensorrt并使用c++進行部署,53ms一張圖

本專欄博客第49篇文章分享了將 GroundingDINO模型導出onnx并使用c++進行部署,并嘗試將onnx模型轉換為trt模型,fp16進行推理,可以發現推理速度提升了一倍。為此對GroundingDINO的trt推理進行調研,發現 在GroundingDINO-TensorRT-and-ONNX-Inference項目中分享了模型導出onnx…

一個關于相對速度的假想的故事-6

既然已經知道了速度是不能疊加的,同時也知道這個疊加是怎么做到的,那么,我們實際上就知道了光速的來源,也就是這里的虛數單位的來源: 而它的來源則是, 但這是兩個速度的比率,而光速則是一個速度…

深度學習激活函數與損失函數全解析:從Sigmoid到交叉熵的數學原理與實踐應用

目錄 前言一、sigmoid 及導數求導二、tanh 三、ReLU 四、Leaky Relu五、 Prelu六、Softmax七、ELU八、極大似然估計與交叉熵損失函數8.1 極大似然估計與交叉熵損失函數算法理論8.1.1 伯努利分布8.1.2 二項分布8.1.3 極大似然估計總結 前言 書接上文 PaddlePaddle線性回歸詳解…

Python內置函數---breakpoint()

用于在代碼執行過程中動態設置斷點,暫停程序并進入調試模式。 1. 基本語法與功能 breakpoint(*args, kwargs) - 參數:接受任意數量的位置參數和關鍵字參數,但通常無需傳遞(默認調用pdb.set_trace())。 - 功能&#x…

從零手寫 RPC-version1

一、 前置知識 1. 反射 獲取字節碼的三種方式 Class.forName("全類名") (全類名,即包名類名)類名.class對象.getClass() (任意對象都可調用,因為該方法來自Object類) 獲取成員方法 Method getMethod(St…

ARINC818協議(六)

上圖中,紅色虛線上面為我們常用的simple mode簡單模式,下面和上面的結合在一起,就形成了extended mode擴展模式。 ARINC818協議 container header容器頭 ancillary data輔助數據 視頻流 ADVB幀映射 FHCP傳輸協議 R_CTRL:路由控制routing ctr…

PyCharm 鏈接 Podman Desktop 的 podman-machine-default Linux 虛擬環境

#工作記錄 PyCharm Community 連接到Podman Desktop 的 podman-machine-default Linux 虛擬環境詳細步驟 1. 準備工作 確保我們已在 Windows 系統中正確安裝并啟動了 Podman Desktop。 我們將通過 Podman Desktop 提供的名為 podman-machine-default 的 Fedora Linux 41 WSL…

小白自學python第一天

學習python的第一天 一、常用的值類型(先來粗略認識一下~) 類型說明數字(number)包含整型(int)、浮點型(float)、復數(complex)、布爾(boolean&…

初階數據結構--排序算法(全解析!!!)

排序 1. 排序的概念 排序:所謂排序,就是使一串記錄,按照其中的某個或某些些關鍵字的大小,遞增或遞減的排列起來的操作。 2. 常見的排序算法 3. 實現常見的排序算法 以下排序算法均是以排升序為示例。 3.1 插入排序 基本思想:…

Android studio開發——room功能實現用戶之間消息的發送

文章目錄 1. Flask-SocketIO 后端代碼后端代碼 2. Android Studio Java 客戶端代碼客戶端代碼 3. 代碼說明 SocketIO基礎 1. Flask-SocketIO 后端代碼 后端代碼 from flask import Flask, request from flask_socketio import SocketIO, emit import uuidapp Flask(__name_…

4.LinkedList的模擬實現:

LinkedList的底層是一個不帶頭的雙向鏈表。 不帶頭雙向鏈表中的每一個節點有三個域:值域,上一個節點的域,下一個節點的域。 不帶頭雙向鏈表的實現: public class Mylinkdelist{//定義一個內部類(節點)stat…

Sentinel數據S2_SR_HARMONIZED連續云掩膜+中位數合成

在GEE中實現時,發現簡單的QA60是無法去云的,最近S2地表反射率數據集又進行了更新,原有的屬性集也進行了變化,現在的SR數據集名稱是“S2_SR_HARMONIZED”。那么: 要想得到研究區無云的圖像,可以參考執行以下…

理解計算機系統_網絡編程(1)

前言 以<深入理解計算機系統>(以下稱“本書”)內容為基礎&#xff0c;對程序的整個過程進行梳理。本書內容對整個計算機系統做了系統性導引,每部分內容都是單獨的一門課.學習深度根據自己需要來定 引入 網絡是計算機科學中非常重要的部分,筆者過去看過相關的內…

【2025】Datawhale AI春訓營-RNA結構預測(AI+創新藥)-Task2筆記

【2025】Datawhale AI春訓營-RNA結構預測&#xff08;AI創新藥&#xff09;-Task2筆記 本文對Task2提供的進階代碼進行理解。 任務描述 Task2的任務仍然是基于給定的RNA三維骨架結構&#xff0c;生成一個或多個RNA序列&#xff0c;使得這些序列能夠折疊并盡可能接近給定的目…

vim 命令復習

命令模式下的命令及快捷鍵 # dd刪除光所在行的內容 # ndd從光標所在行開始向下刪除n行 # yy復制光標所在行的內容 # nyy復制光標所在行向下n行的內容 # p將復制的內容粘貼到光標所在行以下&#xff08;小寫&#xff09; # P將復制的內容粘貼到光標所在行以上&#xff08;大寫&…

哪些心電圖表現無緣事業編體檢呢?

根據《公務員錄用體檢通用標準》心血管系統條款及事業單位體檢實施細則&#xff0c;心電圖不合格主要涉及以下類型及處置方案&#xff1a; 一、心律失常類 早搏&#xff1a;包括房性早搏、室性早搏和交界性早搏。如果每分鐘早搏次數較多&#xff08;如超過5次&#xff09;&…

Linux學習——UDP

編程的整體框架 bind&#xff1a;綁定服務器&#xff1a;TCP地址和端口號 receivefrom()&#xff1a;阻塞等待客戶端數據 sendto():指定服務器的IP地址和端口號&#xff0c;要發送的數據 無連接盡力傳輸&#xff0c;UDP:是不可靠傳輸 實時的音視頻傳輸&#x…