自注意力機制、多頭自注意力機制、填充掩碼 Python實現

原理講解

【Transformer系列(2)】注意力機制、自注意力機制、多頭注意力機制、通道注意力機制、空間注意力機制超詳細講解

自注意力機制

import torch
import torch.nn as nn# 自注意力機制
class SelfAttention(nn.Module):def __init__(self, input_dim):super(SelfAttention, self).__init__()self.query = nn.Linear(input_dim, input_dim)self.key = nn.Linear(input_dim, input_dim)self.value = nn.Linear(input_dim, input_dim)        def forward(self, x, mask=None):batch_size, seq_len, input_dim = x.shapeq = self.query(x)k = self.key(x)v = self.value(x)atten_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(input_dim, dtype=torch.float))if mask is not None:mask = mask.unsqueeze(1)attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        atten_scores = torch.softmax(atten_weights, dim=-1)attented_values = torch.matmul(atten_scores, v)return attented_values# 自動填充函數
def pad_sequences(sequences, max_len=None):batch_size = len(sequences)input_dim = sequences[0].shape[-1]lengths = torch.tensor([seq.shape[0] for seq in sequences])max_len = max_len or lengths.max().item()padded = torch.zeros(batch_size, max_len, input_dim)for i, seq in enumerate(sequences):seq_len = seq.shape[0]padded[i, :seq_len, :] = seqmask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)return padded, mask.long()if __name__ == '__main__':batch_size = 2seq_len = 3input_dim = 128seq_len_1 = 3seq_len_2 = 5x1 = torch.randn(seq_len_1, input_dim)            x2 = torch.randn(seq_len_2, input_dim)target_seq_len = 10    padded_x, mask = pad_sequences([x1, x2], target_seq_len)selfattention = SelfAttention(input_dim)    attention = selfattention(padded_x)print(attention)

多頭自注意力機制

import torch
import torch.nn as nn# 定義多頭自注意力模塊
class MultiHeadSelfAttention(nn.Module):def __init__(self, input_dim, num_heads):super(MultiHeadSelfAttention, self).__init__()self.num_heads = num_headsself.head_dim = input_dim // num_headsself.query = nn.Linear(input_dim, input_dim)self.key = nn.Linear(input_dim, input_dim)self.value = nn.Linear(input_dim, input_dim)        def forward(self, x, mask=None):batch_size, seq_len, input_dim = x.shape# 將輸入向量拆分為多個頭## transpose(1,2)后變成 (batch_size, self.num_heads, seq_len, self.head_dim)形式q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# 計算注意力權重attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))# 應用 padding maskif mask is not None:# mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) 用于廣播mask = mask.unsqueeze(1).unsqueeze(2)  # 擴展維度以便于廣播attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        attn_scores = torch.softmax(attn_weights, dim=-1)# 注意力加權求和attended_values = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, input_dim)return attended_values# 自動填充函數
def pad_sequences(sequences, max_len=None):batch_size = len(sequences)input_dim = sequences[0].shape[-1]lengths = torch.tensor([seq.shape[0] for seq in sequences])max_len = max_len or lengths.max().item()padded = torch.zeros(batch_size, max_len, input_dim)for i, seq in enumerate(sequences):seq_len = seq.shape[0]padded[i, :seq_len, :] = seqmask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)return padded, mask.long()if __name__ == '__main__':heads = 2batch_size = 2seq_len_1 = 3seq_len_2 = 5input_dim = 128x1 = torch.randn(seq_len_1, input_dim)            x2 = torch.randn(seq_len_2, input_dim)target_seq_len = 10    padded_x, mask = pad_sequences([x1, x2], target_seq_len)multiheadattention = MultiHeadSelfAttention(input_dim, heads)attention = multiheadattention(padded_x, mask)    print(attention)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/77796.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/77796.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/77796.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【大模型】Browser-Use AI驅動的瀏覽器自動化工具

Browser-Use AI驅動的瀏覽器自動化工具 1. 項目概述2. 核心架構3. 實戰指南3.1 環境安裝3.2 快速啟動3.3 進階功能 4. 常見問題與解決5. 項目優勢與局限6. 擴展資源7. 總結 1. 項目概述 項目地址&#xff1a;browser-use Browser-Use 是一個開源工具&#xff0c;旨在通過 AI 代…

