relaxed_madd
這條指令到底做了什么
核心:relaxed_madd
是一個分量級別 (Component-wise) 的操作
首先,最重要的一點是:v128.relaxed_madd<f32>(a, b, c)
不是矩陣乘法。它是一個在三個向量 a
, b
, c
之間進行的、逐個分量的、并行的融合乘加操作。
這三個向量 a
, b
, c
都是 v128
類型,我們可以把它們看作包含四個 32位浮點數 (f32) 的數組。
1. 定義我們的輸入向量
讓我們用數學形式來表示這三個輸入向量 a
, b
, c
:
- 向量 a?=(a0a1a2a3)\vec{a} = \begin{pmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \end{pmatrix}a=?a0?a1?a2?a3???
- 向量 b?=(b0b1b2b3)\vec{b} = \begin{pmatrix} b_0 \\ b_1 \\ b_2 \\ b_3 \end{pmatrix}b=?b0?b1?b2?b3???
- 向量 c?=(c0c1c2c3)\vec{c} = \begin{pmatrix} c_0 \\ c_1 \\ c_2 \\ c_3 \end{pmatrix}c=?c0?c1?c2?c3???
這里的 a0,a1,…,c3a_0, a_1, \dots, c_3a0?,a1?,…,c3? 都是普通的 32 位浮點數。
2. relaxed_madd
的數學公式
v128.relaxed_madd<f32>(a, b, c)
這條指令執行的計算,其結果是一個新的向量,我們稱之為 $ \vec{r} $ (result)。這個計算的數學公式是:
r?=(a?⊙b?)+c?\vec{r} = (\vec{a} \odot \vec{b}) + \vec{c}r=(a⊙b)+c
這里的 ⊙\odot⊙ 符號代表 哈達瑪積 (Hadamard Product),也就是分量相乘 (component-wise multiplication)。
3. 展開公式,看清細節
把上面的公式展開到每一個分量上,就能看清它到底發生了什么。結果向量 r?\vec{r}r 的四個分量 r0,r1,r2,r3r_0, r_1, r_2, r_3r0?,r1?,r2?,r3? 是這樣并行計算出來的:
r?=(r0r1r2r3)=(a0?b0+c0a1?b1+c1a2?b2+c2a3?b3+c3)\vec{r} = \begin{pmatrix} r_0 \\ r_1 \\ r_2 \\ r_3 \end{pmatrix} = \begin{pmatrix} a_0 \cdot b_0 + c_0 \\ a_1 \cdot b_1 + c_1 \\ a_2 \cdot b_2 + c_2 \\ a_3 \cdot b_3 + c_3 \end{pmatrix}r=?r0?r1?r2?r3???=?a0??b0?+c0?a1??b1?+c1?a2??b2?+c2?a3??b3?+c3???
這就是 relaxed_madd
的全部真相:它在一條指令里,同時并行地完成了這四個獨立的融合乘加運算。
它在你的矩陣乘法代碼中是如何被應用的?
現在我們把你代碼中的一行拿出來,用這個公式來解釋:
// 代碼行
res0 = v128.relaxed_madd<f32>(sA0, rB0, res0_prev);
// (我把之前的 res0 重命名為 res0_prev 以便區分)
這里的輸入是什么?
-
sA0
: 這是一個由矩陣 A 的元素A[0][0]
廣播 (splat) 而來的向量。
sA0?=(A[0][0]A[0][0]A[0][0]A[0][0])\vec{sA0} = \begin{pmatrix} A[0][0] \\ A[0][0] \\ A[0][0] \\ A[0][0] \end{pmatrix}sA0=?A[0][0]A[0][0]A[0][0]A[0][0]?? -
rB0
: 這是矩陣 B 的第一行向量。
rB0?=(B[0][0]B[0][1]B[0][2]B[0][3])\vec{rB0} = \begin{pmatrix} B[0][0] \\ B[0][1] \\ B[0][2] \\ B[0][3] \end{pmatrix}rB0=?B[0][0]B[0][1]B[0][2]B[0][3]?? -
res0_prev
: 這是上一步計算的結果(累加值)。
那么,v128.relaxed_madd<f32>(sA0, rB0, res0_prev)
這一步到底計算了什么?我們套用上面的公式:
res0?=(sA0?⊙rB0?)+res0prev?\vec{res0} = (\vec{sA0} \odot \vec{rB0}) + \vec{res0_{prev}}res0=(sA0⊙rB0)+res0prev??
展開來看就是:
res0?=(A[0][0]?B[0][0]+res0prev,0A[0][0]?B[0][1]+res0prev,1A[0][0]?B[0][2]+res0prev,2A[0][0]?B[0][3]+res0prev,3)\vec{res0} = \begin{pmatrix} A[0][0] \cdot B[0][0] + res0_{prev,0} \\ A[0][0] \cdot B[0][1] + res0_{prev,1} \\ A[0][0] \cdot B[0][2] + res0_{prev,2} \\ A[0][0] \cdot B[0][3] + res0_{prev,3} \end{pmatrix}res0=?A[0][0]?B[0][0]+res0prev,0?A[0][0]?B[0][1]+res0prev,1?A[0][0]?B[0][2]+res0prev,2?A[0][0]?B[0][3]+res0prev,3???
這完美地對應了我們矩陣乘法思想:用 A 的一個標量元素,去數乘 B 的一整個行向量,然后加到累加器上。
當你把四次 relaxed_madd
調用鏈接起來后,最終的結果 res0
的第一個分量就是:
r0=(A[0][0]?B[0][0])+(A[0][1]?B[1][0])+(A[0][2]?B[2][0])+(A[0][3]?B[3][0])r_0 = (A[0][0] \cdot B[0][0]) + (A[0][1] \cdot B[1][0]) + (A[0][2] \cdot B[2][0]) + (A[0][3] \cdot B[3][0])r0?=(A[0][0]?B[0][0])+(A[0][1]?B[1][0])+(A[0][2]?B[2][0])+(A[0][3]?B[3][0])
這正好是結果矩陣 C 的 C[0][0]
元素!其他分量同理。
FMA 的“融合”體現在哪里?
“融合” (Fused) 的意思是,在計算 ai?bi+cia_i \cdot b_i + c_iai??bi?+ci? 時:
- 計算 ai?bia_i \cdot b_iai??bi? 的乘積,得到一個內部的、高精度的中間結果(比如 80 位浮點數)。
- 不進行舍入,直接用這個高精度結果與 cic_ici? 相加。
- 對最終的和只進行一次舍入,得到 32 位的浮點數結果。
相比之下,非融合的 mul
+ add
會進行兩次舍入,可能會損失精度。更重要的是,FMA 是一條硬件指令,吞吐量更高。