RT-DETR+Flask實現目標檢測推理案例

今天,帶大家利用RT-DETR(我們可以換成任意一個模型)+Flask來實現一個目標檢測平臺小案例,其實現效果如下:

目標檢測案例

這個案例很簡單,就是讓我們上傳一張圖像,隨后選擇一下置信度,即可檢測出圖像中的目標,那么具體該如何實現呢?

RT-DETR模型推理

在先前的學習過程中,博主對RT-DETR進行來了簡要的介紹,作為百度提出的實時性目標檢測模型,其無論是速度還是精度均取得了較為理想的效果,今天則主要介紹一下RT-DETR的推理過程,與先前使用DETR中使用pth權重與網絡結構相結合的推理方式不同,RT-DETR中使用的是onnx這種權重文件,因此,我們需要先對onnx文件進行一個簡單了解:
在這里插入圖片描述

ONNX模型文件

import onnx
# 加載模型
model = onnx.load('onnx_model.onnx')
# 檢查模型格式是否完整及正確
onnx.checker.check_model(model)
# 獲取輸出層,包含層名稱、維度信息
output = self.model.graph.output
print(output)

在原本的DETR類目標檢測算法中,推理是采用權重文件與模型結構代碼相結合的方式,而在RT-DETR中,則采用onnx模型文件來進行推理,即只需要該模型文件即可。

首先是將pth文件與模型結構進行匹配,從而導出onnx模型文件

"""by lyuwenyu
"""import os 
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))import argparse
import numpy as np from src.core import YAMLConfigimport torch
import torch.nn as nn def main(args, ):"""main"""cfg = YAMLConfig(args.config, resume=args.resume)if args.resume:checkpoint = torch.load(args.resume, map_location='cpu') if 'ema' in checkpoint:state = checkpoint['ema']['module']else:state = checkpoint['model']else:raise AttributeError('only support resume to load model.state_dict by now.')# NOTE load train mode state -> convert to deploy modecfg.model.load_state_dict(state)class Model(nn.Module):def __init__(self, ) -> None:super().__init__()self.model = cfg.model.deploy()self.postprocessor = cfg.postprocessor.deploy()print(self.postprocessor.deploy_mode)def forward(self, images, orig_target_sizes):outputs = self.model(images)return self.postprocessor(outputs, orig_target_sizes)model = Model()dynamic_axes = {'images': {0: 'N', },'orig_target_sizes': {0: 'N'}}data = torch.rand(1, 3, 640, 640)size = torch.tensor([[640, 640]])torch.onnx.export(model, (data, size), args.file_name,input_names=['images', 'orig_target_sizes'],output_names=['labels', 'boxes', 'scores'],dynamic_axes=dynamic_axes,opset_version=16, verbose=False)if args.check:import onnxonnx_model = onnx.load(args.file_name)onnx.checker.check_model(onnx_model)print('Check export onnx model done...')if args.simplify:import onnxsimdynamic = True input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else Noneonnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)onnx.save(onnx_model_simplify, args.file_name)print(f'Simplify onnx model {check}...')
if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--config', '-c',  default="D:\graduate\programs\RT-DETR-main\RT-DETR-main//rtdetr_pytorch\configs/rtdetr/rtdetr_r18vd_6x_coco.yml",type=str, )parser.add_argument('--resume', '-r', default="D:\graduate\programs\RT-DETR-main\RT-DETR-main/rtdetr_pytorch/tools\output/rtdetr_r18vd_6x_coco\checkpoint0024.pth",type=str, )parser.add_argument('--file-name', '-f', type=str, default='model.onnx')parser.add_argument('--check',  action='store_true', default=False,)parser.add_argument('--simplify',  action='store_true', default=False,)args = parser.parse_args()main(args)

隨后,便是利用onnx模型文件進行目標檢測推理過程了
onnx也有自己的一套流程:

onnx前向InferenceSession的使用

