YOLOv8訓練流程-原理解析[目標檢測理論篇]

????????? 關于YOLOv8的主干網絡在YOLOv8網絡結構介紹-CSDN博客介紹了,為了更好地學習本章內容,建議先去看預測流程的原理分析YOLOv8原理解析[目標檢測理論篇]-CSDN博客,再次把YOLOv8網絡結構圖放在這里,方便隨時查看。

?

1.前言

??????????YOLOv8訓練流程這一塊內容還是比較復雜的,所以先來談一下訓練流程的思路,一共就兩步:第一步就是從網絡預測的結果中找到正樣本,并且確定正樣本要預測的對象;第二步就是計算預測結果和標簽之間的損失,分別計算預測框的損失以及預測類別的損失。

???????? 如下圖所示,假設一張圖片中只有3個標簽,那么需要從8400個Grid cell中找到這3個標簽對應的正樣本,然后通過計算正樣本的預測值和標簽值之間的損失,最后通過損失的反向傳播更新模型的權值和偏差。

?

????????為了更好地理解YOLOv8或者說是YOLO系列網絡,需要對Grid cell建立概念,如下所示:

????????首先可以看到通過網絡輸出的三個特征圖的分辨率分別為:80*80,40*40,20*20,本文所說的Grid cell即為圖中的紅點、藍點以及黃點,從圖中可以得到以下信息:第一,紅點代表的Grid cell是80*80分辨率中每個像素的中心點,因為紅色Grid cell比較密集并且可以x8將紅點映射回原圖,所以80*80分辨率的特征圖Grid cell是用來訓練小目標的,藍色和黃色Grid cell同理;第二,如果8400個Grid cell全部當成正樣本的話是不實際的,所以必須從這8400個Grid cell中選出一些正樣本;第三,由于YOLOv8是Anchor Free的模型,所以會將這三個尺度的特征圖展開變成長度為8400的一維向量。

2.?Task Aligned Assigner

????????Task Aligned Assigner中文翻譯為任務對齊分配器,是一種正負樣本分配策略,也就是找正樣本的方法,也就是訓練流程中的第一步。

???????? 在正式開始找正樣本之前,需要先把網絡預測值Box和Cls解碼,同時也需要把標簽的Box和Cls解碼,過程如下圖所示:首先是網絡預測結果Pred的Box需要解碼成4維,用來預測LTRB的(解碼過程在預測原理第三章有提到YOLOv8預測流程-原理解析[目標檢測理論篇]-CSDN博客),另外還需要轉換為XYXY格式且預測的坐標值是相對于網絡輸入尺寸的(即640*640);Cls只需要使用Sigmoid()解碼就行。其次是標簽Target的解碼,其實只有Box需要解碼,為了和Pred的解碼格式保持一致,需要將XYWH格式轉換為XYXY,并且標簽值對應的坐標是相對于網絡輸入尺寸(即640*640)。

????????然后正式開始找正樣本了,假設一張圖片上只有一個GT Box,使用紅色框作為標記。由于已經將三個特征圖下的grid cell都轉換到640*640坐標系了,結合GT框的位置和大小,找到合適的中心點作為訓練的正樣本,這就是TaskAlignedAssigner的任務,一共分成三步,即初步篩選,精細篩選,剔除多余三個步驟。

??????? (1)初步篩選:即select_candidates_in_gts,轉換之后的Grid Cell落在GT Box內部,作為初步篩選的正樣本,如圖中所示紅色點為初步篩選的Grid Cell,而落在GT Box外部的點或者落在GT Box角上或者邊上的都需要過濾掉,如藍色點所示;經過初步篩選,圖2中9個紅色的點作為初篩后留下的正樣本點。

?(2)精細篩選:即get_box_metrics,select_topk_candidates,通過公式align_metric=s^α?u^β(s和u分別表示分類得分和CIoU得分, a和b是權重系數,默認值分別0.5和6.0),計算出每個預測框的得分,然后把得分低的預測框給過濾掉,一般會取得分最高的top10個gird cell。

????????其中分類得分,取的在是GT Box內對應的類別的預測值,比如該GT的類別下標為1,那么落在GT box內的點所預測的類別下標為1時的置信度。另外計算IoU使用的是CIoU,計算公式和計算過程如下所示,如何理解CIoU呢,IoU并無法充分表示預測框和標注框之間的關系,需要引入中心點距離,以及最小矩形框斜邊距離,通過這兩者的比值來表示預測框和標注框的相似度。所以會在IoU的基礎上減去該比值,再減去由預測框寬高和標注框寬高組成的式子。

????????(3)剔除多余:保證一個Grid Cell只預測一個GT框,如果一個Grid Cell同時匹配到兩個GT Box,那么將從這兩個GT中,選出與他CIoU值最大的一個作為他要預測的GT Box。如圖所示,Grid Cell A、B、C負責預測GT1,包括預測GT1的類別和Box,而Grid Cell D負責預測GT2,也是預測類別和Box。?

