【深度學習】計算機視覺(18)——從應用到設計

文章目錄

  • 1 不同的注意力機制
    • 1.1 自注意力
    • 1.2 多頭注意力
    • 1.3 交叉注意力
      • 1.3.1 基礎
      • 1.3.2 進階

1 不同的注意力機制

在學習的過程中,發現有很多計算注意力的方法,例如行/列注意力、交叉注意力等,如果對注意力機制本身不是特別實現,很難進行自己的網絡設計。

1.1 自注意力

在這里插入圖片描述

又拿出這張快被我盤包漿的圖。假設輸入序列的維度為(batch_size, seq_len, d_model),通過線性變換矩陣 W Q , W K , W V ∈ R d m o d e l × d m o d e l W^Q, W^K, W^V ∈ \mathbb{R}^{d_{model}×d_{model}} WQ,WK,WVRdmodel?×dmodel?生成 Q Q Q/ K K K/ V V V,形狀為(batch_size, seq_len, d_model)。注意到, Q ? K T Q·K^T Q?KT再通過Softmax操作得到了Attention Map,是注意力權重矩陣(后續用 A A A表示)。通過之前的學習可以知道,注意力權重矩陣A的格式為二維矩陣,形狀為(batch_size, n,n),其中 n n n是輸入序列的長度(即token數量)。假設輸入序列長度為3,每個token的長度為4:
在這里插入圖片描述
那么 A A A中紅色格子表示第二個token與第三個token的關聯,即 A [ i ] [ j ] A[i][j] A[i][j]每個元素表示輸入序列中第 i i i個序列對第 j j j個序列的注意力權重。這里要注意,是否 A A A是一個以對角線為對稱軸的對稱矩陣呢?雖然 Q ? K T Q·K^T Q?KT是對稱的,但是經過Softmax后,每一行都會轉換為概率分布,這樣“位置3對位置2的影響”與“位置2對位置3的影響”就不同了。
接下來要計算 A ? V A·V A?V,表示每個位置綜合其他位置的加權求和。
在這里插入圖片描述

1.2 多頭注意力

若使用多頭注意力,只是列的長度發生改變,被均分成頭的數量。假設輸入序列的維度為(batch_size, seq_len, d_model),通過線性變換矩陣 W Q , W K , W V ∈ R d m o d e l × d k W^Q, W^K, W^V ∈ \mathbb{R}^{d_{model}×d_{k}} WQ,WK,WVRdmodel?×dk?生成 Q Q Q/ K K K/ V V V,形狀為(batch_size, seq_len, d_k),其中 d k = d m o d e l h d_k=\frac{d_model}{h} dk?=hdm?odel?(h為多頭注意力頭數)。

在多頭注意力(Multi-Head Attention)中, A A A的格式會擴展為四維張量:(batch_size, num_heads, n, n),batch_size表示樣本批次大小,num_heads表示注意力頭數,n表示序列長度。

1.3 交叉注意力

1.3.1 基礎

標準的自注意力機制中, Q Q Q/ K K K/ V V V通常由同一個輸入矩陣 x x x通過不同的線性變換生成。自注意力機制關注于單一輸入序列內部元素之間的關系,通過同源輸入捕捉序列內部依賴關系。

交叉注意力(Cross-Attention)則關注于兩個不同輸入序列之間的相互作用。 Q Q Q K K K可以分布來自不同的輸入序列,常見于編碼器-解碼器架構。

在Transformer模型中,CrossAttention通常用于編碼器和解碼器之間的交互。編碼器負責將輸入序列編碼為一系列特征向量,而解碼器則根據這些特征向量逐步生成輸出序列。為了使解碼器能夠更有效地利用編碼器的信息,CrossAttention層被引入其中。解碼器的每個位置會生成一個查詢向量(query),該向量用于在編碼器的所有位置進行注意力權重計算。編碼器的每個位置則生成一組鍵向量(keys)和值向量(values)。通過計算查詢向量與鍵向量的相似度,并經過softmax函數歸一化后,得到注意力權重。最后,注意力權重與值向量相乘并求和,得到編碼器調整后的輸出,供解碼器使用。

