2025-05-31 Python深度學習9——網絡模型的加載與保存

文章目錄

  • 1 使用現有網絡
  • 2 修改網絡結構
    • 2.1 添加新層
    • 2.2 替換現有層
  • 3 保存網絡模型
    • 3.1 完整保存
    • 3.2 參數保存(推薦)
  • 4 加載網絡模型
    • 4.1 加載完整模型文件
    • 4.2 加載參數文件
  • 5 Checkpoint
    • 5.1 保存 Checkpoint
    • 5.2 加載 Checkpoint

本文環境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

? PyTorch 通過torchvision.models提供預訓練模型(如 VGG16)。

? 網址鏈接:https://docs.pytorch.org/vision/stable/models.html。

1 使用現有網絡

? 以 VGG16 為例,進入網址:https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.vgg16.html#torchvision.models.vgg16。

image-20250531103635500

方法一:使用隨機初始化權重

? 將 weights 設置為 None,從 0 開始訓練自己的網絡。

vgg16_false = torchvision.models.vgg16(weights=None)  # 權重隨機初始化

方法二:加載預訓練權重

? 也可以使用預訓練好的網絡參數,加載后可直接使用網絡。
這將從官網上下載已訓練好的模型文件。

vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

? 可打印網絡查看其模型結構:

print(vgg16_true)
image-20250531104433678
...
image-20250531104447912

2 修改網絡結構

2.1 添加新層

? 使用add_module在分類器(classifier)后追加全連接層:

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
image-20250531104536554

2.2 替換現有層

? 直接修改分類器的最后一層(如適配 CIFAR10 的 10 分類任務):

vgg16_false.classifier[6] = nn.Linear(4096, 10)  # 替換第6層
image-20250531104551228

3 保存網絡模型

? 使用torch.save()方法保存網絡模型。文件擴展名推薦使用.pt.pth

3.1 完整保存

? 將模型類和參數一并保存到文件中。

torch.save(vgg16, 'vgg16_method1.pth')  # 包含模型類和參數
  • 優點:加載時無需重新定義模型結構。
  • 缺點:文件較大,且依賴原始代碼環境(見 4.1 節)。

3.2 參數保存(推薦)

? 僅保存參數字典到文件中。

torch.save(vgg16.state_dict(), 'vgg16_method2.pth')  # 僅保存參數字典
  • 優點:文件小,靈活性強,適合生產部署。

示例

import torch
import torchvision.models
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)# 保存方式 1,模型結構 + 模型參數
torch.save(vgg16, 'vgg16_method1.pth')# 保存方式 2,模型參數(官方推薦)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

4 加載網絡模型

? 使用torch.load()方法加載網絡模型。

4.1 加載完整模型文件

? 加載完整模型時,需將 weights_only 參數設置為 False。

model = torch.load('vgg16_method1.pth', weights_only=False)  # 需確保模型類已定義

? 模型打印結果如下:

print(model)
image-20250531111142517

注意

? 若保存自定義模型,加載時必須確保環境中也有該模型的定義,否則會出現報錯。

  • model_save.py

    # model_save.pyimport torch
    from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):return self.conv1(x)model = MyModel()
    torch.save(model, 'my_model_method1.pth')
    
  • model_load.py

    import torchmodel = torch.load('my_model_method1.pth', weights_only=False)  # 報錯,找不到 MyModel 的定義
    

    先運行 model_save.py,再運行 model_load.py,則會出現以下報錯:

image-20250531110244566

?

4.2 加載參數文件

? 首先,使用torch.load()方法加載網絡模型。

? 使用模型時,需先創建匹配的網絡結構,再使用model.load_state_dict()加載參數數據。

vgg16 = torchvision.models.vgg16(weights=None)
model_dict = torch.load('vgg16_method2.pth')
vgg16.load_state_dict(model_dict)  # 需結構匹配

? 模型打印結果是參數字典:

print(model_dict)
image-20250531111411199

注意

? 模型保存時若在 GPU 上,加載時需指定 map_location 為 cup。

torch.load('model.pth', map_location=torch.device('cpu'))

? 將參數加載到模型后,手動遷移到 GPU:

model = MyModel()
model.load_state_dict(model_dict)
model.to('cuda:0')

5 Checkpoint

? 使用 Checkpoint 可以在訓練過程中定期保存模型的狀態,以便在中斷后可以恢復訓練,或者在測試時使用最終的模型。文件擴展名推薦使用.tar

5.1 保存 Checkpoint

