位置編碼/絕對位置編碼/相對位置編碼/Rope原理+公式詳細推導及代碼實現

文章目錄

    • 1. 位置編碼概述
      • 1.1 為什么需要位置編碼?
    • 2. 絕對位置編碼 (Absolute Position Encoding)
      • 2.1 原理
      • 2.2 數學公式
      • 2.3 代碼實現
      • 2.4 代碼與公式的對應關系
      • 2.5 特性與優勢
      • 2.6 可學習的絕對位置編碼
    • 3. 相對位置編碼 (Relative Position Encoding)
      • 3.1 原理
      • 3.2 數學公式
      • 3.3 Shaw et al. (2018) 相對位置編碼
      • 3.4 代碼與公式的對應關系
      • 3.5 特性與優勢
      • 3.6 帶相對位置的注意力計算
    • 4. RoPE (Rotary Position Embedding)
      • 4.1 原理
      • 4.2 數學公式
      • 4.3 公式的詳細推導
        • 1. 旋轉向量點積的展開
        • 2. 合并第二項和第三項
        • 3. 合并第一項和第四項
        • 4. 驗證相對位置依賴
      • 4.5 代碼實現
      • 4.6 代碼與公式的對應關系
      • 4.7 特性與優勢
      • 4.8 帶RoPE的注意力機制
    • 5. 各種位置編碼對比
      • 5.1 特點對比
      • 5.2 性能測試代碼
    • 6. 總結

1. 位置編碼概述

位置編碼是Transformer架構中的關鍵組件,用于為序列中的每個位置提供位置信息。由于自注意力機制本身是位置無關的,需要額外的位置信息來理解序列中元素的順序。

1.1 為什么需要位置編碼?

輸入序列: "我  愛  中   國"[1] [2] [3] [4]沒有位置編碼:自注意力機制無法區分詞語的順序
有位置編碼:每個位置都有唯一的位置標識

2. 絕對位置編碼 (Absolute Position Encoding)

2.1 原理

絕對位置編碼為序列中的每個位置分配一個唯一的編碼向量。最經典的是Transformer論文中的正弦余弦位置編碼。

2.2 數學公式

正弦位置編碼的數學公式如下:

對于位置 pospospos 和維度 iii

  • iii 為偶數時:
    PE(pos,2i)=sin?(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i)?=sin(100002i/dmodel?pos?)

  • iii 為奇數時:
    PE(pos,2i+1)=cos?(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i+1)?=cos(100002i/dmodel?pos?)

