pytorch小記(七):pytorch中的保存/加載模型操作
- 1. 加載模型參數 (`state_dict`)
- 1.1 保存模型參數
- 1.2 加載模型參數
- 1.3 常見變種
- 1.3.1 指定加載設備
- 1.3.2 非嚴格加載(跳過部分層)
- 1.3.3 打印加載的參數
- 2. 加載整個模型
- 2.1 保存整個模型
- 2.2 加載整個模型
- 2.3 注意事項
- 3. 總結
- 4. 加載模型的完整代碼示例
- 4.1 保存和加載參數
- 4.2 保存和加載整個模型
- 4.3 加載到不同設備
- 4.4 忽略部分參數(非嚴格加載)
- 5. 檢查模型是否加載成功
在 PyTorch 中,加載模型通常分為兩種情況:加載模型參數(state_dict) 和 加載整個模型。以下是加載模型的所有相關操作及其詳細步驟:
1. 加載模型參數 (state_dict
)
當僅保存了模型的參數時(使用 model.state_dict()
保存),加載模型的步驟如下:
1.1 保存模型參數
torch.save(model.state_dict(), 'model.pth')
- 文件內容:只保存模型的參數(權重和偏置)。
- 優點:
- 節省存儲空間。
- 靈活性更高,可以與不同的模型架構配合使用。
- 缺點:
- 需要手動重新定義模型結構。
1.2 加載模型參數
-
重新定義模型架構:
model = MyModel() # 替換為你的模型類
-
加載參數:
state_dict = torch.load('model.pth') # 加載參數字典 model.load_state_dict(state_dict) # 加載參數到模型
-
選擇運行設備:
model.to('cuda') # 如果需要運行在 GPU 上
1.3 常見變種
1.3.1 指定加載設備
- 如果保存時模型在 GPU 上,而加載時在 CPU 環境中,可以使用
map_location
:state_dict = torch.load('model.pth', map_location='cpu')
1.3.2 非嚴格加載(跳過部分層)
- 如果保存的參數與模型結構不完全匹配(例如額外的層或不同的順序),可以使用
strict=False
:model.load_state_dict(state_dict, strict=False)
1.3.3 打印加載的參數
- 可以檢查參數字典的內容:
print(state_dict.keys())
2. 加載整個模型
當模型是通過 torch.save(model)
保存時,文件包含了模型的結構和參數,加載更為簡單。
2.1 保存整個模型
torch.save(model, 'model_full.pth')
- 文件內容:包含模型的架構和參數。
- 優點:
- 無需重新定義模型結構。
- 直接加載并使用。
- 缺點:
- 文件依賴于保存時的代碼版本(如模型定義)。
- 文件體積較大。
2.2 加載整個模型
model = torch.load('model_full.pth')
model.to('cuda') # 如果需要在 GPU 上運行
2.3 注意事項
- 動態定義的模型:
- 如果模型結構是動態定義的(如包含條件邏輯),保存和加載整個模型可能會依賴于代碼的一致性。
- 確保在加載時導入了與保存時相同的模型類。
3. 總結
操作 | 使用場景 | 優點 | 缺點 |
---|---|---|---|
保存參數 (state_dict ) | 推薦大多數情況 | 文件小、靈活性高 | 需要手動定義模型架構 |
保存整個模型 | 模型復雜且固定時 | 不需要重新定義模型,直接加載 | 文件大、依賴保存時的代碼版本 |
4. 加載模型的完整代碼示例
4.1 保存和加載參數
import torch
import torch.nn as nn# 定義模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 保存參數
model = MyModel()
torch.save(model.state_dict(), 'model.pth')# 加載參數
model = MyModel() # 重新定義模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda') # 運行在 GPU
4.2 保存和加載整個模型
# 保存整個模型
torch.save(model, 'model_full.pth')# 加載整個模型
model = torch.load('model_full.pth')
model.to('cuda') # 運行在 GPU
4.3 加載到不同設備
# 保存參數
torch.save(model.state_dict(), 'model.pth')# 加載到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)# 加載到 GPU
model.to('cuda')
4.4 忽略部分參數(非嚴格加載)
# 保存參數
torch.save(model.state_dict(), 'model.pth')# 加載參數(非嚴格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)
5. 檢查模型是否加載成功
-
驗證權重是否加載
for name, param in model.named_parameters():print(f"{name}: {param.data}")
-
進行推理驗證
x = torch.randn(1, 10).to('cuda') # 假設輸入維度為 10 output = model(x) print(output)
通過以上操作,你可以靈活加載 PyTorch 模型,無論是僅加載參數還是加載整個模型結構和權重。