計算機視覺目標檢測-DETR網絡

目錄

  • 摘要
  • abstract
  • DETR目標檢測網絡詳解
    • 二分圖匹配和損失函數
  • DETR總結
  • 總結

摘要

DETR(DEtection TRansformer)是由Facebook AI提出的一種基于Transformer架構的端到端目標檢測方法。它通過將目標檢測建模為集合預測問題,摒棄了錨框設計和非極大值抑制(NMS)等復雜后處理步驟。DETR使用卷積神經網絡提取圖像特征,并將其通過位置編碼轉換為輸入序列,送入Transformer的Encoder-Decoder結構。Decoder通過固定數量的目標查詢(Object Queries),預測類別和邊界框位置。DETR創新性地引入匈牙利算法進行二分圖匹配,確保預測與真實值的唯一對應關系,且采用交叉熵損失和L1-GIoU損失進行優化。在COCO數據集上的實驗表明,DETR在大目標檢測中表現優異,并能靈活遷移到其他任務,如全景分割。

abstract

DETR (DEtection TRansformer) is an end-to-end target detection method based on Transformer architecture proposed by Facebook AI. By modeling object detection as a set prediction problem, it eliminates complex post-processing steps such as anchor frame design and non-maximum suppression (NMS). DETR uses convolutional neural networks to extract image features and convert them via positional encoding into input sequences that feed into Transformer’s Encoder-Decoder structure. Decoder predicts categories and bounding box positions with a fixed number of Object Queries. DETR innovates by introducing the Hungarian algorithm for bipartite graph matching to ensure a unique relationship between the prediction and the true value, and optimizes with cross-entropy losses and L1-GIoU losses. Experiments on the COCO dataset show that DETR performs well in large target detection and can be flexibly migrated to other tasks, such as panoramic segmentation.

下圖是目標檢測中檢測器模型的發展:
在這里插入圖片描述

DETR目標檢測網絡詳解

DETR(DEtection TRansformer)是由Facebook AI在2020年提出的一種基于Transformer架構的端到端目標檢測方法。與傳統的目標檢測方法(如Faster R-CNN、YOLO等)不同,DETR直接將目標檢測建模為一個集合預測問題,擺脫了錨框設計和復雜的后處理(如NMS)。結果在 COCO 數據集上效果與 Faster RCNN 相當,在大目標上效果比 Faster RCNN 好,且可以很容易地將 DETR 遷移到其他任務例如全景分割。
在這里插入圖片描述
簡單來說,就是通過CNN提取圖像特征(通常 Backbone 的輸出通道為 2048,圖像高和寬都變為了 1/32),并經過input embedding+positional encoding操作轉換為圖像序列(如下圖所說,就是類似[N, HW, C]的序列)作為transformer encoder的輸入,得到了編碼后的圖像序列,在圖像序列的幫助下,將object queries(下圖中說的是固定數量的可學習的位置embeddings)轉換/預測為固定數量的類別+bbox預測。相當于Transformer本質上起了一個序列轉換的作用。
在這里插入圖片描述
下圖為DETR的詳細結構:
在這里插入圖片描述
DETR中的encoder-decoder與transformer中的encoder-decoder對比:

  1. spatial positional encoding:新提出的二維空間位置編碼方法,該位置編碼分別被加入到了encoder的self attention的QK和decoder的cross attention的K,同時object queries也被加入到了decoder的兩個attention(第一個加到了QK中,第二個加入了Q)中。而原版的Transformer將位置編碼加到了input和output embedding中。
  2. DETR在計算attention的時候沒有使用masked attention,因為將特征圖展開成一維以后,所有像素都可能是互相關聯的,因此沒必要規定mask。
  3. object queries的轉換過程:object queries是預定義的目標查詢的個數,代碼中默認為100。它的意義是:根據Encoder編碼的特征,Decoder將100個查詢轉化成100個目標,即最終預測這100個目標的類別和bbox位置。最終預測得到的shape應該為[N, 100, C],N為Batch Num,100個目標,C為預測的100個目標的類別數+1(背景類)以及bbox位置(4個值)。
  4. 得到預測結果以后,將object predictions和ground truth box之間通過匈牙利算法進行二分匹配:假如有K個目標,那么100個object predictions中就會有K個能夠匹配到這K個ground truth,其他的都會和“no object”匹配成功,使其在理論上每個object query都有唯一匹配的目標,不會存在重疊,所以DETR不需要nms進行后處理。
  5. 分類loss采用的是交叉熵損失,針對所有predictions;bbox loss采用了L1 loss和giou loss,針對匹配成功的predictions。

匈牙利算法是用于解決二分圖匹配的問題,即將Ground Truth的K個bbox和預測出的100個bbox作為二分圖的兩個集合,匈牙利算法的目標就是找到最大匹配,即在二分圖中最多能找到多少條沒有公共端點的邊。匈牙利算法的輸入就是每條邊的cost 矩陣
在這里插入圖片描述

