python基于DETR(DEtection TRansformer)開發構建鋼鐵產業產品智能自動化檢測識別系統

在前文中我們基于經典的YOLOv5開發構建了鋼鐵產業產品智能自動化檢測識別系統,這里本文的主要目的是想要實踐應用DETR這一端到端的檢測模型來開發構建鋼鐵產業產品智能自動化檢測識別系統。

DETR (DEtection TRansformer) 是一種基于Transformer架構的端到端目標檢測模型。與傳統的基于區域提議的目標檢測方法(如Faster R-CNN)不同,DETR采用了全新的思路,將目標檢測問題轉化為一個序列到序列的問題,通過Transformer模型實現目標檢測和目標分類的聯合訓練。

DETR的工作流程如下:

輸入圖像通過卷積神經網絡(CNN)提取特征圖。
特征圖作為編碼器輸入,經過一系列的編碼器層得到圖像特征的表示。
目標檢測問題被建模為一個序列到序列的轉換任務,其中編碼器的輸出作為解碼器的輸入。
解碼器使用自注意力機制(self-attention)對編碼器的輸出進行處理,以獲取目標的位置和類別信息。
最終,DETR通過一個線性層和softmax函數對解碼器的輸出進行分類,并通過一個線性層預測目標框的坐標。
DETR的優點包括:

端到端訓練:DETR模型能夠直接從原始圖像到目標檢測結果進行端到端訓練,避免了傳統目標檢測方法中復雜的區域提議生成和特征對齊的過程,簡化了模型的設計和訓練流程。
不受固定數量的目標限制:DETR可以處理變長的輸入序列,因此不受固定數量目標的限制。這使得DETR能夠同時檢測圖像中的多個目標,并且不需要設置預先確定的目標數量。
全局上下文信息:DETR通過Transformer的自注意力機制,能夠捕捉到圖像中不同位置的目標之間的關系,提供了更大范圍的上下文信息。這有助于提高目標檢測的準確性和魯棒性。
然而,DETR也存在一些缺點:

計算復雜度高:由于DETR采用了Transformer模型,它在處理大尺寸圖像時需要大量的計算資源,導致其訓練和推理速度相對較慢。
對小目標的檢測性能較差:DETR模型在處理小目標時容易出現性能下降的情況。這是因為Transformer模型在處理小尺寸目標時可能會丟失細節信息,導致難以準確地定位和分類小目標。

首先看下實例效果:
?

簡單看下數據集:

PyTorch訓練代碼和DETR(DEDetection-TRansformer)的預訓練模型。我們用Transformer替換了完全復雜的手工制作的對象檢測管道,并將Faster R-CNN與ResNet-50匹配,使用一半的計算能力(FLOP)和相同數量的參數在COCO上獲得42個AP。

官方項目地址在這里,如下所示:

可以看到目前已經收獲了超過1.2w的star量,還是很不錯的了。

DETR整體數據流程示意圖如下所示:

官方也提供了對應的預訓練模型,可以自行使用:

本文選擇的預訓練官方權重是detr-r50-e632da11.pth,首先需要基于官方的預訓練權重開發能夠用于自己的 個性化數據集的權重,如下所示:

pretrained_weights = torch.load("./weights/detr-r50-e632da11.pth")
num_class = 10 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights,'./weights/detr_r50_%d.pth'%num_class)

因為這里我的類別數量為10,所以num_class修改為:10+1,根據自己的實際情況修改即可。生成后如下所示:

之后按照官方說明準備好數據集即可,啟動訓練模型命令如下所示:

python main.py --dataset_file "coco" --coco_path "/0000" --epoch 100 --lr=1e-4 --batch_size=32 --num_workers=0 --output_dir="outputs" --resume="weights/detr_r50_11.pth"

借助于plot_util.py模塊可以實現對模型的評估和可視化,如下:

def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):'''Function to plot specific fields from training log(s). Plots both training and test results.:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file- fields = which results to plot from each log file - plots both training and test for each field.- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots- log_name = optional, name of log file if different than default 'log.txt'.:: Outputs - matplotlib plots of results in fields, color coded for each log file.- solid lines are training results, dashed lines are test results.'''func_name = "plot_utils.py::plot_logs"# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,# convert single Path to list to avoid 'not iterable' errorif not isinstance(logs, list):if isinstance(logs, PurePath):logs = [logs]print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")else:raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \Expect list[Path] or single Path obj, received {type(logs)}")# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dirfor i, dir in enumerate(logs):if not isinstance(dir, PurePath):raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")if not dir.exists():raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")# verify log_name existsfn = Path(dir / log_name)if not fn.exists():print(f"-> missing {log_name}.  Have you gotten to Epoch 1 in training?")print(f"--> full path of missing log file: {fn}")return# load log file(s) and plotdfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):for j, field in enumerate(fields):if field == 'mAP':coco_eval = pd.DataFrame(np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]).ewm(com=ewm_col).mean()axs[j].plot(coco_eval, c=color)else:df.interpolate().ewm(com=ewm_col).mean().plot(y=[f'train_{field}', f'test_{field}'],ax=axs[j],color=[color] * 2,style=['-', '--'])for ax, field in zip(axs, fields):ax.legend([Path(p).name for p in logs])ax.set_title(field)def plot_precision_recall(files, naming_scheme='iter'):if naming_scheme == 'exp_id':# name becomes exp_idnames = [f.parts[-3] for f in files]elif naming_scheme == 'iter':names = [f.stem for f in files]else:raise ValueError(f'not supported {naming_scheme}')fig, axs = plt.subplots(ncols=2, figsize=(16, 5))for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):data = torch.load(f)# precision is n_iou, n_points, n_cat, n_area, max_detprecision = data['precision']recall = data['params'].recThrsscores = data['scores']# take precision for all classes, all areas and 100 detectionsprecision = precision[0, :, :, 0, -1].mean(1)scores = scores[0, :, :, 0, -1].mean(1)prec = precision.mean()rec = data['recall'][0, :, 0, -1].mean()print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +f'score={scores.mean():0.3f}, ' +f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}')axs[0].plot(recall, precision, c=color)axs[1].plot(recall, scores, c=color)axs[0].set_title('Precision / Recall')axs[0].legend(names)axs[1].set_title('Scores / Recall')axs[1].legend(names)return fig, axs

結果如下所示:

iter 000: mAP@50= 24.0, score=0.317, f1=0.341
iter 050: mAP@50= 27.7, score=0.339, f1=0.400
iter latest: mAP@50= 26.4, score=0.348, f1=0.393
iter 000: mAP@50= 24.0, score=0.317, f1=0.341
iter 050: mAP@50= 27.7, score=0.339, f1=0.400
iter latest: mAP@50= 26.4, score=0.348, f1=0.393

可視化如下所示:

【Precision曲線】
精確率曲線(Precision-Recall Curve)是一種用于評估二分類模型在不同閾值下的精確率性能的可視化工具。它通過繪制不同閾值下的精確率和召回率之間的關系圖來幫助我們了解模型在不同閾值下的表現。
精確率(Precision)是指被正確預測為正例的樣本數占所有預測為正例的樣本數的比例。召回率(Recall)是指被正確預測為正例的樣本數占所有實際為正例的樣本數的比例。
繪制精確率曲線的步驟如下:
使用不同的閾值將預測概率轉換為二進制類別標簽。通常,當預測概率大于閾值時,樣本被分類為正例,否則分類為負例。
對于每個閾值,計算相應的精確率和召回率。
將每個閾值下的精確率和召回率繪制在同一個圖表上,形成精確率曲線。
根據精確率曲線的形狀和變化趨勢,可以選擇適當的閾值以達到所需的性能要求。
通過觀察精確率曲線,我們可以根據需求確定最佳的閾值,以平衡精確率和召回率。較高的精確率意味著較少的誤報,而較高的召回率則表示較少的漏報。根據具體的業務需求和成本權衡,可以在曲線上選擇合適的操作點或閾值。
精確率曲線通常與召回率曲線(Recall Curve)一起使用,以提供更全面的分類器性能分析,并幫助評估和比較不同模型的性能。
【Recall曲線】
召回率曲線(Recall Curve)是一種用于評估二分類模型在不同閾值下的召回率性能的可視化工具。它通過繪制不同閾值下的召回率和對應的精確率之間的關系圖來幫助我們了解模型在不同閾值下的表現。
召回率(Recall)是指被正確預測為正例的樣本數占所有實際為正例的樣本數的比例。召回率也被稱為靈敏度(Sensitivity)或真正例率(True Positive Rate)。
繪制召回率曲線的步驟如下:
使用不同的閾值將預測概率轉換為二進制類別標簽。通常,當預測概率大于閾值時,樣本被分類為正例,否則分類為負例。
對于每個閾值,計算相應的召回率和對應的精確率。
將每個閾值下的召回率和精確率繪制在同一個圖表上,形成召回率曲線。
根據召回率曲線的形狀和變化趨勢,可以選擇適當的閾值以達到所需的性能要求。
通過觀察召回率曲線,我們可以根據需求確定最佳的閾值,以平衡召回率和精確率。較高的召回率表示較少的漏報,而較高的精確率意味著較少的誤報。根據具體的業務需求和成本權衡,可以在曲線上選擇合適的操作點或閾值。
召回率曲線通常與精確率曲線(Precision Curve)一起使用,以提供更全面的分類器性能分析,并幫助評估和比較不同模型的性能。

