🧠 首先搞清楚 LoRA 是怎么做微調的
我們原來要訓練的參數矩陣是 W W W,但 LoRA 說:
別動 W,我在它旁邊加一個低秩矩陣 Δ W = U V \Delta W = UV ΔW=UV,只訓練這個部分!
也就是說,LoRA 用一個新的權重矩陣:
W ′ = W + U V W' = W + UV W′=W+UV
只訓練 U U U 和 V V V, W W W 不動。
📦 所以前向傳播其實用的是:
模型輸入 x ? W ′ x = W x + U V x ? 輸出 ? L \text{模型輸入}x \longrightarrow W'x = Wx + UVx \longrightarrow \text{輸出} \longrightarrow \mathcal{L} 模型輸入x?W′x=Wx+UVx?輸出?L
在這個過程中,損失函數 L \mathcal{L} L 是基于 W + U V W + UV W+UV 來計算的。
🔁 反向傳播的時候怎么求梯度?
LoRA 要訓練的是 U U U 和 V V V,所以我們要算:
? L ? U 和 ? L ? V \frac{\partial \mathcal{L}}{\partial U} \quad \text{和} \quad \frac{\partial \mathcal{L}}{\partial V} ?U?L?和?V?L?
但問題是:損失函數 L \mathcal{L} L 不是直接依賴 U U U 和 V V V,而是依賴 U V UV UV
所以要用鏈式法則,先對 U V UV UV 求導,然后傳播回 U U U、 V V V。而對UV求導等價于對 W W W求導
? 關鍵點來了
我們記:
? L ? W = G \frac{\partial \mathcal{L}}{\partial W} = G ?W?L?=G
這個 G G G 就是“如果我們在做全量微調,該怎么更新 W W W 的梯度”。
LoRA 說:
“雖然我不更新 W W W,但我要更新的是 U V UV UV。所以我也可以用這個 G G G 來指導我怎么更新 U U U 和 V V V。”
于是我們得到:
? L ? U = G V ? , ? L ? V = U ? G \frac{\partial \mathcal{L}}{\partial U} = G V^\top, \quad \frac{\partial \mathcal{L}}{\partial V} = U^\top G ?U?L?=GV?,?V?L?=U?G
LoRA 的梯度建立在 ? L ? W \frac{\partial \mathcal{L}}{\partial W} ?W?L? 上, 是因為它相當于“用低秩矩陣 U V UV UV 來代替全量的參數更新”, 所以梯度傳播也必須從 ? L ? W \frac{\partial \mathcal{L}}{\partial W} ?W?L? 開始。
LoRA 往往只是顯存不足的無奈之選,因為一般情況下全量微調的效果都會優于 LoRA,所以如果算力足夠并且要追求效果最佳時,請優先選擇全量微調。
使用 LoRA 的另一個場景是有大量的微型定制化需求,要存下非常多的微調結果,此時使用 LoRA 能減少儲存成本。
🔍 為什么
為什么 ? L ? W \frac{\partial \mathcal{L}}{\partial W} ?W?L?,就是對 U V UV UV 的梯度?
換句話說:LoRA 中的 W ′ = W + U V W' = W + UV W′=W+UV,那我們訓練時不是更新 W W W,只更新 U V UV UV,那為什么還能用 ? L ? W \frac{\partial \mathcal{L}}{\partial W} ?W?L? 來指導 U U U 和 V V V 的更新呢?
? 答案是:因為前向傳播中 W + U V W + UV W+UV 是一起作為整體參與運算的
所以:
? L ? W = ? L ? ( W + U V ) = ? L ? ( U V ) \frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial (W + UV)} = \frac{\partial \mathcal{L}}{\partial (UV)} ?W?L?=?(W+UV)?L?=?(UV)?L?
這是因為:
- 我們的模型使用的是 W + U V W + UV W+UV
- 所以損失函數 L \mathcal{L} L 是以 W + U V W + UV W+UV 為輸入計算出來的
- 那么對 W W W 求導,其實是對這個整體求導
- 而因為 W W W 是固定的(不訓練,看作常數),所以梯度全部由 U V UV UV 來承接
- 本來我們應該更新 W W W:
W ← W ? η ? L ? W W \leftarrow W - \eta \frac{\partial \mathcal{L}}{\partial W} W←W?η?W?L? - 現在我們不動 W W W,讓 U V UV UV 來“做這個事情”:
W + U V ← W + U V ? η ? ( LoRA方向上的梯度 ) W + UV \leftarrow W + UV - \eta \cdot \left(\text{LoRA方向上的梯度}\right) W+UV←W+UV?η?(LoRA方向上的梯度)
所以如果要算 U V UV UV 的導數,就是算 ? L ? W \frac{\partial \mathcal{L}}{\partial W} ?W?L?