【現代深度學習技術】注意力機制05:多頭注意力

在這里插入圖片描述

【作者主頁】Francek Chen
【專欄介紹】 ? ? ?PyTorch深度學習 ? ? ? 深度學習 (DL, Deep Learning) 特指基于深層神經網絡模型和方法的機器學習。它是在統計機器學習、人工神經網絡等算法模型基礎上,結合當代大數據和大算力的發展而發展出來的。深度學習最重要的技術特征是具有自動提取特征的能力。神經網絡算法、算力和數據是開展深度學習的三要素。深度學習在計算機視覺、自然語言處理、多模態數據分析、科學探索等領域都取得了很多成果。本專欄介紹基于PyTorch的深度學習算法實現。
【GitCode】專欄資源保存在我的GitCode倉庫:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目錄

    • 一、模型
    • 二、實現
    • 小結


??在實踐中,當給定相同的查詢、鍵和值的集合時,我們希望模型可以基于相同的注意力機制學習到不同的行為,然后將不同的行為作為知識組合起來,捕獲序列內各種范圍的依賴關系(例如,短距離依賴和長距離依賴關系)。因此,允許注意力機制組合使用查詢、鍵和值的不同子空間表示(representation subspaces)可能是有益的。

??為此,與其只使用單獨一個注意力匯聚,我們可以用獨立學習得到的 h h h組不同的線性投影(linear projections)來變換查詢、鍵和值。然后,這 h h h組變換后的查詢、鍵和值將并行地送到注意力匯聚中。最后,將這 h h h個注意力匯聚的輸出拼接在一起,并且通過另一個可以學習的線性投影進行變換,以產生最終輸出。這種設計被稱為多頭注意力(multihead attention)。對于 h h h個注意力匯聚輸出,每一個注意力匯聚都被稱作一個(head)。圖1展示了使用全連接層來實現可學習的線性變換的多頭注意力。

在這里插入圖片描述

圖1 多頭注意力:多個頭連結然后線性變換

一、模型

??在實現多頭注意力之前,讓我們用數學語言將這個模型形式化地描述出來。給定查詢 q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q} qRdq?、鍵 k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k} kRdk?和值 v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v} vRdv?,每個注意力頭 h i \mathbf{h}_i hi? i = 1 , … , h i = 1, \ldots, h i=1,,h)的計算方法為:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v (1) \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v} \tag{1} hi?=f(Wi(q)?q,Wi(k)?k,Wi(v)?v)Rpv?(1) 其中,可學習的參數包括 W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} Wi(q)?Rpq?×dq? W i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} Wi(k)?Rpk?×dk? W i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} Wi(v)?Rpv?×dv?,以及代表注意力匯聚的函數 f f f f f f可以是注意力評分函數中的加性注意力和縮放點積注意力。多頭注意力的輸出需要經過另一個線性轉換,它對應著 h h h個頭連結后的結果,因此其可學習參數是 W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times h p_v} Wo?Rpo?×hpv?
W o [ h 1 ? h h ] ∈ R p o (2) \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o} \tag{2} Wo? ?h1??hh?? ?Rpo?(2)

??基于這種設計,每個頭都可能會關注輸入的不同部分,可以表示比簡單加權平均值更復雜的函數。

import math
import torch
from torch import nn
from d2l import torch as d2l

二、實現