【PR曲線】
精確率-召回率曲線(Precision-Recall Curve)是一種用于評估二分類模型性能的可視化工具。它通過繪制不同閾值下的精確率(Precision)和召回率(Recall)之間的關系圖來幫助我們了解模型在不同閾值下的表現。
精確率是指被正確預測為正例的樣本數占所有預測為正例的樣本數的比例。召回率是指被正確預測為正例的樣本數占所有實際為正例的樣本數的比例。
繪制精確率-召回率曲線的步驟如下:
使用不同的閾值將預測概率轉換為二進制類別標簽。通常,當預測概率大于閾值時,樣本被分類為正例,否則分類為負例。
對于每個閾值,計算相應的精確率和召回率。
將每個閾值下的精確率和召回率繪制在同一個圖表上,形成精確率-召回率曲線。
根據曲線的形狀和變化趨勢,可以選擇適當的閾值以達到所需的性能要求。
精確率-召回率曲線提供了更全面的模型性能分析,特別適用于處理不平衡數據集和關注正例預測的場景。曲線下面積(Area Under the Curve, AUC)可以作為評估模型性能的指標,AUC值越高表示模型的性能越好。
通過觀察精確率-召回率曲線,我們可以根據需求選擇合適的閾值來權衡精確率和召回率之間的平衡點。根據具體的業務需求和成本權衡,可以在曲線上選擇合適的操作點或閾值。?

感興趣的話可以自行動手實踐嘗試下!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/166443.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/166443.shtml
英文地址,請注明出處:http://en.pswp.cn/news/166443.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

springboot項目修改項目名稱

參考該文章正確修改項目名稱:SpringBoot項目怎么重命名_springboot修改項目名稱-CSDN博客

【Lodash】 Filter 與Map 的結合使用

用Filter過濾數據之后,想給某個字段重新賦值 在使用 filter() 方法過濾數據后,如果你想給某個字段賦值,你可以使用 map() 方法來修改數組中的元素。map() 方法可以對數組中的每個元素應用一個函數,并返回一個新的數組。 以下是一…

【Django使用】10大章31模塊md文檔,第5篇:Django模板和數據庫使用

當你考慮開發現代化、高效且可擴展的網站和Web應用時,Django是一個強大的選擇。Django是一個流行的開源Python Web框架,它提供了一個堅實的基礎,幫助開發者快速構建功能豐富且高度定制的Web應用 全套Django筆記直接地址: 請移步這…

外匯天眼:多名投資者賬戶被惡意清空,遠離volofinance!

最近,外匯平臺volofinance因有多名投資者投訴,“榮幸”成為外匯天眼黑平臺榜單中的一員,那么volofinance到底做了什么導致投資者前來投訴曝光呢? 起底volofinace 在網絡搜索中,關于volofinance的信息少之又少&#xf…

成為AI產品經理——模型評估指標

目錄 一、模型評估分類 1.在線評估 2.離線評估 二、離線模型評估 1.特征評估 ① 特征自身穩定性 ② 特征來源穩定性 ③ 特征成本 2.模型評估 ① 統計性評估 覆蓋度 最大值、最小值 分布形態 ② 模型性能指標 分類問題 回歸問題 ③ 模型的穩定性 模型評估指標分…

配置mvn打包參數,不同環境使用不同的配置文件

方法一: 首先在/resource目錄下創建各自環境的配置 要在不同的環境中使用不同的配置文件進行Maven打包,可以使用Maven的profiles特性和資源過濾功能。下面是配置Maven打包參數的步驟: 在項目的pom.xml文件中,添加profiles配置…

python 負數 處理

num_negative -4 print(num_negative) num_dec_to_hex hex(num_negative) print(負數轉十六進制: num_dec_to_hex) /---------------------------------------------------------/ -4 負數轉十六進制:-0x4通過上面代碼片段可以看到,python…

第一個Mybatis項目