3.Loss

?????????YOLOv8的Loss由三部分組成:Loss_box,Loss_cls,Loss_DFL分別表示回歸框損失,類比損失和DFL損失(其實也是回歸框的損失),下面會詳細介紹這三種損失。

????????還是先來簡單了解下Loss計算的思路,如下圖所示:左邊Target表示標簽值,右邊Pred表示預測值,均需要借助上一章找到的正樣本,然后通過對比同一個Grid cell正樣本的預測值和標簽值,計算對應的Loss。?

????????先來看一下get_targets函數做了哪些處理,GT_Box是經過預處理的,(1,3,4)表示XYXY且相對于640*640尺度的坐標。GT_Cls沒有經過處理,表示GT_Box1、 GT_Box2、 GT_Box3的類別。

??? 假設GT1的box是(x0,y0,X0,Y0),cls是0; GT2的box是(x1,y1,X1,Y1),cls是1;根據找到的正樣本和負樣本來舉個例子,其中負樣本為E,正樣本為A/B/C/D.由圖可以看到Target_Score在這一步已經區分了正負樣本了,其中負樣本使用[0,0]來表示。而Target_Bbox并沒有區分正負樣本,負樣本統統會選擇第1個GT的Box作為其Target_Bbox,換句話說,Target_Bbox值為[x0,y0,X0,Y0]的Grid cell可能為正樣本也可能為負樣本。

????????

3.1Loss_cls?

?????????下面是YOLOv8中計算Loss_cls的代碼:

target_scores_sum = max(target_scores.sum(), 1)loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum   # BCE

????????而主要的部分是Loss_cls采用了BCELoss損失函數,損失計算公式如下(注意:YOLOv8中的Cls使用的是BCEWithLogitsLoss,傳入的預測值是不需要自己進行Sigmoid,損失內部會自動進行sigmoid,但我這里演示使用的是BCELoss):

????????假設當前只有兩個類別,取出其中三個Grid cell的值,其中(0,0)表示負樣本,(0,1)和(1,0)表示正樣本,經過Normalize后得到帶有權重的真實標簽,這里正負樣本均計算Loss.

3.2Loss_box

????????下面是YOLOv8中計算Loss_box的代碼:

iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

?????????而主要的部分是Loss_box采用了CIoULoss損失函數,損失計算公式如下:

????????由前面可知,經過get_targets后的Target_Bbox并沒有區分正負樣本,因此下一步將利用fg_mask來區分正負樣本,從而得到30個正樣本。對Box會求兩個損失,所以有Target_Bbox1和Target_Bbox2,都需要還原到各自的特征圖的比例進行計算(可能這樣數字比較小計算比較方便),并且分別采用XYXY格式和LTRB格式表示。

????????另一方面,Pred_Box1需要通過網絡預測的結果(1,64,8400)解碼成(1,4,8400)并采用XYXY坐標的格式表示,并且找到對應的30個正樣本和Target_Bbox1計算CIoU損失;Pred_Box2則是直接把網絡預測的結果(1,64,8400)取出來,然后找到對應的30個正樣本,和Target_Bbox1計算DFL損失。

?????????這里再說一下為什么會是30個正樣本,因為有3個GT,每個GT取top10個得分最高的grid cell,并且這30個中沒有因重復而被過濾掉的Grid cell。

3.3Loss_DFL?

?????????下面是YOLOv8中計算Loss_DFL的代碼:

loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weightloss_dfl = loss_dfl.sum() / target_scores_sum

????????而主要的部分是Loss_DFL,損失計算公式如下:

????????下面演示一個Grid cell正樣本LTRB的計算過程:

????????首先,Pred_Box2即Pred_dist,是一個(120,16)的矩陣,可以理解為(30*4,16),即共有30個正樣本,每個正樣本需要預測LTRB四個數值,并且這四個數又分別通過0~15來表示。其次,Target_Bbox2即Target,是一個(120,1)的向量,分別對應著30個樣本中每個樣本的LTRB真實值。最后,由于Target一般不會是整數值,所以需要計算相鄰的兩個真實值對應的損失。損失函數使用Cross_entropy損失.

????????前面提到了由于Target一般不會是整數值,所以需要計算相鄰的兩個真實值對應的損失,那么如何選擇呢?這兩個損失之間的權重又是怎么樣的呢?為了加深理解,又單獨舉例演示該Grid cell中的Top_loss是怎么計算的:

