#d 兩種保存方式比較
僅保存模型參數
優點:
- 更加靈活,只保存模型的參數,不保存模型的結構,可以在不同的模型結構中加載參數(只要參數匹配)。
- 文件大小通常比保存整個模型小。
- 安全性更高,因為不直接執行pickle內容。
缺點:
- 加載模型前需要先定義模型的結構,增加了代碼量。
保存整個模型
優點:
- 保存簡單,一行代碼完成。
- 加載模型時不需要再定義模型的結構。
缺點:
- 保存的模型依賴于具體的類定義,如果模型的結構有所改變(例如類名、層的結構等),加載時可能會出現問題。
- 文件通常比僅保存狀態字典的方式大。
- 可能存在安全風險,因為
torch.load
會加載任何pickle內容。
總結:
僅保存模型的參數(狀態字典)是更加推薦的方式,因為它更加靈活和安全。但是,如果你想要快速保存和加載整個模型,不擔心模型結構變化或安全問題,保存整個模型也是一個可行的選擇。
1 僅保存模型參數
#c 說明 保存加載方式
PyTorch保存模型的「學習參數」是通過state_dict
的一個內部狀態字典,使用torch.save
來保存模型的學習參數。
#e 模型保存方式一
model = models.vgg16(weights='IMAGENET1K_V1')
'''
vgg16是一個非常流行的卷積神經網絡,經過了大量的訓練,可以識別1000個不同的對象。
weights='IMAGENET1K_V1'表示加載了在ImageNet數據集上預訓練的權重。
'''
torch.save(model.state_dict(), 'model_weights.pth')#狀態字典與保存路徑
#e 模型加載方式一
加載模型權重,首先需要創建一個與「原始模型相同的模型實例」,然后使用load_state_dict
方法加載參數。
注意:需要使用model.eval()
方法將模型設置為評估模式,這將關閉Dropout和BatchNorm層。否則將會導致不一致的推理結果。
model = models.vgg16()#加載模型
model.load_state_dict(torch.load('model_weights.pth'))#加載模型權重
model.eval()#設置模型為評估模式
2 保存整個模型
#c 說明 保存整個模型
在加載模型權重時,需要首先實例化模型類,因為模型類定義了網絡的結構。如果希望將模型類的架構與模型一起保存,那么可以傳遞模型本身(而不是模型的狀態字典model.state_dict())給保存函數。
#e 模型保存方式二
torch.save(model, 'model.pth')#保存模型
#e 模型加載方式二
model = torch.load('model.pth')#加載模型