【PL 基礎】如何啟用早停機制
- 摘要
- 1. on_train_batch_start()
- 2. EarlyStopping Callback
摘要
??本文介紹了兩種在 PyTorch Lightning 中實現早停機制的方法。第一種是通過重寫on_train_batch_start()
方法手動控制訓練流程;第二種是使用內置的EarlyStopping
回調,可以監控驗證指標并在指標停止改善時自動停止訓練。文章詳細說明了EarlyStopping
的參數設置,包括監控指標、模式選擇、耐心值等核心參數,以及停止閾值、發散閾值等進階參數。同時介紹了如何通過子類化修改早停觸發時機,并提醒注意驗證頻率與耐心值的配合使用。文末提供了完整的代碼示例,展示了如何在實際訓練中配置和使用早停機制。
1. on_train_batch_start()
??通過重寫 on_train_batch_start()
方法,在滿足特定條件時提前返回,從而停止并跳過當前epoch的剩余訓練批次。
??如果對于最初要求的每個epoch重復這樣做,將停止整個訓練。
2. EarlyStopping Callback
??EarlyStopping
回調可用于監控指標,并在沒有觀察到改善時停止訓練。
要啟用此功能,請執行以下操作:
-
導入
EarlyStopping
回調模塊; -
使用
log()
方法記錄需要監控的指標; -
初始化回調并設置要監控的指標名稱(
monitor
參數); -
根據指標特性設置監控模式(
mode
參數); -
將
EarlyStopping
回調傳遞給Trainer
的callbacks
參數。
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingclass LitModel(LightningModule):def validation_step(self, batch, batch_idx):loss = ...self.log("val_loss", loss)model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)
可以通過更改其參數來自定義回調行為。
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
用于在極值點停止訓練的附加參數:
-
stopping_threshold
(停止閾值):當監控指標達到該閾值時立即終止訓練。適用于已知超過特定最優值后模型不再提升的場景。 -
divergence_threshold
(發散閾值):當監控指標劣化至該閾值時即刻停止訓練。當指標惡化至此程度時,我們認為模型已無法恢復,此時應提前終止并嘗試不同初始條件。 -
check_finite
(有限值檢測):啟用后,若監控指標出現NaN(非數值)或無窮大時終止訓練。 -
check_on_train_epoch_end
(訓練周期結束檢測):啟用后,在訓練周期結束時檢查指標。僅當監控指標通過周期級訓練鉤子記錄時才需啟用此功能。
若需在訓練過程的其他階段啟用早停機制,請通過創建子類繼承 EarlyStopping
類并修改其調用位置:
class MyEarlyStopping(EarlyStopping):def on_validation_end(self, trainer, pl_module):# override this to disable early stopping at the end of val looppassdef on_train_end(self, trainer, pl_module):# instead, do it at the end of training loopself._run_early_stopping_check(trainer)
默認情況下,EarlyStopping
回調會在每個驗證周期結束時觸發。但驗證頻率可通過 Trainer
中的參數調節,例如通過設置 check_val_every_n_epoch
(每N個訓練周期驗證一次)和 val_check_interval
(驗證間隔)。需特別注意:patience
(耐心值)統計的是驗證結果未提升的次數,而非訓練周期數。因此當設置 check_val_every_n_epoch=10
且 patience=3
時,訓練器需經歷至少 40個訓練周期才會停止。