其中:

  • pospospos 是序列中的位置(范圍:000max_len?1max\_len-1max_len?1
  • iii 是編碼向量中的維度索引(范圍:000dmodel?1d_{\text{model}}-1dmodel??1
  • dmodeld_{\text{model}}dmodel? 是模型的嵌入維度

2.3 代碼實現

import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as pltclass SinusoidalPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.d_model = d_model# 創建位置編碼矩陣pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1).float()# 計算除法項div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))# 應用正弦和余弦函數pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 注冊為緩沖區(不參與梯度更新)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):# x shape: (batch_size, seq_len, d_model)seq_len = x.size(1)return x + self.pe[:, :seq_len]# 使用示例
d_model = 512
max_len = 100
pos_encoding = SinusoidalPositionalEncoding(d_model, max_len)# 模擬輸入
batch_size, seq_len = 2, 20
x = torch.randn(batch_size, seq_len, d_model)
output = pos_encoding(x)
print(f"輸入形狀: {x.shape}")
print(f"輸出形狀: {output.shape}")

2.4 代碼與公式的對應關系

  1. 創建位置編碼矩陣

    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    

    這里生成了一個形狀為 (max_len, d_model) 的零矩陣,并準備好位置索引向量。

  2. 計算分母項

    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    

    這對應公式中的分母部分 100002i/dmodel10000^{2i/d_{\text{model}}}100002i/dmodel?,通過指數和對數運算轉換為:
    exp?(?log?(10000)?2idmodel)\exp\left(-\frac{\log(10000) \cdot 2i}{d_{\text{model}}}\right)exp(?dmodel?log(10000)?2i?)

  3. 應用正弦和余弦函數

    pe[:, 0::2] = torch.sin(position * div_term)  # 偶數維度使用正弦
    pe[:, 1::2] = torch.cos(position * div_term)  # 奇數維度使用余弦
    

    這直接對應公式中的正弦和余弦部分,分別應用于偶數和奇數維度。

  4. 注冊為緩沖區

    self.register_buffer('pe', pe.unsqueeze(0))
    

    將位置編碼注冊為模型的緩沖區(不參與訓練),并添加批次維度。

  5. 前向傳播

    def forward(self, x):seq_len = x.size(1)return x + self.pe[:, :seq_len]
    

    將位置編碼加到輸入張量上,只取與輸入序列長度匹配的部分。

2.5 特性與優勢

  1. 相對位置表示:正弦位置編碼能夠表達相對位置關系,因為對于任意固定偏移量 kkkPEpos+kPE_{pos+k}PEpos+k? 可以表示為 PEposPE_{pos}PEpos? 的線性函數。

  2. 泛化能力:可以推廣到比訓練期間見過的更長的序列長度。

  3. 計算高效:無需學習參數,在推理時直接生成位置編碼。

  4. 梯度穩定性:由于使用固定函數生成,不會影響模型訓練的梯度流動。

2.6 可學習的絕對位置編碼

class LearnablePositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.pos_embedding = nn.Embedding(max_len, d_model)self.max_len = max_lendef forward(self, x):batch_size, seq_len, _ = x.shapepositions = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)pos_encodings = self.pos_embedding(positions)return x + pos_encodings# 使用示例
learnable_pos = LearnablePositionalEncoding(d_model, max_len)
output_learnable = learnable_pos(x)
print(f"可學習位置編碼輸出形狀: {output_learnable.shape}")

3. 相對位置編碼 (Relative Position Encoding)

3.1 原理

相對位置編碼關注的是位置之間的相對關系,而不是絕對位置。這種方法在處理長序列時表現更好。

3.2 數學公式

相對位置編碼的核心思想是在注意力計算中引入相對位置信息。對于兩個位置 iiijjj,其相對位置為 k=i?jk = i - jk=i?j,編碼公式主要體現在注意力得分的計算中:

  1. 標準自注意力公式(無相對位置):
    Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk??QKT?)V

  2. 加入相對位置編碼的注意力公式
    Attention(Q,K,V)=softmax((Q+Rq)(K+Rk)Tdk)(V+Rv)\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{(Q + R_q)(K + R_k)^T}{\sqrt{d_k}}\right)(V + R_v)Attention(Q,K,V)=softmax(dk??(Q+Rq?)(K+Rk?)T?)(V+Rv?)

    其中:

    • RqR_qRq?RkR_kRk?RvR_vRv? 分別是查詢(Query)、鍵(Key)、值(Value)的相對位置編碼矩陣;
    • RkR_kRk?RvR_vRv? 通常由相對位置索引 kkk 映射得到,即 Rk=Ek(k)R_k = E_k(k)Rk?=Ek?(k)Rv=Ev(k)R_v = E_v(k)Rv?=Ev?(k),其中 EkE_kEk?EvE_vEv? 是可學習的嵌入矩陣。

3.3 Shaw et al. (2018) 相對位置編碼

class RelativePositionEncoding(nn.Module):def __init__(self, d_model, max_relative_position=50):super().__init__()self.d_model = d_modelself.max_relative_position = max_relative_position# 相對位置嵌入vocab_size = 2 * max_relative_position + 1self.relative_position_k = nn.Embedding(vocab_size, d_model)self.relative_position_v = nn.Embedding(vocab_size, d_model)def get_relative_positions(self, seq_len):"""生成相對位置矩陣"""range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)# 裁剪到最大相對位置distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)# 轉換為正數索引final_mat = distance_mat_clipped + self.max_relative_positionreturn final_matdef forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)# 獲取相對位置編碼relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb# 使用示例
rel_pos_encoding = RelativePositionEncoding(d_model)
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)rel_k, rel_v = rel_pos_encoding(q, k, v)
print(f"相對位置編碼K形狀: {rel_k.shape}")
print(f"相對位置編碼V形狀: {rel_v.shape}")