(一)為什么要用Mybatis? (1)Mybatis對比JDBC而言,sql(單獨寫在xml的配置文件中)和java編碼分開,功能邊界清晰,一個專注業務,一個專注數據。 (2&…

【C++】:多態

朋友們、伙計們,我們又見面了,本期來給大家解讀一下有關多態的知識點,如果看完之后對你有一定的啟發,那么請留下你的三連,祝大家心想事成! C 語 言 專 欄:C語言:從入門到精通 數據結…

Linux(CentOS7)上安裝mysql

在CentOS中默認安裝有MariaDB(MySQL的一個分支),可先移除/卸載MariaDB。 yum remove mariadb // 查看是否存在mariadb rpm -qa|grep -i mariadb // 卸載 mariadb rpm -e --nodeps rpm -qa|grep mariadb yum安裝 下載rpm // 5.6版本 wge…

XML映射文件

<?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE mapperPUBLIC "-//mybatis.org//DTD Mapper 3.0//EN""http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper namespace"org.mybatis.example.BlogMapper&q…

conan 入門(三十二):package_info中配置禁用CMakeDeps生成使用項目自己生成的config.cmake

conanfile.py中定義的package_info()方法用于向package的調用者(conumer)提供包庫名&#xff0c;編譯/連接選項&#xff0c;文件夾等等信息&#xff0c;有了這些信息構建工具的generator就可以根據它們生成對應的文件&#xff0c;用于調用者引用package. 比如基于cmake的CMakeD…

線索二叉樹:C++實現

引言&#xff1a; 線索二叉樹是一種特殊的二叉樹&#xff0c;它可以通過線索&#xff08;線索是指在二叉樹中將空指針改為指向前驅或后繼的指針&#xff09;的方式將二叉樹轉化為一個線性結構&#xff0c;從而方便對二叉樹進行遍歷。本文將介紹如何使用C實現線索二叉樹。 技術…

安全地公網訪問樹莓派等設備的服務 內網穿透--frp 23年11月方法

如果想要樹莓派可以被公網訪問&#xff0c;可以選擇直接網上搜內網穿透提供商&#xff0c;一個月大概10塊錢&#xff0c;也有免費的&#xff0c;但是免費的速度就不要希望很好了。 也可以選擇接下來介紹的frp&#xff0c;這種方式不需要付費&#xff0c;但是需要你有一臺有著公…

vue3自定義拖拽指令

<template><div v-move class"box"></div> </template><script setup lang"ts"> import { Directive } from vue const vMove:Directive (el:HTMLElement) >{const mousedown (e:MouseEvent) >{// 鼠標按下const s…

【Golang】解決使用interface{}解析json數字會變成科學計數法的問題

在使用解析json結構體的時候&#xff0c;使用interface{}接數字會發現變成了科學計數法格式的數字&#xff0c;不符合實際場景的使用要求。 舉例代碼如下&#xff1a; type JsonUnmStruct struct {Id interface{} json:"id"Name string json:"name"…

Linux 的性能調優的思路

Linux操作系統是一個開源產品&#xff0c;也是一個開源軟件的實踐和應用平臺&#xff0c;在這個平臺下有無數的開源軟件支撐&#xff0c;我們常見的apache、tomcat、mysql等。 開源軟件的最大理念是自由、開放&#xff0c;那么Linux作為一個開源平臺&#xff0c;最終要實現的是…

Java反射調用kotlin中的類,Object類,Companion對

Java反射調用kotlin中的類&#xff0c;Object類&#xff0c;Companion對象 1. Java反射調用kotlin中的普通類 kotlin普通類&#xff1a; package com.common; class TestNormal {fun get():String{return "Nolmal abc"}fun showNum(v:Int){println("Nolmal s…

uniApp微信支付實現

后端&#xff1a;小程序下單 - 小程序支付 | 微信支付商戶文檔中心 服務端需要請求&#xff1a;https://api.mch.weixin.qq.com該地址獲取微信支付Api接口需要的參數。 服務端請求接口需要的Body參數&#xff1a; 客戶端&#xff08;前端&#xff09;需要調用&#xff1a;wx.…

12V降3.3V100mA穩壓芯片WT7133

12V降3.3V100mA穩壓芯片WT7133 WT71XX系列是一款采用CMOS工藝實現的三端高輸入電壓、低壓差、小輸出電流電壓穩壓器。 它的輸出電流可達到100mA&#xff0c;輸入電壓可達到18V。其固定輸出電壓的范圍是2.5V&#xff5e;8.0V&#xff0c;用戶 也可通過外圍應用電路來實現可變電壓…