物體檢測-系列教程20:YOLOV5 源碼解析10 (Model類前向傳播、forward_once函數、_initialize_biases函數)

😎😎😎物體檢測-系列教程 總目錄

有任何問題歡迎在下面留言
本篇文章的代碼運行界面均在Pycharm中進行
本篇文章配套的代碼資源已經上傳
點我下載源碼

14、Model類

14.2 前向傳播

    def forward(self, x, augment=False, profile=False):if augment:img_size = x.shape[-2:]  # height, widths = [1, 0.83, 0.67]  # scalesf = [None, 3, None]  # flips (2-ud, 3-lr)y = []  # outputsfor si, fi in zip(s, f):xi = scale_img(x.flip(fi) if fi else x, si)yi = self.forward_once(xi)[0]  # forwardyi[..., :4] /= si  # de-scaleif fi == 2:yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip udelif fi == 3:yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lry.append(yi)return torch.cat(y, 1), None  # augmented inference, trainelse:return self.forward_once(x, profile)  # single-scale inference, train

這段代碼是forward方法的實現,它定義了模型的前向傳播過程,支持正常和增強兩種推理模式:

  1. 前向傳播函數,輸入x,是否進行數據增強augment,是否分析性能profile
  2. 是否使用數據增強
  3. img_size ,獲取輸入圖像的長寬
  4. s,定義縮放尺度
  5. f,定義翻轉模式,這里None表示不翻轉,3表示左右翻轉
  6. y,初始化輸出列表
  7. 使用zip函數將尺度因子列表s和翻轉指示列表f組合起來,然后遍歷每一對尺度因子和翻轉指示
  8. xi,如果fi不為None,先根據fi的值對圖像進行翻轉,然后調用scale_img函數根據si的值縮放處理圖像;否則直接調用scale_img函數根據si的值縮放處理圖像
  9. yi,將xi進行一次前向傳播,取第一個輸出
  10. 對輸出yi的前四個維度進行縮放調整,以恢復到原始的尺度。這通常是對邊界框坐標的調整
  11. 如果使用了上下翻轉
  12. 則調整y的坐標
  13. 如果使用了左右翻轉
  14. 則調整x坐標
  15. 將處理后的輸出添加到列表
  16. 將list y的所有輸出按照第一個維度進行拼接
  17. 如果在當前循環中沒有使用數據增強
  18. 直接進行一次正常的前向傳播

前向傳播方法,包括了一個可選的圖像增強步驟。在增強模式下,通過對輸入圖像應用不同的尺度和翻轉,生成多個變體,對每個變體單獨進行前向傳播,并對輸出進行調整以適應原始圖像的尺寸和方向,最后將所有變體的輸出合并。這種方法可以增加模型的泛化能力,因為它讓模型在訓練時見到更多的數據變化。如果不進行圖像增強,它將執行一次標準的前向傳播。通過這種設計,模型可以更靈活地應對不同的輸入和訓練需求

14.3 forward_once函數

    def forward_once(self, x, profile=False):y, dt = [], []  # outputsfor m in self.model:if m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]if profile:try:import thopo = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # FLOPSexcept:o = 0t = time_synchronized()for _ in range(10):_ = m(x)dt.append((time_synchronized() - t) * 100)print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))x = m(x)  # runy.append(x if m.i in self.save else None)  # save outputif profile:print('%.1fms total' % sum(dt))return x
  1. forward_once函數,輸入和forward函一樣
  2. y, dt ,初始化兩個空列表,y用于存儲每一層的輸出,dt用于在性能分析模式下存儲每一層的執行時間
  3. 遍歷模型的每一層
  4. 如果當前層的輸入不是來自上一層的輸出
  5. 如果m.f是整數,則直接從y中獲取對應的層輸出作為輸入。如果m.f是一個列表,則根據列表中的索引從y中選擇輸入,如果索引為-1,則使用原始輸入x
  6. 是否開啟性能分析模式
  7. try
  8. 導入thop庫,用于計算浮點運算數(FLOPS)
  9. o,使用thop.profile計算當前層m的FLOPS,結果除以1E9轉換為GigaFLOPS,并乘以2。這里假設thop.profile返回的是一個元組,其第一個元素是所需的FLOPS
  10. 如果嘗試執行失敗
  11. 則將o(FLOPS)設置為0
  12. t,調用time_synchronized函數,獲取當前精確的時間
  13. 循環10次
  14. 為了穩定測量時間,通過多次執行減少偶然誤差
  15. 調用time_synchronized函數計算執行當前層操作的總時間,并將其添加到dt列表中
  16. 打印當前層的FLOPS、參數數量、執行時間和層類型。為性能分析提供詳細信息
  17. 執行當前層的前向傳播,并更新x為該層的輸出
  18. 如果當前層的索引m.i在保存列表self.save中,則將輸出x保存到y列表中;否則,保存None. 這樣做可以減少內存占用,只保存那些后續步驟中需要的層的輸出
  19. 再次檢查是否開啟了性能分析模式。這個檢查是為了在性能分析完成后打印總的執行時間
  20. 如果開啟了性能分析,計算所有層執行時間的總和并打印。這提供了整個前向傳播過程的總執行時間,幫助了解模型的性能瓶頸
  21. 返回最后一層的輸出

