BatchSize
顯存占用:與batch_size呈線性關系,可理解為 M t o t a l = M f i x e d + B a t c h S i z e ? M p e r ? s a m p l e M_{total}=M_{fixed}+BatchSize*M_{per-sample} Mtotal?=Mfixed?+BatchSize?Mper?sample?,其中 M f i x e d M_{fixed} Mfixed?指的是模型本身固定占用的顯存(由參數數量決定)和優化器狀態(也由參數數量決定)
總訓練時間:理論上與BatchSize無關(總數不變,單步訓練時間增加,總步數減少),但實際中隨BatchSize越大,總時間可能減少(硬件并行效率提升),直到顯存或硬件并行能力達到瓶頸。
截斷長度(輸入序列分詞后的最大長度,即每條樣本被大模型讀取的最大長度)
1. 顯存占用
在大型語言模型(如 Transformer)中,顯存占用主要與模型的激活值(Activations)有關,而激活值的大小受到輸入序列長度(即截斷長度)的直接影響。以下是逐步分析:
激活值的定義
激活值是指模型在正向傳播過程中每一層計算出的中間結果,通常存儲在顯存中,以便反向傳播時計算梯度。對于 Transformer 模型,激活值主要與注意力機制(Self-Attention)和前饋網絡(Feed-Forward Network, FFN)的計算相關。
顯存占用的組成
顯存占用主要包括:
- 模型參數(權重和偏置):與模型規模(層數、隱藏維度)相關,與截斷長度無關。
- 激活值:與輸入序列長度(截斷長度 L L L)、批次大小(batch size B B B)、隱藏維度(hidden size H H H)和層數( N N N)成正比。
- 梯度(訓練時):與參數量和激活值大小相關。
對于激活值部分,顯存占用主要來源于:
- 注意力機制:計算 Q ? K T Q \cdot K^T Q?KT的注意力分數矩陣,尺寸為 ( B , L , L ) (B, L, L) (B,L,L),每層需要存儲。
- 中間張量:如 V V V的加權和、前饋層的輸出等。
數學表達式
假設: L L L:截斷長度(序列長度), B B B:批次大小, H H H:隱藏維度, N N N:模型層數, P P P:浮點數精度(如 FP32 為 4 字節,FP16 為 2 字節)
激活值的顯存占用近似為:
顯存 激活值 ≈ N ? B ? L ? H ? P + N ? B ? L 2 ? P \text{顯存}_{\text{激活值}} \approx N \cdot B \cdot L \cdot H \cdot P + N \cdot B \cdot L^2 \cdot P 顯存激活值?≈N?B?L?H?P+N?B?L2?P
- 第一項 N ? B ? L ? H ? P N \cdot B \cdot L \cdot H \cdot P N?B?L?H?P:表示每層的線性張量(如 Q , K , V Q, K, V Q,K,V或 FFN 輸出)的顯存占用。
- 第二項 N ? B ? L 2 ? P N \cdot B \cdot L^2 \cdot P N?B?L2?P:表示注意力分數矩陣的顯存占用(僅在標準注意力機制中顯著,若使用優化如 FlashAttention,則可能減少)。
結論:顯存占用與截斷長度 L L L呈線性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的關系,具體取決于注意力機制的實現方式。
2. 訓練時間
訓練時間主要與計算量(FLOPs,浮點運算次數)和硬件并行能力有關,而截斷長度會影響計算量。
計算量的組成
- 注意力機制:每層的計算量與 L 2 L^2 L2相關,因為需要計算 L × L L \times L L×L的注意力矩陣。
- 前饋網絡:每層的計算量與 L L L線性相關,因為對每個 token 獨立計算。
總計算量(FLOPs)近似為:
FLOPs ≈ N ? B ? ( 2 ? L 2 ? H + 4 ? L ? H 2 ) \text{FLOPs} \approx N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2) FLOPs≈N?B?(2?L2?H+4?L?H2)
- 2 ? L 2 ? H 2 \cdot L^2 \cdot H 2?L2?H:注意力機制的矩陣乘法(如 Q ? K T Q \cdot K^T Q?KT和 softmax ? V \text{softmax} \cdot V softmax?V),
- 4 ? L ? H 2 4 \cdot L \cdot H^2 4?L?H2:前饋網絡的計算(假設 FFN 隱藏層維度為 4 H 4H 4H)。
訓練時間
訓練時間與 FLOPs 成正比,同時受硬件并行能力(如 GPU 的計算核心數)影響。假設每秒浮點運算能力為 F GPU F_{\text{GPU}} FGPU?(單位:FLOPs/s),則單次前向+反向傳播的訓練時間為:
時間 ≈ FLOPs F GPU ≈ N ? B ? ( 2 ? L 2 ? H + 4 ? L ? H 2 ) F GPU \text{時間} \approx \frac{\text{FLOPs}}{F_{\text{GPU}}} \approx \frac{N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2)}{F_{\text{GPU}}} 時間≈FGPU?FLOPs?≈FGPU?N?B?(2?L2?H+4?L?H2)?
結論:訓練時間與截斷長度 L L L呈線性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的關系,具體取決于注意力機制的計算占比。
3. 總結
- 顯存占用:與 L L L呈 O ( L ) O(L) O(L)或 O ( L 2 ) O(L^2) O(L2)關系,取決于是否存儲完整的注意力矩陣。
- 訓練時間:與 L L L呈 O ( L ) O(L) O(L)到 O ( L 2 ) O(L^2) O(L2)關系,注意力機制的二次項通常更顯著。
1
假設某模型大小為5GB,推理所需顯存也為5GB,普通Lora微調(FP16)所需顯存為5GB*2=10GB,8bit的QLora量化為5GB/2=2.5GB,4bit的QLora量化為5GB/4=1.25GB