YOLOv8 剪枝模型加載踩坑記:解決 YAML 覆蓋剪枝結構的問題

1. 問題背景

模型剪枝是實現模型輕量化、加速推理的關鍵步驟。然而,在 Ultralytics YOLOv8 的生態中,在成功剪枝后,進行微調(Fine-tuning)時會遇到一個令人困惑的現象:明明加載的是剪枝后的模型(例如 20M 參數),但訓練啟動時打印的日志卻顯示為標準版模型的參數(例如 25M)。并且經過驗證,微調后的模型參數就是標準的yolo模型。

加載代碼如下:

    model = YOLO("pruned.pt")     # load a pretrained model (recommended for training)model.train(data=name_yaml, device=0, imgsz=640, epochs=50, batch=32, workers=16, name=path_fineturn)  # train the model

原因是Ultralytics 的 Trainer 仍會先依據 原始 YAML 構建標準結構(約 25M 參數)。隨后僅將 .pt 文件中的權重加載到這張標準結構中。


2. 代碼觸發點與根本原因

問題的根源在于 Ultralytics 的 Trainer 在初始化模型時(get_model 方法)的執行順序。

ultralytics/engine/model.py中的Model類的train()方法中,原始代碼如下:

self.trainer.get_model 方法的執行流程如下:

  • 優先使用 cfg 參數構建模型:該參數接收 cfg=self.model.yaml。由于 pruned.pt 在保存時不會自動更新其內部的 YAML 配置(?model = YOLO("pruned.pt")會構造出一個實例,里面的self.model有很多屬性,其中self.model.model是模型網絡,這是真正的、由網絡層構成的可執行實體。我們的剪枝操作直接修改了這個對象,比如減少了某些卷積層的通道數,從而改變了它的實際結構self.model.yaml是配置文件,剪枝時只修改了self.model.model,沒有更新原始的self.model.yaml),所以這里的 self.model.yaml 仍然是標準版 YOLOv8m 的網絡結構

  • 創建標準結構并打印摘要get_model?會立即執行?model = DetectionModel(cfg)?通過self.model.yaml來構建一個完整的未剪枝模型(25.8M)。隨后調用?model.info()?方法,這就是日志中顯示"標準版"摘要的原因。完成標準結構創建后,get_model?才會處理 weights 參數,將 pruned.pt 中的權重加載到剛創建的標準結構中。PyTorch 的?load_state_dict?會按照名稱和形狀匹配的原則加載對應層的權重,跳過不匹配的層,此時模型仍保持標準骨架結構。


3. 改進寫法(實際切換到剪枝后結構)

為了解決這個問題,我們必須在 Trainer 開始訓練前,確保其內部持有的模型對象是我們剪枝后的那一個。

將代碼調整為:

        if not args.get("resume"):  # manually set model only if not resumingself.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# ★ 關鍵修正:用我們剪枝后的模型對象,替換掉 Trainer 內部剛剛由 YAML 創建的模型self.trainer.model.model = self.model.modelprint("\n--- Verifying model after swapping in Trainer ---")# 打印替換后的模型參數量params_after_swap = sum(p.numel() for p in self.trainer.model.model.parameters()) / 1e6print(f"Parameters inside trainer: {params_after_swap:.2f}M\n")  # 應顯示約 20.8Mself.model = self.trainer.modelif SETTINGS["hub"] is True and not self.session:

  • 依然允許 get_model 按部就班地完成它的初始化流程(包括打印那條“誤導性”的日志)。

  • 但在這之后,立即通過 self.trainer.model.model = self.model.model 這行代碼,強行將 Trainer 內部的 nn.Module 對象替換為我們真正的、剪枝后的模型 (self.model.model)

  • 啟動階段的日志已打印過標準版結構,因此顯示上仍是標準參數量,但通過打印替換后的模型對象的參數量可以看到已經替換為剪枝后的模型