????????該正樣本需要對GT對應的LTRB中的T為例,該正樣本的中心點距離上邊框是7.29像素,因為網絡預測只能是0~15的整數,那么只能選擇7和8這兩個相鄰的值作為標簽值,即yi=7和yi+1=8。接下來是選擇這兩個損失的權重,遵循一個原則:離得越近權重越大,所以當計算標簽為7的時候,選擇權重0.71,即yi+1-y;而計算標簽為8的時候,選擇權重0.21,即y-yi+1。

??? 前面也提到了損失函數使用Cross_entropy損失,和BCE損失有兩點區別,第一是把網絡預測的每個正樣本的LTRB值都需要進行SoftMax(),使得∑value=1,這和預測的時候是一樣的;第二是只選取標簽值對應的值作為損失,比如在該正樣本預測Top的損失計算中,有7和8兩個標簽值,那么7對應的損失值為1.4676,即-Log(Si),8對應的損失值為1.6825,即-log(Si+1)。最后該正樣本的Loss_Top為1.53,該正樣本的總損失為(Loss_Left+ Loss_Top+ Loss_Right+ Loss_Bottom)/4.

? ? ? ? 訓練過程的原理會稍微復雜,先整理成這樣子,后面我再優化下表達,爭取每個人都可以看得懂。

?????????

???????

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/10613.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/10613.shtml
英文地址,請注明出處:http://en.pswp.cn/web/10613.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Map中KEY去除下劃線并首字母轉換為大寫工具類

在運維舊項目時候&#xff0c;碰上sql查詢結果只能返回List<Map>&#xff0c;key為表單字段名&#xff0c;value為獲取到的結果數據。 懶得一個一個敲出來&#xff0c;就直接寫個方法轉換&#xff0c;并賦值到相應實體對象里去。 Map中KEY去除下劃線并首字母轉換為大寫&…

算法提高之矩陣距離

算法提高之矩陣距離 核心思想&#xff1a;多源bfs 從多個源頭做bfs&#xff0c;求距離 先把所有1的坐標存入隊列 再把所有1連接的位置存入 一層一層求 #include <iostream>#include <cstring>#include <algorithm>using namespace std;const int N 1…

Kafka 面試題(八)

1. Kafka&#xff1a;硬件配置選擇和調優的建議 &#xff1f; Kafka的硬件配置選擇和調優是確保Kafka集群高效穩定運行的關鍵環節。以下是一些建議&#xff1a; 硬件配置選擇&#xff1a; 內存&#xff08;RAM&#xff09;&#xff1a;建議至少使用32GB內存的服務器。為Kafk…

Web3Tools - 助記詞生成

Web3Tools - 助記詞生成工具 本文介紹了一個簡單的助記詞生成工具&#xff0c;使用 React 和 Material-UI 構建。用戶可以選擇助記詞的語言和長度&#xff0c;然后生成隨機的助記詞并顯示在頁面上 功能介紹 選擇語言和長度&#xff1a; 用戶可以在下拉菜單中選擇助記詞的語言&…

