YOLOv5推理代碼解析

代碼如下

import cv2
import numpy as np
import onnxruntime as ort
import time
import random# 畫一個檢測框
def plot_one_box(x, img, color=None, label=None, line_thickness=None):"""description: 在圖像上繪制一個矩形框。param:x: 框的坐標 [x1, y1, x2, y2]img: 輸入圖像color: 矩形框的顏色,默認為隨機顏色label: 框內顯示的標簽line_thickness: 矩形框的線條寬度return: 無返回值,直接在圖像上繪制"""tl = (line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1)  # line/font thickness,計算線條或字體的粗細color = color or [random.randint(0, 255) for _ in range(3)]  # 如果沒有提供顏色,隨機生成顏色c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))  # 左上角和右下角的坐標cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)  # 繪制矩形框if label:  # 如果提供了標簽,則繪制標簽tf = max(tl - 1, 1)  # 字體的粗細t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]  # 獲取標簽的大小c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3  # 計算標簽背景框的位置cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # 繪制標簽背景框cv2.putText(img,label,(c1[0], c1[1] - 2),0,tl / 3,[225, 255, 255],thickness=tf,lineType=cv2.LINE_AA,)  # 繪制標簽文本# 生成網格坐標
def _make_grid(nx, ny):"""description: 生成網格坐標,用于解碼預測框位置。param:nx, ny: 網格的行數和列數return: 返回網格坐標"""xv, yv = np.meshgrid(np.arange(ny), np.arange(nx))  # 生成網格坐標return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32)  # 轉換為需要的格式# 輸出解碼
def cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride):"""description: 對模型輸出的坐標進行解碼,轉換為圖像坐標。param:outs: 模型輸出的框的偏移量nl: 輸出層數量na: 每層的anchor數目model_w, model_h: 模型輸入圖像的尺寸anchor_grid: anchor的尺寸stride: 每個輸出層的縮放步長return: 解碼后的輸出"""row_ind = 0grid = [np.zeros(1)] * nl  # 每個層對應一個網格for i in range(nl):h, w = int(model_w / stride[i]), int(model_h / stride[i])  # 計算該層特征圖的高和寬length = int(na * h * w)  # 當前層的總框數if grid[i].shape[2:4] != (h, w):  # 如果網格的大小不匹配,則重新生成網格grid[i] = _make_grid(w, h)# 解碼每個框的中心坐標和寬高outs[row_ind:row_ind + length, 0:2] = (outs[row_ind:row_ind + length, 0:2] * 2. - 0.5 + np.tile(grid[i], (na, 1))) * int(stride[i])outs[row_ind:row_ind + length, 2:4] = (outs[row_ind:row_ind + length, 2:4] * 2) ** 2 * np.repeat(anchor_grid[i], h * w, axis=0)  # 計算寬高row_ind += lengthreturn outs# 后處理,計算檢測框
def post_process_opencv(outputs, model_h, model_w, img_h, img_w, thred_nms, thred_cond):"""description: 對模型輸出的框進行后處理,得到最終的檢測框。param:outputs: 模型輸出的框model_h, model_w: 模型輸入的高度和寬度img_h, img_w: 原圖的高度和寬度thred_nms: 非極大值抑制的閾值thred_cond: 置信度閾值return: 返回處理后的框、置信度和類別"""conf = outputs[:, 4].tolist()  # 獲取每個框的置信度c_x = outputs[:, 0] / model_w * img_w  # 計算中心點x坐標c_y = outputs[:, 1] / model_h * img_h  # 計算中心點y坐標w = outputs[:, 2] / model_w * img_w  # 計算框的寬度h = outputs[:, 3] / model_h * img_h  # 計算框的高度p_cls = outputs[:, 5:]  # 獲取分類得分if len(p_cls.shape) == 1:  # 如果分類結果只有一維,增加一維p_cls = np.expand_dims(p_cls, 1)cls_id = np.argmax(p_cls, axis=1)  # 獲取類別編號# 計算框的四個角坐標p_x1 = np.expand_dims(c_x - w / 2, -1)p_y1 = np.expand_dims(c_y - h / 2, -1)p_x2 = np.expand_dims(c_x + w / 2, -1)p_y2 = np.expand_dims(c_y + h / 2, -1)areas = np.concatenate((p_x1, p_y1, p_x2, p_y2), axis=-1)  # 合并成框的坐標areas = areas.tolist()  # 轉為列表形式ids = cv2.dnn.NMSBoxes(areas, conf, thred_cond, thred_nms)  # 非極大值抑制if len(ids) > 0:  # 如果有框被保留return np.array(areas)[ids], np.array(conf)[ids], cls_id[ids]else:return [], [], []# 圖像推理
def infer_img(img0, net, model_h, model_w, nl, na, stride, anchor_grid, thred_nms=0.4, thred_cond=0.5):"""description: 對輸入圖像進行推理,輸出檢測框。param:img0: 原始圖像net: 加載的ONNX模型model_h, model_w: 模型的輸入尺寸nl: 輸出層數量na: 每層的anchor數量stride: 每層的縮放步長anchor_grid: 每層的anchor尺寸thred_nms: 非極大值抑制閾值thred_cond: 置信度閾值return: 檢測框、置信度和類別"""# 圖像預處理img = cv2.resize(img0, [model_w, model_h], interpolation=cv2.INTER_AREA)  # 將圖像調整為模型輸入大小img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # 轉換為RGB格式img = img.astype(np.float32) / 255.0  # 歸一化blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)  # 將圖像轉為模型輸入格式# 模型推理outs = net.run(None, {net.get_inputs()[0].name: blob})[0].squeeze(axis=0)  # 推理并去掉batch維度# 輸出坐標矯正outs = cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride)# 檢測框計算img_h, img_w, _ = np.shape(img0)  # 獲取原圖的尺寸boxes, confs, ids = post_process_opencv(outs, model_h, model_w, img_h, img_w, thred_nms, thred_cond)return boxes, confs, idsif __name__ == "__main__":# 加載ONNX模型model_pb_path = "a.onnx"  # 模型文件路徑so = ort.SessionOptions()net = ort.InferenceSession(model_pb_path, so)# 類別字典dic_labels = {0: 'jn', 1: 'pill_bag', 2: 'pill_ban', 3: 'yg', 4: 'ys', 5: 'kfy',6: 'pw', 7: 'yanyao_1', 8: 'yanyao_2', 9: 'paper_cup', 10: 'musai',11: 'carrot', 12: 'potato', 13: 'potato_s', 14: 'potato_black',15: 'cizhuan', 16: 'eluanshi_guang', 17: 'stone', 18: 'zhuankuai_bai',19: 'zhuankuai_red', 20: 'empty'}# 模型參數model_h = 320model_w = 320nl = 3na = 3stride = [8., 16., 32.]anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(nl, -1, 2)# 打開攝像頭video = 1  # 攝像頭設備編號,1表示默認攝像頭cap = cv2.VideoCapture(video)  # 視頻捕獲對象flag_det = False  # 檢測開關while True:success, img0 = cap.read()  # 讀取每一幀if success:if flag_det:  # 如果檢測開啟t1 = time.time()  # 記錄推理前的時間det_boxes, scores, ids = infer_img(img0, net, model_h, model_w, nl, na, stride, anchor_grid,thred_nms=0.4, thred_cond=0.5)  # 推理t2 = time.time()  # 記錄推理后的時間# 繪制檢測框和標簽for box, score, id in zip(det_boxes, scores, ids):label = '%s:%.2f' % (dic_labels[id], score)plot_one_box(box.astype(np.int16), img0, color=(255, 0, 0), label=label, line_thickness=None)# 計算并顯示FPSdelta_time = t2 - t1if delta_time > 0:str_FPS = "FPS: %.2f" % (1. / delta_time)else:str_FPS = "FPS: inf"cv2.putText(img0, str_FPS, (50, 50), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 255, 0), 3)cv2.imshow("video", img0)  # 顯示圖像key = cv2.waitKey(1) & 0xFF  # 等待鍵盤輸入if key == ord('q'):  # 按q鍵退出breakelif key & 0xFF == ord('s'):  # 按s鍵切換檢測開關flag_det = not flag_detprint(flag_det)cap.release()  # 釋放視頻捕獲對象