? 要保存一個模型的 Checkpoint,通常需要保存以下數據:

  • 模型的 state_dict(狀態字典);
  • 優化器的狀態;
  • 額外的信息,如 epoch 等。
import torch# 假設 model 是你的模型,optimizer 是你的優化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
}# 保存checkpoint
torch.save(checkpoint, 'checkpoint.tar')

5.2 加載 Checkpoint

? 加載 Checkpoint,首先需要加載文件,然后將其內容恢復到模型和優化器的狀態中。

# 假設 model 和 optimizer 是你的模型和優化器實例
checkpoint = torch.load('checkpoint.tar')model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果需要,可以繼續訓練
model.train()  # 確保模型處于訓練模式

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/907924.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/907924.shtml
英文地址,請注明出處:http://en.pswp.cn/news/907924.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

批量導出CAD屬性塊信息生成到excel——CAD C#二次開發(插件實現)

本插件可實現批量導出文件夾內大量dwg文件的指定塊名的屬性信息到excel,效果如下: 插件界面: dll插件如下: 使用方法: 1、獲取此dll插件。 2、cad命令行輸入netload ,加載此dll(要求AutoCAD&…

在Linux環境里面,Python調用C#寫的動態庫,如何實現?

在Linux環境中,Python可以通過pythonnet(CLR的Python綁定)或subprocess調用C#動態庫。以下是兩種方法的示例: 方法1:使用pythonnet(推薦) 前提條件 安裝Mono或.NET Core運行時安裝pythonnet包…

小程序跳轉H5或者其他小程序

1. h5跳轉小程序有兩種情況 &#xff08;1&#xff09;從普通瀏覽器打開的h5頁面跳轉小程序使用wx-open-launch-weapp可以實現h5跳轉小程序 <wx-open-launch-weappstyle"display:block;"v-elseid"launch-btn":username"wechatYsAppid":path…

性能優化 - 案例篇:緩沖區

文章目錄 Pre1. 引言2. 緩沖概念與類比3. Java I/O 中的緩沖實現3.1 FileReader vs BufferedReader&#xff1a;裝飾者模式設計3.2 BufferedInputStream 源碼剖析3.2.1 緩沖區大小的權衡與默認值 4. 異步日志中的緩沖&#xff1a;Logback 異步日志原理與配置要點4.1 Logback 異…

文檔整合自動化

主要功能是按照JSON文件&#xff08;Sort.json&#xff09;中指定的順序合并多個Word文檔&#xff08;.docx&#xff09;&#xff0c;并清除文檔中的所有超鏈接。最終輸出合并后的文檔名為"sorted_按章節順序.docx"。 主要分為幾個部分&#xff1a; 初始化配置 定…

嵌入式(C語言篇)Day13

嵌入式Day13 一段話總結 文檔主要介紹帶有頭指針和尾指針的單鏈表的實現及操作&#xff0c;涵蓋創建、銷毀、頭插、尾插、按索引/數據增刪查、遍歷等核心操作&#xff0c;強調頭插/尾插時間復雜度為O(1)&#xff0c;按索引/數據操作需遍歷鏈表、時間復雜度為O(n)&#xff0c;并…

【ASR】基于分塊非自回歸模型的流式端到端語音識別

論文地址:https://arxiv.org/abs/2107.09428 摘要 非自回歸 (NAR) 模型在語音處理中越來越受到關注。 憑借最新的基于注意力的自動語音識別 (ASR) 結構,與自回歸 (AR) 模型相比,NAR 可以在僅精度略有下降的情況下實現有前景的實時因子 (RTF) 提升。 然而,識別推理需要等待…

RNN循環網絡:給AI裝上“記憶“(superior哥AI系列第5期)

&#x1f504; RNN循環網絡&#xff1a;給AI裝上"記憶"&#xff08;superior哥AI系列第5期&#xff09; 嘿&#xff01;小伙伴們&#xff0c;又見面啦&#xff01;&#x1f44b; 上期我們學會了讓AI"看懂"圖片&#xff0c;今天要給AI裝上一個更酷的技能——…

DAY41 CNN

可以看到即使在深度神經網絡情況下&#xff0c;準確率仍舊較差&#xff0c;這是因為特征沒有被有效提取----真正重要的是特征的提取和加工過程。MLP把所有的像素全部展平了&#xff08;這是全局的信息&#xff09;&#xff0c;無法布置到局部的信息&#xff0c;所以引入了卷積神…

【仿生系統】愛麗絲機器人的設想(可行性優先級較高)

非程序化、能夠根據環境和交互動態產生情感和思想&#xff0c;并以微妙、高級的方式表達出來的能力 我們不想要一個“假”的智能&#xff0c;一個僅僅通過if-else邏輯或者簡單prompt來模擬情感的機器人。您追求的是一種更深層次的、能夠學習、成長&#xff0c;并形成獨特“個性…

面向連接的運輸:TCP

目錄 TCP連接 TCP報文段結構 往返時間估計與超時 可靠數據傳輸 回退N步or超時重傳 超時間隔加倍 快速重傳 流量控制 TCP連接管理 三次握手 1. 客戶端 → 服務器&#xff1a;SYN 包 2. 服務器 → 客戶端&#xff1a;SYNACK 包 3. 客戶端 → 服務器&#xff1a;AC…

SpringAI系列 - 升級1.0.0

目錄 一、調整pom二、MessageChatMemoryAdvisor調整三、ChatMemory get方法刪除lastN參數四、QuestionAnswerAdvisor調整Spring AI發布1.0.0正式版了?? ,搞起… 一、調整pom <properties><java.version>17</java.version><spring-ai.version>

前端高頻面試題2:JavaScript/TypeScript

1.什么是類數組對象 一個擁有 length 屬性和若干索引屬性的對象就可以被稱為類數組對象&#xff0c;類數組對象和數組類似&#xff0c;但是不能調用數組的方法。常見的類數組對象有 arguments 和 DOM 方法的返回結果&#xff0c;還有一個函數也可以被看作是類數組對象&#xff…

Spring Security入門:創建第一個安全REST端點項目

項目初始化與基礎配置 創建基礎Spring Boot項目 我們首先創建一個名為ssia-ch2-ex1的空項目(該名稱與配套源碼中的示例項目保持一致)。項目需要添加以下兩個核心依賴: org.springframework.bootspring-boot-starter-weborg.springframework.bootspring-boot-starter-secur…

秋招Day12 - 計算機網絡 - UDP

說說TCP和UDP的區別&#xff1f; TCP使用無邊界的字節流傳輸&#xff0c;可能發生拆包和粘包&#xff0c;接收方并不知道數據邊界&#xff1b;UDP采用數據報傳輸&#xff0c;數據報之間相互獨立&#xff0c;有邊界。 應用場景方面&#xff0c;TCP適合對數據的可靠性要求高于速…

【QQ音樂】sign簽名| data參數加密 | AES-GCM加密 | webpack (下)

1.目標 網址&#xff1a;https://y.qq.com/n/ryqq/toplist/26 我們知道了 sign P(n.data)&#xff0c;其中n.data是明文的請求參數 2.webpack生成data加密參數 那么 L(n.data)就是密文的請求參數。返回一個Promise {<pending>}&#xff0c;所以L(n.data) 是一個異步函數…

Codeforces Round 1028 (Div. 2)(A-D)

題面鏈接&#xff1a;Dashboard - Codeforces Round 1028 (Div. 2) - Codeforces A. Gellyfish and Tricolor Pansy 思路 要知道騎士如果沒了那么這個人就失去了攻擊手段&#xff0c;貪心的來說我們只需要攻擊血量少的即可&#xff0c;那么取min比較一下即可 代碼 void so…

【存儲基礎】存儲設備和服務器的關系和區別

文章目錄 1. 存儲設備和服務器的區別2. 客戶端訪問數據路徑場景1&#xff1a;經過服務器處理場景2&#xff1a;客戶端直連 3. 服務器作為"中轉站"的作用 剛開始接觸存儲的時候&#xff0c;以為數據都是存放在服務器上的&#xff0c;服務器和存儲設備是一個東西&#…

macOS 安裝 Grafana + Prometheus + Node Exporter

macOS 安裝指南&#xff1a;Grafana Prometheus Node Exporter 目錄簡介&#x1f680; 快速開始 安裝 Homebrew1. 安裝 Homebrew2. 更新 Homebrew 安裝 Node Exporter使用 Homebrew 安裝驗證 Node Exporter 安裝 Prometheus使用 Homebrew 安裝驗證安裝 安裝 Grafana使用 Home…

不可變集合類型轉換異常

記錄一個異常&#xff1a;class java.util.ImmutableCollections$ListN cannot be cast to class java.util.ArrayList (java.util.ImmutableCollections$ListN and java.util.ArrayList 文章目錄 1、原因2、解決方式一3、解決方式二4、關于不可變集合的補充4.1 JDK8和9的對比4…