ubuntu20.04安裝安裝x11vnc服務基于gdm3或lightdm這兩種主流的顯示管理器。

前言&#xff1a;在服務端安裝vnc服務&#xff0c;可以方便的遠程操作服務器&#xff0c;而不用非要插上顯示器才行。所以在服務器上安裝vnc是很重要的。在ubuntu20中&#xff0c;默認的顯示管理器已經變為gdm3&#xff0c;它可以帶來與 GNOME 無縫銜接的體驗&#xff0c;強調功…

用銀河麒麟 LiveCD 快速查看原系統 IP 和打印機配置

原文鏈接&#xff1a;用銀河麒麟 LiveCD 快速查看原系統 IP 和打印機配置 Hello&#xff0c;大家好啊&#xff01;今天給大家帶來一篇在銀河麒麟操作系統的 LiveCD 或系統試用鏡像環境下&#xff0c;如何查看原系統中電腦的 IP 地址與網絡打印機 IP 地址的實用教程。在系統損壞…

C++——STL——容器deque(簡單介紹),適配器——stack,queue,priority_queue

目錄 1.deque&#xff08;簡單介紹&#xff09; 1.1 deque介紹&#xff1a; 1.2 deque迭代器底層 1.2.1 那么比如說用迭代器實現元素的遍歷&#xff0c;是如何實現的呢&#xff1f; 1.2.2 頭插 1.2.3 尾插 1.2.4 實現 ?編輯 1.2.5 總結 2.stack 2.1 函數介紹 2.2 模…

Java并發編程-線程池