14.4 _initialize_biases函數

    def _initialize_biases(self, cf=None):m = self.model[-1]  # Detect() modulefor mi, s in zip(m.m, m.stride):  # fromb = mi.bias.data.view(m.na, -1).clone()obj_add = math.log(8 / (640 / s) ** 2)  # 計算obj層需要增加的值cls_add = math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())b[:, 4] = b[:, 4] + obj_addb[:, 5:] = b[:, 5:] + cls_addmi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  1. 初始化偏執的函數,接受一個可選的參數,這個參數用于根據數據集中各類別出現的頻率來調整分類(cls)層的偏置
  2. m,獲取模型中的最后一個模塊,檢測層(Detect模塊),用于目標檢測
  3. 遍歷檢測層中的每個子模塊mi及其對應的步長stride,這里的步長是指輸入圖像被縮減的尺度,對目標尺寸預測非常關鍵
  4. b,獲取子模塊mi的偏置項,并將其重塑(reshape)成(m.na, -1)的形狀,其中m.na是每個特征圖位置預測的錨框數量。.clone()確保在修改b時不會影響原始的偏置值
  5. obj_add ,計算對象(obj)層偏置需要增加的值。這個公式基于假設每640像素的圖像中有8個對象,并根據特征圖的尺度(通過步長s計算)來調整。目的是調整檢測層對于不同尺寸特征圖上對象數量預測的偏置
  6. cls_add ,計算分類(cls)層偏置需要增加的值。如果沒有提供類頻率(cf為None),則使用一個基于類數量m.nc的固定公式。如果提供了類頻率,那么使用類頻率來計算每個類的偏置調整值,以此反映數據集中類別的分布
  7. 將計算出的對象層偏置調整值加到b的第4列上,這是因為在目標檢測中,偏置項通常包括4個坐標偏置和一個對象存在的偏置,后者位于第5個位置(索引為4)
  8. 將計算出的分類層偏置調整值加到b的第5列及之后的所有列上,對應于每個類別的偏置
  9. 將調整后的偏置b重塑回原始形狀并設置為mi的偏置,確保這些偏置在訓練過程中可以被進一步調整(requires_grad=True)

14.5 其他輔助函數

    def _print_biases(self):m = self.model[-1]  # Detect() modulefor mi in m.m:  # fromb = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  1. 獲取模型的最后一個模塊,這里假設是一個目標檢測模塊(Detect模塊)
  2. 遍歷檢測模塊中的每個子模塊mi
  3. 取得當前子模塊mi的偏置,通過.detach()確保不會影響梯度計算,.view(m.na, -1)調整形狀以匹配錨點數量m.na和偏置的其它維度,最后進行轉置以便于處理
  4. 打印當前子模塊卷積層的輸入通道數和偏置的統計信息,包括前五個偏置的平均值和之后所有偏置的平均值