深度解析:為什么是替換?.model.model 而不是 .model
  1. yolo.model 對象 (DetectionModelBaseModel 的實例)
    它是一個“功能完備的檢測器”,不僅包含了網絡結構,還封裝了與之相關的元數據和方法(如 .train(), .info(), .yaml 等)。把它理解為一個高級接口

  2. yolo.model.model 對象 (純 nn.Module 實例)
    這才是我們通常意義上所說的PyTorch 模型網絡。它是一個純粹的 torch.nn.Module 子類,由各種網絡層搭建而成。我們的剪枝操作,直接修改的就是這個對象。

為什么不寫成 self.trainer.model = self.model

  • 源(Source)self.model.model 是我們從加載的 pruned.pt 中取出的、那個已經被剪枝過的純粹網絡結構

  • 目標(Destination)self.trainer.model.modelTrainer 內部那個標準結構的純粹網絡

self.trainer.model 是一個高級的 BaseModel 對象,Trainer 在初始化時已經對其進行了一些配置(如設備分配等)。如果我們用self.trainer.model = self.model整個地替換掉它,可能會破壞這些已經完成的設置,存在潛在風險。只替換最底層的 nn.Module,既能保證網絡結構正確,又不會干擾 Trainer 的其他工作流程。
注意替換模型必須在self.trainer.model構建好之后,如果直接使用self.trainer.model.model = self.model.model會顯示self.trainer.model是個str,還不是對象。

4. 顯示不一致的原因

  • Summary 打印時機get_model 在構建標準結構后立即輸出層數與參數量。

  • 結構替換發生在 summary 之后:沒有重新打印,因此日志沒有更新為剪枝后的參數量

  • 保存階段:調用 model.save()torch.save({'model': ...}) 時,寫入的是替換后的剪枝模型對象,所以最終 .pt 文件尺寸/參數量正確

5. 驗證流程建議

為了確保操作是正確的,最好進行驗證。

步驟 1:驗證初始剪枝模型
在開始微調訓練前,先確認?pruned.pt 是真的被剪枝了。

from ultralytics import YOLO
initial_model = YOLO("pruned.pt")
print("--- Verifying initial pruned model ---")
initial_model.model.info(verbose=False)  # 應顯示約 20.8M 參數

步驟 2:在替換后立即驗證
在修正代碼的核心行之后,立刻加入打印驗證,就是之前的代碼。

# ...
self.trainer.model.model = self.model.model
print("\n--- Verifying model after swapping in Trainer ---")
# 打印替換后的模型參數量
params_after_swap = sum(p.numel() for p in self.trainer.model.model.parameters()) / 1e6
print(f"Parameters inside trainer: {params_after_swap:.2f}M\n") # 應顯示約 20.8M

步驟 3:驗證最終保存的模型
訓練結束后,加載最終生成的權重文件,再次確認。

final_model = YOLO("runs/train/exp/weights/last.pt")
print("--- Verifying final saved model ---")
final_model.model.info() # 應顯示約 20.8M 參數

結果如圖:

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/90937.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/90937.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/90937.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

js的學習1

1.數組 數組方法 push()數組尾部添加unshift()數組頭部添加pop()數組尾部刪除shift()數組頭部刪除splice(起始位置,刪除幾個元素,要替換的元素)刪除指定的元素,改變了原數組,返回值是被刪除的元素indexOf()第一次查到的索引&#…

LeetCode 2563.統計公平數對的數目

給你一個下標從 0 開始、長度為 n 的整數數組 nums &#xff0c;和兩個整數 lower 和 upper &#xff0c;返回 公平數對的數目 。 如果 (i, j) 數對滿足以下情況&#xff0c;則認為它是一個 公平數對 &#xff1a; 0 < i < j < n&#xff0c;且 lower < nums[i] n…

ZABBIX配置自動發現與自動注冊,網易郵箱告警和釘釘告警

