文章目錄
- 1. 位置編碼概述
- 1.1 為什么需要位置編碼?
- 2. 絕對位置編碼 (Absolute Position Encoding)
- 2.1 原理
- 2.2 數學公式
- 2.3 代碼實現
- 2.4 代碼與公式的對應關系
- 2.5 特性與優勢
- 2.6 可學習的絕對位置編碼
- 3. 相對位置編碼 (Relative Position Encoding)
- 3.1 原理
- 3.2 數學公式
- 3.3 Shaw et al. (2018) 相對位置編碼
- 3.4 代碼與公式的對應關系
- 3.5 特性與優勢
- 3.6 帶相對位置的注意力計算
- 4. RoPE (Rotary Position Embedding)
- 4.1 原理
- 4.2 數學公式
- 4.3 公式的詳細推導
- 1. 旋轉向量點積的展開
- 2. 合并第二項和第三項
- 3. 合并第一項和第四項
- 4. 驗證相對位置依賴
- 4.5 代碼實現
- 4.6 代碼與公式的對應關系
- 4.7 特性與優勢
- 4.8 帶RoPE的注意力機制
- 5. 各種位置編碼對比
- 5.1 特點對比
- 5.2 性能測試代碼
- 6. 總結
1. 位置編碼概述
位置編碼是Transformer架構中的關鍵組件,用于為序列中的每個位置提供位置信息。由于自注意力機制本身是位置無關的,需要額外的位置信息來理解序列中元素的順序。
1.1 為什么需要位置編碼?
輸入序列: "我 愛 中 國"[1] [2] [3] [4]沒有位置編碼:自注意力機制無法區分詞語的順序
有位置編碼:每個位置都有唯一的位置標識
2. 絕對位置編碼 (Absolute Position Encoding)
2.1 原理
絕對位置編碼為序列中的每個位置分配一個唯一的編碼向量。最經典的是Transformer論文中的正弦余弦位置編碼。
2.2 數學公式
正弦位置編碼的數學公式如下:
對于位置 pospospos 和維度 iii:
-
當 iii 為偶數時:
PE(pos,2i)=sin?(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i)?=sin(100002i/dmodel?pos?) -
當 iii 為奇數時:
PE(pos,2i+1)=cos?(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i+1)?=cos(100002i/dmodel?pos?)
其中:
- pospospos 是序列中的位置(范圍:000 到 max_len?1max\_len-1max_len?1)
- iii 是編碼向量中的維度索引(范圍:000 到 dmodel?1d_{\text{model}}-1dmodel??1)
- dmodeld_{\text{model}}dmodel? 是模型的嵌入維度
2.3 代碼實現
import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as pltclass SinusoidalPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.d_model = d_model# 創建位置編碼矩陣pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1).float()# 計算除法項div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))# 應用正弦和余弦函數pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 注冊為緩沖區(不參與梯度更新)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):# x shape: (batch_size, seq_len, d_model)seq_len = x.size(1)return x + self.pe[:, :seq_len]# 使用示例
d_model = 512
max_len = 100
pos_encoding = SinusoidalPositionalEncoding(d_model, max_len)# 模擬輸入
batch_size, seq_len = 2, 20
x = torch.randn(batch_size, seq_len, d_model)
output = pos_encoding(x)
print(f"輸入形狀: {x.shape}")
print(f"輸出形狀: {output.shape}")
2.4 代碼與公式的對應關系
-
創建位置編碼矩陣:
pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1).float()
這里生成了一個形狀為
(max_len, d_model)
的零矩陣,并準備好位置索引向量。 -
計算分母項:
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
這對應公式中的分母部分 100002i/dmodel10000^{2i/d_{\text{model}}}100002i/dmodel?,通過指數和對數運算轉換為:
exp?(?log?(10000)?2idmodel)\exp\left(-\frac{\log(10000) \cdot 2i}{d_{\text{model}}}\right)exp(?dmodel?log(10000)?2i?) -
應用正弦和余弦函數:
pe[:, 0::2] = torch.sin(position * div_term) # 偶數維度使用正弦 pe[:, 1::2] = torch.cos(position * div_term) # 奇數維度使用余弦
這直接對應公式中的正弦和余弦部分,分別應用于偶數和奇數維度。
-
注冊為緩沖區:
self.register_buffer('pe', pe.unsqueeze(0))
將位置編碼注冊為模型的緩沖區(不參與訓練),并添加批次維度。
-
前向傳播:
def forward(self, x):seq_len = x.size(1)return x + self.pe[:, :seq_len]
將位置編碼加到輸入張量上,只取與輸入序列長度匹配的部分。
2.5 特性與優勢
-
相對位置表示:正弦位置編碼能夠表達相對位置關系,因為對于任意固定偏移量 kkk,PEpos+kPE_{pos+k}PEpos+k? 可以表示為 PEposPE_{pos}PEpos? 的線性函數。
-
泛化能力:可以推廣到比訓練期間見過的更長的序列長度。
-
計算高效:無需學習參數,在推理時直接生成位置編碼。
-
梯度穩定性:由于使用固定函數生成,不會影響模型訓練的梯度流動。
2.6 可學習的絕對位置編碼
class LearnablePositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.pos_embedding = nn.Embedding(max_len, d_model)self.max_len = max_lendef forward(self, x):batch_size, seq_len, _ = x.shapepositions = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)pos_encodings = self.pos_embedding(positions)return x + pos_encodings# 使用示例
learnable_pos = LearnablePositionalEncoding(d_model, max_len)
output_learnable = learnable_pos(x)
print(f"可學習位置編碼輸出形狀: {output_learnable.shape}")
3. 相對位置編碼 (Relative Position Encoding)
3.1 原理
相對位置編碼關注的是位置之間的相對關系,而不是絕對位置。這種方法在處理長序列時表現更好。
3.2 數學公式
相對位置編碼的核心思想是在注意力計算中引入相對位置信息。對于兩個位置 iii 和 jjj,其相對位置為 k=i?jk = i - jk=i?j,編碼公式主要體現在注意力得分的計算中:
-
標準自注意力公式(無相對位置):
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk??QKT?)V -
加入相對位置編碼的注意力公式:
Attention(Q,K,V)=softmax((Q+Rq)(K+Rk)Tdk)(V+Rv)\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{(Q + R_q)(K + R_k)^T}{\sqrt{d_k}}\right)(V + R_v)Attention(Q,K,V)=softmax(dk??(Q+Rq?)(K+Rk?)T?)(V+Rv?)其中:
- RqR_qRq?、RkR_kRk?、RvR_vRv? 分別是查詢(Query)、鍵(Key)、值(Value)的相對位置編碼矩陣;
- RkR_kRk? 和 RvR_vRv? 通常由相對位置索引 kkk 映射得到,即 Rk=Ek(k)R_k = E_k(k)Rk?=Ek?(k) 和 Rv=Ev(k)R_v = E_v(k)Rv?=Ev?(k),其中 EkE_kEk? 和 EvE_vEv? 是可學習的嵌入矩陣。
3.3 Shaw et al. (2018) 相對位置編碼
class RelativePositionEncoding(nn.Module):def __init__(self, d_model, max_relative_position=50):super().__init__()self.d_model = d_modelself.max_relative_position = max_relative_position# 相對位置嵌入vocab_size = 2 * max_relative_position + 1self.relative_position_k = nn.Embedding(vocab_size, d_model)self.relative_position_v = nn.Embedding(vocab_size, d_model)def get_relative_positions(self, seq_len):"""生成相對位置矩陣"""range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)# 裁剪到最大相對位置distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)# 轉換為正數索引final_mat = distance_mat_clipped + self.max_relative_positionreturn final_matdef forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)# 獲取相對位置編碼relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb# 使用示例
rel_pos_encoding = RelativePositionEncoding(d_model)
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)rel_k, rel_v = rel_pos_encoding(q, k, v)
print(f"相對位置編碼K形狀: {rel_k.shape}")
print(f"相對位置編碼V形狀: {rel_v.shape}")
3.4 代碼與公式的對應關系
-
初始化嵌入層:
vocab_size = 2 * max_relative_position + 1 self.relative_position_k = nn.Embedding(vocab_size, d_model) self.relative_position_v = nn.Embedding(vocab_size, d_model)
vocab_size
對應所有可能的相對位置范圍(從-max_relative_position
到+max_relative_position
);relative_position_k
和relative_position_v
分別對應公式中的 EkE_kEk? 和 EvE_vEv?,用于將相對位置索引映射為嵌入向量。
-
生成相對位置矩陣:
def get_relative_positions(self, seq_len):range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)final_mat = distance_mat_clipped + self.max_relative_positionreturn final_mat
- 生成的
distance_mat
是所有位置對 (i,j)(i,j)(i,j) 的相對距離矩陣(即 k=i?jk = i - jk=i?j); clamp
操作將相對距離限制在預設范圍內,避免過遠的位置影響;final_mat
將相對距離轉換為非負索引(通過加上max_relative_position
),便于嵌入層查找。
- 生成的
-
獲取相對位置編碼:
def forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb
relative_position_k_emb
和relative_position_v_emb
分別對應公式中的 RkR_kRk? 和 RvR_vRv?;- 它們的形狀均為
(seq_len, seq_len, d_model)
,表示任意兩個位置之間的相對位置編碼。
3.5 特性與優勢
-
捕捉相對位置關系
相比絕對位置編碼(如Sinusoidal PE),相對位置編碼直接建模token對之間的距離,更適合捕捉序列中的結構信息(如語法依賴關系)。 -
參數高效
只需存儲有限范圍內的相對位置嵌入(通常為2*max_relative_position+1
個向量),而不是為每個絕對位置存儲一個向量。 -
泛化能力
對于長度超過訓練時所見的序列,仍能通過相對位置編碼處理,而絕對位置編碼可能超出預定義范圍。 -
靈活應用
可選擇性地應用于注意力機制的不同組件(如僅應用于Key,或同時應用于Key和Value),根據任務需求調整。 -
提升長序列性能
在長文本任務(如文檔摘要、長對話生成)中,相對位置編碼能更好地捕捉遠距離依賴關系。
3.6 帶相對位置的注意力計算
class RelativeMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_relative_position=50):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rel_pos_encoding = RelativePositionEncoding(self.d_k, max_relative_position)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 線性變換Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 計算注意力分數attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)# 添加相對位置編碼rel_k, rel_v = self.rel_pos_encoding(query, key, value)rel_k = rel_k.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1, 1)# 相對位置注意力rel_attention = torch.matmul(Q.unsqueeze(-2), rel_k.transpose(-2, -1)).squeeze(-2)attention_scores = attention_scores + rel_attention# 應用掩碼if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)# Softmaxattention_weights = torch.softmax(attention_scores, dim=-1)# 應用注意力權重context = torch.matmul(attention_weights, V)# 重新整形并通過輸出層context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights
4. RoPE (Rotary Position Embedding)
4.1 原理
RoPE(Rotary Positional Encoding)是一種基于旋轉機制的位置編碼方法,通過旋轉向量空間來隱式表達token間的相對位置關系。它在Transformer模型中取得了顯著效果,尤其是在長序列建模和語言理解任務中。
4.2 數學公式
RoPE的核心思想是通過旋轉操作將位置信息直接融入到向量表示中。對于位置 mmm 處的向量 qmq_mqm? 和位置 nnn 處的向量 knk_nkn?,RoPE的計算公式如下:
-
旋轉操作:
對于位置 mmm 和維度 ddd,將向量 qmq_mqm? 旋轉 θm\theta_mθm? 角度:
RoPE(qm,m)d=qm?cos?(mθd)+RotateHalf(qm)?sin?(mθd)\text{RoPE}(q_m, m)_d = q_m \cdot \cos(m\theta_d) + \text{RotateHalf}(q_m) \cdot \sin(m\theta_d) RoPE(qm?,m)d?=qm??cos(mθd?)+RotateHalf(qm?)?sin(mθd?)
其中:- θd=110000ddmodel\theta_d = \frac{1}{10000^{\frac{d}{d_{\text{model}}}}}θd?=10000dmodel?d?1? 是頻率參數;
- RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 表示將向量 xxx 的前半部分與后半部分交換符號后拼接,即 [xd/2+1,xd/2+2,...,xd,?x1,?x2,...,?xd/2][x_{d/2+1}, x_{d/2+2}, ..., x_d, -x_1, -x_2, ..., -x_{d/2}][xd/2+1?,xd/2+2?,...,xd?,?x1?,?x2?,...,?xd/2?]。
-
在注意力機制中的應用:
RoPE通過以下方式改變注意力得分計算:Attention(qm,kn)=RoPE(qm,m)?RoPE(kn,n)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \end{aligned} Attention(qm?,kn?)?=RoPE(qm?,m)?RoPE(kn?,n)?
4.3 公式的詳細推導
1. 旋轉向量點積的展開
RoPE(qm,m)?RoPE(kn,n)=[qmcos?(mθ)+RotateHalf(qm)sin?(mθ)]?[kncos?(nθ)+RotateHalf(kn)sin?(nθ)]=(qmcos?(mθ))?(kncos?(nθ))+(qmcos?(mθ))?(RotateHalf(kn)sin?(nθ))+(RotateHalf(qm)sin?(mθ))?(kncos?(nθ))+(RotateHalf(qm)sin?(mθ))?(RotateHalf(kn)sin?(nθ))\begin{aligned} &\text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \\ =& \left[ q_m \cos(m\theta) + \text{RotateHalf}(q_m) \sin(m\theta) \right] \cdot \left[ k_n \cos(n\theta) + \text{RotateHalf}(k_n) \sin(n\theta) \right] \\ =& (q_m \cos(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (q_m \cos(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \end{aligned} ==?RoPE(qm?,m)?RoPE(kn?,n)[qm?cos(mθ)+RotateHalf(qm?)sin(mθ)]?[kn?cos(nθ)+RotateHalf(kn?)sin(nθ)](qm?cos(mθ))?(kn?cos(nθ))+(qm?cos(mθ))?(RotateHalf(kn?)sin(nθ))+(RotateHalf(qm?)sin(mθ))?(kn?cos(nθ))+(RotateHalf(qm?)sin(mθ))?(RotateHalf(kn?)sin(nθ))?
2. 合并第二項和第三項
根據 RotateHalf\text{RotateHalf}RotateHalf 的正交性:
q?RotateHalf(k)=?RotateHalf(q)?kq \cdot \text{RotateHalf}(k) = -\text{RotateHalf}(q) \cdot k q?RotateHalf(k)=?RotateHalf(q)?k
以及,三角函數的角度差公式
sin?(a?b)=sin?acos?b?cos?asin?b\sin(a-b) = \sin a \cos b - \cos a \sin b sin(a?b)=sinacosb?cosasinb
將上述展開式中的第二和第三項重寫:
第二項=qm?RotateHalf(kn)?cos?(mθ)sin?(nθ)=?RotateHalf(qm)?kn?cos?(mθ)sin?(nθ)\begin{aligned} \text{第二項} &= q_m \cdot \text{RotateHalf}(k_n) \cdot \cos(m\theta)\sin(n\theta) \\ &= -\text{RotateHalf}(q_m) \cdot k_n \cdot \cos(m\theta)\sin(n\theta) \end{aligned} 第二項?=qm??RotateHalf(kn?)?cos(mθ)sin(nθ)=?RotateHalf(qm?)?kn??cos(mθ)sin(nθ)?
第三項=RotateHalf(qm)?kn?sin?(mθ)cos?(nθ)\begin{aligned} \text{第三項} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin(m\theta)\cos(n\theta) \end{aligned} 第三項?=RotateHalf(qm?)?kn??sin(mθ)cos(nθ)?
將第二和第三項合并:
第二項+第三項=RotateHalf(qm)?kn?[sin?(mθ)cos?(nθ)?cos?(mθ)sin?(nθ)]=RotateHalf(qm)?kn?sin?((m?n)θ)\begin{aligned} \text{第二項} + \text{第三項} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \left[ \sin(m\theta)\cos(n\theta) - \cos(m\theta)\sin(n\theta) \right] \\ &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} 第二項+第三項?=RotateHalf(qm?)?kn??[sin(mθ)cos(nθ)?cos(mθ)sin(nθ)]=RotateHalf(qm?)?kn??sin((m?n)θ)?
3. 合并第一項和第四項
第一項=qm?kn?cos?(mθ)cos?(nθ)第四項=RotateHalf(qm)?RotateHalf(kn)?sin?(mθ)sin?(nθ)\begin{aligned} \text{第一項} &= q_m \cdot k_n \cdot \cos(m\theta)\cos(n\theta) \\ \text{第四項} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第一項第四項?=qm??kn??cos(mθ)cos(nθ)=RotateHalf(qm?)?RotateHalf(kn?)?sin(mθ)sin(nθ)?
根據 RotateHalf\text{RotateHalf}RotateHalf 的旋轉后點積不變性:
RotateHalf(q)?RotateHalf(k)=q?k\text{RotateHalf}(q) \cdot \text{RotateHalf}(k) = q \cdot k RotateHalf(q)?RotateHalf(k)=q?k
以及,三角函數的角度差公式
cos(a?b)=cos?acos?b+sin?asin?bcos(a-b) = \cos a \cos b + \sin a \sin bcos(a?b)=cosacosb+sinasinb
第四項=RotateHalf(qm)?RotateHalf(kn)?sin?(mθ)sin?(nθ)=qm?kn?sin?(mθ)sin?(nθ)\begin{aligned} \text{第四項} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta)\\ &= q_m \cdot k_n \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第四項?=RotateHalf(qm?)?RotateHalf(kn?)?sin(mθ)sin(nθ)=qm??kn??sin(mθ)sin(nθ)?
所以,可以將這兩項合并為:
第一項+第四項=qm?kn?cos?((m?n)θ)\begin{aligned} \text{第一項} + \text{第四項} &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) \end{aligned} 第一項+第四項?=qm??kn??cos((m?n)θ)?
最終,ROPE上述表達式可簡化為:
Attention(qm,kn)=RoPE(qm,m)?RoPE(kn,n)=qm?kn?cos?((m?n)θ)+RotateHalf(qm)?kn?sin?((m?n)θ)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n)\\ &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) + \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} Attention(qm?,kn?)?=RoPE(qm?,m)?RoPE(kn?,n)=qm??kn??cos((m?n)θ)+RotateHalf(qm?)?kn??sin((m?n)θ)?
4. 驗證相對位置依賴
最終表達式中的所有三角函數項均包含 (m?n)θ(m-n)\theta(m?n)θ,即只依賴于位置差 (m-n),而非單獨的 (m) 或 (n)。這表明:
- 相對位置信息被隱式編碼在注意力得分中
- 當 (m-n) 固定時,無論 (m) 和 (n) 的絕對位置如何變化,注意力得分保持不變
- 模型能夠通過這種機制學習到序列中的相對距離關系
4.5 代碼實現
class RoPEPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000, base=10000):super().__init__()self.d_model = d_modelself.max_len = max_lenself.base = base# 預計算頻率inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))self.register_buffer('inv_freq', inv_freq)# 預計算位置編碼self._build_cache(max_len)def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)# 計算sin和cossin_angles = torch.sin(angles)cos_angles = torch.cos(angles)# 存儲緩存self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)def rotate_half(self, x):"""旋轉向量的一半維度"""x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)def forward(self, x, seq_len=None):if seq_len is None:seq_len = x.shape[-2]# 獲取sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 擴展維度以匹配輸入if x.dim() == 4: # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重復sin和cos以匹配完整維度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 應用旋轉return x * cos + self.rotate_half(x) * sin# 使用示例
rope = RoPEPositionalEncoding(d_model)
x_rope = torch.randn(batch_size, seq_len, d_model)
output_rope = rope(x_rope)
print(f"RoPE輸出形狀: {output_rope.shape}")
4.6 代碼與公式的對應關系
-
預計算頻率參數:
inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model)) self.register_buffer('inv_freq', inv_freq)
這對應公式中的 θd=110000d/dmodel\theta_d = \frac{1}{10000^{d/d_{\text{model}}}}θd?=10000d/dmodel?1?,用于生成不同維度的旋轉頻率。
-
預計算位置角度的sin和cos值:
def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)sin_angles = torch.sin(angles)cos_angles = torch.cos(angles)self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)
angles
矩陣對應 mθdm\theta_dmθd?,即位置 mmm 在維度 ddd 上的旋轉角度;sin_cached
和cos_cached
分別存儲 sin?(mθd)\sin(m\theta_d)sin(mθd?) 和 cos?(mθd)\cos(m\theta_d)cos(mθd?),避免重復計算。
-
向量旋轉操作:
def rotate_half(self, x):x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)
實現了 RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 操作,將向量后半部分取負后與前半部分拼接。
-
應用旋轉位置編碼:
def forward(self, x, seq_len=None):# 獲取對應位置的sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 擴展維度以匹配輸入if x.dim() == 4: # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重復以匹配完整維度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 應用旋轉:x * cos + rotate_half(x) * sinreturn x * cos + self.rotate_half(x) * sin
這直接對應RoPE的核心公式:
RoPE(x,m)=x?cos?(mθ)+RotateHalf(x)?sin?(mθ)\text{RoPE}(x, m) = x \cdot \cos(m\theta) + \text{RotateHalf}(x) \cdot \sin(m\theta)RoPE(x,m)=x?cos(mθ)+RotateHalf(x)?sin(mθ)
4.7 特性與優勢
-
隱式相對位置編碼
RoPE通過旋轉操作隱式地將相對位置信息融入注意力計算,使得模型能夠更好地捕捉序列中的相對距離關系,優于傳統的絕對位置編碼。 -
旋轉不變性
RoPE保證了位置編碼的旋轉不變性,即對于任意向量 xxx 和位置偏移 kkk,有:
RoPE(x,m)?RoPE(y,m+k)=RoPE(x,0)?RoPE(y,k)\text{RoPE}(x, m) \cdot \text{RoPE}(y, m+k) = \text{RoPE}(x, 0) \cdot \text{RoPE}(y, k)RoPE(x,m)?RoPE(y,m+k)=RoPE(x,0)?RoPE(y,k)
這使得模型在不同位置上具有一致的表示能力。 -
無需額外參數
RoPE不需要像可學習位置編碼那樣引入大量額外參數,只需預計算 sin?\sinsin 和 cos?\coscos 值,計算效率高。 -
長序列建模能力
實驗表明,RoPE在長序列任務(如長文本生成、文檔級NLP)中表現優于Sinusoidal PE和絕對位置編碼,能夠更有效地捕捉遠距離依賴關系。 -
兼容性強
可以直接應用于現有的Transformer架構,無需修改模型的整體結構,易于集成到各種NLP系統中。
4.8 帶RoPE的注意力機制
class RoPEMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_len=5000):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rope = RoPEPositionalEncoding(self.d_k, max_len)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 線性變換Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 應用RoPEQ = self.rope(Q)K = self.rope(K)# 計算注意力attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)attention_weights = torch.softmax(attention_scores, dim=-1)context = torch.matmul(attention_weights, V)# 重新整形context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights
5. 各種位置編碼對比
5.1 特點對比
編碼類型 | 優點 | 缺點 | 適用場景 |
---|---|---|---|
絕對位置編碼 | 簡單直觀,計算效率高 | 對長序列泛化能力差 | 固定長度序列 |
相對位置編碼 | 更好的泛化能力 | 計算復雜度高 | 需要處理可變長度序列 |
RoPE | 完美的長度外推能力 | 實現相對復雜 | 長序列,語言模型 |
5.2 性能測試代碼
import timedef benchmark_position_encodings():batch_size, seq_len, d_model = 32, 512, 512num_heads = 8# 創建模型models = {'Sinusoidal': SinusoidalPositionalEncoding(d_model),'Learnable': LearnablePositionalEncoding(d_model),'RoPE': RoPEPositionalEncoding(d_model)}x = torch.randn(batch_size, seq_len, d_model)# 基準測試for name, model in models.items():start_time = time.time()for _ in range(100):with torch.no_grad():output = model(x)end_time = time.time()print(f"{name}: {(end_time - start_time) * 1000:.2f}ms")# 運行基準測試
benchmark_position_encodings()
6. 總結
位置編碼是Transformer架構中的關鍵組件,不同類型的位置編碼各有特點:
- 絕對位置編碼:簡單高效,適用于固定長度序列
- 相對位置編碼:關注位置關系,泛化能力更強
- RoPE:通過旋轉矩陣優雅地處理位置信息,支持長度外推
選擇合適的位置編碼方式需要根據具體應用場景和性能需求來決定。現代大語言模型(如GPT、LLaMA等)普遍采用RoPE,因為它在處理長序列時表現出色。