文章目錄
- 什么是 Warmup?
- 實現示例
- 科學設置 Warmup 的黃金法則
- 直觀例子
什么是 Warmup?
Warmup 是一種學習率調度策略,在訓練初期逐步增加學習率(LR),而不是直接使用目標學習率。它解決了兩個關鍵問題:
- 避免早期震蕩:模型參數初始化為隨機值,直接高LR會導致不穩定更新。
- 穩定Adam優化器:Adam的動量估計在初始階段不準確,需要漸進調整。
實現示例
def get_lr(step, warmup_steps, d_model):# 1. 預熱階段:線性增長if step < warmup_steps:return base_lr * (step / warmup_steps)# 2. 衰減階段:反平方根衰減scale = (warmup_steps ** 0.5) * min(step ** (-0.5), step * (warmup_steps ** (-1.5))return base_lr * scale
科學設置 Warmup 的黃金法則
- NNN = 總樣本數(條)
- BBB = 每次 forward 的原始 batch(每卡)
- AAA = 梯度累積步數(
accum_grad
) - EEE = epoch 數
- WWW = 卡數(
world_size
,單卡 = 1)
- 每 epoch 的 forward 次數(向上取整):
I=?N/(B×W)?I = \lceil N / (B \times W) \rceil I=?N/(B×W)?
- 每 epoch 的 optimizer 更新次數(每 AAA 次 forward 做一次 update,向上取整):
S=?I/A?S = \lceil I / A \rceil S=?I/A?
- 總 optimizer step(也就是 scheduler 用的 total steps):
T=S×ET = S \times E T=S×E
- 推薦的 warmup 步數:
warmup={max?{?0.10×T?,10},T<4000clamp?(?0.05×T?,4000,20000),T≥4000\text{warmup} = \begin{cases} \max\{\lceil 0.10 \times T \rceil, 10\}, & T < 4000\\[4pt] \operatorname{clamp}(\, \lfloor 0.05 \times T \rceil,\; 4000,\; 20000 \,), & T \ge 4000 \end{cases} warmup={max{?0.10×T?,10},clamp(?0.05×T?,4000,20000),?T<4000T≥4000?
并且最終確保 warmup≤T?1\text{warmup} \le T-1warmup≤T?1。
解釋:小訓練用 10%,大訓練用 5%,并在 4k–20k 之間限制
直觀例子
假設 B=16,A=8,E=120,W=1B=16, A=8, E=120, W=1B=16,A=8,E=120,W=1:
- 若 N=100,000N = 100{,}000N=100,000:
- I=?100000/16?=6250I=\lceil100000/16\rceil=6250I=?100000/16?=6250
- S=?6250/8?=782S=\lceil6250/8\rceil=782S=?6250/8?=782
- T=782×120=93,840T=782\times120=93{,}840T=782×120=93,840
- warmup ≈ ?0.05×93840?=4,692\lfloor0.05\times93840\rceil=4{,}692?0.05×93840?=4,692(取 4000–20000 區間內 → 4692)
- 若 N=1,000,000N = 1{,}000{,}000N=1,000,000:
- I=?1000000/16?=62500I=\lceil1000000/16\rceil=62500I=?1000000/16?=62500
- S=?62500/8?=7813S=\lceil62500/8\rceil=7813S=?62500/8?=7813
- T=7813×120=937,560T=7813\times120=937{,}560T=7813×120=937,560
- 0.05×T ? 20000 → clamp → 20000
- 若 N=100N = 100N=100(極小樣本、僅作示例):
- I=?100/16?=7I=\lceil100/16\rceil=7I=?100/16?=7
- S=?7/8?=1S=\lceil7/8\rceil=1S=?7/8?=1
- T=1×120=120T=1\times120=120T=1×120=120
- 因為 T<4000T<4000T<4000,warmup = max(ceil(0.1×120),10) = 12
import mathdef suggest_warmup(N,B,A,E,W=1):I = math.ceil(N / (B*W))S = math.ceil(I / A)T = S * Eif T < 4000:w = max(math.ceil(0.10*T), 10)else:w = round(0.05*T)w = max(4000, min(w, 20000))w = min(w, T-1)return {"iters_per_epoch":I, "opt_steps_per_epoch":S, "total_steps":T, "warmup":w}print(suggest_warmup(403733, 16, 8, 120, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 789, 'total_steps': 94680, 'warmup': 4734}print(suggest_warmup(403733, 16, 4, 100, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 1578, 'total_steps': 157800, 'warmup': 7890}