Q Q Q是來自解碼器的當前狀態(例如翻譯任務中的目標語言詞, K K K V V V是來自編碼器的輸出(例如源語言的特征)。
Softmax僅要求 Q Q Q K K K的維度匹配,并不限制來源。

假設 Q Q Q的輸入形狀為(batch_size, seq_len_q, d_model),seq_len_q為目標序列長度; K K K/ V V V的輸入形狀為(batch_size, seq_len_kv, d_model),seq_len_kv為源序列長度,輸出的形狀為(batch_size, h, seq_len_q, seq_len_kv)

最終輸出的注意力權重矩陣 A A A作用于 V V V矩陣,生成融合跨序列信息的輸出 O u t p u t = A ? V Output=A·V Output=A?V

Q Q Q由解碼器的自注意力層輸出生成,解碼器在生成目標序列的每一步時,會將已生成的部分序列通過掩碼自注意力層處理,生成當前步的上下文表示,這一表示作為 Q Q Q的輸入。

以機器翻譯為例,將“我喜歡狼”由中文翻譯成英文。每次生成一個詞,假設當前已經生成了"I"、“LIKE”,接下來要進行后面的詞的翻譯,如下圖。 x 1 ′ x1' x1是已經生成的上下文表示,由解碼器的自注意力層輸出。 x 2 ′ x2' x2是源序列“我喜歡狼”經編碼器輸出的特征向量。
在這里插入圖片描述
圖示紅色問號表示待生成的詞語,當生成第三個目標詞時,原矩陣新增一行,該行表示問號詞對源序列所有三個詞的關注權重,而該行的初始值是基于已生成的詞的嵌入向量和位置編碼生成。對于A,表示接下來要生成的詞與源序列的相關度,比如紅色陰影部分表示問號詞與“我”的語義依賴強度。

?強對齊??:目標詞與源詞存在直接翻譯關系(如"WOLVES"→"狼"),對應權重接近1。
??弱對齊??:目標詞依賴源序列的上下文(如生成冠詞 “THE” 時可能關注源序列的主語位置)。
??零權重??:源詞與當前目標詞無關(如生成英文標點時,權重集中于源序列的句尾詞)。

1.3.2 進階

編碼器的 K K K V V V在推理時是固定不變的,但解碼器的 Q Q Q隨著目標序列生成動態擴展。例如,生成“WOLVES”時, Q 3 Q_3 Q3?需與編碼器的 K K K計算相似度,而歷史 Q Q Q和編碼器的 K K K可能已經被緩存,然后只需要計算 ∑ j = 0 4 A 3 ? V j \sum_{j=0}^{4}A_3·V_j j=04?A3??Vj?即可。

訓練階段,不需要考慮生成詞的先后順序,模型并行處理整個目標序列而非逐詞生成,此時所有目標詞的 Q Q Q必須同時計算,以利用GPU的并行計算能力加速訓練。同時,需要通過反向傳播更新所有Q的權重矩陣 W Q W_Q WQ?,這要求通過計算完整的 Q Q Q矩陣計算所有注意力權重 A A A,才能正確更新權重矩陣 W Q W_Q WQ?,如果僅計算 Q 3 ? K T Q_3·K^T Q3??KT,將導致 W Q W_Q WQ?的梯度無法涵蓋歷史位置的語義關聯。

每個注意力頭的 W Q W_Q WQ?矩陣是固定維度的(d_model×d_k),將每行向量從原來指定的特征向量長度轉換為分多頭之后的特征向量長度,無論目標序列長度如何,所有 Q Q Q向量均通過同一組 W Q W_Q WQ?進行投影。這種設計使得模型能夠處理任意長度的目標序列,但要求所有 Q Q Q的投影邏輯一致。

解碼器隱狀態 H i H_i Hi?指的是解碼器在第 i i i步生成的動態時序表示,包含目標序列的生成進度(如已生成的詞數 i i i)、上下文語義、目標序列內部依賴,計算公式為 s i = g ( s i ? 1 , y i ? 1 , c ) s_i=g(s_{i-1}, y_{i-1}, c) si?=g(si?1?,yi?1?,c),其中 g g g是解碼器的更新函數, c c c是編碼器的上下文向量。


參考來源:
AIGC
深入理解CrossAttention:交叉注意力機制的奧秘
【深度學習】Cross-Attention(交叉注意力)機制詳解與應用

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/905265.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/905265.shtml
英文地址,請注明出處:http://en.pswp.cn/news/905265.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

洛谷 P1955 [NOI2015] 程序自動分析

【題目鏈接】 洛谷 P1955 [NOI2015] 程序自動分析 【題目考點】 1. 并查集 2. 離散化 【解題思路】 多組數據問題,對于每組數據,有多個 x i x j x_ix_j xi?xj?或 x i ≠ x j x_i \neq x_j xi?xj?的約束條件。 所有相等的變量構成一個集合&…

[Java] 輸入輸出方法+猜數字游戲

目錄 1. 輸入輸出方法 1.1 輸入方法 1.2 輸出方法 2. 猜數字游戲 1. 輸入輸出方法 Java中輸入和輸出是屬于Scanner類里面的方法,如果要使用這兩種方法需要引用Scanner類。 import java.util.Scanner; java.util 是Java里面的一個包,里面包含一些工…

zst-2001 上午題-歷年真題 UML(13個內容)

UML基礎 UML - 第1題 ad UML - 第2題 依賴是暫時使用對象,關聯是長期連接 依賴:依夜情 關聯:天長地久 組合:組一輩子樂隊 聚合:好聚好散 bd UML - 第3題 adc UML - 第4題 bad UML - 第5題 d UML…

WebFlux vs WebMVC vs Servlet 對比

WebFlux vs WebMVC vs Servlet 技術對比 WebFlux、WebMVC 和 Servlet 是 Java Web 開發中三種不同的技術架構,它們在編程模型、并發模型和適用場景上有顯著區別。以下是它們的核心對比: 核心區別總覽 特性ServletSpring WebMVCSpring WebFlux編程模型…

htmlUnit和Selenium的區別以及使用BrowserMobProxy捕獲網絡請求

1. Selenium:瀏覽器自動化之王 核心定位: 跨平臺、跨語言的瀏覽器操控框架,通過驅動真實瀏覽器實現像素級用戶行為模擬。 技術架構: 核心特性: 支持所有主流瀏覽器(含移動端模擬) 精…

SSRF相關

SSRF(Server Side Request Forgery,服務器端請求偽造),攻擊者以服務器的身份發送一條構造好的請求給服務器所在地內網進行探測或攻擊。 產生原理: 服務器端提供了能從其他服務器應用獲取數據的功能,如從指定url獲取網頁內容、加載指定地址的圖…

SaaS備份的必要性:廠商之外的數據保護策略

在當今數字化時代,企業對SaaS(軟件即服務)應用的依賴程度不斷攀升。SaaS應用為企業提供了便捷的生產力工具,然而,這也使得數據安全面臨諸多挑戰,如意外刪除、勒索軟件攻擊以及供應商故障等。因此&#xff0…

【Python 基礎語法】

Python 基礎語法是編程的基石,以下從核心要素到實用技巧進行系統梳理: 一、代碼結構規范 縮進規則 使用4個空格縮進(PEP 8標準)縮進定義代碼塊(如函數、循環、條件語句) def greet(name):if name: # 正確縮…

利用“Flower”實現聯邦機器學習的實戰指南

一個很尷尬的現狀就是我們用于訓練 AI 模型的數據快要用完了。所以我們在大量的使用合成數據! 據估計,目前公開可用的高質量訓練標記大約有 40 萬億到 90 萬億個,其中流行的 FineWeb 數據集包含 15 萬億個標記,僅限于英語。 作為…

自動化測試與功能測試詳解

🍅 點擊文末小卡片,免費獲取軟件測試全套資料,資料在手,漲薪更快 什么是自動化測試? 自動化測試是指利用軟件測試工具自動實現全部或部分測試,它是軟件測試的一個重要組成 部分,能完成許多手工測試無…

MySQL全量,增量備份與恢復

目錄 一.MySQL數據庫備份概述 1.數據備份的重要性 2.數據庫備份類型 3.常見的備份方法 二:數據庫完全備份操作 1.物理冷備份與恢復 2.mysqldump 備份與恢復 3.MySQL增量備份與恢復 3.1MySQL增量恢復 3.2MySQL備份案例 三:定制企業備份策略思路…

Ubuntu 安裝 Nginx

Nginx 是一個高性能的 Web 服務器和反向代理服務器,同時也可以用作負載均衡器和 HTTP 緩存。 Nginx 的主要用途 用途說明Web服務器提供網頁服務,處理用戶的 HTTP 請求,返回 HTML、CSS、JS、圖片等靜態資源。反向代理服務器將用戶請求轉發到…

人工智能 機器學習期末考試題

自測試卷2 一、選擇題 1.下面哪個屬性不是NumPy中數組的屬性( )。 A.ndim B.size C.shape D.add 2.一個簡單的Series是由( )的數據組成的。 A.兩…

使用阿里云CLI調用OpenAPI

介紹使用阿里云CLI調用OpenAPI的具體操作流程,包括安裝、配置憑證、生成并調用命令等步驟。 方案概覽 使用阿里云CLI調用OpenAPI,大致分為四個步驟: 安裝阿里云CLI:根據您使用設備的操作系統,選擇并安裝相應的版本。…

K8S Svc Port-forward 訪問方式

在 Kubernetes 中,kubectl port-forward 是一種 本地與集群內資源(Pod/Service)建立臨時網絡隧道 的訪問方式,無需暴露服務到公網,適合開發調試、臨時訪問等場景。以下是詳細使用方法及注意事項: 1. 基礎用…

23、DeepSeek-V2論文筆記

DeepSeek-V2 1、背景2、KV緩存優化2.0 KV緩存(Cache)的核心原理2.1 KV緩存優化2.2 性能對比2.3 架構2.4多頭注意力 (MHA)2.5 多頭潛在注意力 (MLA)2.5.1 低秩鍵值聯合壓縮 (Low-Rank Key-Value …

MySQL OCP試題解析(2)

試題如下圖所示: 一、題目背景還原 假設存在以下MySQL用戶權限配置: -- 創建本地會計用戶CREATE USER accountinglocalhost IDENTIFIED BY acc_123;-- 創建匿名代理用戶(用戶名為空,允許任意主機)CREATE USER % IDENTI…

深度學習Y7周:YOLOv8訓練自己數據集

🍨 本文為🔗365天深度學習訓練營中的學習記錄博客🍖 原作者:K同學啊 一、配置環境 1.官網下載源碼 2.安裝需要環境 二、準備好自己的數據 目錄結構: 主目錄 data images(存放圖片) annotati…

英偉達Blackwell架構重構未來:AI算力革命背后的技術邏輯與產業變革

——從芯片暴力美學到分布式智能體網絡,解析英偉達如何定義AI基礎設施新范式 開篇:當算力成為“新石油”,英偉達的“煉油廠”如何升級? 2025年3月,英偉達GTC大會上,黃仁勛身披標志性皮衣,宣布了…

CurrentHashMap的整體系統介紹及Java內存模型(JVM)介紹

當我們提到ConurrentHashMap時,先想到的就是HashMap不是線程安全的: 在多個線程共同操作HashMap時,會出現一個數據不一致的問題。 ConcurrentHashMap是HashMap的線程安全版本。 它通過在相應的方法上加鎖,來保證多線程情況下的…