Transformer 中 Self-Attention 的二次方復雜度(Quadratic Complexity )問題及改進方法:中英雙語

Transformer 中 Self-Attention 的二次方復雜度問題及改進方法

隨著大型語言模型(LLM)輸入序列長度的增加,Transformer 結構中的核心模塊——自注意力機制(Self-Attention) 的計算復雜度和內存消耗都呈現二次方增長。這不僅限制了模型處理長序列的能力,也成為訓練和推理階段的重要瓶頸。

本篇博客將詳細解釋 Transformer 中 Self-Attention 機制的二次方復雜度來源,結合代碼示例展示這一問題,并介紹一些常見的改進方法。


1. Self-Attention 機制簡介

原理與公式

在自注意力(Self-Attention)機制中,輸入序列 ( X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d ) 被映射到三個向量:查詢(Query) ( Q Q Q )、鍵(Key) ( K K K ) 和 值(Value) ( V V V ),三者通過權重矩陣 ( W Q W_Q WQ? )、( W K W_K WK? )、( W V W_V WV? ) 得到:

Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ?,K=XWK?,V=XWV?

自注意力輸出的計算公式為:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk? ?QKT?)V

  • ( n n n ) 是輸入序列的長度(token 數量)。
  • ( d d d ) 是輸入特征的維度。
  • ( d k d_k dk? ) 是鍵向量的維度(通常 ( d k = d / h d_k = d / h dk?=d/h ),其中 ( h h h ) 是多頭注意力的頭數)。

時間復雜度分析

從公式可以看出,自注意力機制中的關鍵操作是:

  1. ( Q K T Q K^T QKT ):查詢向量 ( Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk? ) 與鍵向量 ( K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} KRn×dk? ) 相乘,得到 ( n × n n \times n n×n ) 的注意力分數矩陣。

    • 計算復雜度為 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk?) )。
  2. softmax 操作:在 ( n × n n \times n n×n ) 的注意力矩陣上進行歸一化,復雜度為 ( O ( n 2 ) O(n^2) O(n2) )。

  3. 注意力分數與 ( V V V ) 相乘:將 ( n × n n \times n n×n ) 的注意力分數矩陣與 ( V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} VRn×dv? ) 相乘,復雜度為 ( O ( n 2 d v ) O(n^2 d_v) O(n2dv?) )。

綜上,自注意力機制的時間復雜度為:

O ( n 2 d k + n 2 + n 2 d v ) ≈ O ( n 2 d ) O(n^2 d_k + n^2 + n^2 d_v) \approx O(n^2 d) O(n2dk?+n2+n2dv?)O(n2d)

  • 當 ( d d d ) 是常數時,復雜度主要取決于輸入序列的長度 ( n n n ),即呈二次方增長

空間復雜度分析

自注意力的注意力分數矩陣 ( Q K T Q K^T QKT ) 具有 ( n × n n \times n n×n ) 的大小,需要 ( O ( n 2 ) O(n^2) O(n2) ) 的內存進行存儲。


2. 代碼示例:計算復雜度與空間消耗

以下代碼展示了輸入序列長度增加時,自注意力機制的時間和空間消耗情況:

import torch
import time# 定義自注意力機制
def self_attention(Q, K, V):attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))attention_weights = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_weights, V)return output# 測試輸入序列長度不同的時間復雜度
def test_attention_complexity():d_k = 64  # 特征維度for n in [128, 256, 512, 1024, 2048]:  # 輸入序列長度Q = torch.randn((1, n, d_k))  # QueryK = torch.randn((1, n, d_k))  # KeyV = torch.randn((1, n, d_k))  # Valuestart_time = time.time()output = self_attention(Q, K, V)end_time = time.time()print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")if __name__ == "__main__":test_attention_complexity()

運行結果示例

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

從結果可以看出,隨著序列長度的增加,計算時間呈現明顯的二次方增長。


3. 二次方復雜度的改進方法

為了減少自注意力機制的計算復雜度,許多研究者提出了優化方案,主要包括:

1. 低秩近似方法

利用低秩矩陣分解減少 ( Q K T Q K^T QKT ) 的計算復雜度,例如:

  • Linformer:將 ( n × n n \times n n×n ) 的注意力矩陣通過低秩分解近似為 ( n × k n \times k n×k )(其中 ( k ? n k \ll n k?n )),復雜度降為 ( O ( n k ) O(nk) O(nk) )。

2. 稀疏注意力(Sparse Attention)

  • LongformerBigBird:通過引入局部窗口和全局注意力機制,僅計算部分注意力分數,避免完整的 ( Q K T Q K^T QKT ) 計算,將復雜度降低為 ( O ( n log ? n ) O(n \log n) O(nlogn) ) 或 ( O ( n ) O(n) O(n) )。