代碼中推理相關的函數逐行詳細中文注釋和解釋,幫助你從初學者的角度完全理解代碼。我們重點講解這幾個核心函數:


1. infer_img() 推理主函數

def infer_img(img0, net, model_h, model_w, nl, na, stride, anchor_grid, thred_nms=0.4, thred_cond=0.5):

這是圖像推理的主函數,完成從原始圖像到預測結果的所有操作。


第一步:圖像預處理

img = cv2.resize(img0, [model_w, model_h], interpolation=cv2.INTER_AREA)
  • 將原始圖像 img0 縮放成模型輸入要求的大小(例如 320×320)。

  • cv2.INTER_AREA 是一種圖像插值方式,適合縮小圖像時使用。

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  • OpenCV 讀取圖像是 BGR 順序,而深度學習模型通常使用 RGB,因此這里需要轉換顏色通道。

img = img.astype(np.float32) / 255.0
  • 把圖像的數據類型轉為 float32,并將像素值從 [0, 255] 范圍歸一化到 [0, 1],符合模型輸入要求。

blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
  • OpenCV圖像的格式是 (H, W, C),而 PyTorch 模型(如YOLO)的輸入是 (B, C, H, W)

  • np.transpose(img, (2, 0, 1)) 把通道 C 移到第一個維度

  • np.expand_dims(..., axis=0) 增加 batch 維度:變成 (1, 3, 320, 320)