3.4 代碼與公式的對應關系

  1. 初始化嵌入層

    vocab_size = 2 * max_relative_position + 1
    self.relative_position_k = nn.Embedding(vocab_size, d_model)
    self.relative_position_v = nn.Embedding(vocab_size, d_model)
    
    • vocab_size 對應所有可能的相對位置范圍(從 -max_relative_position+max_relative_position);
    • relative_position_krelative_position_v 分別對應公式中的 EkE_kEk?EvE_vEv?,用于將相對位置索引映射為嵌入向量。
  2. 生成相對位置矩陣

    def get_relative_positions(self, seq_len):range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)final_mat = distance_mat_clipped + self.max_relative_positionreturn final_mat
    
    • 生成的 distance_mat 是所有位置對 (i,j)(i,j)(i,j) 的相對距離矩陣(即 k=i?jk = i - jk=i?j);
    • clamp 操作將相對距離限制在預設范圍內,避免過遠的位置影響;
    • final_mat 將相對距離轉換為非負索引(通過加上 max_relative_position),便于嵌入層查找。
  3. 獲取相對位置編碼

    def forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb
    
    • relative_position_k_embrelative_position_v_emb 分別對應公式中的 RkR_kRk?RvR_vRv?
    • 它們的形狀均為 (seq_len, seq_len, d_model),表示任意兩個位置之間的相對位置編碼。

3.5 特性與優勢

  1. 捕捉相對位置關系
    相比絕對位置編碼(如Sinusoidal PE),相對位置編碼直接建模token對之間的距離,更適合捕捉序列中的結構信息(如語法依賴關系)。

  2. 參數高效
    只需存儲有限范圍內的相對位置嵌入(通常為 2*max_relative_position+1 個向量),而不是為每個絕對位置存儲一個向量。

  3. 泛化能力
    對于長度超過訓練時所見的序列,仍能通過相對位置編碼處理,而絕對位置編碼可能超出預定義范圍。

  4. 靈活應用
    可選擇性地應用于注意力機制的不同組件(如僅應用于Key,或同時應用于Key和Value),根據任務需求調整。

  5. 提升長序列性能
    在長文本任務(如文檔摘要、長對話生成)中,相對位置編碼能更好地捕捉遠距離依賴關系。

3.6 帶相對位置的注意力計算

class RelativeMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_relative_position=50):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rel_pos_encoding = RelativePositionEncoding(self.d_k, max_relative_position)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 線性變換Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 計算注意力分數attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)# 添加相對位置編碼rel_k, rel_v = self.rel_pos_encoding(query, key, value)rel_k = rel_k.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1, 1)# 相對位置注意力rel_attention = torch.matmul(Q.unsqueeze(-2), rel_k.transpose(-2, -1)).squeeze(-2)attention_scores = attention_scores + rel_attention# 應用掩碼if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)# Softmaxattention_weights = torch.softmax(attention_scores, dim=-1)# 應用注意力權重context = torch.matmul(attention_weights, V)# 重新整形并通過輸出層context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights

4. RoPE (Rotary Position Embedding)

4.1 原理

RoPE(Rotary Positional Encoding)是一種基于旋轉機制的位置編碼方法,通過旋轉向量空間來隱式表達token間的相對位置關系。它在Transformer模型中取得了顯著效果,尤其是在長序列建模和語言理解任務中。

4.2 數學公式

