pytorch小記(七):pytorch中的保存/加載模型操作

pytorch小記(七):pytorch中的保存/加載模型操作

  • 1. 加載模型參數 (`state_dict`)
    • 1.1 保存模型參數
    • 1.2 加載模型參數
    • 1.3 常見變種
      • 1.3.1 指定加載設備
      • 1.3.2 非嚴格加載(跳過部分層)
      • 1.3.3 打印加載的參數
  • 2. 加載整個模型
    • 2.1 保存整個模型
    • 2.2 加載整個模型
    • 2.3 注意事項
  • 3. 總結
  • 4. 加載模型的完整代碼示例
    • 4.1 保存和加載參數
    • 4.2 保存和加載整個模型
    • 4.3 加載到不同設備
    • 4.4 忽略部分參數(非嚴格加載)
    • 5. 檢查模型是否加載成功


在 PyTorch 中,加載模型通常分為兩種情況:加載模型參數(state_dict)加載整個模型。以下是加載模型的所有相關操作及其詳細步驟:


1. 加載模型參數 (state_dict)

當僅保存了模型的參數時(使用 model.state_dict() 保存),加載模型的步驟如下:

1.1 保存模型參數

torch.save(model.state_dict(), 'model.pth')
  • 文件內容:只保存模型的參數(權重和偏置)。
  • 優點
    • 節省存儲空間。
    • 靈活性更高,可以與不同的模型架構配合使用。
  • 缺點
    • 需要手動重新定義模型結構。

1.2 加載模型參數

  1. 重新定義模型架構:

    model = MyModel()  # 替換為你的模型類
    
  2. 加載參數:

    state_dict = torch.load('model.pth')  # 加載參數字典
    model.load_state_dict(state_dict)    # 加載參數到模型
    
  3. 選擇運行設備:

    model.to('cuda')  # 如果需要運行在 GPU 上
    

1.3 常見變種

1.3.1 指定加載設備

  • 如果保存時模型在 GPU 上,而加載時在 CPU 環境中,可以使用 map_location
    state_dict = torch.load('model.pth', map_location='cpu')
    

1.3.2 非嚴格加載(跳過部分層)

  • 如果保存的參數與模型結構不完全匹配(例如額外的層或不同的順序),可以使用 strict=False
    model.load_state_dict(state_dict, strict=False)
    

1.3.3 打印加載的參數

  • 可以檢查參數字典的內容:
    print(state_dict.keys())
    

2. 加載整個模型

當模型是通過 torch.save(model) 保存時,文件包含了模型的結構和參數,加載更為簡單。

2.1 保存整個模型

torch.save(model, 'model_full.pth')
  • 文件內容:包含模型的架構和參數。
  • 優點
    • 無需重新定義模型結構。
    • 直接加載并使用。
  • 缺點
    • 文件依賴于保存時的代碼版本(如模型定義)。
    • 文件體積較大。

2.2 加載整個模型

model = torch.load('model_full.pth')
model.to('cuda')  # 如果需要在 GPU 上運行

2.3 注意事項

  • 動態定義的模型
    • 如果模型結構是動態定義的(如包含條件邏輯),保存和加載整個模型可能會依賴于代碼的一致性。
    • 確保在加載時導入了與保存時相同的模型類。

3. 總結

操作使用場景優點缺點
保存參數 (state_dict)推薦大多數情況文件小、靈活性高需要手動定義模型架構
保存整個模型模型復雜且固定時不需要重新定義模型,直接加載文件大、依賴保存時的代碼版本

4. 加載模型的完整代碼示例

4.1 保存和加載參數

import torch
import torch.nn as nn# 定義模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 保存參數
model = MyModel()
torch.save(model.state_dict(), 'model.pth')# 加載參數
model = MyModel()  # 重新定義模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda')  # 運行在 GPU

4.2 保存和加載整個模型

# 保存整個模型
torch.save(model, 'model_full.pth')# 加載整個模型
model = torch.load('model_full.pth')
model.to('cuda')  # 運行在 GPU

4.3 加載到不同設備

# 保存參數
torch.save(model.state_dict(), 'model.pth')# 加載到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)# 加載到 GPU
model.to('cuda')

4.4 忽略部分參數(非嚴格加載)

# 保存參數
torch.save(model.state_dict(), 'model.pth')# 加載參數(非嚴格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)

