在pytorch加載預訓練模型時,可能遇到以下幾種情況。
分為以下幾種
- 在pytorch加載預訓練模型時,可能遇到以下幾種情況。
- 1.多卡訓練模型加載單卡預訓練模型
- 2. 多卡訓練模型加載多卡預訓練模型
- 3. 單卡訓練模型加載單卡預訓練模型
- 4. 單卡訓練模型加載多卡預訓練模型
- 5.直接刪除預訓練模型中不匹配的鍵
- 6. 新版torch的模型加載torch<0.4 版本模型
- 7.在加載的參數模型中增加缺失的鍵,然后賦予隨機參數
問題分為幾種情況:
1.多卡訓練模型加載單卡預訓練模型
if isinstance(self.netG, torch.nn.DataParallel):self.netG = self.netG.module
self.netG.load_state_dict(torch.load(path))
這是多卡訓練的模型加載單卡訓練的模型出現的問題。
2. 多卡訓練模型加載多卡預訓練模型
self.netG.load_state_dict(torch.load(path))
3. 單卡訓練模型加載單卡預訓練模型
self.netG.load_state_dict(torch.load(path))
4. 單卡訓練模型加載多卡預訓練模型
對預訓練模型創建新的字典,去掉key值前面的’module.’
state_dict = torch.load('checkpoint.pt’)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k,v in state_dict.items():name = k[7:]new_state_dict[name] =v
self.netG.load_state_dict(new_state_dict)
5.直接刪除預訓練模型中不匹配的鍵
model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}pretrained_dict=model_zoo.load_url(http['url'])model_dict = model.state_dict()pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys model_dict.update(pretrained_dict)model.load_state_dict(model_dict)model = torch.nn.DataParallel(model).cuda()
6. 新版torch的模型加載torch<0.4 版本模型
baol
7.在加載的參數模型中增加缺失的鍵,然后賦予隨機參數
在state_dict 參數模型中增加開頭是conv1一些鍵
state_dict = torch.load(path, map_location=self.device)
model_dict = self.netG_A.state_dict()for k,v in model_dict.items():if k.startswith('conv11') or k.startswith('conv21') or k.startswith('conv31'):state_dict[k] = vself.netG_A.load_state_dict(state_dict)