模型網絡文件是深度學習模型的存儲形式,保存了模型的架構、參數等信息。
讀寫模型網絡文件是深度學習流程中的關鍵環節,方便模型的訓練、測試、部署與共享。
1. 主流框架讀寫方法
(一)TensorFlow
保存模型
可以使用
tf.saved_model.save
方法保存整個模型,包括架構、參數、編譯信息等。例如:model.save('model_dir', save_format='tf')
,將模型保存在文件夾 'model_dir' 中。
加載模型
使用
tf.keras.models.load_model
加載保存的模型。如:loaded_model = tf.keras.models.load_model('model_dir')
,即可加載之前保存的模型進行預測、繼續訓練等操作。
(二)PyTorch
使用 torch.save 和 torch.load 來保存和加載 張量。
保存模型
通常有兩種方式:一種是保存整個模型對象,使用
torch.save(model, 'model.pth')
,將模型結構和參數都保存下來。另一種是僅保存模型的參數狀態字典,即torch.save(model.state_dict(), 'model_state_dict.pth')
,這種方式更常見,因為當模型架構修改時,只要能正確加載參數,就無需重新訓練整個模型。
加載模型
對于保存整個模型的情況,直接使用
model = torch.load('model.pth')
。對于僅保存參數的情況,先定義好模型架構,再用model.load_state_dict(torch.load('model_state_dict.pth'))
加載參數,使模型具備相應的能力。
對于深度學習模型而言,通常只需保存其權重參數即可滿足需求。在 PyTorch 框架中,可以使用 torch.save()?函數來保存網絡的 state_dict?參數,這是保存模型權重的一種高效方式。
而在加載模型權重時,可以借助網絡的 load_state_dict() 方法,搭配 torch.load()?函數來實現對網絡參數的讀取,從而恢復模型的訓練狀態和性能表現。
2. 模型保存示例
torch.save(model.state_dict(), path)
只保存“參數”(一個純字典),文件小、加載靈活。
torch.save(model.state_dict(), "best_model.pt")
1. 加載時必須先重新建網絡,再把參數填進去:
new_model = MyNet() # 重新建圖
new_model.load_state_dict(torch.load("best_model.pt"))
new_model.eval() # 記得切到推理模式
2. 優點
- 文件 ≈ 僅參數大小,磁盤占用小
- 不關心原始類定義,跨代碼版本更穩
3. 缺點
????????需要手動重建網絡結構才能用
torch.save(model, path)
:把整個模型(結構+參數)序列化為一個 Pickle 對象,一步到位。
torch.save(model, "full_model.pt")
1. 加載極其簡單:
model = torch.load("full_model.pt") # 結構+參數全回來
model.eval()
2. 優點
????????一行代碼即可復現模型,適合快速分享、斷點繼續訓練
3. 缺點
Pickle 會硬編碼類定義路徑,代碼位置/類名一變就加載失敗
文件更大(含結構+參數)
選用建議
生產/長期維護 → 用 state_dict(穩妥、小、可遷移)。
臨時 checkpoint / 本地快速實驗 → 用 完整模型(省事)。