Java并發編程-線程池 線程池運行原理線程池生命周期線程池的核心參數線程池的阻塞隊列線程池的拒絕策略線程池的種類newFixedThreadPoolnewSingleThreadExecutornewCachedThreadPoolnewScheduledThreadPool 創建線程池jdk的Executors(不建議&#xff0c;會導致OOM)jdk的ThreadP…

【前沿】成像“跨界”測量——掃焦光場成像

01 背景 眼睛是人類認識世界的重要“窗口”&#xff0c;而相機作為眼睛的“延伸”&#xff0c;已經成為生產生活中最常見的工具之一&#xff0c;廣泛應用于工業檢測、醫療診斷與影音娛樂等領域。傳統相機通常以“所見即所得”的方式記錄場景&#xff0c;傳感器捕捉到的二維圖像…

TM1640學習手冊及示例代碼

數據手冊 TM1640數據手冊 數據手冊解讀 這里我們看管腳定義DIN和SCLK&#xff0c;一個數據線一個時鐘線 SEG1~SEG8為段碼&#xff0c;GRID1~GRID16為位碼&#xff08;共陰極情況下&#xff09; 這里VDD給5V 數據指令 數據命令設置 地址命令設置 顯示控制命令 共陰極硬件連接圖…

uni-app 開發企業級小程序課程

課程大小&#xff1a;7.7G 課程下載&#xff1a;https://download.csdn.net/download/m0_66047725/90616393 更多資源下載&#xff1a;關注我 備注&#xff1a;缺少兩個視頻5-14 tabs組件進行基本的數據展示和搜索歷史 處理searchData的刪除操作 1-1導學.mp4 2-10小程序內…

判斷點是否在多邊形內

代碼段解析: const intersect = ((yi > y) !== (yj > y)) && (x < (xj - xi) * (y - yi) / (yj - yi) + xi); 第一部分:(yi > y) !== (yj > y) 作用:檢查點 (x,y) 的垂直位置是否跨越多邊形的當前邊。 yi > y 和 yj > y 分別檢查邊的兩個端…

【redis】集群 如何搭建集群詳解

文章目錄 集群搭建1. 創建目錄和配置2. 編寫 docker-compose.yml完整配置文件 3. 啟動容器4. 構建集群超時 集群搭建 基于 docker 在我們云服務器上搭建出一個 redis 集群出來 當前節點&#xff0c;主要是因為我們只有一個云服務器&#xff0c;搞分布式系統&#xff0c;就比較…

[langchain教程]langchain03——用langchain構建RAG應用

RAG RAG過程 離線過程&#xff1a; 加載文檔將文檔按一定條件切割成片段將切割的文本片段轉為向量&#xff0c;存入檢索引擎&#xff08;向量庫&#xff09; 在線過程&#xff1a; 用戶輸入Query&#xff0c;將Query轉為向量從向量庫檢索&#xff0c;獲得相似度TopN信息將…

C語言復習筆記--字符函數和字符串函數(下)

在上篇我們了解了部分字符函數及字符串函數,下面我們來看剩下的字符串函數. strstr 的使用和模擬實現 老規矩,我們先了解一下strstr這個函數,下面看下這個函數的函數原型. char * strstr ( const char * str1, const char * str2); 如果沒找到就返回NULL指針. 下面我們看下它的…

FreeRTOS中的優先級翻轉問題及其解決方案:互斥信號量詳解

FreeRTOS中的優先級翻轉問題及其解決方案&#xff1a;互斥信號量詳解 在實時操作系統中&#xff0c;任務調度是基于優先級的&#xff0c;高優先級任務應該優先于低優先級任務執行。但在實際應用中&#xff0c;有時會出現"優先級翻轉"的現象&#xff0c;嚴重影響系統…

深度學習-全連接神經網絡

四、參數初始化 神經網絡的參數初始化是訓練深度學習模型的關鍵步驟之一。初始化參數&#xff08;通常是權重和偏置&#xff09;會對模型的訓練速度、收斂性以及最終的性能產生重要影響。下面是關于神經網絡參數初始化的一些常見方法及其相關知識點。 官方文檔參考&#xff1…

GIS開發筆記(9)結合osg及osgEarth實現三維球經緯網格繪制及顯隱

一、實現效果 二、實現原理 按照5的間隔分別創建經緯線的節點,掛在到組合節點,組合節點掛接到根節點。可以根據需要設置間隔度數和線寬、線的顏色。 三、參考代碼 //創建經緯線的節點 osg::Node *GlobeWidget::createGraticuleGeometry(float interval, const osg::Vec4 …

《Relay IR的基石:expr.h 中的表達式類型系統剖析》

TVM Relay源碼深度解讀 文章目錄 TVM Relay源碼深度解讀一 、從Constant看Relay表達式的設計哲學1. 類定義概述2. ConstantNode 詳解1. 核心成員2. 關鍵方法3. 類型系統注冊 3. Constant 詳解1. 核心功能 二. 核心內容概述(1) Relay表達式基類1. RelayExprNode 和 RelayExpr 的…

自動駕駛地圖數據傳輸協議ADASIS v2

ADASIS&#xff08;Advanced Driver Assistance Systems Interface Specification&#xff09;直譯過來就是 ADAS 接口規格&#xff0c;它要負責的東西其實很簡單&#xff0c;就是為自動駕駛車輛提供前方道路交通相關的數據&#xff0c;這些數據被抽象成一個標準化的概念&#…

Flutter 狀態管理 Riverpod

Android Studio版本 Flutter SDK 版本 將依賴項添加到您的應用 flutter pub add flutter_riverpod flutter pub add riverpod_annotation flutter pub add dev:riverpod_generator flutter pub add dev:build_runner flutter pub add dev:custom_lint flutter pub add dev:riv…

【EasyPan】MySQL主鍵與索引核心作用解析

【EasyPan】項目常見問題解答&#xff08;自用&持續更新中…&#xff09;匯總版 MySQL主鍵與索引核心作用解析 一、主鍵&#xff08;PRIMARY KEY&#xff09;核心作用 1. 數據唯一標識 -- 創建表時定義主鍵 CREATE TABLE users (id INT AUTO_INCREMENT PRIMARY KEY,use…

IcePlayer音樂播放器項目分析及學習指南

IcePlayer音樂播放器項目分析及學習指南 項目概述 IcePlayer是一個基于Qt5框架開發的音樂播放器應用程序&#xff0c;使用Visual Studio 2013作為開發環境。該項目實現了音樂播放、歌詞顯示、專輯圖片獲取等功能&#xff0c;展現了桌面應用程序開發的核心技術和設計思想。 技…