RoPE的核心思想是通過旋轉操作將位置信息直接融入到向量表示中。對于位置 mmm 處的向量 qmq_mqm? 和位置 nnn 處的向量 knk_nkn?,RoPE的計算公式如下:

  1. 旋轉操作
    對于位置 mmm 和維度 ddd,將向量 qmq_mqm? 旋轉 θm\theta_mθm? 角度:
    RoPE(qm,m)d=qm?cos?(mθd)+RotateHalf(qm)?sin?(mθd)\text{RoPE}(q_m, m)_d = q_m \cdot \cos(m\theta_d) + \text{RotateHalf}(q_m) \cdot \sin(m\theta_d) RoPE(qm?,m)d?=qm??cos(mθd?)+RotateHalf(qm?)?sin(mθd?)
    其中:

    • θd=110000ddmodel\theta_d = \frac{1}{10000^{\frac{d}{d_{\text{model}}}}}θd?=10000dmodel?d?1? 是頻率參數;
    • RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 表示將向量 xxx 的前半部分與后半部分交換符號后拼接,即 [xd/2+1,xd/2+2,...,xd,?x1,?x2,...,?xd/2][x_{d/2+1}, x_{d/2+2}, ..., x_d, -x_1, -x_2, ..., -x_{d/2}][xd/2+1?,xd/2+2?,...,xd?,?x1?,?x2?,...,?xd/2?]
  2. 在注意力機制中的應用
    RoPE通過以下方式改變注意力得分計算:

    Attention(qm,kn)=RoPE(qm,m)?RoPE(kn,n)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \end{aligned} Attention(qm?,kn?)?=RoPE(qm?,m)?RoPE(kn?,n)?

4.3 公式的詳細推導

1. 旋轉向量點積的展開

RoPE(qm,m)?RoPE(kn,n)=[qmcos?(mθ)+RotateHalf(qm)sin?(mθ)]?[kncos?(nθ)+RotateHalf(kn)sin?(nθ)]=(qmcos?(mθ))?(kncos?(nθ))+(qmcos?(mθ))?(RotateHalf(kn)sin?(nθ))+(RotateHalf(qm)sin?(mθ))?(kncos?(nθ))+(RotateHalf(qm)sin?(mθ))?(RotateHalf(kn)sin?(nθ))\begin{aligned} &\text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \\ =& \left[ q_m \cos(m\theta) + \text{RotateHalf}(q_m) \sin(m\theta) \right] \cdot \left[ k_n \cos(n\theta) + \text{RotateHalf}(k_n) \sin(n\theta) \right] \\ =& (q_m \cos(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (q_m \cos(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \end{aligned} ==?RoPE(qm?,m)?RoPE(kn?,n)[qm?cos(mθ)+RotateHalf(qm?)sin(mθ)]?[kn?cos(nθ)+RotateHalf(kn?)sin(nθ)](qm?cos(mθ))?(kn?cos(nθ))+(qm?cos(mθ))?(RotateHalf(kn?)sin(nθ))+(RotateHalf(qm?)sin(mθ))?(kn?cos(nθ))+(RotateHalf(qm?)sin(mθ))?(RotateHalf(kn?)sin(nθ))?

2. 合并第二項和第三項

根據 RotateHalf\text{RotateHalf}RotateHalf 的正交性:
q?RotateHalf(k)=?RotateHalf(q)?kq \cdot \text{RotateHalf}(k) = -\text{RotateHalf}(q) \cdot k q?RotateHalf(k)=?RotateHalf(q)?k
以及,三角函數的角度差公式
sin?(a?b)=sin?acos?b?cos?asin?b\sin(a-b) = \sin a \cos b - \cos a \sin b sin(a?b)=sinacosb?cosasinb

將上述展開式中的第二和第三項重寫:
第二項=qm?RotateHalf(kn)?cos?(mθ)sin?(nθ)=?RotateHalf(qm)?kn?cos?(mθ)sin?(nθ)\begin{aligned} \text{第二項} &= q_m \cdot \text{RotateHalf}(k_n) \cdot \cos(m\theta)\sin(n\theta) \\ &= -\text{RotateHalf}(q_m) \cdot k_n \cdot \cos(m\theta)\sin(n\theta) \end{aligned} 第二項?=qm??RotateHalf(kn?)?cos(mθ)sin(nθ)=?RotateHalf(qm?)?kn??cos(mθ)sin(nθ)?

第三項=RotateHalf(qm)?kn?sin?(mθ)cos?(nθ)\begin{aligned} \text{第三項} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin(m\theta)\cos(n\theta) \end{aligned} 第三項?=RotateHalf(qm?)?kn??sin(mθ)cos(nθ)?

將第二和第三項合并:
第二項+第三項=RotateHalf(qm)?kn?[sin?(mθ)cos?(nθ)?cos?(mθ)sin?(nθ)]=RotateHalf(qm)?kn?sin?((m?n)θ)\begin{aligned} \text{第二項} + \text{第三項} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \left[ \sin(m\theta)\cos(n\theta) - \cos(m\theta)\sin(n\theta) \right] \\ &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} 第二項+第三項?=RotateHalf(qm?)?kn??[sin(mθ)cos(nθ)?cos(mθ)sin(nθ)]=RotateHalf(qm?)?kn??sin((m?n)θ)?

3. 合并第一項和第四項

第一項=qm?kn?cos?(mθ)cos?(nθ)第四項=RotateHalf(qm)?RotateHalf(kn)?sin?(mθ)sin?(nθ)\begin{aligned} \text{第一項} &= q_m \cdot k_n \cdot \cos(m\theta)\cos(n\theta) \\ \text{第四項} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第一項第四項?=qm??kn??cos(mθ)cos(nθ)=RotateHalf(qm?)?RotateHalf(kn?)?sin(mθ)sin(nθ)?

根據 RotateHalf\text{RotateHalf}RotateHalf 的旋轉后點積不變性:
RotateHalf(q)?RotateHalf(k)=q?k\text{RotateHalf}(q) \cdot \text{RotateHalf}(k) = q \cdot k RotateHalf(q)?RotateHalf(k)=q?k
以及,三角函數的角度差公式
cos(a?b)=cos?acos?b+sin?asin?bcos(a-b) = \cos a \cos b + \sin a \sin bcos(a?b)=cosacosb+sinasinb
第四項=RotateHalf(qm)?RotateHalf(kn)?sin?(mθ)sin?(nθ)=qm?kn?sin?(mθ)sin?(nθ)\begin{aligned} \text{第四項} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta)\\ &= q_m \cdot k_n \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第四項?=RotateHalf(qm?)?RotateHalf(kn?)?sin(mθ)sin(nθ)=qm??kn??sin(mθ)sin(nθ)?
所以,可以將這兩項合并為:
第一項+第四項=qm?kn?cos?((m?n)θ)\begin{aligned} \text{第一項} + \text{第四項} &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) \end{aligned} 第一項+第四項?=qm??kn??cos((m?n)θ)?

最終,ROPE上述表達式可簡化為:
Attention(qm,kn)=RoPE(qm,m)?RoPE(kn,n)=qm?kn?cos?((m?n)θ)+RotateHalf(qm)?kn?sin?((m?n)θ)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n)\\ &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) + \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} Attention(qm?,kn?)?=RoPE(qm?,m)?RoPE(kn?,n)=qm??kn??cos((m?n)θ)+RotateHalf(qm?)?kn??sin((m?n)θ)?

