pytorch:加載預訓練模型(多卡加載單卡預訓練模型,多GPU,單GPU)

在pytorch加載預訓練模型時,可能遇到以下幾種情況。

分為以下幾種

  • 在pytorch加載預訓練模型時,可能遇到以下幾種情況。
    • 1.多卡訓練模型加載單卡預訓練模型
    • 2. 多卡訓練模型加載多卡預訓練模型
    • 3. 單卡訓練模型加載單卡預訓練模型
    • 4. 單卡訓練模型加載多卡預訓練模型
    • 5.直接刪除預訓練模型中不匹配的鍵
    • 6. 新版torch的模型加載torch<0.4 版本模型
    • 7.在加載的參數模型中增加缺失的鍵,然后賦予隨機參數

問題分為幾種情況:

1.多卡訓練模型加載單卡預訓練模型

if isinstance(self.netG, torch.nn.DataParallel):self.netG = self.netG.module
self.netG.load_state_dict(torch.load(path))

在這里插入圖片描述
這是多卡訓練的模型加載單卡訓練的模型出現的問題。

2. 多卡訓練模型加載多卡預訓練模型

self.netG.load_state_dict(torch.load(path))

3. 單卡訓練模型加載單卡預訓練模型

self.netG.load_state_dict(torch.load(path))

4. 單卡訓練模型加載多卡預訓練模型

對預訓練模型創建新的字典,去掉key值前面的’module.’

state_dict = torch.load('checkpoint.pt’)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k,v in state_dict.items():name = k[7:]new_state_dict[name]  =v 
self.netG.load_state_dict(new_state_dict)

5.直接刪除預訓練模型中不匹配的鍵

 model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}pretrained_dict=model_zoo.load_url(http['url'])model_dict = model.state_dict()pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys model_dict.update(pretrained_dict)model.load_state_dict(model_dict)model = torch.nn.DataParallel(model).cuda()

6. 新版torch的模型加載torch<0.4 版本模型

baol

7.在加載的參數模型中增加缺失的鍵,然后賦予隨機參數

在state_dict 參數模型中增加開頭是conv1一些鍵

state_dict = torch.load(path, map_location=self.device)
model_dict = self.netG_A.state_dict()for k,v in model_dict.items():if k.startswith('conv11') or k.startswith('conv21') or k.startswith('conv31'):state_dict[k] = vself.netG_A.load_state_dict(state_dict)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/535199.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/535199.shtml
英文地址,請注明出處:http://en.pswp.cn/news/535199.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python中 numpy轉list list 轉numpy

list to numpy import numpy as np a [1,2] b np.array(a)numpy to list a np.zero(1,1) a.tolist()

知識蒸餾 knowledge distill 相關論文理解

Knowledge Distil 相關文章1.FitNets : Hints For Thin Deep Nets &#xff08;ICLR2015&#xff09;2.A Gift from Knowledge Distillation&#xff1a;Fast Optimization, Network Minimization and Transfer Learning (CVPR 2017)3.Matching Guided Distillation&#xff08…

模型壓縮 相關文章解讀

模型壓縮相關文章Learning both Weights and Connections for Efficient Neural Networks (NIPS2015)Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding&#xff08;ICLR2016&#xff09;Learning both Weights and …

Linux 殺死進程

kill -9 進程名 殺死進程

計算圖片相似度的方法

文章目錄1.余弦相似度計算2.哈希算法計算圖片的相似度3.直方圖計算圖片的相似度4.SSIM&#xff08;結構相似度度量&#xff09;計算圖片的相似度5.基于互信息&#xff08;Mutual Information&#xff09;計算圖片的相似度1.余弦相似度計算 把圖片表示成一個向量&#xff0c;通…

.size .shape .size() type的運用

.size ndarray.size 數組元素的總個數&#xff0c;相當于 .shape 中 n*m 的值 a np.array([2,2]) print(a.size)2.shap ndarray.shape 數組的維度&#xff0c;對于矩陣&#xff0c;n 行 m 列 a np.array([2,2]) print(a.shape) (1,2)torch.tensor 數組的維度 x torch.r…

矩陣相加

tensor 類型 a torch.randn(1,3,3) b torch.randn(1,3,3) c a b numpy.array 類型 a np.array([2,2]) b np.array([2,2]) print(type(a)) print(ab)[4,4]

Latex 生成的PDF增加行號 左右兩邊

增加行號 \usepackage[switch]{lineno}\linenumbers \begin{document} \end{document}