??在實現過程中通常選擇縮放點積注意力作為每一個注意力頭。為了避免計算代價和參數代價的大幅增長,我們設定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq?=pk?=pv?=po?/h。值得注意的是,如果將查詢、鍵和值的線性變換的輸出數量設置為 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pq?h=pk?h=pv?h=po?,則可以并行計算 h h h個頭。在下面的實現中, p o p_o po?是通過參數num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):"""多頭注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形狀:# (batch_size,查詢或者“鍵-值”對的個數,num_hiddens)# valid_lens 的形狀:# (batch_size,)或(batch_size,查詢的個數)# 經過變換后,輸出的queries,keys,values 的形狀:# (batch_size*num_heads,查詢或者“鍵-值”對的個數,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在軸0,將第一項(標量或者矢量)復制num_heads次,# 然后如此復制第二項,然后諸如此類。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形狀:(batch_size*num_heads,查詢的個數,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形狀:(batch_size,查詢的個數,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

??為了能夠使多個頭并行計算,上面的MultiHeadAttention類將使用下面定義的兩個轉置函數。具體來說,transpose_output函數反轉了transpose_qkv函數的操作。

#@save
def transpose_qkv(X, num_heads):"""為了多注意力頭的并行計算而變換形狀"""# 輸入X的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_hiddens)# 輸出X的形狀:(batch_size,查詢或者“鍵-值”對的個數,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 輸出X的形狀:(batch_size,num_heads,查詢或者“鍵-值”對的個數,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最終輸出的形狀:(batch_size*num_heads,查詢或者“鍵-值”對的個數,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆轉transpose_qkv函數的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

??下面使用鍵和值相同的小例子來測試我們編寫的MultiHeadAttention類。多頭注意力輸出的形狀是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()

在這里插入圖片描述

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

在這里插入圖片描述

小結

  • 多頭注意力融合了來自于多個注意力匯聚的不同知識,這些知識的不同來源于相同的查詢、鍵和值的不同的子空間表示。
  • 基于適當的張量操作,可以實現多頭注意力的并行計算。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/80476.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/80476.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/80476.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SpringBoot 集成滑塊驗證碼AJ-Captcha行為驗證碼 Redis分布式 接口限流 防爬蟲

介紹 滑塊驗證碼比傳統的字符驗證碼更加直觀和用戶友好,能夠很好防止爬蟲獲取數據。 AJ-Captcha行為驗證碼,包含滑動拼圖、文字點選兩種方式,UI支持彈出和嵌入兩種方式。后端提供Java實現,前端提供了php、angular、html、vue、u…

邊緣網關(邊緣計算)

邊緣網關是邊緣計算架構中的關鍵組件,充當連接終端設備(如傳感器、IoT設備)與云端或核心網絡的橋梁。它在數據源頭附近進行實時處理、分析和過濾,顯著提升效率并降低延遲。 核心功能 協議轉換 ○ 支持多種通信協議(如…

OpenCV定位地板上的書

任務目標是將下面的圖片中的書本找出來: 使用到的技術包括:轉灰度圖、提取顏色分量、二值化、形態學、輪廓提取等。 我們嘗試先把圖片轉為灰度圖,然后二值化,看看效果: 可以看到,二值化后,書的…

機器學習第一講:機器學習本質:讓機器通過數據自動尋找規律

機器學習第一講:機器學習本質:讓機器通過數據自動尋找規律 資料取自《零基礎學機器學習》。 查看總目錄:學習大綱 關于DeepSeek本地部署指南可以看下我之前寫的文章:DeepSeek R1本地與線上滿血版部署:超詳細手把手指…

修改圖像分辨率

在這個教程中,您將學習如何使用Python和深度學習技術來調整圖像的分辨率。我們將從基礎的圖像處理技術開始,逐步深入到使用預訓練的深度學習模型進行圖像超分辨率處理。 一、常規修改方法 1. 安裝Pillow庫 首先,你需要確保你的Python環境中…

jsAPI

環境準備 1 安裝nvm nvm 即 (node version manager),好處是方便切換 node.js 版本 安裝注意事項 要卸載掉現有的 nodejs提示選擇 nvm 和 nodejs 目錄時,一定要避免目錄中出現空格選用【以管理員身份運行】cmd 程序來執行 nvm 命令首次運行前設置好國…

SCDN是什么?

SCDN是安全內容分發網絡的簡稱,它在傳統內容分發網絡(CDN)的基礎上,集成了安全防護能力,旨在同時提升內容傳輸速度和網絡安全性。 SCDN的核心功能有: DDoS防御:識別并抵御大規模分布式拒絕服務…

Qt/C++開發監控GB28181系統/實時視頻預覽/視頻點播/rtp解包解碼顯示

一、前言 通過gb28181做實時視頻預覽,也就是視頻點播功能,是最重要的功能了,絕對是整個系統排第一重要的,這就是核心功能,什么設備注冊、獲取通道等都是為了實時預覽做準備的,當然這個功能也是最難的&…

找銀子 題解(c++)

題目 思路 首先,這道題乍一看,應該可以用搜索來做。 但是,搜索會不會超時間限制呢? 為了防止時間超限,我們可以換一種做法。 先創立兩個二維數組,一個是輸入的數組a,一個是數組b。 假設 i 行 j 列的數…

子集樹算法文檔

1.算法概述 子集樹是一種 回溯算法,用于生成一個集合的所有子集。給定一個數組 arr,該算法遞歸地遍歷所有可能的子集,并通過一個輔助數組 x 標記當前元素是否被選中。 2.算法特點 時間復雜度:O(2n)(因為一個包含 n 個…

HTTP/1.1 host虛擬主機詳解

一、核心需求:為什么需要虛擬主機? 在互聯網上,我們常常希望在一臺物理服務器(它通常只有一個公網 IP 地址)上運行多個獨立的網站,每個網站都有自己獨特的域名(例如 www.a-site.com?, www.b-s…

amass:深入攻擊面映射和資產發現工具!全參數詳細教程!Kali Linux教程!

簡介 OWASP Amass 項目使用開源信息收集和主動偵察技術執行攻擊面網絡映射和外部資產發現。 此軟件包包含一個工具,可幫助信息安全專業人員使用開源信息收集和主動偵察技術執行攻擊面網絡映射并執行外部資產發現。 使用的信息收集技術 技術數據來源APIs&#xf…

Spring Web MVC響應

返回靜態頁面 第一步 創建html時,要注意創建的路徑,要在static下面 第二步 把需要寫的內容寫到body內 第三步 直接訪問路徑就可以 返回數據ResponseBody RestController Controller ResponseBody Controller:返回視圖 ResponseBody&…

?鴻蒙PC正式發布:國產操作系統實現全場景生態突破

鴻蒙PC正式發布:國產操作系統實現全場景生態突破? 2025年5月8日,華為在深圳舉辦發布會,正式推出搭載鴻蒙操作系統的個人電腦(PC),標志著國產操作系統在核心技術與生態布局上實現歷史性跨越。此次發布的鴻蒙…

【計算機視覺】OpenCV實戰項目:Text-Extraction-Table-Image:基于OpenCV與OCR的表格圖像文本提取系統深度解析

Text-Extraction-Table-Image:基于OpenCV與OCR的表格圖像文本提取系統深度解析 1. 項目概述2. 技術原理與算法設計2.1 圖像預處理流水線2.2 表格結構檢測算法2.3 OCR優化策略 3. 實戰部署指南3.1 環境配置3.2 核心代碼解析3.3 執行流程示例 4. 常見問題與解決方案4.…

Redis BigKey 問題是什么

BigKey 問題是什么 BigKey 的具體表現是 redis 中的 key 對應的 value 很大,占用的 redis 空間比較大,本質上是大 value 問題。 BigKey怎么找 redis-cli --bigkeysscanBig Key 產生的原因 1.redis數據結構使用不恰當 2.未及時清理垃圾數據 3.對業務預…

go-gin

前置 gin是go的一個web框架,我們簡單介紹一下gin的使用 導入gin :"github.com/gin-gonic/gin" 我們使用import導入gin的包 簡單示例: package mainimport ("github.com/gin-gonic/gin" )func main() {r : gin.Default(…

C# NX二次開發:判斷兩個體是否干涉和獲取系統日志的UFUN函數

大家好,今天要講關于如何判斷兩個體是否干涉和獲取系統日志的UFUN函數。 (1)UF_MODL_check_interference:這個函數的定義為根據單個目標體檢查每個指定的工具體是否有干擾。 Defined in: uf_modl.h Overview Checks each sp…

如何解決 Linux 系統文件描述符耗盡的問題

在Linux系統中,文件描述符(File Descriptor, FD)是操作系統管理打開文件、套接字、管道等資源的抽象標識。當進程或系統耗盡文件描述符時,會導致服務崩潰、連接失敗等嚴重問題。以下是詳細的排查和解決方案: --- ###…

LVGL簡易計算器實戰

文章目錄 📁 文件結構建議🔹 eval.h 表達式求值頭文件🔹 eval.c 表達式求值實現文件(帶詳細注釋)🔹 ui.h 界面頭文件🔹 ui.c 界面實現文件🔹 main.c 主函數入口? 總結 項目效果&…