3. 線性注意力(Linear Attention)

  • Performer:使用核技巧將自注意力計算轉化為線性操作,復雜度降為 ( O ( n d ) O(n d) O(nd) )。

4. 分塊方法(Blockwise Attention)

將輸入序列分成多個塊,僅在塊內或塊間進行注意力計算,適用于長序列任務。


4. 總結

在 Transformer 的自注意力機制中,由于需要計算 ( Q K T Q K^T QKT ) 和存儲 ( n × n n \times n n×n ) 的注意力矩陣,其時間和空間復雜度均為 ( O ( n 2 ) O(n^2) O(n2) )。這對于處理長序列任務(如長文本、DNA 序列分析等)來說是一個顯著的挑戰。

為了解決這一問題,近年來提出了多種優化方法,包括低秩近似、稀疏注意力、線性注意力等,成功將復雜度從 ( O ( n 2 ) O(n^2) O(n2) ) 降低到 ( O ( n ) O(n) O(n) ) 或 ( O ( n log ? n ) O(n \log n) O(nlogn) ),從而使 Transformer 更加高效地處理長序列任務。

代碼示例和實驗結果清楚地展示了二次方復雜度的實際影響,同時也強調了優化方法的重要性。

英文版

The Quadratic Complexity of Self-Attention in Transformers and Possible Improvements

The core of the Transformer architecture in large language models (LLMs) is the self-attention mechanism. While it has proven revolutionary, its computational complexity and memory requirements grow quadratically as the input sequence length increases. This blog will explain the source of this quadratic complexity, demonstrate it with code, and discuss possible optimization methods.


1. Understanding Self-Attention

Mathematical Formulation

Given an input sequence ( X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d ) with ( n n n ) tokens and ( d d d ) features, the self-attention mechanism computes the query (Q), key (K), and value (V) matrices as follows:

Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ?,K=XWK?,V=XWV?

The output of the self-attention mechanism is calculated as:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk? ?QKT?)V

Where:

  • ( n n n ): Sequence length
  • ( d d d ): Feature dimension
  • ( d k d_k dk? ): Dimension of queries/keys (typically ( d k = d / h d_k = d/h dk?=d/h ) for multi-head attention with ( h h h ) heads)

Time Complexity Analysis

The computational bottlenecks of self-attention are:

  1. Computing ( Q K T Q K^T QKT ):
    The query matrix ( Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk? ) is multiplied with the transposed key matrix ( K T ∈ R d k × n K^T \in \mathbb{R}^{d_k \times n} KTRdk?×n ), producing an ( n × n n \times n n×n ) attention score matrix.
    Complexity: ( O ( n 2 d k ) O(n^2 d_k) O(n2dk?) ).

  2. Softmax Operation:
    Softmax normalization is applied along each row of the ( n × n n \times n n×n ) attention matrix.
    Complexity: ( O ( n 2 ) O(n^2) O(n2) ).

  3. Computing Weighted Values:
    The ( n × n n \times n n×n ) attention scores are multiplied by the value matrix ( V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} VRn×dv? ).
    Complexity: ( O ( n 2 d v ) O(n^2 d_v) O(n2dv?) ).

Combining all these steps, the overall time complexity of self-attention is:

O ( n 2 d ) O(n^2 d) O(n2d)

When ( d d d ) is fixed (a constant), the complexity primarily depends on ( n n n ), making it quadratic.


Space Complexity

The attention score matrix ( Q K T Q K^T QKT ) has a size of ( n × n n \times n n×n ), requiring ( O ( n 2 ) O(n^2) O(n2) ) memory to store. This quadratic memory cost limits the model’s ability to handle long sequences.


2. Code Demonstration: Quadratic Complexity in Practice

The following code measures the computation time of self-attention as the input sequence length increases:

import torch
import time# Self-attention function
def self_attention(Q, K, V):attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))attention_weights = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_weights, V)return output# Test different sequence lengths
def test_attention_complexity():d_k = 64  # Feature dimensionfor n in [128, 256, 512, 1024, 2048]:  # Sequence lengthsQ = torch.randn((1, n, d_k))  # QueryK = torch.randn((1, n, d_k))  # KeyV = torch.randn((1, n, d_k))  # Valuestart_time = time.time()output = self_attention(Q, K, V)end_time = time.time()print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")if __name__ == "__main__":test_attention_complexity()

Example Output

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

From the output, it is clear that the computation time increases quadratically with the sequence length ( n ).


3. Solutions to Address the Quadratic Complexity

To address the inefficiency of quadratic complexity, several optimization methods have been proposed:

1. Low-Rank Approximation

Techniques like Linformer approximate the ( n × n n \times n n×n ) attention matrix using low-rank decomposition:

  • Complexity is reduced to ( O ( n k ) O(n k) O(nk) ), where ( k ? n k \ll n k?n ).

2. Sparse Attention