5. 檢查模型是否加載成功

  1. 驗證權重是否加載

    for name, param in model.named_parameters():print(f"{name}: {param.data}")
    
  2. 進行推理驗證

    x = torch.randn(1, 10).to('cuda')  # 假設輸入維度為 10
    output = model(x)
    print(output)
    

通過以上操作,你可以靈活加載 PyTorch 模型,無論是僅加載參數還是加載整個模型結構和權重。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/65829.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/65829.shtml
英文地址,請注明出處:http://en.pswp.cn/web/65829.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Mysql--運維篇--主從復制和集群(主從復制I/O線程,SQL線程,二進制日志,中繼日志,集群NDB)

一、主從復制 MySQL的主從復制(Master-Slave Replication)是一種數據冗余和高可用性的解決方案,它通過將一個或多個從服務器(Slave)與主服務器(Master)同步來實現。主從復制的基本原理是&#…

【EI會議征稿通知】第十一屆機械工程、材料和自動化技術國際會議(MMEAT 2025)

本次大會旨在匯聚全球機械工程、材料科學及自動化技術的創新學者和行業專家,為他們提供一個卓越的交流與合作平臺。隨著全球對可持續技術和智能制造需求的不斷增加,MMEAT 2025將重點關注這些領域的最新發展趨勢和未來前景。此次大會的主要目標是推動機械…

OpenCV基礎:視頻的采集、讀取與錄制

從攝像頭采集視頻 相關接口 - VideoCapture VideoCapture 用于從視頻文件、攝像頭或其他視頻流設備中讀取視頻幀。它可以捕捉來自多種源的視頻。 主要參數: cv2.VideoCapture(source): source: 這是一個整數或字符串,表示視頻的來源。 如果是整數&a…

解讀Linux Bridge中的東西流向與南北流向

解讀Linux Bridge中的東西流向與南北流向 在現代云計算和虛擬化環境中,網絡流量的管理和優化變得越來越重要。Linux Bridge作為Linux內核提供的一個強大的二層交換機工具,在虛擬化和容器化應用中扮演著至關重要的角色。本文將深入探討Linux Bridge中的兩…

在線實用工具 json格式化,base64轉碼,正則表達式測試工具

1、在線json格式化工具: https://json.openai2025.com/ 2、在線base64轉碼工具 https://base64.openai2025.com/ 3、在線正則表達式測試工具 https://reg.openai2025.com/ 4、在線去水印工具 https://watermark.openai2025.com

java 中 main 方法使用 KafkaConsumer 拉取 kafka 消息如何禁止輸出 debug 日志

pom 依賴&#xff1a; <dependency><groupId>org.springframework.kafka</groupId><artifactId>spring-kafka</artifactId><version>2.5.14.RELEASE</version> </dependency> 或者 <dependency><groupId>org.ap…

車聯網安全--TLS握手過程詳解

目錄 1. TLS協議概述 2. 為什么要握手 2.1 Hello 2.2 協商 2.3 同意 3.總共握了幾次手&#xff1f; 1. TLS協議概述 車內各ECU間基于CAN的安全通訊--SecOC&#xff0c;想必現目前多數通信工程師們都已經搞的差不多了&#xff08;不要再問FvM了&#xff09;&#xff1b;…

RuoYi Cloud項目解讀【四、項目配置與啟動】

四、項目配置與啟動 當上面環境全部準備好之后&#xff0c;接下來就是項目配置。需要將項目相關配置修改成當前相關環境。 1 后端配置 1.1 數據庫 創建數據庫ry-cloud并導入數據腳本ry_2024xxxx.sql&#xff08;必須&#xff09;&#xff0c;quartz.sql&#xff08;可選&…

C#對象池

一、資源管理的困境與破局 在軟件開發的征程中&#xff0c;我們時常陷入資源管理的泥沼。以一個繁忙餐廳為例&#xff0c;每個顧客都急需一個盤子盛美食&#xff0c;可盤子數量有限&#xff0c;如果每次顧客用完盤子后&#xff0c;都不假思索地去清洗一個全新的盤子來供下一位…

Vue.js組件開發-如何使用moment.js

在Vue.js組件開發中&#xff0c;需要處理日期和時間&#xff0c;moment.js 是一個非常有用的庫。moment.js 提供了豐富的API來解析、驗證、操作和顯示日期和時間。 步驟&#xff1a; 1. 安裝moment.js 首先&#xff0c;需要通過npm或yarn安裝moment.js。在項目根目錄下運行以…

微信小程序mp3音頻播放組件,僅需傳入url即可

