SAVE AND LOAD THE MODEL (保存和加載模型)
PyTorch
?模型存儲學習到的參數在內部狀態字典中,稱為?state_dict
, 他們的持久化通過?torch.save
?方法。
model = models.shufflenet_v2_x0_5(pretrained=True)
torch.save(model, "../../data/ShuffleNetV2_X0.5.pth")
如果要加載模型的話,首先需要實例化一個同類型的模型對象,然后用 load_state_dict() 方法加載參數。
model = models.shufflenet_v2_x0_5()
model.load_state_dict(torch.load("../../data/ShuffleNetV2_X0.5.pth"))
model.eval()
Output exceeds the size limit. Open the full output data in a text editor
ShuffleNetV2((conv1): Sequential((0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(stage2): Sequential((0): InvertedResidual((branch1): Sequential((0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)(3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): ReLU(inplace=True))(branch2): Sequential((0): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)(4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)(6): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): ReLU(inplace=True)
...(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(fc): Linear(in_features=1024, out_features=1000, bias=True)
)
Saving and Loading Models with Shapes
當加載模型權重時,我們需要首先實例化模型類,因為類定義了網絡的結構。我們可能想要保存類的結構以及模型,在這種情況下,我們可以將 model (而不是 model.state_dict() ) 傳遞給保存函數:
?
torch.save(model, "../../data/ShuffleNetV2_X0.5_eval2.pth")
加載模型如這樣:
model = torch.load("../../data/ShuffleNetV2_X0.5_eval2.pth")
print(model)
這種方法在序列化模型時使用 Python?pickle?模塊,因此它依賴于加載模型時可用的實際類定義。
Lnton羚通專注于音視頻算法、算力、云平臺的高科技人工智能企業。 公司基于視頻分析技術、視頻智能傳輸技術、遠程監測技術以及智能語音融合技術等, 擁有多款可支持ONVIF、RTSP、GB/T28181等多協議、多路數的音視頻智能分析服務器/云平臺。