4. 驗證相對位置依賴

最終表達式中的所有三角函數項均包含 (m?n)θ(m-n)\theta(m?n)θ,即只依賴于位置差 (m-n),而非單獨的 (m) 或 (n)。這表明:

  • 相對位置信息被隱式編碼在注意力得分中
  • 當 (m-n) 固定時,無論 (m) 和 (n) 的絕對位置如何變化,注意力得分保持不變
  • 模型能夠通過這種機制學習到序列中的相對距離關系

4.5 代碼實現

class RoPEPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000, base=10000):super().__init__()self.d_model = d_modelself.max_len = max_lenself.base = base# 預計算頻率inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))self.register_buffer('inv_freq', inv_freq)# 預計算位置編碼self._build_cache(max_len)def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)# 計算sin和cossin_angles = torch.sin(angles)cos_angles = torch.cos(angles)# 存儲緩存self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)def rotate_half(self, x):"""旋轉向量的一半維度"""x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)def forward(self, x, seq_len=None):if seq_len is None:seq_len = x.shape[-2]# 獲取sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 擴展維度以匹配輸入if x.dim() == 4:  # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重復sin和cos以匹配完整維度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 應用旋轉return x * cos + self.rotate_half(x) * sin# 使用示例
rope = RoPEPositionalEncoding(d_model)
x_rope = torch.randn(batch_size, seq_len, d_model)
output_rope = rope(x_rope)
print(f"RoPE輸出形狀: {output_rope.shape}")