關于onnx的前向推理,onnx使用了onnxruntime計算引擎。
onnx runtime是一個用于onnx模型的推理引擎。微軟聯合Facebook等在2017年搞了個深度學習以及機器學習模型的格式標準–ONNX,順路提供了一個專門用于ONNX模型推理的引擎(onnxruntime)。

import onnxruntime
# 創建一個InferenceSession的實例,并將模型的地址傳遞給該實例
sess = onnxruntime.InferenceSession('onnxmodel.onnx')
# 調用實例sess的潤方法進行推理
outputs = sess.run(output_layers_name, {input_layers_name: x})

推理詳細代碼

推理代碼如下:

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensorif __name__ == "__main__":##################classes = ['car','truck',"bus"]################### print(onnx.helper.printable_graph(mm.graph))#############img_path = "1.jpg"#############im = Image.open(img_path).convert('RGB')im = im.resize((640, 640))im_data = ToTensor()(im)[None]print(im_data.shape)size = torch.tensor([[640, 640]])sess = ort.InferenceSession("model.onnx")import timestart = time.time()output = sess.run(output_names=['labels', 'boxes', 'scores'],#output_names=None,input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()})end = time.time()fps = 1.0 / (end - start)print(fps)# print(type(output))# print([out.shape for out in output])labels, boxes, scores = outputdraw = ImageDraw.Draw(im)thrh = 0.6for i in range(im_data.shape[0]):scr = scores[i]lab = labels[i][scr > thrh]box = boxes[i][scr > thrh]print(i, sum(scr > thrh))#print(lab)print(f'box:{box}')for l, b in zip(lab, box):draw.rectangle(list(b), outline='red',)print(l.item())draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )#############im.save('2.jpg')#############

前端代碼

前端代碼包含兩部分,一個是上傳頁面,一個是顯示頁面

上傳頁面如下:

<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><meta name="viewport" content="initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /><title></title><script src="http://www.jq22.com/jquery/jquery-1.10.2.js"></script><style>#addCommodityIndex {text-align: center;width: 300px;height: 340px;position: absolute;left: 50%;top: 50%;margin: -200px 0 0 -200px;border: solid #ccc 1px;padding: 35px;}#imghead {cursor: pointer;}.btn {width: 100%;height: 40px;text-align: center;}</style><link rel="stylesheet" href="../static/css/bootstrap.min.css"  crossorigin="anonymous">
</head><body><div id="addCommodityIndex"><h2>目標檢測</h2><div class="form-group row"><form id="upload"  action="/upload" enctype="multipart/form-data" method="POST"><img src=""><div class="form-group row"><label>上傳圖像</label><input type="file" class="form-control"  name='file'></div><div class="form-group row"><label>選擇置信度</label><select class="form-control" name="score" id="exampleFormControlSelect1"><option value="0.5">0.5</option><option value="0.6">0.6</option><option value="0.7">0.7</option><option value="0.8">0.8</option><option value="0.9">0.9</option></select></div><div class="form-group row"><div class="btn"><input type="submit" class="btn btn-success" value="提交圖像" /></div></div></form></div></div></body>
</html>

顯示頁面:

<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /><title></title><script src="http://www.jq22.com/jquery/jquery-1.10.2.js"></script><style>#addCommodityIndex {text-align: center;position: absolute;left: 40%;top: 50%;margin: -200px 0 0 -200px;border: solid #ccc 1px;}#imghead {cursor: pointer;}.result {width: 100%;height: 100%;text-align: center;}</style><link rel="stylesheet" href="../static/css/bootstrap.min.css"  crossorigin="anonymous">
</head><body><div id="addCommodityIndex">
<div class="card mb-3" style="max-width: 680px;"><div class="row no-gutters"><div class="col-md-5"><img src="../static/img/result.jpg" class="result"></div><div class="col-md-5"><div class="card-body"><h5 class="card-title">檢測結果</h5><p class="card-text">目標數量:{{num}}</p><p class="card-text">檢測速度:{{fps}}/</p><a  href="/home" class="btn btn-success">繼續提交</a></div></div></div>
</div>
</div>
</body>
</html>

Flask框架代碼:

# -*- coding: utf-8 -*-
from flask import Flask,request,render_template
import json
import os
import time
app = Flask(__name__)
import infer
@app.route('/home',methods=['GET'])
def home():return render_template('upload.html')@app.route('/upload',methods=['GET','POST'])
def upload():if request.method == 'POST':f = request.files['file'] #獲取數據流rootPath = os.path.dirname(os.path.abspath(__file__)) #根目錄路徑#創建存儲文件的文件夾,使用時間戳防止重名覆蓋file_path = 'static/upload/' + str(int(time.time()))absolute_path = os.path.join(rootPath,file_path).replace('\\','/') #存儲文件的絕對路徑,window路徑顯示\\要轉化/if not os.path.exists(absolute_path): #不存在改目錄則會自動創建os.makedirs(absolute_path)save_file_name = os.path.join(absolute_path,f.filename).replace('\\','/') #文件存儲路徑(包含文件名)f.save(save_file_name)score=request.values.to_dict().get("score")num,fps=infer.inference(save_file_name,score)#return json.dumps({'code':200,'url':url_path},ensure_ascii=False)return render_template("show.html",num=num,fps=fps)app.run(port='5000',debug=True)

上述項目博主已經上傳到github上

git init
git add README.md
git commit -m "first commit"
git branch -M main
git remote add origin https://github.com/pengxiang1998/rt-detr.git
git push -u origin main

項目地址

在使用onnx時,安裝了onnxruntime后,出現了下面的錯誤:

ImportError: cannot import name 'create_and_register_allocator_v2' from 'onnxruntime.capi._pybind_state'

這是由于onnxruntime-gpu版本與CUDA、CuDNN版本不匹配導致的,可以查看下面的網址來查看匹配版本

https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

在這里插入圖片描述
隨后又出現錯誤:

> This ORT build has ['TensorrtExecutionProvider',
> 'CUDAExecutionProvider', 'CPUExecutionProvider'] enabled. Since ORT
> 1.9, you are required to explicitly set the providers parameter when instantiating InferenceSession. For example,
> onnxruntime.InferenceSession(...,
> providers=['TensorrtExecutionProvider',

這是由于InferenceSession中沒有提供對應的provider,修改代碼如下:

if torch.cuda.is_available():print("GPU")sess = ort.InferenceSession("model.onnx", None, providers=["CUDAExecutionProvider"])else:print("CPU")sess= ort.InferenceSession("model.onnx", None)

隨后運行,發現安裝了onnxruntime-gpu后的速度竟然滿了下來,fps僅為0.2,而原本使用onnxruntime的fps則為7左右,這到底是怎么回事呢?

在這里插入圖片描述

YOLO集成推理

而在YOLO集成的RT-DETR項目中,訓練得到的權重 文件為.pt,在推理時需要與RT-DETR搭配使用,從而實現推理過程:
需要注意的是,由于YOLO里面集成了多種模型,因此為了具有適配性,其代碼都具有通用性

from ultralytics.models import RTDETR
if __name__ == '__main__':model=RTDETR("weights/best.pt")model.predict(source="images/1.mp4",save=True,conf=0.6)

隨后執行predict,代碼如下:

def predict(self,source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,stream: bool = False,predictor=None,**kwargs,) -> list:if source is None:source = ASSETSLOGGER.warning(f"WARNING ?? 'source' is missing. Using 'source={source}'.")is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(x in ARGV for x in ("predict", "track", "mode=predict", "mode=track"))custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"}  # method defaultsargs = {**self.overrides, **custom, **kwargs}  # highest priority args on the rightprompts = args.pop("prompts", None)  # for SAM-type modelsif not self.predictor:self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)self.predictor.setup_model(model=self.model, verbose=is_cli)else:  # only update args if predictor is already setupself.predictor.args = get_cfg(self.predictor.args, args)if "project" in args or "name" in args:self.predictor.save_dir = get_save_dir(self.predictor.args)if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type modelsself.predictor.set_prompts(prompts)return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

這部分代碼在功能上具有復用性,因此在理解上存在一定難度。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/46367.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/46367.shtml
英文地址,請注明出處:http://en.pswp.cn/web/46367.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

GPT LangChain experimental agent - allow dangerous code

題意&#xff1a;GPT LangChain 實驗性代理 - 允許危險代碼 問題背景&#xff1a; Im creating a chatbot in VS Code where it will receive csv file through a prompt on Streamlit interface. However from the moment that file is loaded, it is showing a message with…

第12章 結構化命令《Linux命令行與Shell腳本編程大全筆記》

12.1 if-then命令 不同于其他語言&#xff0c;if后面不是一個等式&#xff0c;而是命令&#xff0c;如果命令運行成功返回狀態碼0則運行then語句部分把分號&#xff08;;&#xff09;放到命令尾部&#xff0c;可以將then語句寫在同一行 12.4 test命令 格式&#xff1a;if te…

激活pytorch遇到報錯usage: conda-script.py [-h] [--no-plugins] [-V] COMMAND ...

問題 今天初次嘗試在pycharm上創建與激活虛擬環境&#xff0c;創建結束后&#xff0c;使用命令conda activate pytorch激活虛擬環境時出現以下報錯&#xff1a; usage: conda-script.py [-h] [–no-plugins] [-V] COMMAND … conda-script.py: error: argument COMMAND: inval…

Selenium原理深度解析

在自動化測試領域&#xff0c;Selenium無疑是最受歡迎和廣泛使用的工具之一。它支持多種瀏覽器和操作系統&#xff0c;為開發人員和測試人員提供了強大的自動化測試解決方案。本文將深入探討Selenium的工作原理&#xff0c;包括其架構、核心組件、執行流程以及它在自動化測試中…

獨立開發者系列(26)——域名與解析

域名&#xff08;英語&#xff1a;Domain Name&#xff09;&#xff0c;又稱網域&#xff0c;是由一串用點分隔的名字組成的互聯網上某一臺計算機或計算機組的名稱&#xff0c;用于在數據傳輸時對計算機的定位標識&#xff08;有時也指地理位置&#xff09;。 由于IP地址不方便…

postMessageXss續2

原文地址如下:https://research.securitum.com/art-of-bug-bounty-a-way-from-js-file-analysis-to-xss/ 在19年我寫了一篇文章&#xff0c;是基于postMessageXss漏洞的入門教學:https://www.cnblogs.com/piaomiaohongchen/p/14727871.html 這幾天瀏覽mXss技術的時候&#xff…

模型蒸餾、量化、裁剪的概念和區別

模型壓縮概述 1.1 模型壓縮的重要性 隨著深度學習技術的快速發展&#xff0c;神經網絡模型在各種任務中取得了顯著的成功。然而&#xff0c;這些模型通常具有大量的參數和復雜的結構&#xff0c;導致模型體積龐大、計算資源消耗高和推理時間長。這些問題限制了深度學習模型在…

車載音視頻App框架設計

簡介 統一播放器提供媒體播放一致性的交互和視覺體驗&#xff0c;減少各個媒體應用和場景獨自開發的重復工作量&#xff0c;實現媒體播放鏈路的一致性&#xff0c;減少碎片化的Bug。本文面向應用開發者介紹如何快速接入媒體播放器。 主要功能&#xff1a; 新設計的統一播放U…

新版本cesium編譯1.103之后的版本

cesium1.1之后的版本文件結構域1.1之前的版本有了很大的差別&#xff0c;源碼也全部移到了packages目錄中。有很多依賴包沒有寫在根目錄的package.json文件中。npm i 后直接編譯會保持。 cesium源碼git https://github.com/CesiumGS/cesium 1、添加缺少的包&#xff0c;缺少的…

4. 雙端口ram設計

1. 設計要求 設計一個位寬8bit&#xff0c;地址深度為128&#xff0c;可以同時讀寫的雙端口RAM 要求&#xff1a;模塊名字為RAM_DUAL 輸入端口&#xff1a;ADDR_W&#xff0c;ADDR_R CLK_R&#xff0c;CLK_W&#xff0c;RSTn ADDR_R[6:0]&#xff0c;ADDR_W[6:0] DATA_WR…

k8s學習——創建測試鏡像

創建一個安裝了ifconfig、telnet、curl、nc、traceroute、ping、nslookup等網絡工具的鏡像&#xff0c;便于集群中的測試。 創建一個Dockerfile文件 # 使用代理下載 Ubuntu 鏡像作為基礎 FROM docker.m.daocloud.io/library/ubuntu:latest# 設置環境變量 DEBIAN_FRONTEND 為 …

學習測試9-接口測試 2-抓包工具Fiddler

Fiddler 抓包工具的使用 怎么找接口信息&#xff0c;可以通過瀏覽器的開發者工具 Fiddler 是一個 HTTP 協議調試代理工具 File 菜單&#xff1a; Capture Traffic&#xff08;或 F12&#xff09;&#xff1a;是個開關&#xff0c;可以控制是否把 Fiddler 注冊為系統代理。當把…

淺談Open.Json.pickle.Os

一、Open函數使用 open函數是 Python 中用于打開文件的內置函數&#xff0c;它返回一個文件對象&#xff0c;該文件對象提供了對文件進行讀寫操作的方法。使用 open 函數時&#xff0c;通常需要指定至少兩個參數&#xff1a;文件名&#xff08;file&#xff09;和模式&#xf…

【網絡工具】Charles 介紹及環境配置

?個人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;專欄地址&#xff1a;http://t.csdnimg.cn/iAmAo &#x1f4da;專欄簡介&#xff1a;在這個專欄中&#xff0c;我將會整理一些工作或學習中用到的工具介紹給大家~ &#x1f4d8;Charles 系列其它文章&#xff1a;【網絡…

Git操縱本地倉庫和遠程倉庫

git是一個代碼托管的平臺&#xff0c;我們可以對我們的代碼進行分支 推送提交 打標簽等等操作&#xff0c;而且git使用過程中也是支持一些linux語言的 比如cd呀 touch mkdir啊等等等 git的具體安裝過程就不再贅述 我個人認為 好多東西就是 代碼也好 文字 文檔 也好&…

【C語言】結構體,枚舉,聯合超詳解!!!

目錄 結構體 結構體聲明 結構體成員的訪問 結構體自引用 結構體變量定義&#xff0c;初始化&#xff0c;傳參 結構體內存對齊 位段 枚舉 聯合(共用體) 結構體 結構體聲明 1. 概念 1. 結構體是一些值的集合&#xff0c;這些值稱為成員變量。 2. 結構體的每個成員可…

長難句打卡7.15

The trend was naturally most obvious in those areas of science based especially on a mathematical or laboratory training, and can be illustrated in terms of the development of geology in the United Kingdom 這一趨勢自然在以數學或實驗室訓練為基礎的科學領域里…

Unlink

Unlink 原理 我們在利用 unlink 所造成的漏洞時&#xff0c;其實就是對 chunk 進行內存布局&#xff0c;然后借助 unlink 操作來達成修改指針的效果。簡單回顧一下 unlink 的目的與過程&#xff0c;其目的是把一個雙向鏈表中的空閑塊拿出來&#xff08;例如 free 時和目前物理…

Leetcode二分搜索法淺析

文章目錄 1.二分搜索法1.1什么是二分搜索法&#xff1f;1.2解法思路 1.二分搜索法 題目原文&#xff1a; 給定一個 n 個元素有序的&#xff08;升序&#xff09;整型數組 nums 和一個目標值 target &#xff0c;寫一個函數搜索 nums 中的 target&#xff0c;如果目標值存在返…

從PyTorch官方的一篇教程說開去(1 - 初心)

原文在此&#xff0c;喜歡讀原汁原味的可以自行去跟&#xff0c;這是一個非常經典和有學習意義的例子&#xff0c;在此向老爺子們致敬 - https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html 開源文化好是好&#xff0c;但是“公地的悲哀”這點避不開…