uniapp 圖片添加水印代碼封裝(優化版、圖片上傳壓縮、生成文字根據頁面自適應比例、增加文字背景色

uniapp 圖片添加水印代碼封裝(優化版、圖片上傳壓縮、生成文字根據頁面自適應比例、增加文字背景色 多張照片上傳封裝 <template><view class"image-picker"><uni-file-picker v-model"imageValue" :auto-upload"false" :title…

關于服務端接口知識的匯總

大家好&#xff0c;今天給大家分享一下之前整理的關于接口知識的匯總&#xff0c;對于測試人員來說&#xff0c;深入了解接口知識能帶來諸多顯著的好處。 一、為什么要了解接口知識&#xff1f; 接口是系統不同模塊之間交互的關鍵通道。只有充分掌握接口知識&#xff0c;才能…

http-server實現本地服務器

要實現一個本地服務器&#xff0c;你可以使用Node.js的http-server模塊。首先&#xff0c;確保你已經安裝了Node.js和npm。然后&#xff0c;按照以下步驟操作&#xff1a; 打開終端或命令提示符&#xff0c;進入你想要作為服務器根目錄的文件夾&#xff1b;運行以下命令安裝ht…

Axure PR 10 制作頂部下拉三級菜單和側邊三級菜單教程和源碼

在線預覽地址&#xff1a;Untitled Document 2.側邊三級下拉菜單 在線預覽地址&#xff1a;Untitled Document 文件包和教程下載地址&#xff1a;https://pan.quark.cn/s/77e55945bfa4 程序員必備資源網站&#xff1a;天夢星服務平臺 (tmxkj.top)

Linux x86_64 dump_stack()函數基于FP棧回溯

文章目錄 前言一、dump_stack函數使用二、dump_stack函數源碼解析2.1 show_stack2.2 show_stack_log_lvl2.3 show_trace_log_lvl2.4 dump_trace2.5 print_context_stack 參考資料 前言 Linux x86_64 centos7 Linux&#xff1a;3.10.0 一、dump_stack函數使用 dump_stack函數…

Unity開發中導彈路徑散射的原理與實現

Unity開發中導彈路徑散射的原理與實現 前言邏輯原理代碼實現導彈自身腳本外部控制腳本 應用效果結語 前言 前面我們學習了導彈的追蹤的效果&#xff0c;但是在動畫或游戲中&#xff0c;我們經常可以看到導彈發射后的彈道是不規則的&#xff0c;扭扭曲曲的飛行&#xff0c;然后擊…

數字生態系統的演進與企業API管理的關鍵之路

數字生態系統的演進與企業API管理的關鍵之路 在數字化時代&#xff0c;企業正經歷著一場轉型的浪潮&#xff0c;而API&#xff08;應用程序編程接口&#xff09;扮演著至關重要的角色。API如同一座橋梁&#xff0c;將組織內部的價值轉化為可市場化的產品&#xff0c;從而增強企…

韓國站群服務器在全球網絡架構中的重要作用?

韓國站群服務器在全球網絡架構中的重要作用? 在全球互聯網的蓬勃發展中&#xff0c;站群服務器作為網絡架構的核心組成部分之一&#xff0c;扮演著至關重要的角色。韓國站群服務器以其卓越的技術實力、優越的地理位置、穩定的網絡基礎設施和強大的安全保障能力&#xff0c;成…

LeetCode 題目 118:楊輝三角

題目描述 給定一個非負整數 numRows&#xff0c;生成楊輝三角的前 numRows 行。在楊輝三角中&#xff0c;每個數是它左上方和右上方的數的和。 楊輝三角解析 在這個詳解中&#xff0c;我們將使用 ASCII 圖形來說明楊輝三角的構建過程&#xff0c;包括逐行添加新的行的過程。…

250 基于matlab的5種時頻分析方法((短時傅里葉變換)STFT

基于matlab的5種時頻分析方法&#xff08;(短時傅里葉變換)STFT,Gabor展開和小波變換,Wigner-Ville&#xff08;WVD&#xff09;,偽Wigner-Ville分布(PWVD),平滑偽Wigner-Ville分布&#xff08;SPWVD&#xff09;,每條程序都有詳細的說明&#xff0c;設置仿真信號進行時頻輸出。…

Parted分區大容量磁盤

創建了新的虛擬磁盤10T , 掛載后分區格式化一.fdisk無法創建大容量的分區 Fileserver:~ # fdisk /dev/sdb Welcome to fdisk (util-linux 2.29.2). Changes will remain in memory only, until you decide to write them. Be careful before using the write command. Device …

使用html和css實現個人簡歷表單的制作

根據下列要求&#xff0c;做出下圖所示的個人簡歷&#xff08;表單&#xff09; 表單要求 Ⅰ、表格整體的邊框為1像素&#xff0c;單元格間距為0&#xff0c;表格中前六列列寬均為100像素&#xff0c;第七列 為200像素&#xff0c;表格整體在頁面上居中顯示&#xff1b; Ⅱ、前…

git提交代碼異常報錯error:bad signature 0x00000000

報錯信息 error:bad signature 0x00000000 異常原因 git 提交過程中異常關機或重啟&#xff0c;造成當前項目工程中的.git/index 文件損壞&#xff0c;無法提交 解決步驟 刪除.git/index文件 rm -f .git/index 重啟git git reset

Java 【數據結構】 哈希(Hash超詳解)HashSetHashMap【神裝】

登神長階 第十神裝 HashSet 第十一神裝 HashMap 目錄 &#x1f454;一.哈希 &#x1f9e5;1.概念 &#x1fa73;2.Object類的hashCode()方法: &#x1f45a;3.String類的哈希碼: &#x1f460;4.注意事項: &#x1f3b7;二.哈希桶 &#x1fa97;1.哈希桶原理 &#x…

Bert基礎(二十二)--Bert實戰:對話機器人

一 、概念簡介 1.1 生成式對話機器人 1.1.1什么是生成式對話機器人? 生成式對話機器人是一種能夠通過自然語言交互來理解和生成響應的人工智能系統。它們能夠進行開放域的對話,即在對話過程中,機器人可以根據用戶的需求和上下文信息,自主地生成新的、連貫的回復,而不僅…

如何使用CertCrunchy從SSL證書中發現和識別潛在的主機名稱

關于CertCrunchy CertCrunchy是一款功能強大的網絡偵查工具&#xff0c;該工具基于純Python開發&#xff0c;廣大研究人員可以利用該工具輕松從SSL證書中發現和識別潛在的主機信息。 支持的在線源 該工具支持從在線源或給定IP地址范圍獲取SSL證書的相關數據&#xff0c;并檢索…