MHD、MQA、GQA注意力機制詳解

MHD、MQA、GQA注意力機制詳解

  • 注意力機制詳解及代碼
    • 前言:
    • MHA
    • MQA
    • GQA

注意力機制詳解及代碼

前言:

自回歸解碼器推理是 Transformer 模型的 一個嚴重瓶頸,因為在每個解碼步驟中加 載解碼器權重以及所有注意鍵和值會產生 內存帶寬開銷

下圖為三種注意力機制的結構圖和實驗結果

在這里插入圖片描述

在這里插入圖片描述

MHA

多頭注意力機制是Transformer模型中的核心組件。在其設計中,"多頭"意味著該機制并不只計算一種注意力權重,而是并行計算多種權重,每種權重都從不同的“視角”捕獲輸入的不同信息。

  • hidden_state經過線性層得到q、k、v
  • q、k、v經過split后增加一個維度:num_heads
  • q、k計算注意力分數score
  • softmax對注意力分數進行歸一化得到注意力權重attention_probs
  • 使用注意力權重和值計算輸出:output
  • 對注意力輸出進行拼接concat
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiHeadAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩陣self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, hidden_size)self.v_linear = nn.Linear(hidden_size, hidden_size)## 輸出線性層self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key)value = self.split_head(value)## 計算注意力分數attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 對注意力分數進行歸一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)## 對注意力輸出進行拼接output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x):batch_size = x.size()[0]return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查詢注意力(MQA)可能導致質量下降和訓練不穩定,并且訓練針對質量和推理優化的單獨模型可能不可行。此外,雖然一些語言模型已經使用了多查詢注意力,如PaLM但許多語言模型沒有,包括公開可用的語言模型,如T5和LLaM.

  • hidden_state經過線性層得到q、k、v
  • q、k、v經過split后增加一個維度:num_heads(q = num_heads,k=1,v=1)。相當于多個query,即多查詢。
  • q、k計算注意力分數score
  • softmax對注意力分數進行歸一化得到注意力權重attention_probs
  • 使用注意力權重和值計算輸出:output
  • 對注意力輸出進行拼接concat
## 多查詢注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(MutiQueryAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_heads## 初始化Q、K、V投影矩陣self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.head_dim) ###self.v_linear = nn.Linear(hidden_size, self.head_dim) ##### 輸出線性層self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, 1)value = self.split_head(value, 1)## 計算注意力分數attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 對注意力分數進行歸一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, head_num=None):batch_size = x.size()[0]if head_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)

GQA

  • 使用 5% 的原始預訓練 計算將現有的多頭語言模型檢查點訓 練到具有 MQA 的模型中
  • 引入分組查詢注意力 (GQA),這是多 頭語言模型的泛化。查詢注意力,它使用中間,多于一個,少于查詢頭數量的鍵值頭。
  • 經過訓練的GQA 實現了接近多頭注意力 的質量,并且速度與 MQA 相當。
  • hidden_state經過線性層得到q、k、v
  • q、k、v經過split后增加一個維度:num_heads(q = num_heads,k=group_num,v=group_num)。相當于把多頭分組了,比如原先有10個頭,那就是10個query,分成5組,每組2個query,1個value,1個key。
  • q、k計算注意力分數score
  • softmax對注意力分數進行歸一化得到注意力權重attention_probs
  • 使用注意力權重和值計算輸出:output
  • 對注意力輸出進行拼接concat
