Pytorch筆記一之 cpu模型保存、加載與推理
1.保存模型
首先,在加載模型之前,我們需要了解如何保存模型。PyTorch 提供了兩種保存模型的方法:保存整個模型和僅保存模型的狀態字典(state dict)。推薦使用第二種方式,因為它更靈活且體積較小。
import torch
import torch.nn as nn# 定義一個簡單的神經網絡
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 實例化模型并進行訓練
model = SimpleNN()
# 模型訓練過程(省略)# 保存模型的狀態字典
torch.save(model.state_dict(), 'simple_nn.pth')
2. 加載模型
一旦你保存了模型,接下來就可以加載它。在加載過程中,確保模型的架構與訓練時一致。以下是加載模型的步驟:
- 1.創建一個模型實例
- 2.調用 load_state_dict() 方法加載狀態字典
代碼示例如下:
# 重新定義模型架構
model = SimpleNN()# 加載模型狀態字典
model.load_state_dict(torch.load('simple_nn.pth', map_location=torch.device('cpu')))
3. 在 CPU 上進行推理
完成模型加載后,接下來就可以使用模型進行推理。以下是一個簡單的示例:
# 模擬輸入數據
input_data = torch.randn(1, 10)# 在 CPU 上進行推理
with torch.no_grad(): # 禁用梯度計算,節省內存output = model(input_data)print(output)