文章目錄
- 1 使用現有網絡
- 2 修改網絡結構
- 2.1 添加新層
- 2.2 替換現有層
- 3 保存網絡模型
- 3.1 完整保存
- 3.2 參數保存(推薦)
- 4 加載網絡模型
- 4.1 加載完整模型文件
- 4.2 加載參數文件
- 5 Checkpoint
- 5.1 保存 Checkpoint
- 5.2 加載 Checkpoint
本文環境:
- Pycharm 2025.1
- Python 3.12.9
- Pytorch 2.6.0+cu124
? PyTorch 通過torchvision.models
提供預訓練模型(如 VGG16)。
? 網址鏈接:https://docs.pytorch.org/vision/stable/models.html。
1 使用現有網絡
? 以 VGG16 為例,進入網址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

方法一:使用隨機初始化權重
? 將 weights 設置為 None,從 0 開始訓練自己的網絡。
vgg16_false = torchvision.models.vgg16(weights=None) # 權重隨機初始化
方法二:加載預訓練權重
? 也可以使用預訓練好的網絡參數,加載后可直接使用網絡。
這將從官網上下載已訓練好的模型文件。
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
? 可打印網絡查看其模型結構:
print(vgg16_true)


2 修改網絡結構
2.1 添加新層
? 使用add_module
在分類器(classifier
)后追加全連接層:
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

2.2 替換現有層
? 直接修改分類器的最后一層(如適配 CIFAR10 的 10 分類任務):
vgg16_false.classifier[6] = nn.Linear(4096, 10) # 替換第6層

3 保存網絡模型
? 使用torch.save()
方法保存網絡模型。文件擴展名推薦使用.pt
或.pth
。
3.1 完整保存
? 將模型類和參數一并保存到文件中。
torch.save(vgg16, 'vgg16_method1.pth') # 包含模型類和參數
- 優點:加載時無需重新定義模型結構。
- 缺點:文件較大,且依賴原始代碼環境(見 4.1 節)。
3.2 參數保存(推薦)
? 僅保存參數字典到文件中。
torch.save(vgg16.state_dict(), 'vgg16_method2.pth') # 僅保存參數字典
- 優點:文件小,靈活性強,適合生產部署。
示例
import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型結構 + 模型參數
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型參數(官方推薦)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')
4 加載網絡模型
? 使用torch.load()
方法加載網絡模型。
4.1 加載完整模型文件
? 加載完整模型時,需將 weights_only 參數設置為 False。
model = torch.load('vgg16_method1.pth', weights_only=False) # 需確保模型類已定義
? 模型打印結果如下:
print(model)

注意
? 若保存自定義模型,加載時必須確保環境中也有該模型的定義,否則會出現報錯。
model_save.py
# model_save.pyimport torch from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel() torch.save(model, 'my_model_method1.pth')
model_load.py
import torchmodel = torch.load('my_model_method1.pth', weights_only=False) # 報錯,找不到 MyModel 的定義
先運行 model_save.py,再運行 model_load.py,則會出現以下報錯:
![]()
?
4.2 加載參數文件
? 首先,使用torch.load()
方法加載網絡模型。
? 使用模型時,需先創建匹配的網絡結構,再使用model.load_state_dict()
加載參數數據。
vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict) # 需結構匹配
? 模型打印結果是參數字典:
print(model_dict)

注意
? 模型保存時若在 GPU 上,加載時需指定 map_location 為 cup。
torch.load('model.pth', map_location=torch.device('cpu'))
? 將參數加載到模型后,手動遷移到 GPU:
model = MyModel() model.load_state_dict(model_dict) model.to('cuda:0')
5 Checkpoint
? 使用 Checkpoint 可以在訓練過程中定期保存模型的狀態,以便在中斷后可以恢復訓練,或者在測試時使用最終的模型。文件擴展名推薦使用.tar
。
5.1 保存 Checkpoint
? 要保存一個模型的 Checkpoint,通常需要保存以下數據:
- 模型的 state_dict(狀態字典);
- 優化器的狀態;
- 額外的信息,如 epoch 等。
import torch# 假設 model 是你的模型,optimizer 是你的優化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')
5.2 加載 Checkpoint
? 加載 Checkpoint,首先需要加載文件,然后將其內容恢復到模型和優化器的狀態中。
# 假設 model 和 optimizer 是你的模型和優化器實例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以繼續訓練
model.train() # 確保模型處于訓練模式