目錄
- 摘要
- abstract
- DETR目標檢測網絡詳解
- 二分圖匹配和損失函數
- DETR總結
- 總結
摘要
DETR(DEtection TRansformer)是由Facebook AI提出的一種基于Transformer架構的端到端目標檢測方法。它通過將目標檢測建模為集合預測問題,摒棄了錨框設計和非極大值抑制(NMS)等復雜后處理步驟。DETR使用卷積神經網絡提取圖像特征,并將其通過位置編碼轉換為輸入序列,送入Transformer的Encoder-Decoder結構。Decoder通過固定數量的目標查詢(Object Queries),預測類別和邊界框位置。DETR創新性地引入匈牙利算法進行二分圖匹配,確保預測與真實值的唯一對應關系,且采用交叉熵損失和L1-GIoU損失進行優化。在COCO數據集上的實驗表明,DETR在大目標檢測中表現優異,并能靈活遷移到其他任務,如全景分割。
abstract
DETR (DEtection TRansformer) is an end-to-end target detection method based on Transformer architecture proposed by Facebook AI. By modeling object detection as a set prediction problem, it eliminates complex post-processing steps such as anchor frame design and non-maximum suppression (NMS). DETR uses convolutional neural networks to extract image features and convert them via positional encoding into input sequences that feed into Transformer’s Encoder-Decoder structure. Decoder predicts categories and bounding box positions with a fixed number of Object Queries. DETR innovates by introducing the Hungarian algorithm for bipartite graph matching to ensure a unique relationship between the prediction and the true value, and optimizes with cross-entropy losses and L1-GIoU losses. Experiments on the COCO dataset show that DETR performs well in large target detection and can be flexibly migrated to other tasks, such as panoramic segmentation.
下圖是目標檢測中檢測器模型的發展:
DETR目標檢測網絡詳解
DETR(DEtection TRansformer)是由Facebook AI在2020年提出的一種基于Transformer架構的端到端目標檢測方法。與傳統的目標檢測方法(如Faster R-CNN、YOLO等)不同,DETR直接將目標檢測建模為一個集合預測問題,擺脫了錨框設計和復雜的后處理(如NMS)。結果在 COCO 數據集上效果與 Faster RCNN 相當,在大目標上效果比 Faster RCNN 好,且可以很容易地將 DETR 遷移到其他任務例如全景分割。
簡單來說,就是通過CNN提取圖像特征(通常 Backbone 的輸出通道為 2048,圖像高和寬都變為了 1/32),并經過input embedding+positional encoding操作轉換為圖像序列(如下圖所說,就是類似[N, HW, C]的序列)作為transformer encoder的輸入,得到了編碼后的圖像序列,在圖像序列的幫助下,將object queries(下圖中說的是固定數量的可學習的位置embeddings)轉換/預測為固定數量的類別+bbox預測。相當于Transformer本質上起了一個序列轉換的作用。
下圖為DETR的詳細結構:
DETR中的encoder-decoder與transformer中的encoder-decoder對比:
- spatial positional encoding:新提出的二維空間位置編碼方法,該位置編碼分別被加入到了encoder的self attention的QK和decoder的cross attention的K,同時object queries也被加入到了decoder的兩個attention(第一個加到了QK中,第二個加入了Q)中。而原版的Transformer將位置編碼加到了input和output embedding中。
- DETR在計算attention的時候沒有使用masked attention,因為將特征圖展開成一維以后,所有像素都可能是互相關聯的,因此沒必要規定mask。
- object queries的轉換過程:object queries是預定義的目標查詢的個數,代碼中默認為100。它的意義是:根據Encoder編碼的特征,Decoder將100個查詢轉化成100個目標,即最終預測這100個目標的類別和bbox位置。最終預測得到的shape應該為[N, 100, C],N為Batch Num,100個目標,C為預測的100個目標的類別數+1(背景類)以及bbox位置(4個值)。
- 得到預測結果以后,將object predictions和ground truth box之間通過匈牙利算法進行二分匹配:假如有K個目標,那么100個object predictions中就會有K個能夠匹配到這K個ground truth,其他的都會和“no object”匹配成功,使其在理論上每個object query都有唯一匹配的目標,不會存在重疊,所以DETR不需要nms進行后處理。
- 分類loss采用的是交叉熵損失,針對所有predictions;bbox loss采用了L1 loss和giou loss,針對匹配成功的predictions。
匈牙利算法是用于解決二分圖匹配的問題,即將Ground Truth的K個bbox和預測出的100個bbox作為二分圖的兩個集合,匈牙利算法的目標就是找到最大匹配,即在二分圖中最多能找到多少條沒有公共端點的邊。匈牙利算法的輸入就是每條邊的cost 矩陣
二分圖匹配和損失函數
思考:
DETR 預測了一組固定大小的 N = 100 個邊界框,這比圖像中感興趣的對象的實際數量大得多。怎么樣來計算損失呢?或者說預測出來的框我們怎么知道對應哪一個 ground-truth 的框呢?
為了解決這個問題,第一步是將 ground-truth 也擴展成 N = 100 個檢測框。使用了一個額外的特殊類標簽 ? \phi? 來表示在未檢測到任何對象,或者認為是背景類別。這樣預測和真實都是兩個100 個元素的集合了。這時候采用匈牙利算法進行二分圖匹配,即對預測集合和真實集合的元素進行一一對應,使得匹配損失最小。
σ ^ = arg ? min ? G ∈ G N ∑ i N L m a t c h ( y i , y ^ σ ( i ) ) \hat{\sigma}=\arg\min_{\mathrm{G\in G_N}}\sum_{\mathrm{i}}^{\mathrm{N}}\mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right) σ^=argG∈GN?min?i∑N?Lmatch?(yi?,y^?σ(i)?)
L m a t c h ( y i , y ^ σ ( i ) ) = ? 1 { c i ≠ ? } p ^ σ ( i ) ( c i ) + 1 { c i ≠ ? } L b o x ( b i , b ^ σ ( i ) ) \mathcal{L}_{\mathrm{match}}\left(\mathrm{y_i},\hat{\mathrm{y}}_{\mathrm{\sigma(i)}}\right)=-1_{\{\mathrm{c_i}\neq\varnothing\}}\hat{\mathrm{p}}_{\mathrm{\sigma(i)}}\left(\mathrm{c_i}\right)+1_{\{\mathrm{c_i}\neq\varnothing\}}\mathcal{L}_{\mathrm{box}}\left(\mathrm{b_i},\hat{\mathrm{b}}_{\mathrm{\sigma(i)}}\right) Lmatch?(yi?,y^?σ(i)?)=?1{ci?=?}?p^?σ(i)?(ci?)+1{ci?=?}?Lbox?(bi?,b^σ(i)?)
對于那些不是背景的,獲得其對應的預測是目標類別的概率,然后用框損失減去預測類別概率。這也就是說不僅框要近,類別也要基本一致,是最好的。經過匈牙利算法之后,我們就得到了 ground truth 和預測目標框之間的一一對應關系。然后就可以計算損失函數了。
下面是利用pytorch實現DETR的代碼:
位置編碼部分:
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]
用于為序列數據(如Transformer中的輸入)添加位置信息。位置編碼幫助模型保留序列中元素的位置信息,這是因為Transformer模型本身不具備位置信息感知能力。
使用正弦和余弦函數優點:
優點:
正弦和余弦具有周期性和平滑性;
不同維度具有不同頻率,編碼了多尺度的位置信息。
作用:保留序列的位置信息,使模型能夠感知數據的順序。
編碼可視化結果:
import matplotlib.pyplot as pltimport torch
import torch.nn as nn# 位置編碼
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]pe = PositionalEncoding(d_model=16, max_len=100)
x = torch.zeros(100, 1, 16)
encoded = pe(x).squeeze(1).detach().numpy()plt.figure(figsize=(10, 5))
plt.imshow(encoded, aspect='auto', cmap='viridis')
plt.colorbar(label='Encoding Value')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Visualization')
plt.show()
上圖反應以下幾點變化
不同維度的變化
- 低頻維度(如 d=0,1):顏色變化緩慢,代表位置之間編碼的相似性較高,捕捉全局信息。
- 高頻維度(如 d=14,15):顏色變化迅速,代表位置之間編碼差異較大,捕捉局部信息。
同一位置的編碼:
值的分布(正弦和余弦的相互作用)保證了每個位置在多維空間中具有唯一性。
時間步的相對差異:
相鄰位置(如第1和第2位置)在高維上的值差異較大,這為模型提供了感知時間步變化的能力。
encoder-decoder:
class Transformer(nn.Module):def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6):super().__init__()self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)def forward(self, src, tgt, src_mask=None, tgt_mask=None):memory = self.encoder(src, mask=src_mask)output = self.decoder(tgt, memory, tgt_mask=tgt_mask)return output
DETR模型:
# DETR模型
class DETR(nn.Module):def __init__(self, num_classes, num_queries, backbone='resnet50'):super().__init__()self.num_queries = num_queries# Backboneself.backbone = models.resnet50(pretrained=True)self.conv = nn.Conv2d(2048, 256, kernel_size=1)# Transformerself.transformer = Transformer(d_model=256)self.query_embed = nn.Embedding(num_queries, 256)self.positional_encoding = PositionalEncoding(256)# Prediction headsself.class_embed = nn.Linear(256, num_classes + 1) # +1 for no-object classself.bbox_embed = nn.Linear(256, 4)def forward(self, images):# Feature extractionfeatures = self.backbone(images)features = self.conv(features)h, w = features.shape[-2:]# Flatten and add positional encodingsrc = features.flatten(2).permute(2, 0, 1) # (HW, N, C)src = self.positional_encoding(src)# Query embeddingquery_embed = self.query_embed.weight.unsqueeze(1).repeat(1, images.size(0), 1) # (num_queries, N, C)# Transformerhs = self.transformer(src, query_embed)# Predictionoutputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid() # Normalized to [0, 1]return {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}
DETR總結
DETR通過Transformer實現端到端的目標檢測,無需(如NMS)復雜的后處理。相比傳統檢測器,DETR具有簡潔的架構和強大的全局建模能力,但訓練時對數據和計算資源的需求較高。
總結
DETR簡化了目標檢測的流程,摒棄了傳統檢測器中繁瑣的錨框設計和后處理步驟,架構更簡潔,且依托于Transformer的全局建模能力,在捕捉長距離特征關系方面表現出色。相比傳統方法,DETR在目標數量固定的場景下,能夠更高效地處理目標檢測任務。其優點包括易遷移、多任務適用性和端到端優化能力,但其劣勢在于訓練時間較長、計算資源消耗較大,尤其是在小目標檢測和訓練數據量不足的情況下效果略顯不足。