Sparse attention mechanisms, such as Longformer and BigBird, compute attention only for selected tokens (e.g., local windows or global tokens):

  • Complexity is reduced to ( O ( n log ? n ) O(n \log n) O(nlogn) ) or ( O ( n ) O(n) O(n) ).

3. Linear Attention

Linear attention, such as in Performer, uses kernel functions to approximate the attention mechanism, avoiding the ( Q K T Q K^T QKT ) operation:

  • Complexity becomes ( O ( n d ) O(n d) O(nd) ).

4. Blockwise and Sliding-Window Attention

Divide the input sequence into smaller chunks or sliding windows and compute attention locally within each block:

  • This approach significantly reduces the computational cost for long sequences.

4. Summary

The self-attention mechanism in Transformer models has a time and space complexity of ( O ( n 2 d ) O(n^2 d) O(n2d)), which grows quadratically with sequence length. This becomes a bottleneck for long input sequences, such as lengthy documents or DNA sequences.

Through our code example, we demonstrated the quadratic increase in computational time as the sequence length grows. To address this limitation, several optimizations—such as low-rank approximations, sparse attention, and linear attention—have been introduced to scale Transformers to longer sequences efficiently.

By understanding and leveraging these methods, we can improve the efficiency of self-attention and unlock the potential of Transformers for applications involving extremely long sequences.

后記

2024年12月17日22點26分于上海,在GPT4o大模型輔助下完成。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63230.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63230.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63230.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

模型 A/B測試(科學驗證)

系列文章 分享 模型,了解更多👉 模型_思維模型目錄。控制變量法。 1 A/B測試的應用 1.1 Electronic Arts(EA)《模擬城市》5游戲網站A/B測試 定義目標: Electronic Arts(EA)在發布新版《模擬城…

Java修飾符詳解:從基礎到高級用法

在Java編程語言中,有許多修飾符可以使用,它們大致可以分為兩大類:訪問控制修飾符、其他類型的修飾符。 這些修飾符主要用于指定類、方法或變量的特性,并且通常位于聲明語句的開頭部分。下面通過一些示例來進一步說明這一點&#…

onnx文件轉pytorch pt模型文件

onnx文件轉pytorch pt模型文件 1.onnx2torch轉換及測試2.存在問題參考文獻 從pytorch格式轉onnx格式,官方有成熟的API;那么假如只有onnx格式的模型文件,該怎樣轉回pytorch格式? https://github.com/ENOT-AutoDL/onnx2torch提供了…

Git merge 和 rebase的區別(附圖)

在 Git 中,merge 和 rebase 是兩種用于整合分支變化的方法。雖然它們都可以將一個分支的更改引入到另一個分支中,但它們的工作方式和結果是不同的。以下是對這兩者的詳細解釋: Git Merge 功能:合并分支,將兩個分支的…

【Web】0基礎學Web—js運算符、選擇結構、循環結構

0基礎學Web—js運算符、選擇結構、循環結構 js運算符選擇結構循環結構 js運算符 算術運算符: - * / %取余 賦值運算符: - * / % 單目運算符: i i --i i– 單獨使用是自增1 或 自減1 如果被使用&#xff0c;先看到啥先操作啥 比較運算符&#xff1a; > 、 >、 < 、…

系列3:基于Centos-8.6 Kubernetes使用nfs掛載pod的應用日志文件

每日禪語 古代&#xff0c;一位官員被革職遣返&#xff0c;心中苦悶無處排解&#xff0c;便來到一位禪師的法堂。禪師靜靜地聽完了此人的傾訴&#xff0c;將他帶入自己的禪房之中。禪師指著桌上的一瓶水&#xff0c;微笑著對官員說&#xff1a;?“你看這瓶水&#xff0c;它已經…

tkdiff安裝:Linux下文本對比工具

tkdiff在Linux下源碼安裝 1.下載解壓2.編譯安裝3.配置環境變量4.驗證及運行 本文&#xff0c;在Linux下使用源碼安裝tkdiff工具&#xff0c;以tkdiff-4.2版本為例&#xff0c;其他版本根據需要替換即可。 1.下載解壓 去 http://sourceforge.net/projects/tkdiff/files/tkdiff…

耐蝕鎳基合金的焊接技術與質量控制

耐蝕鎳基合金是一類在腐蝕環境中具有優異性能的合金材料&#xff0c;廣泛應用于化工、海洋工程、石油天然氣等領域。其焊接技術與質量控制對于確保合金的使用性能和安全性至關重要。以下是對耐蝕鎳基合金焊接技術與質量控制的詳細探討。 一、焊接技術 焊條選擇 耐蝕鎳基合金的焊…

Django REST framework(DRF)在處理不同請求方法時的完整流程