4.6 代碼與公式的對應關系

  1. 預計算頻率參數

    inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
    self.register_buffer('inv_freq', inv_freq)
    

    這對應公式中的 θd=110000d/dmodel\theta_d = \frac{1}{10000^{d/d_{\text{model}}}}θd?=10000d/dmodel?1?,用于生成不同維度的旋轉頻率。

  2. 預計算位置角度的sin和cos值

    def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)sin_angles = torch.sin(angles)cos_angles = torch.cos(angles)self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)
    
    • angles 矩陣對應 mθdm\theta_dmθd?,即位置 mmm 在維度 ddd 上的旋轉角度;
    • sin_cachedcos_cached 分別存儲 sin?(mθd)\sin(m\theta_d)sin(mθd?)cos?(mθd)\cos(m\theta_d)cos(mθd?),避免重復計算。
  3. 向量旋轉操作

    def rotate_half(self, x):x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)
    

    實現了 RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 操作,將向量后半部分取負后與前半部分拼接。

  4. 應用旋轉位置編碼

    def forward(self, x, seq_len=None):# 獲取對應位置的sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 擴展維度以匹配輸入if x.dim() == 4:  # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重復以匹配完整維度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 應用旋轉:x * cos + rotate_half(x) * sinreturn x * cos + self.rotate_half(x) * sin
    

    這直接對應RoPE的核心公式:
    RoPE(x,m)=x?cos?(mθ)+RotateHalf(x)?sin?(mθ)\text{RoPE}(x, m) = x \cdot \cos(m\theta) + \text{RotateHalf}(x) \cdot \sin(m\theta)RoPE(x,m)=x?cos(mθ)+RotateHalf(x)?sin(mθ)

4.7 特性與優勢

  1. 隱式相對位置編碼
    RoPE通過旋轉操作隱式地將相對位置信息融入注意力計算,使得模型能夠更好地捕捉序列中的相對距離關系,優于傳統的絕對位置編碼。

  2. 旋轉不變性
    RoPE保證了位置編碼的旋轉不變性,即對于任意向量 xxx 和位置偏移 kkk,有:
    RoPE(x,m)?RoPE(y,m+k)=RoPE(x,0)?RoPE(y,k)\text{RoPE}(x, m) \cdot \text{RoPE}(y, m+k) = \text{RoPE}(x, 0) \cdot \text{RoPE}(y, k)RoPE(x,m)?RoPE(y,m+k)=RoPE(x,0)?RoPE(y,k)
    這使得模型在不同位置上具有一致的表示能力。

  3. 無需額外參數
    RoPE不需要像可學習位置編碼那樣引入大量額外參數,只需預計算 sin?\sinsincos?\coscos 值,計算效率高。

  4. 長序列建模能力
    實驗表明,RoPE在長序列任務(如長文本生成、文檔級NLP)中表現優于Sinusoidal PE和絕對位置編碼,能夠更有效地捕捉遠距離依賴關系。

  5. 兼容性強
    可以直接應用于現有的Transformer架構,無需修改模型的整體結構,易于集成到各種NLP系統中。

4.8 帶RoPE的注意力機制

class RoPEMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_len=5000):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rope = RoPEPositionalEncoding(self.d_k, max_len)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 線性變換Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 應用RoPEQ = self.rope(Q)K = self.rope(K)# 計算注意力attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)attention_weights = torch.softmax(attention_scores, dim=-1)context = torch.matmul(attention_weights, V)# 重新整形context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights

5. 各種位置編碼對比

5.1 特點對比

編碼類型優點缺點適用場景
絕對位置編碼簡單直觀,計算效率高對長序列泛化能力差固定長度序列
相對位置編碼更好的泛化能力計算復雜度高需要處理可變長度序列
RoPE完美的長度外推能力實現相對復雜長序列,語言模型

5.2 性能測試代碼

import timedef benchmark_position_encodings():batch_size, seq_len, d_model = 32, 512, 512num_heads = 8# 創建模型models = {'Sinusoidal': SinusoidalPositionalEncoding(d_model),'Learnable': LearnablePositionalEncoding(d_model),'RoPE': RoPEPositionalEncoding(d_model)}x = torch.randn(batch_size, seq_len, d_model)# 基準測試for name, model in models.items():start_time = time.time()for _ in range(100):with torch.no_grad():output = model(x)end_time = time.time()print(f"{name}: {(end_time - start_time) * 1000:.2f}ms")# 運行基準測試
benchmark_position_encodings()

