Sliding Window Attention(滑動窗口注意力)解析
Sliding Window Attention(滑動窗口注意力) 是 Longformer (來源:https://arxiv.org/pdf/2004.05150)提出的 稀疏注意力機制,旨在解決 標準 Transformer 計算復雜度隨序列長度增加呈二次增長 的問題。它的核心思想是:
- 每個 token 僅關注局部窗口內的其他 token,而不是整個序列。
- 計算復雜度從 ( O ( n 2 ) O(n^2) O(n2)) 降至 ( O ( n ? w ) O(n \cdot w) O(n?w)),其中 ( w w w) 是窗口大小。
- 支持更長的文本處理,避免傳統 Transformer 處理長序列時的顯存和計算瓶頸。
該方法使 Transformer 能夠高效處理上千到上萬個 token 的長文本,特別適用于 文檔級任務,如長文摘要、法律文本分析、醫療文檔理解等。
1. 為什么需要 Sliding Window Attention?
1.1 傳統 Transformer 的問題
Transformer 的 自注意力(Self-Attention) 機制需要計算所有 token 之間的交互:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk??QKT?)V
其中:
- ( Q, K, V ) 分別是 查詢(Query)、鍵(Key)、值(Value) 矩陣,形狀為 ( n × d k n \times d_k n×dk? )。
- 計算量為 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk?) ),隨著序列長度 ( n n n ) 增加,計算量急劇上升。
- 這導致 Transformer 無法處理長文本,因為顯存需求和計算復雜度都 隨 ( n 2 n^2 n2 ) 增長。
1.2 Sliding Window Attention 解決了什么問題?
- 局部注意力(Local Attention):每個 token 僅與附近窗口內的 token 交互,而不是整個序列。
- 計算復雜度降低:從 ( O ( n 2 ) O(n^2) O(n2) ) 降為 ( O ( n ? w ) O(n \cdot w) O(n?w) ),其中 ( w w w ) 是窗口大小。
- 顯存占用減少:只需要存儲窗口內的注意力權重,而非完整的 ( n × n n \times n n×n ) 矩陣。
這意味著,Sliding Window Attention 允許 Transformer 處理更長的序列,從傳統的 512 tokens 提高到 8K-16K tokens 甚至更長。
2. Sliding Window Attention 計算原理
2.1 標準 Transformer Attention
在標準 Transformer 結構中:
- 每個 token 計算所有其他 token 的注意力權重。
- 形成一個 ( n × n n \times n n×n ) 的注意力矩陣。
- 計算復雜度:( O ( n 2 d ) O(n^2 d) O(n2d) )。
2.2 Sliding Window Attention
在 Sliding Window Attention 結構中:
- 每個 token 僅與窗口內的其他 token 交互。
- 注意力矩陣變為稀疏矩陣,只有窗口大小 ( w w w ) 內的注意力權重被計算。
- 計算復雜度變為:( O ( n ? w ? d ) O(n \cdot w \cdot d) O(n?w?d) )。
示例:
- 設 ( w = 5 w = 5 w=5 ),則:
- 第 10 個 token 僅關注
[8, 9, 10, 11, 12]
。 - 第 20 個 token 僅關注
[18, 19, 20, 21, 22]
。 - 這樣每個 token 只計算 5 個注意力權重,而不是所有 n 個。
- 第 10 個 token 僅關注
3. Sliding Window Attention 的 PyTorch 實現
以下是 Longformer 的 Sliding Window Attention 計算的 PyTorch 實現:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SlidingWindowAttention(nn.Module):def __init__(self, embed_dim, num_heads, window_size):"""滑動窗口注意力機制Args:embed_dim: 詞嵌入維度 dnum_heads: 注意力頭的數量 hwindow_size: 滑動窗口大小 w"""super(SlidingWindowAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.window_size = window_sizeself.head_dim = embed_dim // num_heads # 每個頭的維度assert self.head_dim * num_heads == embed_dim, "embed_dim 必須是 num_heads 的整數倍"# 線性投影層self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)def forward(self, x):"""前向傳播Args:x: 輸入張量 [batch, seq_len, embed_dim]Returns:輸出張量 [batch, seq_len, embed_dim]"""batch_size, seq_len, _ = x.shape# 計算 Q, K, VQ = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, seq, d]K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# 初始化注意力矩陣attn_scores = torch.full((batch_size, self.num_heads, seq_len, seq_len), float("-inf"), device=x.device)# 計算滑動窗口注意力for i in range(seq_len):start = max(0, i - self.window_size)end = min(seq_len, i + self.window_size + 1)attn_scores[:, :, i, start:end] = torch.matmul(Q[:, :, i, :], K[:, :, start:end, :].transpose(-2, -1))# 歸一化attn_scores /= self.head_dim ** 0.5attn_weights = F.softmax(attn_scores, dim=-1)# 計算注意力加權的 Valueoutput = torch.matmul(attn_weights, V)# 重新排列形狀output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)output = self.o_proj(output) # 線性變換回原始維度return output
4. Sliding Window Attention 的優缺點
? 優點
- 計算復雜度降低:從 ( O ( n 2 ) O(n^2) O(n2) ) 降至 ( O ( n ? w ) O(n \cdot w) O(n?w) )。
- 可擴展到長序列:支持 8K-16K tokens 甚至更長。
- 適用于文檔級任務:如長文摘要、法律分析、醫療 NLP 任務。
? 缺點
- 不能捕捉遠距離依賴:只能處理窗口范圍內的 token 交互。
- 需要全局注意力補充:必須結合 Global Attention 來補充遠程信息(例如
CLS
位置)。
5. 結論
- Sliding Window Attention 解決了 Transformer 計算復雜度隨序列長度二次增長的問題。
- 通過限制每個 token 只關注局部窗口內的 token,使得計算復雜度降低為 ( O ( n ? w ) O(n \cdot w) O(n?w) )。
- 適用于長文本處理,并可結合 Global Attention 進一步提升模型性能。
🚀 這是 Transformer 在長文本任務上的關鍵優化方案之一!
Sliding Window Attention 如何支持長序列,并處理遠距離依賴?
你提到的 “Sliding Window Attention 只能處理局部信息,那為什么還能支持 8K-16K 甚至更長的序列?” 這個問題很好。我們來詳細拆解這個機制,看看它如何 既能高效處理長序列,又能解決遠距離依賴問題。
1. Sliding Window Attention 適用于長序列的原因
(1) 計算復雜度降低
傳統 全自注意力(Full Self-Attention) 計算復雜度為:
O ( n 2 ? d ) O(n^2 \cdot d) O(n2?d)
即隨著序列長度 ( n n n ) 增加,計算量呈二次增長。例如:
- ( n = 1024 n = 1024 n=1024 ) 時,需要計算 百萬級 注意力分數。
- ( n = 8192 n = 8192 n=8192 ) 時,需要計算 千萬級 注意力分數,顯存消耗極大。
而 Sliding Window Attention 僅在局部窗口 ( w w w ) 內計算注意力:
O ( n ? w ? d ) O(n \cdot w \cdot d) O(n?w?d)
通常 ( w ? n w \ll n w?n )(如 ( w = 512 , n = 8192 w=512, n=8192 w=512,n=8192 )),計算復雜度大幅降低,使得 處理長序列成為可能。
(2) 通過層級疊加,間接傳播長距離信息
雖然單個 Sliding Window 只能看到局部范圍的 token,但 Transformer 具有多層結構,可以通過 層級疊加 逐步擴展信息傳播范圍。
示例:
- 設窗口大小 ( w = 512 ),模型有 12 層 Transformer。
- 第 1 層:每個 token 只看到 相鄰 512 個 token。
- 第 2 層:由于前一層已經融合了 512 個 token 信息,相當于 間接看到 1024 個 token。
- 第 3 層:可看到 1536 個 token。
- ……
- 第 12 層:最終可以捕捉 6144+ token 的信息。
這意味著,即使單層 Sliding Window 只能看到局部信息,但多層疊加后, 整個 Transformer 仍然能捕捉遠程依賴。
? 這類似 CNN 中的感受野(Receptive Field)擴展:
- 低層捕捉局部信息,高層逐步擴大感受野。
- 頂層的
CLS
token 可以聚合全局信息。
2. 如何進一步增強遠程依賴能力?
(1) 結合全局注意力(Global Attention)
Sliding Window 主要用于局部注意力,但為了處理關鍵任務位置(如 CLS
,任務相關實體),通常會額外增加 Global Attention:
Hybrid?Attention = Sliding?Window + Global?Attention \text{Hybrid Attention} = \text{Sliding Window} + \text{Global Attention} Hybrid?Attention=Sliding?Window+Global?Attention
- Global Attention 讓
CLS
token 直接看到所有位置,用于捕捉全局信息。 - 關鍵 token(如問題 token、摘要 token)可被全局注意力連接,使遠距離 token 之間的信息傳遞更高效。
(2) 結合 Dilated Attention(擴張窗口注意力)
為了提高遠程依賴能力,可以使用 Dilated Sliding Window Attention(擴張窗口注意力):
- 例如,窗口間隔
gap = 2
,每個 token 除了看到最近的 512 個 token,還能看到更遠的 token。 - 這種方法 類似 CNN 的 Dilated Convolution,可以擴大感受野,而不會增加太多計算量。
3. Sliding Window Attention 如何影響長文本任務?
-
適用于 8K+ 長文本摘要(Summarization)
- 長文摘要模型(如 Longformer)使用 Sliding Window + Global Attention,使
CLS
位置能整合全局信息。 - 例如:arXiv 論文摘要任務,輸入 16K tokens,模型仍然可以高效運行。
- 長文摘要模型(如 Longformer)使用 Sliding Window + Global Attention,使
-
適用于長文 QA(Long Document QA)
- 傳統 QA 需要截斷上下文(如 BERT 只能用 512 tokens)。
- Longformer 可以處理 8K+ tokens,保證所有信息被覆蓋,提升答案查找準確率。
-
適用于長文分類(Long Document Classification)
CLS
位置的 Global Attention 可以整合 8K+ tokens 的全局信息,提高分類準確度。
4. 結論
? Sliding Window Attention 可以擴展到 8K+ 序列,原因如下:
- 降低計算復雜度,從 ( O ( n 2 ) O(n^2) O(n2) ) 變為 ( O ( n ? w ) O(n \cdot w) O(n?w) ),可擴展到長文本。
- 通過 Transformer 層級堆疊,多層次傳遞信息,間接覆蓋全局依賴。
- 結合 Global Attention,讓關鍵 token 直接連接全局,提高遠程依賴建模能力。
- 結合 Dilated Attention(擴張窗口)可以進一步提升長距離信息傳播。
🔹 最終,Sliding Window Attention + Global Attention + Dilated Attention 讓 Transformer 既能高效處理長文本,又能捕捉全局依賴! 🚀
Hybrid Attention(滑動窗口注意力 + 全局注意力)解析
在 Longformer 等長序列 Transformer 結構中,Hybrid Attention 結合了 Sliding Window Attention(局部注意力)和 Global Attention(全局注意力),以同時實現:
- 高效計算:滑動窗口注意力降低計算復雜度到 ( O ( n ? w ) O(n \cdot w) O(n?w) ),適用于長文本。
- 全局依賴捕捉:全局注意力允許關鍵 token(如
CLS
、問題 token)能訪問所有 tokens,確保長距離信息流通。
1. 什么是 Global Attention?
在標準 全自注意力(Self-Attention) 機制中:
- 每個 token 計算所有 token 的注意力權重,計算復雜度 ( O ( n 2 ) O(n^2) O(n2) )。
- 但 Sliding Window Attention 僅計算局部窗口內的 token,無法直接建模遠程依賴。
Global Attention 的作用:
- 指定部分 token 作為“全局節點”,這些 token 可以訪問所有 tokens,同時 所有 tokens 也可以訪問這些全局節點。
- 一般用于關鍵任務相關 tokens,例如:
[CLS]
(分類任務)問題 tokens
(問答任務)摘要 tokens
(摘要任務)
Hybrid Attention 結合兩者的方式:
Hybrid?Attention = Sliding?Window?Attention + Global?Attention \text{Hybrid Attention} = \text{Sliding Window Attention} + \text{Global Attention} Hybrid?Attention=Sliding?Window?Attention+Global?Attention
- 大部分 tokens 采用 Sliding Window 計算注意力,計算復雜度為 ( O ( n ? w ) O(n \cdot w) O(n?w) )。
- 關鍵 tokens 采用 Global Attention,可以訪問整個序列,補充長距離信息。
2. Hybrid Attention 計算方法
假設:
window_size = 512
global_mask
指定哪些 token 需要全局注意力(如CLS
)。
計算步驟:
- 計算 Sliding Window Attention
- 每個 token 僅計算 窗口范圍內的注意力。
- 計算 Global Attention
- 只有被標記為全局注意力的 token 計算 全局 self-attention。
- 這些 token 可以訪問所有 tokens,所有 tokens 也可以訪問它們。
- 合并兩種注意力機制
- 局部 token 使用 Sliding Window Attention。
- 全局 token 額外加上 Global Attention 權重,確保遠程依賴信息傳遞。
3. PyTorch 實現(可運行)
下面是完整的 Hybrid Attention(滑動窗口注意力 + 全局注意力) 的 PyTorch 實現:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass HybridAttention(nn.Module):def __init__(self, embed_dim, num_heads, window_size):"""Hybrid Attention: Sliding Window Attention + Global AttentionArgs:embed_dim: 詞嵌入維度 dnum_heads: 注意力頭數量 hwindow_size: 滑動窗口大小 w"""super(HybridAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.window_size = window_sizeself.head_dim = embed_dim // num_heads # 每個頭的維度assert self.head_dim * num_heads == embed_dim, "embed_dim 必須是 num_heads 的整數倍"# 線性投影層self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)def forward(self, x, global_mask):"""Args:x: 輸入張量 [batch, seq_len, embed_dim]global_mask: 是否是全局注意力的 mask [batch, seq_len],1 表示全局注意力,0 表示普通窗口注意力Returns:輸出張量 [batch, seq_len, embed_dim]"""batch_size, seq_len, _ = x.shape# 計算 Q, K, VQ = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, seq, d]K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# 初始化注意力分數attn_scores = torch.full((batch_size, self.num_heads, seq_len, seq_len), float("-inf"), device=x.device)# 計算 Sliding Window Attentionfor i in range(seq_len):start = max(0, i - self.window_size)end = min(seq_len, i + self.window_size + 1)attn_scores[:, :, i, start:end] = torch.matmul(Q[:, :, i, :], K[:, :, start:end, :].transpose(-2, -1))# 計算 Global Attentionglobal_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) # [batch, 1, 1, seq_len]global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) # 擴展到注意力頭維度attn_scores.masked_fill_(global_indices == 0, float("-inf")) # 讓非全局 token 只計算滑動窗口注意力global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # 全局 attention 計算attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores) # 合并全局和局部注意力# 歸一化 softmaxattn_scores /= self.head_dim ** 0.5attn_weights = F.softmax(attn_scores, dim=-1)# 計算注意力加權的 Valueoutput = torch.matmul(attn_weights, V)# 重新排列形狀output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)output = self.o_proj(output) # 線性變換回原始維度return output
4. 代碼解讀
-
計算 Sliding Window Attention
- 只在窗口范圍內計算注意力分數,保證計算復雜度 ( O(n \cdot w) )。
-
計算 Global Attention
- 通過
global_mask
選擇需要全局注意力的 token(如CLS
)。 - 計算這些 token 與所有 tokens 之間的注意力分數。
- 通過
-
融合全局 & 局部注意力
- 使用
torch.where()
選擇是否應用全局注意力:- 全局 token:使用全局 self-attention 計算的權重。
- 局部 token:僅計算滑動窗口內的注意力。
- 使用
5. 運行示例
# 測試 Hybrid Attention
batch_size, seq_len, embed_dim, num_heads, window_size = 2, 16, 64, 8, 4
x = torch.randn(batch_size, seq_len, embed_dim)
global_mask = torch.zeros(batch_size, seq_len, dtype=torch.long) # 默認無全局注意力
global_mask[:, 0] = 1 # 讓 CLS 位置作為全局注意力hybrid_attn = HybridAttention(embed_dim, num_heads, window_size)
output = hybrid_attn(x, global_mask)
print(output.shape) # 預期: (batch_size, seq_len, embed_dim)
6. 結論
- Hybrid Attention 結合了 Sliding Window 和 Global Attention,使 Transformer 既高效又能捕捉遠程依賴。
- PyTorch 實現支持運行,適用于長文本任務(如長文摘要、QA、分類等)。
- 適用于 8K+ 長文本,提高推理效率,同時保持全局信息流通!🚀
Hybrid Attention 計算 Global Attention 詳細解析
代碼段
# 計算 Global Attention
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) # [batch, 1, 1, seq_len]
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) # 擴展到注意力頭維度
attn_scores.masked_fill_(global_indices == 0, float("-inf")) # 讓非全局 token 只計算滑動窗口注意力
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # 全局 attention 計算
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores) # 合并全局和局部注意力
1. 代碼解析
(1) 處理 global_mask
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) # [batch, 1, 1, seq_len]
作用
global_mask
是一個[batch, seq_len]
形狀的張量,標識哪些 token 是全局注意力。unsqueeze(1).unsqueeze(1)
作用是擴展維度,使其形狀變成[batch, 1, 1, seq_len]
,以便后續expand()
操作匹配attn_scores
形狀。
示例
假設 global_mask
:
global_mask = torch.tensor([[1, 0, 0, 0, 0], # Batch 0:CLS(位置 0)是全局 token[0, 0, 1, 0, 0] # Batch 1:位置 2 是全局 token
])
經過 unsqueeze()
變成:
global_mask_expanded = torch.tensor([[[[1, 0, 0, 0, 0]]], # Batch 0[[[0, 0, 1, 0, 0]]], # Batch 1
]) # 形狀 [batch=2, 1, 1, seq_len=5]
(2) 擴展 global_mask
維度以匹配 attn_scores
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) # [batch, num_heads, seq_len, seq_len]
作用
expand(-1, self.num_heads, seq_len, -1)
擴展維度,使global_indices
形狀與attn_scores
匹配。
示例(假設 num_heads = 2
)
global_indices = torch.tensor([# Batch 0[[[1, 0, 0, 0, 0], # 頭 1[1, 0, 0, 0, 0]]], # 頭 2# Batch 1[[[0, 0, 1, 0, 0], # 頭 1[0, 0, 1, 0, 0]]] # 頭 2
]) # 形狀 [batch=2, num_heads=2, seq_len=5, seq_len=5]
- 現在,每個 batch 的
global_indices
里:1
表示全局注意力 token。0
表示普通 token。
(3) 讓普通 token 只計算滑動窗口內的注意力
attn_scores.masked_fill_(global_indices == 0, float("-inf"))
作用
- 讓 非全局 token 的注意力變成
-inf
,確保它們只能計算滑動窗口范圍的注意力。
示例
假設 attn_scores
初始值:
attn_scores = torch.tensor([[[10, 20, 30, 40, 50], [15, 25, 35, 45, 55]], [[5, 10, 15, 20, 25], [10, 15, 20, 25, 30]]
], dtype=torch.float32)
執行 masked_fill_()
后:
attn_scores = torch.tensor([[[10, -inf, -inf, -inf, -inf], [15, -inf, -inf, -inf, -inf]], [[-inf, -inf, 15, -inf, -inf], [-inf, -inf, 20, -inf, -inf]]
])
現在:
- 普通 token 的注意力分數變成
-inf
,它們只能計算滑動窗口范圍內的 token。 - 全局 token(如 CLS)不受影響,它們仍然可以訪問所有 token。
(4) 計算全局 Attention
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # 計算全局注意力
作用
- 全局 token 應該能訪問所有 token,所以這里 計算完整的注意力分數矩陣。
Q @ K^T
計算所有Q
和K
之間的點積注意力。
(5) 合并全局和滑動窗口注意力
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores)
作用
global_indices == 1
位置使用 完整的全局注意力global_attn_scores
。- 其他位置仍然使用 滑動窗口注意力
attn_scores
。
示例
假設 global_attn_scores
:
global_attn_scores = torch.tensor([[[100, 110, 120, 130, 140], [105, 115, 125, 135, 145]], [[50, 60, 70, 80, 90], [55, 65, 75, 85, 95]]
])
執行 torch.where()
后:
attn_scores = torch.tensor([[[100, -inf, -inf, -inf, -inf], [105, -inf, -inf, -inf, -inf]], [[-inf, -inf, 70, -inf, -inf], [-inf, -inf, 75, -inf, -inf]]
])
現在:
- 全局 token (
global_indices == 1
) 采用全局注意力global_attn_scores
。 - 普通 token (
global_indices == 0
) 繼續使用滑動窗口注意力(-inf
表示屏蔽)。
6. 結論
代碼 | 作用 |
---|---|
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) | 擴展 global_mask 形狀,方便與 attn_scores 匹配 |
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) | 復制 global_mask 到所有注意力頭 |
attn_scores.masked_fill_(global_indices == 0, float("-inf")) | 讓普通 token 只能訪問滑動窗口內的 token |
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) | 計算完整的全局注意力分數 |
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores) | 讓全局 token 采用全局注意力,而普通 token 繼續使用滑動窗口 |
🚀 最終,我們的 Hybrid Attention 既能高效計算長文本,又能讓 CLS
等關鍵 token 訪問全局信息! 🎯
后記
2025年2月23日14點36分于上海,在GPT 4o大模型輔助下完成。