LoRA是大模型微調方法的一種,它的特點是只在模型的 部分權重(如 QKV 矩陣) 上 添加可訓練參數
通過 低秩矩陣(A×B) 來優化參數更新
優點:
極大降低顯存消耗(deepseek 7B 只需 10GB)
適用于多任務 LoRA 適配器切換
訓練速度快
例如在 Transformer 里,自注意力(Self-Attention)計算:
Y=XW,
其中 X 是input, W是原始模型的權重矩陣(全連接層).
傳統的Fine-tuning就是直接對 W 進行梯度更新,導致需要存儲整個 W 的更新版本,顯存占用極大。
LoRA 關鍵思想:
不直接更新 W,而是 用兩個小矩陣 A A A 和 B B B 近似建模 W 的變化:
W ′ = W + Δ W W' = W + \Delta W W′=W+ΔW
Δ W = A B \Delta W = AB ΔW=AB
其中:
A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r
B ∈ R r × d B \in \mathbb{R}^{r \times d} B∈Rr×d
r ? d r \ll d r?d(低秩),一般 r=4, 8, 16,遠小于 d。
所以只需要訓練A 和 B,大幅減少訓練參數量,用 A B AB AB近似 Δ W \Delta W ΔW, 使得最終 W ′ W' W′仍然能適應新任務。
訓練時,只更新A和B, W保持凍結。
推理時,計算 W + A B W+AB W+AB得到微調后的完整模型, 但A,B遠小于W,開銷極小。
代碼簡單演示一下如何在transformer的q_proj里加入LoRA
在 Transformer 里,q_proj 是 nn.Linear 層
import torch
import torch.nn as nn
import mathclass LoRAQProj(nn.Module):def __init__(self, hidden_size, r=16, lora_alpha=16):super().__init__()self.hidden_size = hidden_sizeself.r = rself.lora_alpha = lora_alphaself.scaling = lora_alpha / r # LoRA 影響力# 原始 Q 投影層(凍結)self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)# LoRA 適配器:A 和 Bself.lora_A = nn.Linear(hidden_size, r, bias=False) # 低秩 Aself.lora_B = nn.Linear(r, hidden_size, bias=False) # 低秩 B# 初始化 LoRA 參數nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))nn.init.zeros_(self.lora_B.weight) # B 矩陣初始化為 0def forward(self, x):"""計算 Self-Attention 里的 Query 矩陣:Q = X * (W_q + AB)"""base_output = self.q_proj(x) # 原始投影lora_output = self.lora_B(self.lora_A(x)) * self.scaling # LoRA 適配器return base_output + lora_output # 總輸出# 測試模型
hidden_size = 512
batch_size = 4
seq_len = 10x = torch.randn(batch_size, seq_len, hidden_size) # 輸入數據
model = LoRAQProj(hidden_size)
output = model(x)print("LoRA Q-Projection Output Shape:", output.shape) # (4, 10, 512)
訓練LoRA適配器
訓練時,凍結self.q_proj, 只訓練lora_A 和 lora_B
# 訓練 LoRA
optimizer = torch.optim.AdamW([p for n, p in model.named_parameters() if "lora" in n], lr=1e-4
)for epoch in range(10):for batch in dataloader: # 假設 dataloader 提供訓練數據optimizer.zero_grad()output = model(batch["input_ids"])loss = loss_function(output, batch["labels"]) # 計算損失loss.backward()optimizer.step()
推理時合并LoRA
LoRA 訓練完成后,我們需要合并 A, B 到 q_proj
計算 W q ′ = W q + A B W_{q}' = W_{q} + AB Wq′?=Wq?+AB,
這樣,可以移除A,B,只保留 W q ′ W_{q}' Wq′?, 加速推理
def merge_lora(model):"""合并 LoRA 適配器到原始權重:W_q' = W_q + AB"""with torch.no_grad():model.q_proj.weight += (model.lora_B.weight @ model.lora_A.weight) * model.scaling# 移除 LoRA 適配器del model.lora_Adel model.lora_Breturn model# 進行推理時合并 LoRA
merged_model = merge_lora(model)
不過實際中,不需要我們自己去寫這些代碼,可以用unsloth, LLaMA-Factory 等框架來實現。