// index.js // packageChat/components/audio-player/index.js Component({/*** 組件的屬性列表*/properties: {/*** MP3 文件的 URL*/src: {type: String,value: ,observer(newVal, oldVal) {if (newVal ! oldVal && newVal) {// 如果 InnerAudioContext 已存在&…

要避免除數絕對值遠遠小于被除數絕對值的除法

要避免除數絕對值遠遠小于被除數絕對值的除法 用絕對值小的數作除數&#xff0c;舍人誤差會增大&#xff0c;如計算 x y \frac xy yx?,若 0 < ∣ y ∣ < ∣ x ∣ 0<|y|<|x| 0<∣y∣<∣x∣&#xff0c;則可能對計算結果帶來嚴重影響&#xff0c;應盡量避免…

深入了解OpenStack中的隧道網絡

在OpenStack環境中&#xff0c;隧道網絡是一項關鍵技術&#xff0c;它確保了虛擬機之間以及虛擬機與外部網絡之間的安全通信。通過隧道機制&#xff0c;我們可以有效地隔離不同租戶的流量&#xff0c;并支持多租戶環境下的復雜網絡需求。之前我們介紹了隧道網絡&#xff0c;下面…

4. scala高階之隱式轉換與泛型

背景 上一節&#xff0c;我介紹了scala中的面向對象相關概念&#xff0c;還有一個特色功能&#xff1a;模式匹配。本文&#xff0c;我會介紹另外一個特別強大的功能隱式轉換&#xff0c;并在最后介紹scala中泛型的使用 1. 隱式轉換 Scala提供的隱式轉換和隱式參數功能&#…

pandas與sql對應關系【幫助sql使用者快速上手pandas】

本頁旨在提供一些如何使用pandas執行各種SQL操作的示例&#xff0c;來幫助SQL使用者快速上手使用pandas。 目錄 SQL語法一、選擇SELECT1、選擇2、添加計算列 二、連接JOIN ON1、內連接2、左外連接3、右外連接4、全外連接 三、過濾WHERE1、AND2、OR3、IS NULL4、IS NOT NULL5、B…

第432場周賽:跳過交替單元格的之字形遍歷、機器人可以獲得的最大金幣數、圖的最大邊權的最小值、統計 K 次操作以內得到非遞減子數組的數目

Q1、跳過交替單元格的之字形遍歷 1、題目描述 給你一個 m x n 的二維數組 grid&#xff0c;數組由 正整數 組成。 你的任務是以 之字形 遍歷 grid&#xff0c;同時跳過每個 交替 的單元格。 之字形遍歷的定義如下&#xff1a; 從左上角的單元格 (0, 0) 開始。在當前行中向…

《探索鴻蒙Next上開發人工智能游戲應用的技術難點》

在科技飛速發展的當下&#xff0c;鴻蒙Next系統為應用開發帶來了新的機遇與挑戰&#xff0c;開發一款運行在鴻蒙Next上的人工智能游戲應用更是備受關注。以下是在開發過程中可能會遇到的一些技術難點&#xff1a; 鴻蒙Next系統適配性 多設備協同&#xff1a;鴻蒙Next的一大特色…

Harry技術添加存儲(minio、aliyun oss)、短信sms(aliyun、模擬)、郵件發送等功能

Harry技術添加存儲&#xff08;minio、aliyun oss&#xff09;、短信sms&#xff08;aliyun、模擬&#xff09;、郵件發送等功能 基于SpringBoot3Vue3前后端分離的Java快速開發框架 項目簡介&#xff1a;基于 JDK 17、Spring Boot 3、Spring Security 6、JWT、Redis、Mybatis-P…

Vue2: el-table為每一行添加超鏈接,并實現光標移至文字上時改變形狀

為表格中的某一列添加超鏈接 一個表格通常有許多列,網上許多教程都可以實現為某一列添加超鏈接,如下,實現了當光標懸浮在“姓名”上時,改變為手形,點擊可實現跳轉。 <el-table :data="tableData"><el-table-column label="姓名" prop=&quo…

R數據分析:多分類問題預測模型的ROC做法及解釋

有同學做了個多分類的預測模型,結局有三個類別,做的模型包括多分類邏輯回歸、隨機森林和決策樹,多分類邏輯回歸是用ROC曲線并報告AUC作為模型評估的,后面兩種模型報告了混淆矩陣,審稿人就提出要統一模型評估指標。那么肯定是統一成ROC了,剛好借這個機會給大家講講ROC在多…