0. 背景
假設我們有兩個矩陣:
- 矩陣 A,尺寸為
(n, d_k)
- 矩陣 B,尺寸為
(d_k, n)
我們要計算它們的乘積 C = A * B。
那么這個過程所需的計算量是多少?
1. 結果矩陣的尺寸
首先,結果矩陣 C 的尺寸是由第一個矩陣的行數和第二個矩陣的列數決定的。
- C 的行數 = A 的行數 =
n
- C 的列數 = B 的列數 =
n
所以,結果矩陣 C 的尺寸為(n, n)
。
2. 單個元素的計算量
接下來,我們看如何計算結果矩陣 C 中的任意一個元素 C_ij
(第 i 行,第 j 列的元素)。
根據矩陣乘法的定義,C_ij
是由 A 的第 i 行和 B 的第 j 列的點積(dot product)得到的。
- A 的第 i 行是一個有
d_k
個元素的行向量。 - B 的第 j 列是一個有
d_k
個元素的列向量。
計算過程如下:
C_ij = A_i1 * B_1j + A_i2 * B_2j + ... + A_id_k * B_d_kj
為了計算這一個 C_ij
元素,我們需要:
d_k
次乘法 (每個A_ik
乘以B_kj
)d_k - 1
次加法 (將d_k
個乘積相加)
3. 總計算量
現在我們來計算整個矩陣 C 的總計算量。
結果矩陣 C 是一個 (n, n)
的矩陣,所以它總共有 n * n = n^2
個元素。
我們將單個元素的計算量乘以總元素數量:
-
總乘法次數 = (每個元素的乘法次數) × (總元素個數)
=d_k * n^2
-
總加法次數 = (每個元素的加法次數) × (總元素個數)
=(d_k - 1) * n^2
4. 結論
將一個 (n, d_k)
矩陣與一個 (d_k, n)
矩陣相乘:
- 總乘法運算量為
n2 * d_k
次。 - 總加法運算量為
n2 * (d_k - 1)
次。
在計算機科學和機器學習領域,我們通常使用浮點運算次數 (FLOPs, Floating Point Operations) 來衡量計算量。一次乘法和一次加法通常被打包看作一次操作(特別是在現代硬件的FMA指令中)。
總FLOPs ≈ 總乘法次數 + 總加法次數
= (n2 * d_k) + (n2 * (d_k - 1))
= n2 * (d_k + d_k - 1)
= n2 * (2d_k - 1)
當 d_k
比較大時,我們通常近似為 2 * n2 * d_k
FLOPs。
5. 應用背景(非常重要)
這個計算 (n, d_k) * (d_k, n)
在 Transformer模型 的自注意力(Self-Attention)機制中非常核心。
n
通常代表序列長度(Sequence Length)。d_k
代表Query和Key向量的維度。
這個計算對應的是 Query (Q) 矩陣 和 Key (K) 矩陣的轉置 (K?) 相乘,以得到注意力分數矩陣(Attention Score Matrix)。
- Q 的尺寸是
(n, d_k)
- K 的尺寸是
(n, d_k)
,所以 K? 的尺寸是(d_k, n)
- Q * K? 的結果是一個
(n, n)
的矩陣,其計算復雜度就是 O(n2 * d_k) 。
這也解釋了為什么標準Transformer模型的計算量和內存占用會隨著序列長度 n
的增加而呈平方級增長,這是限制其處理非常長序列的主要瓶頸之一。
6. 附錄(泛化)
補充一種更加泛化的計算方式。我們來分析一下將一個 (a, b)
矩陣與一個 (b, c)
矩陣相乘的計算量。
假設我們有兩個矩陣:
- 矩陣 A,尺寸為
(a, b)
(a
行,b
列) - 矩陣 B,尺寸為
(b, c)
(b
行,c
列)
我們要計算它們的乘積 C = A * B。
6.1 結果矩陣的尺寸
首先,結果矩陣 C 的尺寸由 A 的行數和 B 的列數決定。
- C 的行數 = A 的行數 =
a
- C 的列數 = B 的列數 =
c
所以,結果矩陣 C 的尺寸為(a, c)
。
6.2 單個元素的計算量
接下來,我們計算結果矩陣 C 中的任意一個元素 C_ij
(第 i
行, 第 j
列)。
C_ij
是由 A 的第 i
行和 B 的第 j
列的點積(dot product)得到的。
- A 的第
i
行是一個長度為b
的行向量。 - B 的第
j
列是一個長度為b
的列向量。
計算公式為:
C_ij = A_i1 * B_1j + A_i2 * B_2j + ... + A_ib * B_bj
為了計算這一個 C_ij
元素,我們需要:
b
次乘法b - 1
次加法
6.3 總計算量
結果矩陣 C 是一個 (a, c)
的矩陣,它總共有 a * c
個元素。
我們將單個元素的計算量乘以總元素數量,得到整個矩陣的計算量:
-
總乘法次數 = (每個元素的乘法次數) × (總元素個數)
=b * (a * c)
=a * b * c
-
總加法次數 = (每個元素的加法次數) × (總元素個數)
=(b - 1) * (a * c)
=a * c * (b - 1)
6.4 結論與總結
對于一個 (a, b)
矩陣和一個 (b, c)
矩陣的乘法:
- 總乘法運算量為
a * b * c
次。 - 總加法運算量為
a * c * (b - 1)
次。
在衡量算法復雜度時,我們通常使用 Big O 表示法,或者計算總的 浮點運算次數 (FLOPs)。
-
總FLOPs ≈ 總乘法次數 + 總加法次數
=(a * b * c) + (a * c * (b - 1))
=a * c * (b + b - 1)
=a * c * (2b - 1)
-
時間復雜度 (Time Complexity):
當a
,b
,c
都很大時,常數2
和-1
可以忽略。因此,計算復雜度為 O(abc)。
6.5 驗證一下之前的問題
讓我們用這個通用公式來驗證你之前的問題:一個 (n, d_k)
矩陣乘以一個 (d_k, n)
矩陣。
這里:
a = n
b = d_k
c = n
代入通用公式:
- 總乘法次數 =
a * b * c
=n * d_k * n
=n2 * d_k
- 總加法次數 =
a * c * (b - 1)
=n * n * (d_k - 1)
=n2 * (d_k - 1)
這與我們之前得到的結論完全一致。這個 O(abc)
的公式是矩陣乘法計算量分析的基礎。