二分圖匹配和損失函數

思考
DETR 預測了一組固定大小的 N = 100 個邊界框,這比圖像中感興趣的對象的實際數量大得多。怎么樣來計算損失呢?或者說預測出來的框我們怎么知道對應哪一個 ground-truth 的框呢?

為了解決這個問題,第一步是將 ground-truth 也擴展成 N = 100 個檢測框。使用了一個額外的特殊類標簽 ? \phi? 來表示在未檢測到任何對象,或者認為是背景類別。這樣預測和真實都是兩個100 個元素的集合了。這時候采用匈牙利算法進行二分圖匹配,即對預測集合和真實集合的元素進行一一對應,使得匹配損失最小。
σ ^ = arg ? min ? G ∈ G N ∑ i N L m a t c h ( y i , y ^ σ ( i ) ) \hat{\sigma}=\arg\min_{\mathrm{G\in G_N}}\sum_{\mathrm{i}}^{\mathrm{N}}\mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right) σ^=argGGN?min?iN?Lmatch?(yi?,y^?σ(i)?)
L m a t c h ( y i , y ^ σ ( i ) ) = ? 1 { c i ≠ ? } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ? } L b o x ( b i , b ^ σ ( i ) ) \mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right)=-1_{\{\mathrm{c_i}\neq\varnothing\}}\hat{\mathrm{p}}_{\mathrm{\sigma(i)}}\left(\mathrm{c_i}\right)+1_{\{\mathrm{c_i}\neq\varnothing\}}\mathcal{L}_{\mathrm{box}}\left(\mathrm{b_i},\hat{\mathrm{b}}_{\mathrm{\sigma(i)}}\right) Lmatch?(yi?,y^?σ(i)?)=?1{ci?=?}?p^?σ(i)?(ci?)+1{ci?=?}?Lbox?(bi?,b^σ(i)?)
對于那些不是背景的,獲得其對應的預測是目標類別的概率,然后用框損失減去預測類別概率。這也就是說不僅框要近,類別也要基本一致,是最好的。經過匈牙利算法之后,我們就得到了 ground truth 和預測目標框之間的一一對應關系。然后就可以計算損失函數了。

下面是利用pytorch實現DETR的代碼:
位置編碼部分:

class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]

用于為序列數據(如Transformer中的輸入)添加位置信息。位置編碼幫助模型保留序列中元素的位置信息,這是因為Transformer模型本身不具備位置信息感知能力。
使用正弦和余弦函數優點
優點:
正弦和余弦具有周期性和平滑性;
不同維度具有不同頻率,編碼了多尺度的位置信息。
作用:保留序列的位置信息,使模型能夠感知數據的順序。

編碼可視化結果:

import matplotlib.pyplot as pltimport torch
import torch.nn as nn# 位置編碼
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]pe = PositionalEncoding(d_model=16, max_len=100)
x = torch.zeros(100, 1, 16)
encoded = pe(x).squeeze(1).detach().numpy()plt.figure(figsize=(10, 5))
plt.imshow(encoded, aspect='auto', cmap='viridis')
plt.colorbar(label='Encoding Value')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Visualization')
plt.show()

在這里插入圖片描述
上圖反應以下幾點變化
不同維度的變化

  1. 低頻維度(如 d=0,1):顏色變化緩慢,代表位置之間編碼的相似性較高,捕捉全局信息。
  2. 高頻維度(如 d=14,15):顏色變化迅速,代表位置之間編碼差異較大,捕捉局部信息。

同一位置的編碼:
值的分布(正弦和余弦的相互作用)保證了每個位置在多維空間中具有唯一性。

時間步的相對差異:
相鄰位置(如第1和第2位置)在高維上的值差異較大,這為模型提供了感知時間步變化的能力。

encoder-decoder:

class Transformer(nn.Module):def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6):super().__init__()self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)def forward(self, src, tgt, src_mask=None, tgt_mask=None):memory = self.encoder(src, mask=src_mask)output = self.decoder(tgt, memory, tgt_mask=tgt_mask)return output

DETR模型:

# DETR模型
class DETR(nn.Module):def __init__(self, num_classes, num_queries, backbone='resnet50'):super().__init__()self.num_queries = num_queries# Backboneself.backbone = models.resnet50(pretrained=True)self.conv = nn.Conv2d(2048, 256, kernel_size=1)# Transformerself.transformer = Transformer(d_model=256)self.query_embed = nn.Embedding(num_queries, 256)self.positional_encoding = PositionalEncoding(256)# Prediction headsself.class_embed = nn.Linear(256, num_classes + 1)  # +1 for no-object classself.bbox_embed = nn.Linear(256, 4)def forward(self, images):# Feature extractionfeatures = self.backbone(images)features = self.conv(features)h, w = features.shape[-2:]# Flatten and add positional encodingsrc = features.flatten(2).permute(2, 0, 1)  # (HW, N, C)src = self.positional_encoding(src)# Query embeddingquery_embed = self.query_embed.weight.unsqueeze(1).repeat(1, images.size(0), 1)  # (num_queries, N, C)# Transformerhs = self.transformer(src, query_embed)# Predictionoutputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()  # Normalized to [0, 1]return {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}

