002 self-attention自注意力

目錄

一、環境

二、self-attention原理

三、完整代碼


一、環境

本文使用環境為:

  • Windows10
  • Python 3.9.17
  • torch?1.13.1+cu117
  • torchvision 0.14.1+cu117

二、self-attention原理

自注意力(Self-Attention)操作是基于 Transformer 的機器翻譯模型的基本操作,在源語言的編
碼和目標語言的生成中頻繁地被使用以建模源語言、目標語言任意兩個單詞之間的依賴關系。給
定由單詞語義嵌入及其位置編碼疊加得到的輸入表示 {xi ∈ Rd},為了實現對上下文語義依賴的建模,進一步引入在自注意力機制中涉及到的三個元素:查詢 qi(Query),鍵 ki(Key),值 vi (Value)。在編碼輸入序列中每一個單詞的表示的過程中,這三個元素用于計算上下文單詞所對應的權重得分。直觀地說,這些權重反映了在編碼當前單詞的表示時,對于上下文不同部分所需要的關注程度。具體來說,如圖所示,通過三個線性變換 WQ,WK ,WV 將輸入序列中的每一個單詞表示 xi 轉換為其對應的 qi,ki ,vi? 向量。

為了得到編碼單詞 xi 時所需要關注的上下文信息,通過位置 i 查詢向量與其他位置的鍵向量做點積得到匹配分數 qi · k1, qi · k2, ..., qi · kt。為了防止過大的匹配分數在后續 Softmax 計算過程中導致的梯度爆炸以及收斂效率差的問題,這些得分會除放縮因子 √d 以穩定優化。放縮后的得分經過 Softmax 歸一化為概率之后,與其他位置的值向量相乘來聚合希望關注的上下文信息,并最小化不相關信息的干擾。上述計算過程可以被形式化地表述如下:

其中 Q? , K? ,V? 分別表示輸入序列中的不同單詞的 q, k, v 向量拼接組成的矩陣,L 表示序列長度,Z 表示自注意力操作的輸出。為了進一步增強自注意力機制聚合上下文信息的能力,提出了多頭自注意力(Multi-head Attention)的機制,以關注上下文的不同側面。具體來說,上下文中每一個單詞的表示 xi 經過多組線性 {WQ*WK*WV } 映射到不同的表示子空間中。公式會在不同的子空間中分別計算并得到不同的上下文相關的單詞序列表示{Zj}。最終,線性變換 WO 用于綜合不同子空間中的上下文表示并形成自注意力層最終的輸出 xi 。

三、完整代碼

import torch.nn as nn
import torch
import math
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, heads, d_model, dropout = 0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // heads # 512 / 8 self.h = headsself.q_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)self.out = nn.Linear(d_model, d_model)def attention(self, q, k, v, d_k, mask=None, dropout=None):scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # self-attention公式# 掩蓋掉那些為了填補長度增加的單元,使其通過 softmax 計算后為 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1) # self-attention公式if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v) # self-attention公式return outputdef forward(self, q, k, v, mask=None):bs = q.size(0) # 進行線性操作劃分為成 h 個頭k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩陣轉置k = k.transpose(1,2) q = q.transpose(1,2) v = v.transpose(1,2) # 計算 attentionscores = self.attention(q, k, v, self.d_k, mask, self.dropout)# 連接多個頭并輸入到最后的線性層concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output# 準備q、k、v張量
d_model = 512
num_heads = 8
batch_size = 32
seq_len = 64q = torch.randn(batch_size, seq_len, d_model) # 64 x 512
k = torch.randn(batch_size, seq_len, d_model) # 64 x 512
v = torch.randn(batch_size, seq_len, d_model) # 64 x 512sa = MultiHeadAttention(heads = num_heads, d_model=d_model)
print(sa(q, k, v).shape) # torch.Size([32, 64, 512])
print('')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/208826.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/208826.shtml
英文地址,請注明出處:http://en.pswp.cn/news/208826.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【XILINX】記錄ISE/Vivado使用過程中遇到的一些warning及解決方案