第二步:模型推理

outs = net.run(None, {net.get_inputs()[0].name: blob})[0].squeeze(axis=0)
  • 用 ONNX Runtime 推理:輸入是 blob

  • net.get_inputs()[0].name 得到模型輸入的名字

  • squeeze(axis=0) 把 batch 維度去掉,形狀變成 (N, 85),N 是預測框數量,85 是每個框的信息(x, y, w, h, conf, + 80類)


第三步:輸出坐標解碼

outs = cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride)
  • YOLO 的輸出是相對 anchor + grid 編碼的,需要轉換為圖像上的真實位置

  • cal_outputs() 就是做這個解碼變換的函數(后面詳細講)


第四步:后處理,獲取檢測框信息

img_h, img_w, _ = np.shape(img0)
boxes, confs, ids = post_process_opencv(outs, model_h, model_w, img_h, img_w, thred_nms, thred_cond)
  • 將模型輸出映射回原始圖像尺寸

  • 使用置信度閾值和 NMS 非極大值抑制刪除重復框

  • 得到最終的:

    • boxes: 框坐標

    • confs: 置信度

    • ids: 類別編號


2. cal_outputs() 坐標解碼函數

def cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride):

含義解釋:

  • outs: 模型輸出,形狀大致是 (N, 85),前4列是框的位置

  • nl: YOLO使用的輸出層數量(3個:大中小目標)

  • na: 每個特征層使用的 anchor 數(通常為 3)

  • anchor_grid: 每層 anchor 的寬高尺寸

  • stride: 每層特征圖相對于原圖的縮放倍數

grid = [np.zeros(1)] * nl
  • 每一層都要生成網格坐標 grid,初始化為占位

for i in range(nl):h, w = int(model_w / stride[i]), int(model_h / stride[i])
  • 計算第 i 層的特征圖尺寸(如:320/8=40)

    length = int(na * h * w)
  • 該層有多少個預測框

    if grid[i].shape[2:4] != (h, w):grid[i] = _make_grid(w, h)
  • 如果還沒有生成 grid,就調用 _make_grid() 創建形狀為 (h*w, 2) 的網格點

    outs[row_ind:row_ind + length, 0:2] = ...outs[row_ind:row_ind + length, 2:4] = ...
  • 對該層的所有框做位置矯正(中心點解碼 + 寬高縮放)

  • 用 grid 和 anchor 反算出真實坐標


3. post_process_opencv() 后處理函數