## 分組注意力查詢
import torch
from torch import nn
class MutiGroupAttention(torch.nn.Module):def __init__(self, hidden_size, num_heads, group_num):super(MutiGroupAttention, self).__init__()self.num_heads = num_headsself.head_dim = hidden_size // num_headsself.group_num = group_num## 初始化Q、K、V投影矩陣self.q_linear = nn.Linear(hidden_size, hidden_size)self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)## 輸出線性層self.o_linear = nn.Linear(hidden_size, hidden_size)def forward(self, hidden_state, attention_mask=None):batch_size = hidden_state.size()[0]query = self.q_linear(hidden_state)key = self.k_linear(hidden_state)value = self.v_linear(hidden_state)query = self.split_head(query)key = self.split_head(key, self.group_num)value = self.split_head(value, self.group_num)## 計算注意力分數attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))if attention_mask != None:attention_scores += attention_mask * -1e-9## 對注意力分數進行歸一化attention_probs = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_probs, value)output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)output = self.o_linear(output)return outputdef split_head(self, x, group_num=None):batch_size,seq_len = x.size()[:2]if group_num == None:return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)else:x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)return x

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/12730.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/12730.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/12730.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

鞏固學習8

在 Pandas 中,sep參數用于指定數據中字段之間的分隔符。常見的參數包括: 逗號:,,常用于CSV文件。 制表符:\t,常用于TSV文件。 空格:’ ,用于空格分隔的數據。 分號:;&…

【合成孔徑雷達】合成孔徑雷達的多視角理解和時/頻成像算法的統一解釋

文章目錄 一、什么是雷達成像(1)主要的遙感探測手段:光學、紅外和雷達(2)從數學的角度:雷達成像主要研究什么?數據采集: y T x n yTxn yTxn信息提取: y ? > x ? y…

編譯錯誤:stray ‘\357’ in program的解決方法

目錄 把報錯文件更換編碼格式,我試的utf-8 bom編碼就可以了,可以多換幾種試試。 網友的另一種案例: 編譯錯誤:stray ‘\357’ in program的解決方法 把報錯文件更換編碼格式,我試的utf-8 bom編碼就可以了&#xff0c…

LabVIEW做儀器測試不知道是否適用

LabVIEW(Laboratory Virtual Instrument Engineering Workbench)是一個用于系統工程和測量系統的圖形編程平臺,由National Instruments開發。它非常適用于儀器控制、數據采集、信號處理以及自動化測試與測量系統的開發。如果您的工作涉及到這…

如何同步管理1000個設備的VLAN數據?

什么是VLAN? VLAN,也就是虛擬局域網,是通過為子網提供數據鏈路連接來抽象出局域網的概念。在企業網中,一個企業級交換機一般是24口或者是48口,連接這些接口的終端在物理上形成一個廣播域。廣播域過大,就會導…

【AI智能體】零代碼構建AI應用,全網都在喊話歌手誰能應戰,一鍵AI制作歌手信息查詢應用

歡迎來到《小5講堂》 這是《文心智能體平臺》系列文章,每篇文章將以博主理解的角度展開講解。 溫馨提示:博主能力有限,理解水平有限,若有不對之處望指正! 目錄 文心智能體大賽背景創建應用平臺地址快速構建【基礎配置】…

前端無樣式id或者class等來定位標簽

目錄: 1、使用背景2、代碼處理 1、使用背景 客戶使用我們產品組件,發現替換文件,每次替換都會新增如下的樣式,造就樣式錯亂,是組件的文件,目前臨時處理的話就是替換文件時刪除新增的樣式,但是發…

8評分卡建模整體流程梳理

評分卡建模整體流程梳理 學習目標 掌握評分卡建模流程使用Toad庫構建評分卡1 加載數據 import pandas as pd from sklearn.metrics import roc_auc_score,roc_curve,auc from sklearn.model_selection import train_test_split from sklearn.linear_model import Logis…

云服務器上Redis數據庫被攻擊實錄+總結

情景重現 Redis日志記錄(異常部分): 36346:M 14 May 2024 15:46:12.505 # Possible SECURITY ATTACK detected. It looks like somebody is sending POST or Host: commands to Redis. This is likely due to an attacker attempting to us…

【JVM】閱讀Class字節碼:常量池

目錄 基本結構解析 常量池 常量池簡介 如何閱讀Class文件中的常量池信息 基本結構解析 Magic(魔數) Magic的唯一作用是確定這個文件是否為一個能被虛擬機所接受的class 文件。魔數值固定為0xCAFEBABE,不會改變。 常量池 常量池簡介 下圖是反編譯過后的字節碼文…