前言 XILINX/AMD是大家常用的FPGA,但是在使用其開發工具ISE/Vivado時免不了會遇到很多warning,(大家是不是發現程序越大warning越多?),并且還有很多warning根據消除不了,看著特心煩? 我這里匯總一些我遇到的…

http和https區別

http和https區別 HTTP(Hypertext Transfer Protocol)和HTTPS(Hypertext Transfer Protocol Secure)是用于在網絡上傳輸數據的兩種協議。它們之間的主要區別在于安全性和數據傳輸方式: 安全性:HTTP是明文傳…

華清遠見嵌入式學習——QT——作業2

作業要求&#xff1a; 代碼運行效果圖&#xff1a; 登錄失敗 和 最小化 和 取消登錄 登錄成功 和 X號退出 代碼&#xff1a; ①&#xff1a;頭文件 #ifndef LOGIN_H #define LOGIN_H#include <QMainWindow> #include <QLineEdit> //行編輯器類 #include…

如何在centos8上配置一個ca證書頒發機構并且頒發一個自簽名證書【超詳細!!!】

在CentOS 8上配置CA證書頒發機構并頒發自簽名證書的步驟如下&#xff1a; 1. 安裝OpenSSL sudo dnf install openssl 2. 創建CA證書目錄 sudo mkdir /etc/pki/CA/ sudo chmod 0700 /etc/pki/CA/ 3. 創建CA證書數據庫 sudo touch /etc/pki/CA/index.txt sudo echo 1000 >…

Java Spring + SpringMVC + MyBatis(SSM)期末作業項目

本系統是一個圖書管理系統&#xff0c;比較適合當作期末作業主要技術棧如下&#xff1a; - 數據庫&#xff1a;MySQL - 開發工具&#xff1a;IDEA - 數據連接池&#xff1a;Druid - Web容器&#xff1a;Apache Tomcat - 項目管理工具&#xff1a;Maven - 版本控制工具&#xf…

探索人工智能領域——每日20個名詞詳解【day12】

目錄 前言 正文 總結 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高興與大家相識&#xff0c;希望我的博客能對你有所幫助。 &#x1f4a1;本文由Filotimo__??原創&#xff0c;首發于CSDN&#x1f4da;。 &#x1f4e3;如需轉載&#xff0c;請事先與我聯系以…

學習JVM

java虛擬機 流程&#xff1a;helloworld.java----(javac編譯)----helloworld.class-------(java運行)——JVM——機器碼JVM功能 *解釋和運行 *內存管理 *即時編譯&#xff08;跨平臺-慢一點&#xff09;jit &#xff08;反復用到的代碼 解釋保存再內存里面&#xff09;…

進程、線程、線程池狀態

線程幾種狀態和狀態轉換 進程主要寫明三種基本狀態&#xff1a; 線程池的幾種狀態&#xff1a;

STM32的BKP與RTC簡介

芯片的供電引腳 引腳表橙色的是芯片的供電引腳&#xff0c;其中VSS/VDD是芯片內部數字部分的供電&#xff0c;VSSA/VDDA是芯片內部模擬部分的供電&#xff0c;這4組以VDD開頭的供電都是系統的主電源&#xff0c;正常使用時&#xff0c;全部都要接3.3V的電源上&#xff0c;VBAT是…

Leetcode2477. 到達首都的最少油耗

Every day a Leetcode 題目來源&#xff1a;2477. 到達首都的最少油耗 解法1&#xff1a;貪心 深度優先搜索 題目等價于給出了一棵以節點 0 為根結點的樹&#xff0c;并且初始樹上的每一個節點上都有一個人&#xff0c;現在所有人都需要通過「車子」向結點 0 移動。 對于…

從阻抗匹配看擁塞控制

先來理解阻抗匹配&#xff0c;但我不按傳統方式解釋&#xff0c;因為傳統方案你要先理解如何定義阻抗&#xff0c;然后再學習什么是輸入阻抗和輸出阻抗&#xff0c;最后再看如何讓它們匹配&#xff0c;而讓它們匹配的目標僅僅是信號不反射&#xff0c;以最大能效被負載接收。 …

