【學習記錄】pytorch載入模型的部分參數

需要從PointNet網絡框架中提取encoder部分的參數,然后賦予自己的模型。因此,需要從一個已有的.pth文件讀取部分參數,加載到自定義模型上面。做了一些嘗試,記錄如下。

關于模型保存與載入

torch.save(): 使用Python的pickle實用程序將對象進行序列化,然后將序列化的對象保存到disk,可以保存各種對象,包括模型、張量和字典等。
torch.load(): 使用pickle unpickle工具將pickle的對象文件反序列化為內存。
可以看出,pth文件本質上是一個序列化的dict。

我們在save時,代碼如下:

state = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
}

然后以下代碼load進來:

checkpoint = torch.load(args.model_file,  map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

查看checkpoint,可以看到包含的就是自己保存時的3個dict,分別是epoch,model_state_dict,和optimizer信息。
在這里插入圖片描述

這里我們重點關注 model_state_dict,數據類型是一個 OrderedDict,有序字典。展開如下:
在這里插入圖片描述
可以看到里面包含了自己定義的encoder,bn1-3,mlp 1-4層,以及每個層對應的參數(權重、bias,對于bn層還有mean, var等)。
這個Dict的順序就是在Model中我們定義的順序,這個和模型是一致的。
因此,如果載入時的模型和保存模型完全一致,直接用load_state_dict()就可以按順序把數據載入進來。但如,如果定義不同怎么辦?這就需要手動載入。

方法1:手動載入指定層的參數

從debug的斷點可以看到,每個參數就是存在dict中的一個tensor。因此,我們只要讀取對應的dict即可。
例如,encoder的conv1的權重,就是 checkpoint['model_state_dict']['encoder.conv1.weight'],那么我們在自己的模型對應的位置讀取這個dict即可。
具體載入方式如下:

# 定義模型
model = MyPointNetSegmentation(channel=3, get_feature=True, batch_size=1)
model.to('cpu')# 載入其他模型的參數
checkpoint = torch.load(model_file, map_location='cpu')
model_dict = checkpoint['model_state_dict']# 將其他模型的參數,賦值給自己模型對應參數
model.encoder.conv1.weight.data.copy_(model_dict['encoder.conv1.weight'])
model.encoder.conv1.bias.data.copy_(model_dict['encoder.conv1.bias'])

把所有有用的參數都賦值過來就好,但要注意參數對應的tensor維度是一樣的。
在這里插入圖片描述

方法2:一次性載入key值相同的參數

如果說兩個model的某些key值相同,可以用python的字典推導方式,將名稱相關的參數提取出來。例如:

def load_dict_from_pointnet(model : Point2VoxelNet, checkpoint):my_model_dict = model.state_dict()pretrained_dict =  checkpoint['model_state_dict']# 只將pretraind_dict中那些在model_dict中的參數,提取出來state_dict = {k:v for k,v in pretrained_dict.items() if k in my_model_dict .keys()}my_model_dict.update(state_dict)		# 注意要更新state的變量,如果直接賦值,會出現某些key沒有定義,導致運行失敗model.load_state_dict(my_model_dict)# 對比參數是否一致print(f"{checkpoint['model_state_dict']['feat.stn.conv1.weight'][1]}")print(f"{model.feat.stn.conv1.weight[1]}")return model

看到這里,可以知道如果自己的模型改了名稱,例如.pth的參數是:feat.stn.conv1,我這邊叫做了 encoder.stn.conv1,那么是無法直接賦值的。可以用方法1,一個個載入,但是太慢了。另一種方式,是做一個鍵值映射,如果讀到的是 feat.xxx,則賦予自定義模型中的 encoder.xxx ,簡單處理即可。

注意事項

  • conv層需要載入的參數有:weight 和 bias
  • BN層涉及的參數有:
    1. weight,bias
    2. running_mean,running_var:這兩個參數用于歸一化的均值和方差, 因此也需要載入
    3. num_batches_tracked:在訓練時需要載入,在test時不需要載入
  • 載入參數后,如果用于測試,需要調用 eval()。注意不能在載入參數前調用 eval。eval 會將 bn 層的training參數設置為 false ,這樣在測試時 batch_size 時如果是 1 也能夠正常運行。

測試

用默認方式載入參數,以及手動方式載入后的兩個模型,預測結果一致。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/74186.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/74186.shtml
英文地址,請注明出處:http://en.pswp.cn/web/74186.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【藍橋杯14天沖刺課題單】Day 8

1.題目鏈接:19714 數字詩意 這道題是一道數學題。 先考慮奇數,已知奇數都可以表示為兩個相鄰的數字之和,2k1k(k1) ,那么所有的奇數都不會被計入。 那么就需要考慮偶數什么情況需要被統計。根據打表,其實可以發現除了…

鴻蒙ArkTS開發:微信/系統來電通話監聽功能實現

本文將介紹如何在鴻蒙應用中使用ArkTS實現通話監聽和錄音功能,利用harmony-utils工具庫簡化開發流程。 工具庫地址 一、功能概述 本實現包含以下核心功能: 通話狀態監聽:檢測來電、去電和通話中狀態 音頻流監控:通過麥克風使用…

NFS 重傳次數速率監控

這張圖展示的是 NFS 重傳次數速率監控,具體解釋如下: 1. 指標含義 監控指標 node_nfs_rpc_retransmissions_total 統計 NFS(網絡文件系統)通信中 RPC(遠程過程調用)的重傳次數,rate(node_nfs_…

【 <二> 丹方改良:Spring 時代的 JavaWeb】之 Spring Boot 中的國際化:支持多語言的 RESTful API

<前文回顧> 點擊此處查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、開篇整…

黑帽SEO之搜索引擎劫持-域名劫持原理分析

問題起源 這是在《Web安全深度剖析》的第二章“深入HTTP請求流程”的2.3章節“黑帽SEO之搜索引擎劫持”提到的內容&#xff0c;但是書中描述并不詳細&#xff0c;沒有講如何攻擊達到域名劫持的效果。 書中對SEO搜索引擎劫持的現象描述如下&#xff1a;直接輸入網站的域名可以進…

theos工具來編譯xcode的swiftUI項目為ipa文件

Theos 是一個開源的開發工具套件&#xff0c;主要用于為 iOS/macOS 平臺開發和編譯 越獄插件&#xff08;Tweaks&#xff09;、動態庫、命令行工具等。它由 Dustin Howett 創建&#xff0c;并被廣泛用于越獄社區的開發中。但這里我主要使用它的打包ipa功能&#xff0c;因為我的…

25.4.1學習總結【Java】

動態規劃題 2140. 解決智力問題https://leetcode.cn/problems/solving-questions-with-brainpower/ 給你一個下標從 0 開始的二維整數數組 questions &#xff0c;其中 questions[i] [pointsi, brainpoweri] 。 這個數組表示一場考試里的一系列題目&#xff0c;你需要 按順…

計算機網絡知識點匯總與復習——(二)物理層

Preface 計算機網絡是考研408基礎綜合中的一門課程&#xff0c;它的重要性不言而喻。然而&#xff0c;計算機網絡的知識體系龐大且復雜&#xff0c;各類概念、協議和技術相互關聯&#xff0c;讓人在學習時容易迷失方向。在進行復習時&#xff0c;面對龐雜的的知識點&#xff0c…

string的底層原理

一.構造函數 我們來看一下&#xff0c;string的底層就是一個字符型指針和一個size來表示string的大小&#xff0c;capacity來表示分配的內存大小。 我們來看我們注釋掉的第一個構造函數&#xff0c;我們是通過初始化列表來初始化size的大小&#xff0c;再通過size的大小來初始化…

Python FastAPI + Celery + RabbitMQ 分布式圖片水印處理系統

FastAPI 服務器Celery 任務隊列RabbitMQ 作為消息代理定時任務處理 首先創建項目結構&#xff1a; c:\Users\Administrator\Desktop\meitu\ ├── app/ │ ├── __init__.py │ ├── main.py │ ├── celery_app.py │ ├── tasks.py │ └── config.py…

【藍橋杯】每日練習 Day18

目錄 前言 動態求連續區間和 分析 代碼 數星星 分析 代碼 星空之夜 分析 代碼 前言 接下來是今天的題目&#xff08;本來是有四道題的但是有一道題是前面講過&#xff08;逆序數的&#xff0c;感興趣的小伙伴可以去看我歸并排序的那一篇&#xff09;的我就不再過多贅…

基于銀河麒麟桌面服務器操作系統的 DeepSeek本地化部署方法【詳細自用版】

一、3種方式使用DeepSeek 1.本地部署 服務器操作系統環境進行,具體流程如下(桌面環境步驟相同): 本例所使用銀河麒麟高級服務器操作系統版本信息: (1)安裝ollama 方式一:按照ollama官網的下載指南,執行如下命令: curl -fsSL https://ollama.com/install.sh | sh方…

Python入門(7):Python序列結構-字典

字典Dictionary 字典(dictionary)和列表類似&#xff0c;也是可變序列&#xff0c;不過與列表不同&#xff0c;它是無序的可變序列&#xff0c;保存的為容是以“鍵-值對”的形式存放的。 Python 中的字典相當于 Java 或者 C中的 Map 對象。在C#中,就是Dictionary<TKey,TVa…

Flutter項目之構建打包分析

目錄&#xff1a; 1、準備部分2、構建Android包2.1、配置修改部分2.2、編譯打包 3、構建ios包3.1、配置修改部分3.2、編譯打包 1、準備部分 2、構建Android包 2.1、配置修改部分 2.2、編譯打包 執行flutter build apk命令進行打包。 3、構建ios包 3.1、配置修改部分 3.2、編譯…

不用再付費~全網書源一鍵下載,實現閱讀自由!!!

現在市面上有許多免費你看書的軟件&#xff0c;但都軟件內太多廣告彈窗&#xff0c;這無疑是很煩&#xff0c;有事一不小心點進去就下載了軟件&#xff0c;簡直讓人頭大&#xff01; 如果你遇到這樣的難題那么就應該看下本文~ 這是一款能一鍵將在線連載小說整合下載成標準格式&…

GCC RISCV 后端 -- GIMPLE IR 表示的一些理解

C/C源代碼經過 GCC 解析&#xff08;Parse&#xff09;及轉換后&#xff0c;通過 GIMPLE IR 予以表示&#xff08;Representation&#xff09;。其中&#xff0c;一個C/C源文件&#xff0c;通過 宏處理后&#xff0c;形成一個 轉譯單元&#xff08;Translation Unit&#xff09…

JAVA設計模式之適配器模式《太白金星有點煩》

太白金星握著月光凝成的鼠標&#xff0c;第108次檢查南天門服務器的運行日志。這個剛從天樞院調來的三等仙官&#xff0c;此刻正盯著瑤池主機房里的青銅鼎發愁——鼎身上"天地同壽"的云紋間&#xff0c;漂浮著三界香火系統每分鐘吞吐的十萬條功德數據。看著居高不下的…

以太坊DApp開發腳手架:Scaffold-ETH 2 詳細介紹與搭建教程

一、什么是Scaffold-ETH 2 Scaffold-ETH 2是一個開源的最新工具包&#xff0c;類似于腳手架。用于在以太坊區塊鏈上構建去中心化應用程序 &#xff08;DApp&#xff09;。它旨在使開發人員更容易創建和部署智能合約&#xff0c;并構建與這些合約交互的用戶界面。 Scaffold-ETH…

畢業設計:實現一個基于Python、Flask和OpenCV的人臉打卡Web系統(六)

畢業設計:實現一個基于Python、Flask和OpenCV的人臉打卡Web系統(六) Flask Flask是一個使用 Python 編寫的輕量級 Web 應用框架。其 WSGI 工具箱采用 Werkzeug ,模板引擎則使用 Jinja2 。Flask使用 BSD 授權。 Flask也被稱為 “microframework” ,因為它使用簡單的核心,…

第十一章 VGA顯示圖片(還不會)

FPGA至簡設計實例 前言 一、項目背景 1. IP核概述 IP 核(Intellectual Property core)指的是知識產權核或知識產權模塊,其是具有特定電路功能的硬件描述語言程序,在EDA技術開發中具有十分重要的地位。美國著名的Dataquest咨詢公司將 半導體產業的IP定義為“用于ASIC或FPGA…