DETR總結

DETR通過Transformer實現端到端的目標檢測,無需(如NMS)復雜的后處理。相比傳統檢測器,DETR具有簡潔的架構和強大的全局建模能力,但訓練時對數據和計算資源的需求較高。

總結

DETR簡化了目標檢測的流程,摒棄了傳統檢測器中繁瑣的錨框設計和后處理步驟,架構更簡潔,且依托于Transformer的全局建模能力,在捕捉長距離特征關系方面表現出色。相比傳統方法,DETR在目標數量固定的場景下,能夠更高效地處理目標檢測任務。其優點包括易遷移、多任務適用性和端到端優化能力,但其劣勢在于訓練時間較長、計算資源消耗較大,尤其是在小目標檢測和訓練數據量不足的情況下效果略顯不足。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/65645.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/65645.shtml
英文地址,請注明出處:http://en.pswp.cn/web/65645.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Vim Masterclass 筆記09】S06L22:Vim 核心操作訓練之 —— 文本的搜索、查找與替換操作(第一部分)

文章目錄 S06L22 Search, Find, and Replace - Part One1 從光標位置起,正向定位到當前行的首個字符 b2 從光標位置起,反向查找某個字符3 重復上一次字符查找操作4 定位到目標字符的前一個字符5 單字符查找與 Vim 命令的組合6 跨行查找某字符串7 Vim 的增…

Python3 JSON

JSON(JavaScript Object Notation)是一種輕量級的數據交換格式,易于人閱讀和編寫,同時也易于機器解析和生成。它基于JavaScript編程語言的一個子集,但JSON是獨立于語言的,很多編程語言都支持JSON格式數據的…

202406 青少年軟件編程等級考試C/C++ 二級真題答案及解析(電子學會)

第 1 題 冠軍魔術 2018年FISM(世界魔術大會)近景總冠軍簡綸廷的表演中有一個情節:以桌面上一根帶子為界,當他將紙牌從帶子的一邊推到另一邊時,紙牌會變成硬幣;把硬幣推回另一邊會變成紙牌。 這里我們假設紙牌會變成等量的硬幣,而硬幣變成紙牌時,紙牌的數量會加倍。那么…

springboot 默認的 mysql 驅動版本

本案例以 springboot 3.1.12 版本為例 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.1.12</version><relativePath/> </parent> 點擊 spring-…

計算機網絡(二)——物理層和數據鏈路層

一、物理層 1.作用 實現相信計算機節點之間比特流的透明傳輸&#xff0c;盡可能屏蔽具體傳輸介質和物理設備的差異。 2.數據傳輸單位 比特。 3.相關通信概念 ①信源和信宿&#xff1a;即信號的發送方和接收方。 ②數據&#xff1a;即信息的實體&#xff0c;比如圖像、視頻等&am…

sql server cdc漏掃數據

SQL Server的CDC指的是“變更數據捕獲”&#xff08;Change Data Capture&#xff09;。這是SQL Server數據庫提供的一項功能&#xff0c;能夠跟蹤并記錄對數據庫表中數據所做的更改。這些更改包括插入、更新和刪除操作。CDC可以捕獲這些變更的詳細信息&#xff0c;并使這些信息…

AI數字人+文旅:打造數字文旅新名片

在數字化浪潮的推動下&#xff0c;人工智能技術正以前所未有的速度滲透到我們生活的每一個角落。特別是在文化和旅游領域&#xff0c;AI數字人的出現&#xff0c;不僅為傳統文旅產業注入了新的活力&#xff0c;也為游客帶來了全新的體驗。 肇慶AI數字人——星湖 “星湖”是肇…

做一個 簡單的Django 《股票自選助手》顯示 用akshare 庫(A股數據獲取)

圖&#xff1a; 股票自選助手 這是一個基于 Django 開發的 A 股自選股票信息查看系統。系統使用 akshare 庫獲取實時股票數據&#xff0c;支持添加、刪除和更新股票信息。 功能特點 支持添加自選股票實時顯示股票價格和漲跌幅一鍵更新所有股票數據支持刪除不需要的股票使用中…

Protobuf編碼規則詳解

