公式 (10.6.2) 描述了位置編碼的具體計算方式,這種位置編碼基于正弦和余弦函數,用于在自注意力機制中引入位置信息。下面我們詳細解釋公式和代碼。
公式 (10.6.2)
公式 (10.6.2) 的目的是為輸入序列中的每個詞元添加一個位置編碼,以保留序列的位置信息:
[
\begin{split}
\begin{aligned}
p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right), \
p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).
\end{aligned}
\end{split}
]
這里:
- ( p_{i, 2j} ) 是位置編碼矩陣 (\mathbf{P}) 的第 (i) 行、第 (2j) 列的元素。
- ( p_{i, 2j+1} ) 是位置編碼矩陣 (\mathbf{P}) 的第 (i) 行、第 (2j+1) 列的元素。
- ( i ) 表示詞元在序列中的位置。
- ( j ) 表示編碼維度的索引。
- ( d ) 是詞元向量的維度。
這些位置編碼使用不同頻率的正弦和余弦函數,較小的頻率用于較低的維度,較大的頻率用于較高的維度。
讓我們詳細解釋一下為什么在公式 (10.6.2) 中使用 ( i ) 和 ( 2j ),為什么是 ( 10000^{2j/d} ),以及為什么選擇正弦和余弦函數。
1. 為什么是 ( i ) 和 ( 2j )
- ( i ): 表示詞元在序列中的位置。
- ( 2j ) 和 ( 2j+1 ): 表示編碼維度的索引。位置編碼矩陣的每個詞元的每個維度都有兩個值,一個是正弦函數值,另一個是余弦函數值。
在位置編碼矩陣中,維度 ( 2j ) 存儲正弦函數值,維度 ( 2j+1 ) 存儲余弦函數值。這種交替存儲方式允許位置編碼同時捕捉到不同頻率的周期信息。
2. 為什么是 ( 10000^{2j/d} )
-
( 10000^{2j/d} ): 這是一個縮放因子,確保不同維度的頻率不同。具體來說,隨著 ( j ) 的增加,頻率會指數級地增加。
- 當 ( j ) 較小時, ( \frac{2j}{d} ) 也較小,這意味著 ( 10000^{2j/d} ) 較小,從而使 ( \frac{i}{10000^{2j/d}} ) 較大,結果是低頻率。
- 當 ( j ) 較大時, ( \frac{2j}{d} ) 也較大,這意味著 ( 10000^{2j/d} ) 較大,從而使 ( \frac{i}{10000^{2j/d}} ) 較小,結果是高頻率。
這種設計保證了不同維度上位置編碼的頻率不同,從而捕捉到多種粒度的位置信息。
3. 為什么選擇正弦和余弦函數
選擇正弦和余弦函數的主要原因是它們的周期性和相位特性。這些函數可以捕捉到序列中的相對位置關系:
-
正弦函數和余弦函數的周期性: 位置編碼利用了正弦和余弦函數的周期性,能夠捕捉到詞元在序列中的相對位置。因為這些函數是周期性的,模型可以通過這些位置編碼了解詞元之間的相對距離。
-
正弦和余弦的互補性: 正弦函數和余弦函數是相位差90度的互補函數,組合在一起可以更全面地描述位置信息。
總結
結合以上幾點,公式 (10.6.2) 的位置編碼設計利用了正弦和余弦函數的周期性特性,通過不同的頻率和相位捕捉序列中詞元的相對位置,從而增強了模型對序列順序信息的理解。
這就是為什么公式 (10.6.2) 被設計成這個樣子:通過 ( i ) 來表示位置,通過 ( 10000^{2j/d} ) 來控制頻率,通過正弦和余弦函數來捕捉不同頻率的位置信息。