def post_process_opencv(outputs, model_h, model_w, img_h, img_w, thred_nms, thred_cond):

功能:

  • 將模型輸出映射回原始圖像尺寸

  • 提取類別信息

  • 使用 OpenCV 的 cv2.dnn.NMSBoxes() 進行非極大值抑制,保留重要框

步驟:

conf = outputs[:, 4].tolist()         # 提取每個框的置信度
c_x = outputs[:, 0] / model_w * img_w
c_y = outputs[:, 1] / model_h * img_h
w = outputs[:, 2] / model_w * img_w
h = outputs[:, 3] / model_h * img_h
  • 將中心點和尺寸從模型尺寸映射回原始圖像尺寸

p_cls = outputs[:, 5:]
cls_id = np.argmax(p_cls, axis=1)
  • 取得每個框的類別分數最大值(即分類結果)

p_x1 = c_x - w/2
p_y1 = c_y - h/2
p_x2 = c_x + w/2
p_y2 = c_y + h/2
  • 把中心點轉為左上角和右下角坐標 [x1, y1, x2, y2]

areas = np.concatenate((p_x1, p_y1, p_x2, p_y2), axis=-1)
ids = cv2.dnn.NMSBoxes(areas, conf, thred_cond, thred_nms)
  • 用 NMS 去除重疊預測框


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/80499.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/80499.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/80499.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

CATIA高效工作指南——常規配置篇(二)

一、結構樹(Specification Tree)操作技巧精講 結構樹是CATIA設計中記錄模型歷史與邏輯關系的核心模塊,其高效管理直接影響設計效率。本節從基礎操作到高級技巧進行系統梳理。 1.1 結構樹激活與移動 ??激活方式??: ??白線…

批量重命名bat

作為一名程序員,怎么可以自己一個個改文件名呢! Windows的批量重命名會自動加上括號和空格,看著很不爽,寫一個bat處理吧!?(ゝω???) 功能:將當前目錄下的所有文件名里面當括號和空格都去掉。 用法&…

嵌入式軟件開發常見warning之 warning: implicit declaration of function

文章目錄 🧩 1. C 編譯流程回顧(背景)📍 2. 出現 warning 的具體階段:**編譯階段(Compilation)**🧬 2.1 詞法分析(Lexical Analysis)🌲 2.2 語法分…

【人工智能-agent】--Dify中MCP工具存數據到MySQL

本文記錄的工作如下: 自定義MCP工具,爬取我的鋼鐵網數據爬取的數據插值處理自定義MCP工具,把爬取到的數據(str)存入本地excel表格中自定義MCP工具,把爬取到的數據(str)存入本地MySQ…

Golang 應用的 CI/CD 與 K8S 自動化部署全流程指南

一、CI/CD 流程設計與工具選擇 1. 技術棧選擇 版本控制:Git(推薦 GitHub/GitLab)CI 工具:Jenkins/GitLab CI/GitHub Actions(本文以 GitHub Actions 為例)容器化:Docker Docker Compose制品庫…

網絡基礎1(應用層、傳輸層)

目錄 一、應用層 1.1 序列化和反序列化 1.2 HTTP協議 1.2.1 URL 1.2.2 HTTP協議格式 1.2.3 HTTP服務器示例 二、傳輸層 2.1 端口號 2.1.1 netstat 2.1.2 pidof 2.2 UDP協議 2.2.1 UDP的特點 2.2.2 基于UDP的應用層…

基于大模型預測的吉蘭 - 巴雷綜合征綜合診療方案研究報告大綱

目錄 一、引言(一)研究背景(二)研究目的與意義二、大模型預測吉蘭 - 巴雷綜合征的理論基礎與技術架構(一)大模型原理概述(二)技術架構設計三、術前預測與手術方案制定(一)術前預測內容(二)手術方案制定依據與策略四、術中監測與麻醉方案調整(一)術中監測指標與數…

【言語】刷題2

front:刷題1 ? 前對策的說理類 題干 新時代是轉型關口,要創新和開放(前對策)創新和開放不能一蹴而就,但是對于現代化很重要 BC片面,排除 A雖然表達出了創新和開放很重要,體現了現代化&#xf…