Protobuf編碼規則詳解 1 Message 結構1.1 tag1.1.1 字段編號(field_num)1.1.2 傳輸類型(wire_type) 1.2 字段順序1.3 默認值 2 編碼2.1 Varint編碼2.1.1 Varint編碼過程2.1.2解碼過程2.1.3 存儲2.1.4 小結2.2 有符號整數(sint32和sint64)編碼的問題與zigzag優化 3 編碼實踐3.1測…

系統思考與因果智慧

“眾生畏果&#xff0c;菩薩畏因”&#xff0c;這句話蘊藏著深厚的因果智慧&#xff0c;與系統思考不謀而合。 眾生畏果&#xff0c;體現了大多數人的行為模式&#xff1a;關注的是眼前的問題與結果&#xff0c;比如失敗、沖突、痛苦。正如在系統思考中&#xff0c;我們稱之為…

【docker】exec /entrypoint.sh: no such file or directory

dockerfile生成的image 報錯內容&#xff1a; exec /entrypoint.sh: no such file or directory查看文件正常在此路徑&#xff0c;但是就是報錯沒找到。 可能是因為sh文件的換行符使用了win的。

計算機的錯誤計算(二百零七)

摘要 利用兩個數學大模型計算 arccot(0.125664e2)的值&#xff0c;結果保留16位有效數字。 實驗表明&#xff0c;它們的輸出中分別僅含有3位和1位正確數字。 例1. 計算 arccot(0.125664e2)的值&#xff0c;結果保留16位有效數字。 下面是與一個數學解題器的對話。 以上為與…

MCANet: 基于多模態字幕感知的大語言模型訓練無關視頻異常檢測

目錄 摘要01 引言02 相關工作2.1 視頻異常檢測2.2 基于視頻的大語言模型&#xff08;VLLMs&#xff09; 03 方法論3.1 問題定義3.2 MCANet3.3 圖像字幕分支3.4 音頻字幕分支3.5 基于LLM的異常評分3.6 視頻-文本分數優化 04 實驗4.1 數據集和評估指標4.2 實現細節4.3 定性結果4.…

WMS倉庫管理系統,Vue前端開發,Java后端技術源碼(源碼學習)

一、項目背景和建設目標 隨著企業業務的不斷擴展&#xff0c;倉庫管理成為影響生產效率、成本控制及客戶滿意度的重要環節。為了提升倉庫作業的透明度、準確性和效率&#xff0c;本方案旨在構建一套全面、高效、易用的倉庫管理系統&#xff08;WMS&#xff09;。該系統將涵蓋庫…

【Uniapp-Vue3】創建自定義頁面模板

大多數情況下我們都使用的是默認模板&#xff0c;但是默認模板是Vue2格式的&#xff0c;如果我們想要定義一個Vue3模板的頁面就需要自定義。 一、我們先復制下面的模板代碼&#xff08;可根據自身需要進行修改&#xff09;&#xff1a; <template><view class"…

【Go】:圖片上添加水印的全面指南——從基礎到高級特性

前言 在數字內容日益重要的今天&#xff0c;保護版權和標識來源變得關鍵。為圖片添加水印有助于聲明所有權、提升品牌認知度&#xff0c;并防止未經授權的使用。本文將介紹如何用Go語言實現圖片水印&#xff0c;包括靜態圖片和帶旋轉、傾斜效果的文字水印&#xff0c;幫助您有…

springCloudGateWay使用總結

1、什么是網關 功能: ①身份認證、權限驗證 ②服務器路由、負載均衡 ③請求限流 2、gateway搭建 2.1、創建一個空項目 2.2、引入依賴 2.3、加配置 3、斷言工廠 4、過濾工廠 5、全局過濾器 6、跨域問題

zig 安裝,Hello World 示例

1. 安裝 Zig 首先&#xff0c;你需要在你的計算機上安裝 Zig 編譯器。你可以從 Zig 官方網站 下載適合你操作系統的版本。 安裝完成后&#xff0c;你可以在終端中運行以下命令來檢查 Zig 是否安裝成功&#xff1a; zig version如果一切正常&#xff0c;它會顯示 Zig 的版本信…

【Docker】Docker與Docker compose離線安裝

文章目錄 一. 離線安裝1. 下載docker2. 安裝 二. 相關命令三. 配置docker-compose 一. 離線安裝 1. 下載docker wget https://download.docker.com/linux/static/stable/x86_64/docker-27.1.2.tgz wget https://download.docker.com/linux/static/stable/aarch64/docker-27.1…

【UE5 C++課程系列筆記】22——多線程基礎——FRunnable和FRunnableThread

目錄 1、FRunnable 1.1 概念 1.2 主要成員函數 &#xff08;1&#xff09;Init 函數 &#xff08;2&#xff09;Run 函數 &#xff08;3&#xff09;Stop 函數 &#xff08;4&#xff09;Exit 函數 2、FRunnableThread 2.1 概念 2.2 主要操作 &#xff08;1&#xff…