6. 總結

位置編碼是Transformer架構中的關鍵組件,不同類型的位置編碼各有特點:

  1. 絕對位置編碼:簡單高效,適用于固定長度序列
  2. 相對位置編碼:關注位置關系,泛化能力更強
  3. RoPE:通過旋轉矩陣優雅地處理位置信息,支持長度外推

選擇合適的位置編碼方式需要根據具體應用場景和性能需求來決定。現代大語言模型(如GPT、LLaMA等)普遍采用RoPE,因為它在處理長序列時表現出色。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/88981.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/88981.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/88981.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

網絡安全初級第一次作業

一,docker搭建和掛載vpm 1.安裝 Docker apt-get install docker.io docker-compose 2.創建文件 mkdir /etc/docker.service.d vim /etc/docker.service.d/http-proxy.conf 3.改寫文件配置 [Service] Environment"HTTP_PROXYhttp://192.168.10.103:7890…

交換類排序的C語言實現

交換類排序包括冒泡排序和快速排序兩種。冒泡排序基本介紹冒泡排序是通過重復比較相鄰元素并交換位置實現排序。其核心思想是每一輪遍歷將未排序序列中的最大(或最小)元素"浮動"到正確位置,類似氣泡上升。基本過程是從序列起始位置…

嵌入式 Linux開發環境構建之Source Insight 的安裝和使用

目錄 一、Source Insight 的安裝 二、Source Insight 使用 一、Source Insight 的安裝 這個軟件是代碼編輯和查看軟件,打開開發板光盤軟件,然后右鍵選擇以管理員身份運行這個安裝包。在彈出來的安裝向導里面點擊 next ,如下圖所示。這里選擇…

【字節跳動】數據挖掘面試題0016:解釋AUC的定義,它解決了什么問題,優缺點是什么,并說出工業界如何計算AUC。

文章大綱 AUC(Area Under the Curve)詳解一、定義:AUC是什么?二、解決了什么問題?三、優缺點分析四、工業界大規模計算AUC的方法1. 標準計算(小數據)2. 工業級大規模計算方案3.工業界最佳實踐4.工業界方案選型建議總結:AUC的本質AUC(Area Under the Curve)詳解 一、…

Python后端項目之:我為什么使用pdm+uv

在試用了一段時間的uv和pdm之后,上個月(2025.06)開始,逐步把用了幾年的poetry替換成了pdmuv(pipx install pdm uv && pdm config use_uv true) ## 為什么poetry -> pdm: 1. 通過ssh連接到服務器并使用poetry shell激活虛擬環境之…

鴻蒙Next開發,配置Navigation的Route