面試寶典之自我介紹

聽人勸、吃飽飯,奉勸各位小伙伴,不要訂閱該文所屬專欄。 如需要項目實戰或者是體系化資源,文末名片加V! 作者:哈哥撩編程,工作十余年, 從事過全棧研發、產品經理等工作,目前在公司擔任研發部門CTO。榮譽:2022年度博客之星Top4、2023年度超級個體得主、谷歌與亞馬遜開發…

Amazon CodeWhisperer 開箱初體驗

文章作者&#xff1a;Coder9527 科技的進步日新月異&#xff0c;正當人工智能發展如火如荼的時候&#xff0c;各大廠商在“解放”碼農的道路上不斷創造出各種 Coding 利器&#xff0c;今天在下就帶大家開箱體驗一個 Coding 利器&#xff1a; Amazon CodeWhisperer。 亞馬遜云科…

99基于matlab的小波分解和小波能量熵函數

基于matlab的小波分解和小波能量熵函數&#xff0c;通過GUI界面導入西儲大學軸承故障數據&#xff0c;以可視化的圖對結果進行展現。數據可更換自己的&#xff0c;程序已調通&#xff0c;可直接運行。 99小波分解和小波能量熵函數 (xiaohongshu.com)https://www.xiaohongshu.co…

【LeetCode每日一題合集】2023.11.27-2023.12.3 (?)

文章目錄 907. 子數組的最小值之和&#xff08;單調棧貢獻法&#xff09;1670. 設計前中后隊列?&#xff08;設計數據結構&#xff09;解法1——雙向鏈表解法2——兩個雙端隊列 2336. 無限集中的最小數字解法1——維護最小變量mn 和 哈希表維護已經去掉的數字解法2——維護原本…

二分查找|前綴和|滑動窗口|2302:統計得分小于 K 的子數組數目

作者推薦 貪心算法LeetCode2071:你可以安排的最多任務數目 本文涉及的基礎知識點 二分查找算法合集 題目 一個數組的 分數 定義為數組之和 乘以 數組的長度。 比方說&#xff0c;[1, 2, 3, 4, 5] 的分數為 (1 2 3 4 5) * 5 75 。 給你一個正整數數組 nums 和一個整數…

response應用及重定向和request轉發

請求和轉發&#xff1a; response說明一、response文件下載二、response驗證碼實現1.前置知識&#xff1a;2.具體實現&#xff1a;3.知識總結 三、response重定向四、request轉發五、重定向和轉發的區別 response說明 response是指HttpServletResponse,該響應有很多的應用&…

JavaScript 一些少見多怪的玩意

$$() [].forEach.call($$("*"), function (a) {a.style.outline "1px solid #" (~~(Math.random() * (1 << 24))).toString(16);}); 直接復制到控制臺&#xff0c;頁面效果就是頁面中不同的HTML結構被不同顏色的框圈著。 原理&#xff1a; $$函數…

力扣面試150題 | 輪轉數組

力扣面試150題 &#xff5c; 輪轉數組 題目描述解題思路代碼實現 題目描述 189.輪轉數組 給定一個整數數組 nums&#xff0c;將數組中的元素向右輪轉 k 個位置&#xff0c;其中 k 是非負數。 示例 1: 輸入: nums [1,2,3,4,5,6,7], k 3 輸出: [5,6,7,1,2,3,4] 解釋: 向右輪…

Kafka在微服務架構中的應用:實現高效通信與數據流動

微服務架構的興起帶來了分布式系統的復雜性&#xff0c;而Kafka作為一款強大的分布式消息系統&#xff0c;為微服務之間的通信和數據流動提供了理想的解決方案。本文將深入探討Kafka在微服務架構中的應用&#xff0c;并通過豐富的示例代碼&#xff0c;幫助大家更全面地理解和應…