模型的保存和加載
僅保存模型參數
- 原理:保存模型的權重參數,不保存模型結構代碼。加載時需提前定義與訓練時一致的模型類。
- 優點:文件體積小(僅含參數),跨框架兼容性強(需自行定義模型結構)。
# 保存模型參數
torch.save(model.state_dict(), "model_weights.pth")# 加載參數(需先定義模型結構)
model = MLP() # 初始化與訓練時相同的模型結構
model.load_state_dict(torch.load("model_weights.pth"))
# model.eval() # 切換至推理模式(可選)
保存模型+權重
- 原理:保存模型結構及參數
- 優點:加載時無需提前定義模型類
- 缺點:文件體積大,依賴訓練時的代碼環境(如自定義層可能報錯)。
# 保存整個模型
torch.save(model, "full_model.pth")# 加載模型(無需提前定義類,但需確保環境一致)
model = torch.load("full_model.pth")
model.eval() # 切換至推理模式(可選)
保存訓練狀態(斷點續訓)
- 原理:保存模型參數、優化器狀態(學習率、動量)、訓練輪次、損失值等完整訓練狀態,用于中斷后繼續訓練。
- 適用場景:長時間訓練任務(如分布式訓練、算力中斷)。
# 保存訓練狀態checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": best_loss,}torch.save(checkpoint, "checkpoint.pth")# 加載并續訓model = MLP()optimizer = torch.optim.Adam(model.parameters())checkpoint = torch.load("checkpoint.pth")model.load_state_dict(checkpoint["model_state_dict"])optimizer.load_state_dict(checkpoint["optimizer_state_dict"])start_epoch = checkpoint["epoch"] + 1 # 從下一輪開始訓練best_loss = checkpoint["loss"]# 繼續訓練循環for epoch in range(start_epoch, num_epochs):train(model, optimizer, ...)
早停法(early stop)
- 正常情況:訓練集和測試集損失同步下降,最終趨于穩定。
- 過擬合:訓練集損失持續下降,但測試集損失在某一時刻開始上升(或不再下降)。
如果可以監控驗證集的指標不再變好,此時提前終止訓練,避免模型對訓練集過度擬合。----監控的對象是驗證集的指標。這種策略叫早停法。
if test_loss.item() < best_test_loss: # 如果當前測試集損失小于最佳損失best_test_loss = test_loss.item() # 更新最佳損失best_epoch = epoch + 1 # 更新最佳epochcounter = 0 # 重置計數器# 保存最佳模型torch.save(model.state_dict(), 'best_model.pth')else:counter += 1if counter >= patience:print(f"早停觸發!在第{epoch+1}輪,測試集損失已有{patience}輪未改善。")print(f"最佳測試集損失出現在第{best_epoch}輪,損失值為{best_test_loss:.4f}")early_stopped = Truebreak # 終止訓練循環
邏輯:
- 首先初始一個計數器counter。
- 每 200 輪訓練執行一次判斷:比較當前損失與歷史最佳損失。
? - 若當前損失更低,保存模型參數。
? - 若當前損失更高或相等,計數器加 1。
? ? - 若計數器達到最大容許的閾值patience,則停止訓練。
@浙大疏錦行