fuse函數,用于融合模型中的卷積層(Conv2d)和批歸一化層(BatchNorm2d)

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layersprint('Fusing layers... ')for m in self.model.modules():if type(m) is Conv:m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatabilitym.conv = fuse_conv_and_bn(m.conv, m.bn)  # update convm.bn = None  # remove batchnormm.forward = m.fuseforward  # update forwardself.info()return self
  1. 遍歷模型中的所有模塊
  2. 檢查當前模塊是否為卷積層
  3. 為了兼容PyTorch 1.6.0,清空非持久性緩沖區集合
  4. 使用fuse_conv_and_bn函數來融合當前卷積層和其后的批歸一化層
  5. 將批歸一化層設為None,表示移除批歸一化層
  6. 更新模塊的前向傳播函數為融合后的版本
  7. 在完成融合后,調用info方法打印模型信息
  8. 返回更新后的模型實例
    def info(self):  # print model informationmodel_info(self)

調用一個model_info函數,傳入當前模型實例,用于收集和打印模型的詳細信息,如參數數量、層的類型等

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/716134.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/716134.shtml
英文地址,請注明出處:http://en.pswp.cn/news/716134.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL 8.0 架構 之錯誤日志文件(Error Log)(2)

文章目錄 MySQL 8.0 架構 之錯誤日志文件(Error Log)(2)MySQL錯誤日志文件(Error Log)錯誤日志相關參數log_errorlog_error_services過濾器(Filter Error Log Components)寫入/接收器…

Vue+SpringBoot打造大學計算機課程管理平臺

目錄 一、摘要1.1 項目介紹1.2 項目錄屏 二、功能模塊2.1 實驗課程檔案模塊2.2 實驗資源模塊2.3 學生實驗模塊 三、系統設計3.1 用例設計3.2 數據庫設計3.2.1 實驗課程檔案表3.2.2 實驗資源表3.2.3 學生實驗表 四、系統展示五、核心代碼5.1 一鍵生成實驗5.2 提交實驗5.3 批閱實…

131. 分割回文串(力扣LeetCode)

文章目錄 131. 分割回文串題目描述回溯代碼 131. 分割回文串 題目描述 給你一個字符串 s,請你將 s 分割成一些子串,使每個子串都是 回文串 。返回 s 所有可能的分割方案。 回文串 是正著讀和反著讀都一樣的字符串。 示例 1: 輸入&#xf…

Android 架構MVI、MVVM、MVC、MVP

目錄 一、MVC(Model-View-Controller) 二、 MVP(Model-View-Presenter) 三. MVVM(Model-View-ViewModel) 四. MVI(Model-View-Intent) 五.MVI簡單實現 先簡單了解一下MVC、MVP和…

索引使用規則6——單列索引聯合索引

1、單列索引 單列索引:即一個索引只包含單個列 舉個例子 1.1、給phone和那么建立索引 create index index_name on tb_qianzhui(name); create index index_phone on tb_qianzhui(phone);1.2、查詢發現可能的索引有好幾個,但是最終選擇了phone的索引…

軟考 系統分析師系列知識點之詳細調查(2)

接前一篇文章:軟考 系統分析師系列知識點之詳細調查(1) 所屬章節: 第10章. 系統分析 第2節. 詳細調查 在系統規劃階段,通過初步調查,系統分析師已經對企業的組織結構、系統功能等有了大致的了解。但是&…

蘿卜大雜燴 | 提高數據科學工作效率的 8 個 Python 庫

本文來源公眾號“蘿卜大雜燴”,僅用于學術分享,侵權刪,干貨滿滿。 原文鏈接:提高數據科學工作效率的 8 個 Python 庫 在進行數據科學時,可能會浪費大量時間編碼并等待計算機運行某些東西。所以我選擇了一些 Python 庫…

Vue3中的Hooks詳解

vue3帶來了Composition API,其中Hooks是其重要組成部分。之前我寫過一篇關于vue3 hooks的文章比較簡單 Vue3從入門到刪庫 第十一章(自定義hooks) 所以本文將深入探討Vue3中Hooks,幫助你在Vue3開發中更加得心應手。 一、Vue3 Hoo…

貪吃蛇(C語言)步驟講解