1. 通過router_map.json配置文件進行 創建頁面配置router_map.json {"routerMap": [{"name": "StateExamplePage","pageSourceFile": "src/main/ets/pages/state/StateExamplePage.ets","buildFunction": "P…

在 GitHub 上創建私有倉庫

一、在 GitHub 上創建私有倉庫打開 GitHub官網 并登錄。點擊右上角的 “” → 選擇 “New repository”。填寫以下內容: Repository name:倉庫名稱,例如 my-private-repo。Description:可選,倉庫描述。Visibility&…

量產技巧之RK3588 Android12默認移除導航欄狀態欄?

本文介紹使用源碼編譯默認去掉導航欄/狀態欄方法,以觸覺智能EVB3588開發板演示,Android12系統,搭載了瑞芯微RK3588芯片,該開發板是核心板加底板設計,音視頻接口、通信接口等各類接口一應俱全,可幫助企業提高產品開發效…

Conda 安裝與配置詳解及常見問題解決

《Conda 安裝與配置詳解及常見問題解決》 安裝 Conda 有兩種主流方式,分別是安裝 Miniconda(輕量級)和 Anaconda(包含常用數據科學包)。下面為你詳細介紹安裝步驟和注意要點。 一、安裝 Miniconda(推薦&a…

Linux ——lastb定時備份清理

lastb 命令顯示的是系統中 /var/log/btmp 文件中的SSH 登錄失敗記錄。你可以像處理 wtmp 那樣,對 btmp 文件進行備份與清理。? 一、備份 lastb 數據cp /var/log/btmp /var/log/btmp.backup.$(date %F)會保存為如 /var/log/btmp.backup.2025-07-14? 二、清空 lastb…

自定義類型 - 聯合體與枚舉(百度筆試題算法優化)

目錄一、聯合體1.1 聯合體類型的聲明1.2 聯合體的特點1.3 相同成員的結構體和聯合體對比1.4 聯合體大小的計算1.5 聯合練習二、枚舉類型2.1 枚舉類型的聲明2.2 枚舉類型的優點總結一、聯合體 1.1 聯合體類型的聲明 像結構體一樣,聯合體也是由一個或者多個成員構成…

FS820R08A6P2LB——英飛凌高性能IGBT模塊,驅動高效能源未來!

產品概述FS820R08A6P2LB 是英飛凌(Infineon)推出的一款高性能、高可靠性IGBT功率模塊,采用先進的EconoDUAL? 3封裝,專為大功率工業應用設計。該模塊集成了IGBT(絕緣柵雙極型晶體管)和二極管,適…

python學智能算法(十八)|SVM基礎概念-向量點積

引言 前序學習進程中,已經對向量的基礎定義有所了解,已經知曉了向量的值和方向向量的定義,學習鏈接如下: 向量的值和方向 在此基礎上,本文進一步學習向量點積。 向量點積 向量點積運算規則,我們在中學階…

【windows辦公小助手】比文檔編輯器更好用的Notepad++輕量編輯器

Notepad 中文版軟件下載:這個路徑總是顯示有百度無法下載,不推薦 更新:推薦下載路徑 https://github.com/notepad-plus-plus/notepad-plus-plus/releases 參考博主:Notepad的安裝與使用

2025年7月12日全國青少年信息素養大賽圖形化(Scratch)編程小學高年級組復賽真題+答案解析

2025年7月12日全國青少年信息素養大賽圖形化(Scratch)編程小學高年級組復賽真題+答案解析 選擇題 題目一 運行如圖所示的程序,舞臺上一共會出現多少只小貓呢?( ) A. 5 B. 6 C. 7 D. 8 正確答案: B 答案解析: 程序中“當綠旗被點擊”后,角色先移到指定位置,然后“重…

對于獨熱編碼余弦相似度結果為0和詞向量解決了詞之間相似性問題的理解

文章目錄深入理解簡單案例結論詞向量(Word Embedding)簡介詞向量如何解決相似性問題?簡單案例:基于上下文的詞向量訓練總結對于獨熱表示的向量,如果采用余弦相似度計算向量間的相似度,可以明顯的發現任意兩…

數據結構·數狀數組(BIT)

樹狀數組(Binary Index Tree) 英文名:使用二進制下標的樹結構 理解:這個樹實際上用數組來存,二進制下標就是將正常的下標拆為二進制來看。 求x的最低位1的函數lowbit(x) 假設x的二進制表示為x ...10000,…

uniapp video視頻全屏播放后退出,頁面字體變大,樣式混亂問題

uniapp官方的說法是因為頁面使用rpx,但是全屏和退出全屏自動計算屏幕尺寸不支持rpx,建議使用px。但是因為uniapp端的開發都是使用rpx作為屏幕尺寸計算參數,不可能因為video全屏播放功能就整個全部修改,工作量大,耗時耗…

重復頻率較高的廣告為何一直在被使用?

在日常生活中,重復評率較高的洗腦廣告我們時常能夠碰到。廣告的本質是信息傳遞,而重復頻率較高的廣告往往可以通過洗腦式的傳播方式來提升傳播效率。下面就讓我們一同來了解下,為何這類廣告一直受到企業的青睞。一、語義凝練高頻率廣告的內容…

內容管理系統指南:企業內容運營的核心引擎

內容管理看似簡單,實際上隨著內容量的激增,管理難度也逐步提升。尤其是在面對大量頁面、圖文、視頻資料等數字內容時,沒有專業工具的支持,效率與準確性都會受到挑戰。此時,內容管理系統(CMS)應運…