Pytorch(4)-模型保存-載入-eval()

模型保存與提取

  • 1. 整個模型 保存-載入
  • 2. 僅模型參數 保存-載入
  • 3. GPU/CPU模型保存與導入
  • 4. net.eval()--固定模型隨機項

神經網絡模型在線訓練完之后需要保存下來,以便下次使用時可以直接導入已經訓練好的模型。pytorch 提供兩種方式保存模型:

方式1:保存整個網絡,載入時直接載入整個網絡,優點:代碼簡單,缺點需要的存儲空間大

方式2:只保存網絡參數,載入時需要先建立與原來網絡一樣結構的網絡,然后將網絡參數導入到該網絡中,方式2的優缺點與方式1相反。

1. 整個模型 保存-載入

模型的結構參數都保存下來了

# 保存模型:設置 保存目錄 和 保存文件名.擴展名,常用擴展名: .pkl .pth (擴展名只要好辨識就即可)
PATH="./model/mynet1.pkl"
# 導入官方提供的預訓練模型
net1=torchvision.models.alexnet(pretrainend=True)
# 用數據集訓練網絡
.....
# 保存訓練好的網絡
torch.save(net1, PATH)
-----------------------------------------------------------
# 載入模型:設置載入路徑,即模型保存的路徑
PATH="./model/mynet1.pkl"
net1_1=torch.load(PATH)

2. 僅模型參數 保存-載入

保存時–只保存網絡中的參數 (速度快, 占內存少), 載入時–需要提前創建好結構和net2是一樣的

# 保存模型:設置 保存目錄 和 保存文件名.擴展名,常用擴展名: .pkl .pth (擴展名只要好辨識就即可)
PATH="./model/mynet2.pkl"
# 導入官方提供的預訓練模型
net2=torchvision.models.alexnet(pretrainend=True)
# 用數據集訓練網絡
.....
# 保存訓練好的網絡
torch.save(net1.state_dict(), PATH)
-----------------------------------------------------------
# 載入模型:設置載入路徑,即模型保存的路徑
PATH="./model/net2.pkl"
# 新建一個網絡
net2_2=torchvision.models.alexnet(pretrained=True)
# 載入模型參數
net2_2.load_state_dict(torch.load(PATH))

迷糊的現象

在使用莫煩的文檔做實驗時,保存的兩個文件:net.pkl,net_params.pkl大小差異比較大。保證在導入模型是比較快。
在這里插入圖片描述
但是使用torchvision.models.模塊中的一系列網絡時,因為網絡的參數很大,所以實驗過程中用兩種方法保存模型的文件大小是一致的。(猜測是內置模型使用torch.save(net1, ‘net.pkl’)時默認保存的是模型參數)

提供一個神經網絡模型占用空間大小的計算方法:
在這里插入圖片描述
參考文檔:經典CNN模型計算量與內存需求分析

3. GPU/CPU模型保存與導入

在訓練是模型是GPU/CPU,決定了模型載入時的模型原型。可以分為下面三種情況
(只展示導入整個網絡模型的情況,具體實驗還沒做過):

1.CPU(原型)->CPU, GPU(原型)->GPU

torch.load( ‘net.pkl’)

2.GPU(原型)->CPU

torch.load(‘model_dict.pkl’, map_location=lambda storage, loc: storage)

3.CPU(模型文件)->GPU

torch.load(‘model_dic.pkl’, map_location=lambda storage, loc: storage.cuda)

參考文檔:https://blog.csdn.net/u012135425/article/details/85217542

4. net.eval()–固定模型隨機項

兩種模型載入方式、.eval() 作用實驗demo

step1: 載入模型

# 20191204 pytorch 模型載入測試
import torchvision as tvt
import torch
net1=tvt.models.alexnet(pretrained=True)  # 1.自動從網上下載的預先訓練模型
net2=torch.load("./model/mynet1.pkl")     # 2.導入事先訓練好的保存的整個網絡net3=tvt.models.alexnet(pretrained=True)  # 3.導入只保存模型參數的網絡,需要新建一個網絡
net3.load_state_dict(torch.load("../model/mynet2.pkl"))
net3.eval()                              #   固定dropout和歸一化層,否則每次推理會生成不同的結果。

step2:輸出三個網絡同一層參數的和,net2 和net3 對應參數相等。可以看出來,兩種模型保存和導入方式是等價的。

net1 tensor(-21257.7656, grad_fn=<SumBackward0>)
net2 tensor(-21253.9473, device='cuda:0', grad_fn=<SumBackward0>)
net3 tensor(-21253.9551, grad_fn=<SumBackward0>)

step3: 產生一個隨機輸入a,輸入到網絡1,2,3,打印輸出結果。