一:文章大概 使用C語言在windows環境的控制臺中模擬實現經典小游戲 實現基本功能: 1.貪吃蛇地圖繪制 2.蛇吃食物的功能(上,下,左,右方向控制蛇的動作) 3.蛇撞墻死亡 4.計算得分 5.蛇身加…

[C語言]——C語言常見概念(1)

目錄 一.C語言是什么、 二.C語言的歷史和輝煌 三.編譯器的選擇(VS2022為例) 1.編譯和鏈接 2.編譯器的對比 3.VS2022 的優缺點 四.VS項目和源文件、頭文件介紹 五.第?個C語言程序 ??????? 一.C語言是什么、 ?和?交流使?的是?然語?&…

【python】爬取鏈家二手房數據做數據分析【附源碼】

一、前言、 在數據分析和挖掘領域中,網絡爬蟲是一種常見的工具,用于從網頁上收集數據。本文將介紹如何使用 Python 編寫簡單的網絡爬蟲程序,從鏈家網上海二手房頁面獲取房屋信息,并將數據保存到 Excel 文件中。 二、效果圖&#…

【JS】解構賦值注意點,解構賦值報錯

報錯代碼 const 小明 { email: 6, pwd: 66 } const 小剛 { email: 9, pwd: 99 }const { email } 小明 const { email } 小剛 報錯圖 原因 2個常量重復,重復在同一個作用域內是不能重復的,例如大括號內{const a 1; const a 2} 小伙伴A提問 問&…

Redis-基礎篇

Redis是一個開源、高性能、內存鍵值存儲數據庫,由 Salvatore Sanfilippo(網名antirez)創建,并在BSD許可下發布。它不僅可以用作緩存系統來加速數據訪問,還可以作為持久化的主數據存儲系統或消息中間件使用。Redis因其數…

leetcode:37.解數獨

題目理解:本題中棋盤的每一個位置都要放一個數字(而N皇后是一行只放一個皇后),并檢查數字是否合法,解數獨的樹形結構要比N皇后更寬更深。 代碼實現:

SpringBoot+Redis 解決海量重復提交問題,yyds!

在實際的開發項目中,一個對外暴露的接口往往會面臨很多次請求,我們來解釋一下冪等的概念:任意多次執行所產生的影響均與一次執行的影響相同。按照這個含義,最終的含義就是 對數據庫的影響只能是一次性的,不能重復處理。如何保證其…

?動類型轉換、強制類型轉換

為何short s1 1;是對的,而float f3.4;是錯的? 整數直接量,默認是int型。所以int a 4L; 會報錯,但是long l 4; 這樣不會,因為這樣會形成一個自動類型的轉換,int類型自動轉換為long類型 小數直接量&#…

JetBrains Gateway Github Copilot 客戶端插件和主機插件

JetBrains Gateway可以通過插件支持Github Copilot(需另行注冊)。 需要安裝插件 客戶端,而非插件 主機,如圖所示: 大概是因為代碼顯示在客戶端(運行在本地的IDE)?

NOC2023軟件創意編程(學而思賽道)python初中組復賽真題

目錄 下載打印原文檔做題: 軟件創意編程 一、參賽范圍 1.參賽組別:小學低年級組(1-3 年級)、小學高年級組(4-6 年級)、初中組。 2.參賽人數:1 人。 3.指導教師:1 人(可空缺)。 4.每人限參加 1 個賽項。 組別確定:以地方教育行政主管部門(教委、教育廳、教育局) 認…

Python 潮流周刊#40:白宮建議使用 Python 等內存安全的語言

△△請給“Python貓”加星標 ,以免錯過文章推送 你好,我是貓哥。這里每周分享優質的 Python、AI 及通用技術內容,大部分為英文。本周刊開源,歡迎投稿[1]。另有電報頻道[2]作為副刊,補充發布更加豐富的資訊,…

三層靶機靶場之環境搭建

下載: 鏈接:百度網盤 請輸入提取碼 提取碼:f4as 簡介 2019某CTF線下賽真題內網結合WEB攻防題庫,涉 及WEB攻擊,內網代理路由等技術,每臺服務器存在一個 Flag,獲取每一 個Flag對應一個積分&…