Blueprints - Gameplay Message Subsystem

一些學習筆記歸檔; Gameplay Message是C插件,安裝方式是把插件文件夾拷貝到Plugins中(沒有的話需要新建該文件夾),然后再刷新源碼,運行項目; 安裝后還需要在插件中激活: 這樣&#…

火山云網站搭建

使用火山引擎的 **火山云(Volcano Engine Cloud)** 搭建網站,主要涉及云服務器、存儲、網絡等核心云服務的配置。以下是搭建網站的基本步驟和關鍵點: --- ### **一、準備工作** 1. **注冊火山引擎賬號** - 訪問火山引擎官網&…

嵌入式開發學習(第二階段 C語言基礎)

直到型循環的實現 特點:先執行,后判斷,不管條件是否滿足,至少執行一次。 **代表:**do…while,goto(已經淘汰,不推薦使用) do…while 語法: 循環變量; do {循環體; }…

Nginx +Nginx-http-flv-module 推流拉流

這兩天為了利用云服務器實現 Nginx 進行OBS Rtmp推流,Flv拉流時發生了諸多情況,記錄實現過程。 環境 OS:阿里云CentOS 7.9 64位Nginx:nginx-1.28.0Nginx-http-flv-module:nginx-http-flv-module-1.2.12 安裝Nginx編…

射頻ADRV9026驅動

參考: ADRV9026 & ADRV9029 Prototyping Platform User Guide [Analog Devices Wiki] 基于ADRV9026的四通道射頻收發FMC子卡-CSDN博客 adrv9026 spi 接口驗證代碼-CSDN博客

使用本地部署的 LLaMA 3 模型進行中文對話生成

以下程序調用本地部署的 LLaMA3 模型進行多輪對話生成,通過 Hugging Face Transformers API 加載、預處理、生成并輸出最終回答。 程序用的是 Chat 模型格式(如 LLaMA3 Instruct 模型),遵循 ChatML 模板,并使用 apply…

Oracle19c中的全局臨時表

應用程序通常使用某種形式的臨時數據存儲來處理過于復雜而無法一次性完成的流程。通常,這些臨時存儲被定義為數據庫表或 PL/SQL 表。從 Oracle 8i 開始,可以使用全局臨時表將臨時表的維護和管理委托給服務器。 一、臨時表分類 Oracle 支持兩種類型的臨…

Windows 安裝 Milvus

說明 操作系統:Window 中間件:docker desktop Milvus:Milvus Standalone(單機版) 安裝 docker desktop 參考:Window、CentOs、Ubuntu 安裝 docker-CSDN博客 安裝 Milvus 參考鏈接:Run Mil…

24、DeepSeek-V3論文筆記

DeepSeek-V3論文筆記 **一、概述****二、核心架構與創新技術**0.匯總:1. **基礎架構**2. **創新策略** 1.DeepSeekMoE無輔助損失負載均衡DeepSeekMoE基礎架構無輔助損失負載均衡互補序列級輔助損失 2.多令牌預測(MTP)1.概念2、原理2.1BPD2.2M…

1.8 梯度

(知識體系演進邏輯樹) 一元導數(1.5) │ ├─→ 多元偏導數(1.6核心突破) │ │ │ └─解決:多變量耦合時的單變量影響分析 │ │ │ ├─幾何:坐標軸切片切線斜率…

274、H指數

題目 給你一個整數數組 citations ,其中 citations[i] 表示研究者的第 i 篇論文被引用的次數。計算并返回該研究者的 h 指數。 根據維基百科上 h 指數的定義:h 代表“高引用次數” ,一名科研人員的 h 指數 是指他(她&#xff09…

【C++11】異常

前言 上文我們學習到了C11中類的新功能【C11】類的新功能-CSDN博客 本文我們來學習C下一個新語法:異常 1.異常的概念 異常的處理機制允許程序在運行時就出現的問題進行相應的處理。異常可以使得我們將問題的發現和問題的解決分開,程序的一部分負…