a=torch.randn([1,3,224,224])
y1=net1(a)
y2=net2(a)
y3=net3(a)
# 第二次輸入
y11=net1(a)
y22=net2(a)
y33=net3(a)
# 打印y1,y2,y3,y11,y22,y33(1000維的和)
y1: tensor(-5.2689, grad_fn=<SumBackward0>)
y2: tensor(-1.6695, device='cuda:0', grad_fn=<SumBackward0>)
y3: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)y11: tensor(-4.4205, grad_fn=<SumBackward0>)
y22: tensor(-5.9475, device='cuda:0', grad_fn=<SumBackward0>)
y33: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)

只有net3的輸出是固定的,因為在模型導入的時候執行了net3.eval().
結論:無論采用 方式1 還是 方式2 導入的模型, 在模型測試時,都需要用.eval()方法固定一下網絡在訓練過程中的隨機項目,如dropout 等,避免網絡在同一個輸入下產生不一樣的結果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/445237.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/445237.shtml
英文地址,請注明出處:http://en.pswp.cn/news/445237.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大數據學習(08)--Hadoop中的數據倉庫Hive

文章目錄目錄1.什么是數據倉庫&#xff1f;1.1數據倉庫概念1.2傳統數據倉庫面臨的挑戰1.3 Hive介紹1.4 Hive與傳統數據庫的對比1.5 Hive在企業中的部署與應用2.Hive系統架構3.Hive工作原理3.1 SQL轉換為MapReduce作業的基本原理3.2 Hive中SQL查詢轉換MapReduce作業的過程4.Hive…

dubbo知識點總結 持續更新

Dubbo 支持哪些協議&#xff0c;每種協議的應用場景&#xff0c;優缺點&#xff1f; ? dubbo&#xff1a; 單一長連接和 NIO 異步通訊&#xff0c;適合大并發小數據量的服務調用&#xff0c; 以及消費者遠大于提供者。傳輸協議 TCP&#xff0c;異步&#xff0c;Hessian 序列化…

使用Linux auto Makefile自動生成的運行步驟

首先創建一個 Linux Makefile.am.這一步是創建Linux Makefile很重要的一步&#xff0c;automake要用的腳本配置文件是Linux Makefile.am&#xff0c;用戶需要自己創建相應的文件。之后&#xff0c;automake工具轉換成Linux Makefile.in。AD&#xff1a; 在向大家詳細介紹Linux …

無限踩坑系列(6)-mySQL數據庫鏈接錯誤

mySQL數據庫鏈接錯誤錯誤1錯誤2長鏈接短連接應用場景需要一直訪問mySQL數據庫&#xff0c;遇到如下錯誤&#xff1a;錯誤1 釋放已經釋放的數據庫鏈接conn.&#xff0c;或者&#xff0c;操作已經釋放的數據庫鏈接conn.或者失去鏈接后再操作數據庫都可能會報這個錯誤 aise err.I…

初探函數式編程和面對對象式編程

文章目錄目錄1.函數式編程和面向對象編程概念1.1 函數式編程1.2 面向對象編程2.函數式編程和面向對象編程的優缺點2.1 函數式編程優點缺點2.2 面對對象編程優點缺點3.為什么在并行計算中函數式編程比較好3.1 什么是并行計算3.2 函數式編程興起原因目錄 1.函數式編程和面向對象…

linux常用解壓和壓縮文件的命令

linux常用解壓和壓縮文件的命令 .tar 解包&#xff1a;tar xvf FileName.tar打包&#xff1a;tar cvf FileName.tar DirName&#xff08;注&#xff1a;tar是打包&#xff0c;不是壓縮&#xff01;&#xff09;———————————————.gz解壓1&#xff1a;gunzip FileN…

Python外(4)-讀寫mat文件

讀寫mat文件1.讀取2.寫入.mat 是matlab中數據存儲的標準格式&#xff0c;Python中能夠通過庫scipy讀取和保存。導入scipy庫 from scipy import io 1.讀取 io.loadmat(file_name, mdictNone, appendmatTrue, **kwargs) 簡便方式&#xff1a; io.loadmat(file_name) append mat–…

Linux下的xml文件的創建

創建一個xml文檔流程如下&#xff1a; l 用xmlNewDoc函數創建一個文檔指針doc&#xff1b; l 用xmlNewNode函數創建一個節點指針root_node&#xff1b; l 用xmlDocSetRootElement將root_node設置為doc的根結點&#xff1b; l 給root_node添加一系列的子節點&#x…

壓力測試http_load 通過修改配置測試https協議成功了。