pytorh .to(device) 和.cuda()的區別

原理 .to(device) 可以指定CPU 或者GPU device torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 單GPU或者CPU model.to(device) #如果是多GPU if torch.cuda.device_count() > 1:model nn.DataParallel(model&#xff0c;devic…

Linux 修改用戶名的主目錄 家目錄

首先root 登陸 sudo -i 輸入密碼然后 vim /etc/passwd 找到用戶名 然后修改后面的路徑即可

ubunt16.04 安裝3090顯卡驅動 cuda cudnn pytorch

安裝驅動 需要的安裝包 30系列顯卡是新一代架構&#xff0c;新驅動不支持cuda9以及cuda10&#xff0c;所以必須安裝cuda11、而pytorch現在穩定版為1.6&#xff0c;最高僅支持到cud10.2。所以唯一的辦法就是使用上處于beta測試的1.7或1.8。這也是為啥一開始就強調本文的寫作時…

3090顯卡 torch.cuda.is_available()返回false的解決辦法

問題 1.執行Nvidia-smi 命令沒有報錯&#xff0c;能夠顯示驅動信息&#xff1b; 2.執行 torch.backends.cudnn.enabled is TRUE 3.torch.cuda.is_available()一直返回False 解決 把torch&#xff0c;torchvision等相關安裝包全部刪除&#xff0c;安裝適合版本的torch。 30系…

測試項目:車牌檢測,行人檢測,紅綠燈檢測,人流檢測,目標識別

本項目為2020年中國軟件杯&#xff21;組第一批賽題&#xff02;基于計算機視覺的交通場景智能應用&#xff02;&#xff0e;項目用python實現&#xff0c;主要使用YOLO模型實現道路目標如人、車、交通燈等物體的識別&#xff0c;使用開源的&#xff02;中文車牌識別HyperLPR&a…

pytorch: 在訓練中保存模型,加載模型

文章目錄1. 保存整個模型2.僅保存和加載模型參數(推薦使用)3. 保存其他參數到模型中&#xff0c;比如optimizer,epoch 等1. 保存整個模型 torch.save(model, model.pkl) model torch.load(model.pkl)2.僅保存和加載模型參數(推薦使用) torch.save(model_object.state_dict()…

Pytorch:保存圖片

1. 直接保存Tensor #!/usr/bin/env python # _*_ coding:utf-8 _*_ import torch from torchvision import utils as vutilsdef save_image_tensor(input_tensor: torch.Tensor, filename):"""將tensor保存為圖片:param input_tensor: 要保存的tensor:param fi…

Python List:合并多個list,listd的合并

第一種方法 a [1,3,3] b [2,3,3] a a b print(a) [1,3,3,2,3,3]第二種方法 a [1,3,3] b [2,3,3] a.extend(b) print(a) [1,3,3,2,3,3]

掛載硬盤問題:mount: wrong fs type, bad option, bad superblock on /dev/sdb,

mount: wrong fs type, bad option, bad superblock on /dev/sdb,missing codepage or helper program, or other error 解決方案&#xff1a; # create mount dir sudo mkdir /hdd6T# new file system sudo mkfs.ext4 /dev/sdc# mount drive sudo mount /dev/sdc /hdd6T/# c…

linux 安裝python3.8的幾種方法

1.命令行搞定 git clone https://github.com/waketzheng/carstino cd carstino python3 upgrade_py.py2.離線安裝 自己在官網下載安裝包 https://www.python.org/ftp/python/3.8.0/ 解壓&#xff1a; tar -zvf Python-3.8.0.tgz安裝 cd Python-3.8.0 ./configure --prefix/u…

面試題目:欠擬合、過擬合及如何防止過擬合

對于深度學習或機器學習模型而言&#xff0c;我們不僅要求它對訓練數據集有很好的擬合&#xff08;訓練誤差&#xff09;&#xff0c;同時也希望它可以對未知數據集&#xff08;測試集&#xff09;有很好的擬合結果&#xff08;泛化能力&#xff09;&#xff0c;所產生的測試誤…

LaTeX:equation, aligned 書寫公式換行,頂部對齊

使用aligined 函數&#xff0c;其中aligned就是用來公式對齊的&#xff0c;在中間公式中&#xff0c;\ 表示換行&#xff0c; & 表示對齊。在公式中等號之前加&&#xff0c;等號介紹要換行的地方加\就可以了。 \begin{equation} \begin{aligned} L_{task} &\lamb…