文章目錄
- PyTorch 學習率調度器(LR Scheduler)
- 1. 一句話定義
- 2. 通用使用套路
- 3. 內置調度器對比速覽
- 4. 各調度器最小模板
- ① LambdaLR(線性 warmup)
- ② StepLR
- ③ MultiStepLR
- ④ CosineAnnealingLR
- ⑤ ReduceLROnPlateau(必須傳指標)
- 5. 常用調試 API
- 6. 易踩坑 Top-3
- 7. 速記口訣
PyTorch 學習率調度器(LR Scheduler)
1. 一句話定義
每過一段時間 / 滿足某條件,自動按規則修改優化器學習率的工具。
2. 通用使用套路
optimizer = torch.optim.Adam(model.parameters(), lr=初始LR)
scheduler = XXXLR(optimizer, ...) # 選下面任意一種
for epoch in range(EPOCH):train(...)val_loss = validate(...)optimizer.step() # ① 先更新參數scheduler.step(val_loss) # ② 再調度LR(ReduceLROnPlateau需傳loss)
順序:先 optimizer.step()
→ 再 scheduler.step()
,否則報警告。
3. 內置調度器對比速覽
調度器 | 觸發規則 | 主要參數 | 參數解釋 | 典型場景 |
---|---|---|---|---|
LambdaLR | 自定義函數 f(epoch) 返回乘數 | lr_lambda , last_epoch | lr_lambda : 接收 epoch,返回 LR 乘數;last_epoch : 重啟訓練時設為上次 epoch | warmup、分段線性 |
StepLR | 固定每 step_size epoch 降一次 | step_size , gamma , last_epoch | step_size : 隔多少 epoch 降;gamma : 乘性衰減系數 | 常規“等間隔”下降 |
MultiStepLR | 指定里程碑 epoch 列表降 | milestones , gamma , last_epoch | milestones : List,到這些 epoch 就 ×gamma | 訓練中期多段下降 |
CosineAnnealingLR | 余弦曲線從初始→η_min | T_max , eta_min , last_epoch | T_max : 半個余弦周期長度;eta_min : 最小 LR | 退火、cosine 重啟 |
ReduceLROnPlateau | 監控指標停止改善時降 | mode , factor , patience , threshold , cooldown , min_lr | 見下方詳注 | 驗證 loss/acc 卡住時 |
ReduceLROnPlateau 參數詳注
mode='min'
或'max'
:指標越小/越大越好factor=0.1
:新 LR = 舊 LR × factorpatience=3
:連續 3 次 epoch 無改善才降threshold=0.01
:改善幅度小于閾值視為無改善cooldown=1
:降 LR 后凍結監控的 epoch 數min_lr=1e-6
:下限,降到此值不再降
4. 各調度器最小模板
① LambdaLR(線性 warmup)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: epoch / 5 if epoch < 5 else 1)
② StepLR
scheduler = StepLR(optimizer, step_size=2, gamma=0.1) # 每 2 epoch ×0.1
③ MultiStepLR
scheduler = MultiStepLR(optimizer, milestones=[2, 6], gamma=0.1)
④ CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
⑤ ReduceLROnPlateau(必須傳指標)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3,threshold=0.01, cooldown=1, min_lr=1e-6)
val_loss = validate(...)
scheduler.step(val_loss) # ← 記得傳指標
5. 常用調試 API
scheduler.get_last_lr() # 當前實際 LR 列表(每個 param_group)
scheduler.last_epoch # 已完成的 epoch 計數(從 0 開始)
6. 易踩坑 Top-3
- 先
optimizer.step()
再scheduler.step()
否則報警告 “Detected call oflr_scheduler.step()
beforeoptimizer.step()
”。 - ReduceLROnPlateau 必須傳監控值
不傳 → RuntimeError。 - Lambda/MultiStep 等無需監控值,傳了 → TypeError。
7. 速記口訣
“優化先邁步,調度再跟進;Plateau 傳 loss,其余不用問。”