Python可視化總結與案例解析

目錄 第一章:Python可視化基礎 1.1 環境搭建 1.2 數據可視化 1.3 統計圖表 1.4 交互式可視化 1.5 實戰案例:網站流量分析 1.6 總結 第二章:Python可視化高級應用 2.1 高級圖表類型 2.2 動態可視化 2.3 數據可視化最佳實踐 2.4 實戰…

TensorFlow的學習

0.基礎概念 術語表: https://developers.google.cn/machine-learning/glossary?hlzh-cn#logits 1.快速入門 https://tensorflow.google.cn/tutorials/quickstart/beginner?hlzh-cn 2.基于Keras進行圖像分類 https://tensorflow.google.cn/tutorials/keras/cl…

gradle 共享存儲掛載緩存目錄的問題

2個任務同時構建的時候,報錯如上。 原因:掛載目錄的問題導致的,掛在最小粒度的目錄下。 /home/app/.gradle/caches/modules-2/files-2.1 掛載到這個級別的目錄下。

一文詳解什么是手機在網時長API

手機在網時長API最近被討論得越來越多,因為隨著移動互聯網的不斷發展,越來越多的場景需要使用到用戶的手機號,比如商品交易、客戶服務、信息收發、網絡即時通訊等。手機號碼狀態查詢功能使用得越來越廣泛,常見的有手機在網時長查詢…

演員怎么上百度百科

百度百科是一個公正、開放、客觀的平臺,它為演員提供了一個展示自己過往經歷和演藝生涯的平臺。以下是百科優化網yajje總結的演員創建百度百科的一些步驟和注意事項: 創建演員百度百科的基本條件 人物影響力:演員創建百度百科需要滿足官方的規…

振弦采集儀在巖土工程監測中的重要性及應用案例分享

振弦采集儀在巖土工程監測中的重要性及應用案例分享 巖土工程監測是為了確保土地和建筑物的穩定性以及確保施工安全而進行的一項重要工作。河北穩控科技振弦采集儀是巖土工程監測中一種常用的儀器設備,通過測量土體振動頻率來評估土體的穩定性和強度變化&#xff0…

寫個進度條

using UnityEngine; using UnityEngine.UI; [ExecuteAlways] public class 進度條控制器 : MonoBehaviour {public Image 母進度條; // 母進度條背景public Image 子進度條; // 子進度條[Range(0, 1)] public float 進度值; // 進度值public float 起點偏移量; // 從左側開始的…

理解打包好的vue項目結構dist包

目錄 linux查詢dist目錄整體解釋子目錄文件解釋CSSFONTSJS linux查詢dist目錄 roothcss-ecs-7881:/www/java_project/dist# ls -l total 3004 drwxr-xr-x 2 root root 4096 Dec 31 10:15 css -rw-r--r-- 1 root root 4286 Dec 31 10:15 favicon.ico drwxr-xr-x 2 root r…

MySQL從主庫恢復從庫

主庫備份數據,拷貝至從節點1.1 備份數據 sudo python /data/apps/xtrabackup/script/xtrabackup.py -m full 備份目錄為: /data/mysql_bakcup/<port>/<date>/full_<date> 例:/data/mysql_backup/13306/20231124/full_164044/ 1.2 拷貝備份數據至從節點 sc…

霸道龍尊短視頻:成都鼎茂宏升文化傳媒公司

霸道龍尊短視頻&#xff1a;龍族的傳奇與現代的交融 在數字化時代的浪潮中&#xff0c;短視頻以其短小精悍、內容豐富的特點&#xff0c;迅速占領了人們的碎片時間。成都鼎茂宏升文化傳媒公司而在這些短視頻中&#xff0c;一股獨特的“霸道龍尊”風潮正在悄然興起&#xff0c;…