最近看到幾個項目都用mask2former做圖像分割,雖然是1年前的論文,但是其attention的設計還是很有借鑒意義,同時,mask2former參考了detr的query設計,實現了語義和實例分割任務的統一。
1.背景
1.1 detr簡介
detr算是第一個嘗試用transformer實現目標檢測的框架,其設計思路也很簡單,就是定義object queries,用來查詢是否存在目標以及目標位置的,類似cnn檢測中的rpn,產生候選框。在detr中,object queries為(100,b,256)的可學習的參數,其中每個256維的向量代表了檢測的box信息,這個信息是由類別和空間信息(box坐標)組成,其中類別信息用于區別類別,而空間信息則描述了目標在圖像中的位置。
通過設置query,則不需要像傳統cnn檢測時預設anchor,最后通過匈牙利匹配算法將query到的目標和gt進行匹配,計算loss。
decoder過程中,query object先初始化為0,然后經過self attention,再和encoder的輸出進行cross attention。
?1.2?Deformable-DETR簡介
Deformable-Detr是在detr的基礎上了主要做了2個改進,Deformable attention(可變形注意力)和多尺度特征,通過可變性注意力降低了顯存,多尺度特征對小目標檢測效果比較好。
(1)Deformable attention(可變形注意力)
這個設計參考了可變性卷積(DCN),后續很多設計都參考了這個。先看下DCN,就是在標準卷積(a)的3 * 3的卷積核上,每個點上增加一個偏移量(dx,dy),讓卷積核不規則,可以適應目標的形狀和尺度。
對于一般的attention,query與key的每個值都要計算注意力,這樣的問題就是耗顯存;另外,對圖像來說,假設其中有一個目標,一般只有離圖像比較近的像素才有用,離比較遠的像素,對目標的貢獻很少,甚至還有負向的干擾。
Defromable attention的設計思路就是query不與全局的key進行計算,而是至于其周圍的key進行計算。至于這個周圍要選哪幾個位置,就類似DCN,讓模型自己去學。
- 單尺度的可變性注意力機制
DeformAttn的公式如下:
- 多尺度的可變性注意力機制
多尺度即類似fpn,提取不同尺度的特征,但由于特征的尺寸不一樣,需要將不同尺度的特征連接起來。
可變性注意力機制公式如下:
相比單尺度的,多尺度多了一個l,代表第幾個尺度,一般取4個層級。
對于一個query,在其參考點(reference point)對應的所有層都采用K個點,然后將每層的K個點特征融合(相加)。
整個deformable atten的流程如下:
2.mask2former
mask2former的設計上使用了deformable detr的可變形注意力。
主要計算過程用下圖表示:
2.1 模型改進
(1)masked attention
一般計算過程中,計算atten時只用前景部分計算,減少顯存占用。
(2) 多分辨率特征
如上圖,圖像經過backbone得到4層特征,然后經過Pixel Decoder得到O1,O2,O3,O4,注意O1,O2,O3經過Linear+Deform atten Layer,O4只通過Linear+卷積得到,具體可以區別看上圖。
(3) decoder優化
在transformer decoder(這個過程用的是標準attention)計算過程中,query剛開始都是隨機初始化的,沒有圖像特征,如果按常規直接self attention可能學不到充分的信息,所以將ca和sa兩個模塊反過來,先和pixdecoder得到的圖像O1,O2,O3計算ca,再繼續計算sa。
2.2 類別和mask分開預測
class和mask預測獨立開來,mask只預測是背景還是前景,class負責預測類別,這部分保留了maskformer的設計。
如上圖,class通過query加上Linear直接將維度轉到(n,k+1),其中k為類別數目。
mask通過decoder和最后一層的mask做外積運算,得到(k,h,w)的tensor,每個k代表一個前景。
采用這種query的方式,既可以做instance也可以做語義分割,query的數量N和類別K數量無關。
2.3 loss優化
mask decoder過程中,主要用最后一層的輸出計算loss;同時為了輔助訓練,默認開啟了auxiliary loss(輔助loss),其他層的輸出也去計算loss。
還有一個trick,mask計算loss時,不是mask上的所有點都去計算,而是隨機采樣一定數目的點去計算loss。默認設置K?= 12544,?i.e., 112?×?112 points,這樣可以節省顯存。
3.擴展
3.1 DAT:另一個Deform atten設計
另一篇deform atten的論文DAT,和deform attention思路類似,也是學習offset。只不過在偏移量設計上有區別,如下圖所示,DAT在當前特征圖F上學習offset時,進行了上采樣2倍,在得到offset后需要插值回F的尺寸,增加了相對位置的bias。
對比幾種查詢的注意力結果,vit是全查,swin固定窗口大小,有可能限制查到的key,DCN為可變性卷積,DAT學到的key更好。
?
模型設計上,參考swin-transformer,只將最后2層替換Deformable attention,效果最好。
?
3.2 視頻實例分割跟蹤
mask2former用于視頻分割,結構如下
模型結構上和圖像的分割基本一致。
修改主要在transformer decoder,包含以下3個地方:
(1)增加時間編碼t
主要在Transformer decoder過程,圖像的位置編碼為(x,y),對于視頻,由于考慮了多幀數據,增加時間t進行編碼,位置編碼為(x,y,t)。
# b, t, c, h, wassert x.dim() == 5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead"if mask is None:mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool)not_mask = ~maskz_embed = not_mask.cumsum(1, dtype=torch.float32) # not_mask【bath,t,h,w】1代表時間列的索引,cumsum累加計算,得到位置idy_embed = not_mask.cumsum(2, dtype=torch.float32) # hx_embed = not_mask.cumsum(3, dtype=torch.float32) # wif self.normalize:eps = 1e-6z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scaley_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scaledim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device)dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2))pos_x = x_embed[:, :, :, :, None] / dim_t # [b,t,h,w]->[b,t,h,w,d] xy編碼的d長度是位置編碼向量長度的一半pos_y = y_embed[:, :, :, :, None] / dim_tpos_z = z_embed[:, :, :, :, None] / dim_t_z # z用編碼向量長度,然后和xy編碼相加pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3) # b, t, c, h, w
(2) query和多幀數據進行atten計算
for i in range(self.num_feature_levels):size_list.append(x[i].shape[-2:])pos.append(self.pe_layer(x[i].view(bs, t, -1, size_list[-1][0], size_list[-1][1]), None).flatten(3))src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) #level_embed size [level_num,d],level embed和輸入相加# NTxCxHW => NxTxCxHW => (TxHW)xNxC # 多幀數據融合_, c, hw = src[-1].shapepos[-1] = pos[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)# 其中src是Pixel decoder的輸出src[-1] = src[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
(3)query和mask計算優化
如代碼所示,query和mask 外積計算,從q外積mask得到mask的shape為[b,q,t,h,w],也就是得到(b,q,t)個instance mask,然后query的instance mask和每幀的gt計算loss。
def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):decoder_output = self.decoder_norm(output)decoder_output = decoder_output.transpose(0, 1)outputs_class = self.class_embed(decoder_output)mask_embed = self.mask_embed(decoder_output)# query和mask 外積計算,從q外積mask得到[b,q,t,h,w]個maskoutputs_mask = torch.einsum("bqc,btchw->bqthw", mask_embed, mask_features)b, q, t, _, _ = outputs_mask.shape# NOTE: prediction is of higher-resolution# [B, Q, T, H, W] -> [B, Q, T*H*W] -> [B, h, Q, T*H*W] -> [B*h, Q, T*HW]attn_mask = F.interpolate(outputs_mask.flatten(0, 1), size=attn_mask_target_size, mode="bilinear", align_corners=False).view(b, q, t, attn_mask_target_size[0], attn_mask_target_size[1])# must use bool type# If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()attn_mask = attn_mask.detach()return outputs_class, outputs_mask, attn_mask
訓練時是以instance作為一個基礎單元,假設有t幀圖像,有n個instance(實例),instance和frame的關系如下圖表示:
?
instance在每幀上都可能存在或者不存在。對于每個instance,初始化t個mask,初始化為0,所以instace的shape是[b,n,t,h,w],如果這個instance在某幀上存在,即賦真值mask,用于匹配計算loss;不存在,即為0。
instance在每幀上都是同一個物體(形態可能變化,但是instance id是相同的),所以預測instance的類別時,每個instance只需要預測一個類別即可,所以類別的shape為[b,n]。
3.3 思考
sam(segment anything model)可以通過prompt進行分割,但是缺乏類別信息,可以參考mask2former的思想,mask和類別是獨立的,可以添加分類的query,接一個分類的分支,然后在coco等數據集上單獨訓練這個分支,讓sam分割后增加類別信息。
4.參考資料
- mask2former論文
- mask2former代碼
附贈
【一】上千篇CVPR、ICCV頂會論文
【二】動手學習深度學習、花書、西瓜書等AI必讀書籍
【三】機器學習算法+深度學習神經網絡基礎教程
【四】OpenCV、Pytorch、YOLO等主流框架算法實戰教程
? 在助理處自取:
?
? 還可咨詢論文輔導?【畢業論文、SCI、CCF、中文核心、El會議】評職稱、研博升學、本升海外學府!