一、自動發現zabbix server 主動的去發現所有的客戶端&#xff0c;然后將客戶端的信息登記在服務端上。缺點是如果定義的網段中的主機數量多&#xff0c;zabbix server 登記耗時較久&#xff0c;且壓力會較大。1、部署準備準備三臺虛擬機192.168.80.151&#xff1b;192.168.80.…

QT(五)常用類

1. QString字符串類(掌握) QString是Qt的字符串類&#xff0c;與C的string相比&#xff0c;不再使用ASCII編碼&#xff0c;QString使用的是Unicode編碼。 QString中每個字符都是一個16位的QChar&#xff0c;而不是8位的char。 QString完全支持中文&#xff0c;但是由于不同的技…

EXCEL怎么提取表名

錯誤的方法&#xff1a;使用以下方法提取表名的時候&#xff0c;會存在1個問題&#xff0c;公式只在當前工作表生效&#xff0c;換工作表會出現表名覆蓋的情況。RIGHT(CELL("filename"),LEN(CELL("filename"))-FIND("]",CELL("filename&quo…

springboot校園外賣配送系統

目 錄 第一章 緒 論 1.1背景及意義 1.2國內外研究概況 1.3 研究的內容 第二章 關鍵技術的研究 2.1開發技術 2.2 Springboot框架介紹 2.3 Vue.js 主要功能 2.4 MVVM模式介紹 2.4 B/S體系工作原理 2.5 MySQL數據庫 第三章 系統分析 3.1 系統設計目標 3.2 系統可行性…

【智慧物聯網平臺】安裝部署教程——仙盟創夢IDE

一、部署前準備1. 環境要求基礎環境&#xff1a;JDK 1.8、MySQL 5.7/8.0、Maven 3.6、Redis&#xff08;用于緩存&#xff09;、Node.js&#xff08;用于前端構建&#xff0c;可選&#xff09;。依賴服務&#xff1a;若需對接門禁、道閘等硬件設備&#xff0c;需確保設備網絡可…

【安全漏洞】防范未然:如何有效關閉不必要的HTTP請求方法,保護你的Web應用

在構建和維護Web應用的過程中&#xff0c;安全問題總是我們最關心的話題之一。今天&#xff0c;我們要探討的是一個經常被忽視的Web漏洞——未關閉或限制不必要的HTTP請求方法。 雖然我們在日常開發中主要使用 GET 和 POST 這兩種請求方法&#xff0c;但像 PUT、DELETE、HEAD、…

嵌入式Linux裸機開發筆記8(IMX6ULL)主頻和時鐘配置實驗(1)

引言在前幾章實驗中我們都沒有涉及到 I.MX6U 的時鐘和主頻配置操作&#xff0c;全部使用的默認配置&#xff0c; 默認配置下 I.MX6U 工作頻率為 396MHz。但是 I.MX6U 系列標準的工作頻率為 528MHz&#xff0c;有些 型號甚至可以工作到 696MHz。本章學習 I.MX6U 的時鐘系統&…

設計模式(四)創建型:生成器模式詳解

設計模式&#xff08;四&#xff09;創建型&#xff1a;生成器模式詳解生成器模式&#xff08;Builder Pattern&#xff09;是 GoF 23 種設計模式中的核心創建型模式之一&#xff0c;其核心價值在于將一個復雜對象的構建過程與其表示分離&#xff0c;使得同樣的構建過程可以創建…

《Angular+Spring Boot:ERP前端采購銷售庫存協同架構解析》

基于Angular與Spring Boot構建的全棧ERP前端&#xff0c;絕非技術的簡單疊加&#xff0c;而是通過深度融合兩者特性&#xff0c;打造出兼具穩定性與靈活性的業務載體。Angular的組件化架構將復雜界面拆解為可復用的獨立單元&#xff0c;依賴注入機制則讓服務調用與數據流轉條理…

Java 排序