文章目錄 一、POST 請求創建對象的流程二、GET 請求獲取對象列表的流程三、GET 請求獲取單個對象的流程四、PUT/PATCH 請求更新對象的流程五、自定義方法的流程自定義 GET 方法自定義 POST 方法 一、POST 請求創建對象的流程 請求到達視圖層 方法調用&#xff1a; dispatch說明…

機器視覺與OpenCV--01篇

計算機眼中的圖像 像素 像素是圖像的基本單位&#xff0c;每個像素存儲著圖像的顏色、亮度或者其他特征&#xff0c;一張圖片就是由若干個像素組成的。 RGB 在計算機中&#xff0c;RGB三種顏色被稱為RGB三通道&#xff0c;且每個通道的取值都是0到255之間。 計算機中圖像的…

qemu源碼解析【03】qom實例

目錄 qemu源碼解析【03】qom實例arm_sbcon_i2c實例 qemu源碼解析【03】qom實例 arm_sbcon_i2c實例 以hw/i2c/arm_sbcon_i2c.c代碼為例&#xff0c;這個實例很簡單&#xff0c;只用100行左右的代碼&#xff0c;調用qemu系統接口實現了一個i2c硬件模擬先看include/hw/i2c/arm_s…

小程序自定義tab-bar,踩坑記錄

從官方下載代碼 https://developers.weixin.qq.com/miniprogram/dev/framework/ability/custom-tabbar.html 1、把custom-tab-bar 文件放置 pages同級 修改下 custom-tab-bar 下的 JS文件 Component({data: {selected: 0,color: "#7A7E83",selectedColor: "#3…

操作系統(14)請求分頁

前言 操作系統中的請求分頁&#xff0c;也稱為頁式虛擬存儲管理&#xff0c;是建立在基本分頁基礎上&#xff0c;為了支持虛擬存儲器功能而增加了請求調頁功能和頁面置換功能的一種內存管理技術。 一、基本概念 分頁&#xff1a;將進程的邏輯地址空間分成若干個大小相等的頁&am…

git企業開發的相關理論(一)

目錄 一.初識git 二.git的安裝 三.初始化/創建本地倉庫 四.配置用戶設置/配置本地倉庫 五.認識工作區、暫存區、版本庫 六.添加文件__場景一 七.查看 .git 文件/添加到本地倉庫后.git中發生的變化 1.執行git add后的變化 index文件&#xff08;暫存區&#xff09; log…

wxpython圖形用戶界面編程

wxpython圖形用戶界面編程 一、wxpython的基礎 1.1 wxpython的基礎 作為圖形用戶界面開發工具包 wxPython&#xff0c;主要提供了如下 GUI 內容&#xff1a; 窗口。控件。事件處理。布局管理。 1.2 wxpython的類層次機構 1.3 wxpython的安裝 Windows 和 macOS 平臺安裝&a…

水仙花數(流程圖,NS流程圖)

題目&#xff1a;打印出所有的100-999之間的"水仙花數"&#xff0c;并畫出流程圖和NS流程圖。所謂"水仙花數"是指一個三位數&#xff0c;其各位數字立方和等于該數本身。例如&#xff1a;153是一個"水仙花數"&#xff0c;因為1531的三次方&#…

不配置python環境,直接用PyCharm就可以?

有的伙伴可能遇到不安裝python環境只安裝pycharm也可以進行運行代碼。 所以自認為是不需要解釋器就可以運行&#xff1f; 這個是不現實的&#xff0c;有很多伙伴可能是安裝了Pycharm&#xff0c;但Pycharm看你電腦上沒有解釋器&#xff0c;所以在安裝的時候給你默認安裝在C盤…

網絡安全滲透測試概論

滲透測試&#xff0c;也稱為滲透攻擊測試是一種通過模擬惡意攻擊者的手段來評估計算機系統、網絡或應用程序安全性的方法。 目的 旨在主動發現系統中可能存在的安全漏洞、脆弱點以及潛在風險&#xff0c;以便在被真正的惡意攻擊者利用之前&#xff0c;及時進行修復和加固&…

爬蟲數據能用于商業嗎?

在當今數字化時代&#xff0c;數據已成為企業獲取競爭優勢的關鍵資源。網絡爬蟲作為一種數據收集工具&#xff0c;能夠從互聯網上抓取大量數據&#xff0c;這些數據在商業分析中扮演著重要角色。然而&#xff0c;使用爬蟲技術獲取的數據是否合法、能否用于商業分析&#xff0c;…

前端面試匯總(不定時更新)

目錄 HTML & CSS1. XML、HTML、XHTML 有什么區別&#xff1f;?2. XML和JSON的區別&#xff1f;3. 是否了解W3C的規范&#xff1f;?4. 什么是語義化標簽&#xff1f;??5. 行內元素和塊級元素的區別&#xff1f;?6. 行內元素和塊級元素的轉換&#xff1f;?7. 常用的塊級…