到http://www.acme.com/software/http_load/ 下載http_load &#xff0c;安裝也很簡單直接make;make instlall 就行。 如果你需要測試https&#xff0c;你必須將 Makefile中 # CONFIGURE: If you want to compile in support for https, uncomment these # definitions. You w…

面向對象設計與分析40講(16)靜態工廠方法模式

前面我們介紹了簡單工廠模式&#xff0c;在創建對象前&#xff0c;我們需要先創建工廠&#xff0c;然后再通過工廠去創建產品。 如果將工廠的創建方法static化&#xff0c;那么無需創建工廠即可通過靜態方法直接調用的方式創建產品&#xff1a; // 工廠類&#xff0c;定義了靜…

搜索詳解

搜索 一.dfs和bfs簡介 深度優先遍歷(dfs) 本質&#xff1a; 遍歷每一個點。 遍歷流程&#xff1a; 從起點開始&#xff0c;在其一條分支上一條路走到黑&#xff0c;走不通了就往回走&#xff0c;只要當前有分支就繼續往下走&#xff0c;直到將所有的點遍歷一遍。 剪枝&a…

Python外(5)-for-enumerate()-zip()

for循環小技巧技巧1&#xff1a;enumerate()技巧2&#xff1a;打包兩個可遍歷數據&#xff0c;一起循環-zip()技巧1&#xff1a;enumerate() 在使用pytorch訓練網絡的過程中&#xff0c;官方教程給出了 for i, data in enumerate(trainloader, 0): 這涉及到enumerate函數的使用…

特征工程總結

目錄1 特征工程是什么&#xff1f; 2 數據預處理   2.1 無量綱化     2.1.1 標準化     2.1.2 區間縮放法     2.1.3 標準化與歸一化的區別   2.2 對定量特征二值化   2.3 對定性特征啞編碼   2.4 缺失值計算   2.5 數據變換 3 特征選擇   3.1 Filter …

Jmeter測試并發https請求成功了

Jmeter2.4 如何測試多個并發https請求&#xff0c;終于成功了借此機會分享給大家 首先要安裝jmeter2.4版本的&#xff0c;而且不建議大家使用badboy&#xff0c;因為這存在兼容性問題。對于安裝&#xff0c;我就不講了&#xff0c;我就說說如何測試https&#xff0c;想必大家都…

關系數據庫——sql基礎1定義

關系數據庫標準語言SQL 基本概念 SQL語言是一個功能極強的關系數據庫語言。同時也是一種介于關系代數與關系演算之間的結構化查詢語言&#xff08;Structured Query Language&#xff09;&#xff0c;其功能包括數據定義、數據查詢、數據操縱和數據控制。 SQL的特點&#xff…

libcurl編程

一、curl簡介 curl是一個利用URL語法在命令行方式下工作的文件傳輸工具。它支持的協議有&#xff1a;FTP, FTPS, HTTP, HTTPS, GOPHER, TELNET, DICT, FILE 以及 LDAP。curl同樣支持HTTPS認證&#xff0c;HTTP POST方法, HTTP PUT方法, FTP上傳, kerberos認證, HTTP上傳, 代理服…

大數據學習(09)--Hadoop2.0介紹

文章目錄目錄1.Hadoop的發展與優化1.1 Hadoop1.0 的不足與局限1.2 Hadoop2.0 的改進與提升2.HDFS2.0 的新特性2.1 HDFS HA2.2 HDFS Federation3. 新一代的資源管理器YARN3.1 MapReduce1.0 缺陷3.2 YARN的設計思路3.3 YARN 體系結構3.4 YARN工作流程3.5 YARN框架與MapReduce1.0框…

Java多線程常用方法

start()與run() start() 啟動線程并執行相應的run()方法 run() 子線程要執行的代碼放入run()方法 getName()和setName() getName() 獲取此線程的名字 setName() 設置此線程的名字 isAlive() 是判斷當前線程是否處于活動狀態。活動狀態就是已經啟動尚未終止。 curren…

MachineLearning(2)-圖像分類常用數據集

圖像分類常用數據集1 CIFAR-102.MNIST3.STL_104.Imagenet5.L-Sun6.caltech-101在訓練神經網絡進行圖像識別分類時&#xff0c;常會用到一些通用的數據集合。利用這些數據集合可以對比不同模型的性能差異。下文整理常用的圖片數據集合&#xff08;持續更新中)。基本信息對比表格…

Linux網絡編程實例詳解

本文介紹了在Linux環境下的socket編程常用函數用法及socket編程的一般規則和客戶/服務器模型的編程應注意的事項和常遇問題的解決方法&#xff0c;并舉了具體代 碼實例。要理解本文所談的技術問題需要讀者具有一定C語言的編程經驗和TCP/IP方面的基本知識。要實習本文的示例&…