文章目錄排序插入排序分析希爾排序分析選擇排序分析堆排序分析冒泡排序分析快速排序霍爾法分析挖坑法找基準前后指針法題目快排的優化三數取中法非遞歸實現快排歸并排序分析非遞歸實現歸并排序海量數據的排序非比較的排序計數排序分析基數排序桶排序排序 穩定的排序&#xff1…

日本IT就職面試|儀容禮儀篇分享建議

日系企業で好印象を與える「身だしなみ」と「面接マナー」ガイドこんにちは。 日系企業への就職?転職活動をされている方にとって、「第一印象」は合否を左右する大切なポイントですよね。実は、面接の評価は入室の瞬間から始まっていると言っても過言ではありません。 今回は…

英語聽力口語詞匯-8.美食類

1.crispy,crisp adj.酥脆的&#xff0c;易碎的 2.sweet adj.甜的 比如說chocolate is so sweet and delicious 3.chewy adj.難嚼的&#xff0c;難咽的 4.oatmeal n.燕麥粉 5.pickle n.泡菜 7.stir-fry v.炒菜 8.bacon n.咸肉&#xff0c;熏肉 9.yummy adj.美味可口的 1…

力扣7:整數反轉

力扣7:整數反轉題目思路代碼題目 給你一個 32 位的有符號整數 x &#xff0c;返回將 x 中的數字部分反轉后的結果。 如果反轉后整數超過 32 位的有符號整數的范圍 [?2^31, 2^31 ? 1] &#xff0c;就返回 0。 思路 這道題我們可以分成兩部分來做&#xff0c;一是完成反轉二…

PWM信號控制電機

1&#xff1a;環境 STM32F103C8T6 KEIL5.38 2個電機 2個輪子 1個L298N STLINKV2 CH340 1個4位獨立按鍵 杜邦線若干 2&#xff1a;代碼 key.h #ifndef __KEY_H #define __KEY_H#include "stm32f10x.h"extern volatile uint8_t key_t ; extern volatile uint8_t …

開源賦能產業,生態共筑未來 | 開源科學計算與系統建模(openSCS)分論壇圓滿舉行

2025開放原子開源生態大會于7月23日-24日在北京國家會議中心召開。本屆大會以“開源賦能產業&#xff0c;生態共筑未來”為主題&#xff0c;匯聚政、產、學、研、用、金、創、投等各領域開源力量&#xff0c;聚焦開源政策導向、生態發展趨勢、開源產業實踐&#xff0c;共探中國…

Android廣播機制體系初識

Android廣播機制體系大白話把Android的廣播機制想象成小區里的“大喇叭”誰在喊話&#xff1f;任何App或系統都能當“大喇叭”&#xff0c;比如喊一嗓子“電量不足啦&#xff01;”&#xff08;這就是發送廣播&#xff09;誰在聽&#xff1f;其他App只要“豎起耳朵”&#xff0…

微信小程序點擊輸入框時,頂部導航欄被遮擋問題如何解決?

前言 不知道大家開發微信小程序的時候有沒有遇到這么一個問題&#xff0c;就是在表單頁面中&#xff0c;點擊輸入框后&#xff0c;輸入框頂起會把頂部欄給遮擋住&#xff0c;如下圖所示&#xff1a;遇到這種情況有沒有解決的辦法呢&#xff1f;能不能既將頁面頂起&#xff0c;同…

通過具有一致性嵌入的大語言模型(LMMs)實現端到端乳腺癌放射治療計劃制定|文獻速遞-醫學影像算法文獻分享

Title題目End-to-end breast cancer radiotherapy planning via LMMs with consistencyembedding通過具有一致性嵌入的大語言模型&#xff08;LMMs&#xff09;實現端到端乳腺癌放射治療計劃制定01文獻速遞介紹近年來&#xff0c;受大型語言模型&#xff08;LLM&#xff09;啟發…