文章目錄
- 1. 推理用: 保存 & 加載權重 (最常見)
- 2. 繼續訓練用: 保存 & 加載完整狀態
- 3. 微調用: 部分加載 (分類頭不同等情況)
1. 推理用: 保存 & 加載權重 (最常見)
import torch
import torch.nn as nnmodel = nn.Linear(10, 2)# 保存權重
torch.save(model.state_dict(), "model.pt")# 加載權重 (推理/評估)
model2 = nn.Linear(10, 2)
state = torch.load("model.pt", map_location="cpu")
model2.load_state_dict(state)
model2.eval() # 推理時別忘了
2. 繼續訓練用: 保存 & 加載完整狀態
# ===== 保存 =====
torch.save({"epoch": epoch,"model": model.state_dict(),"optimizer": optimizer.state_dict(),"scheduler": scheduler.state_dict(),
}, "ckpt.pt")# ===== 加載 =====
ckpt = torch.load("ckpt.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
start_epoch = ckpt["epoch"] + 1
model.train() # 繼續訓練前別忘了
3. 微調用: 部分加載 (分類頭不同等情況)
state = torch.load("pretrain.pt", map_location="cpu")
# 只加載匹配的層, 其余保持初始化
missing, unexpected = model.load_state_dict(state, strict=False)
print("未加載:", missing) # 模型需要,但 checkpoint 里沒有的
print("多余:", unexpected) # checkpoint 里有,但模型不需要的
? 速記
- 推理: 保存/加載
model.state_dict()
- 繼續訓練: 把 optimizer/scheduler 一并保存
- 微調:
strict=False
部分加載 - 安全:
map_location="cpu"
加載更通用 - 模式: 推理用
model.eval()
,訓練用model.train()