文章目錄
- CMUNeXt: An Efficient Medical Image Segmentation Network based on Large Kernel and Skip Fusion
- 摘要
- 本文方法
- 實驗結果
- Boundary Difference Over Union Loss For Medical Image Segmentation(損失函數)
- 摘要
- 本文方法
- 實驗結果
CMUNeXt: An Efficient Medical Image Segmentation Network based on Large Kernel and Skip Fusion
摘要
u型結構已成為醫學圖像分割網絡設計的一個重要范例。然而,由于卷積固有的局部局限性,具有u型結構的全卷積分割網絡難以有效地提取全局上下文信息,而這對于病灶的精確定位至關重要。雖然結合cnn和transformer的混合架構可以解決這些問題,但由于環境和邊緣設備施加的計算資源限制,它們在真實醫療場景中的應用受到限制。此外,輕量級網絡中的卷積感應偏置能很好地擬合稀缺的醫療數據,這是基于Transformer的網絡所缺乏的。為了在利用歸納偏置的同時提取全局上下文信息,我們提出了一種高效的全卷積輕量級醫學圖像分割網絡CMUNeXt,該網絡能夠在真實場景場景中實現快速準確的輔助診斷。
CMUNeXt利用大內核和倒瓶頸設計,將遠距離空間和位置信息徹底混合,高效提取全局上下文信息。我們還介紹了Skip-Fusion模塊,旨在實現平滑的跳過連接,并確保充分的特征融合。在多個醫學圖像數據集上的實驗結果表明,CMUNeXt在分割性能上優于現有的重量級和輕量級醫學圖像分割網絡,同時具有更快的推理速度、更輕的權重和更低的計算成本。
代碼地址
本文方法
CMUNEXT模塊比較簡單,總的來說不是很復雜
實驗結果
Boundary Difference Over Union Loss For Medical Image Segmentation(損失函數)
摘要
醫學圖像分割對臨床診斷至關重要。然而,目前醫學圖像分割的損失主要集中在整體分割結果上,較少提出用于指導邊界分割的損失。那些確實存在的損失往往需要與其他損失結合使用,并且產生無效效果。為了解決這個問題,我們開發了一種簡單有效的損失,稱為邊界差分聯合損失(邊界DoU損失)來指導邊界區域分割。它是通過計算預測與真值的差集與差集與部分交集集并集的比值得到的。我們的損失只依賴于區域計算,使得它易于實現并且訓練穩定,不需要任何額外的損失。此外,我們使用目標大小來自適應調整應用于邊界區域的注意力。
代碼地址
本文方法
實驗結果
class BoundaryDoULoss(nn.Module):def __init__(self, n_classes):super(BoundaryDoULoss, self).__init__()self.n_classes = n_classesdef _one_hot_encoder(self, input_tensor):tensor_list = []for i in range(self.n_classes):temp_prob = input_tensor == itensor_list.append(temp_prob.unsqueeze(1))output_tensor = torch.cat(tensor_list, dim=1)return output_tensor.float()def _adaptive_size(self, score, target):kernel = torch.Tensor([[0,1,0], [1,1,1], [0,1,0]])padding_out = torch.zeros((target.shape[0], target.shape[-2]+2, target.shape[-1]+2))padding_out[:, 1:-1, 1:-1] = targeth, w = 3, 3Y = torch.zeros((padding_out.shape[0], padding_out.shape[1] - h + 1, padding_out.shape[2] - w + 1)).cuda()for i in range(Y.shape[0]):Y[i, :, :] = torch.conv2d(target[i].unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0).cuda(), padding=1)Y = Y * targetY[Y == 5] = 0C = torch.count_nonzero(Y)S = torch.count_nonzero(target)smooth = 1e-5alpha = 1 - (C + smooth) / (S + smooth)alpha = 2 * alpha - 1intersect = torch.sum(score * target)y_sum = torch.sum(target * target)z_sum = torch.sum(score * score)alpha = min(alpha, 0.8) ## We recommend using a truncated alpha of 0.8, as using truncation gives better results on some datasets and has rarely effect on others.loss = (z_sum + y_sum - 2 * intersect + smooth) / (z_sum + y_sum - (1 + alpha) * intersect + smooth)return lossdef forward(self, inputs, target):inputs = torch.softmax(inputs, dim=1)target = self._one_hot_encoder(target)assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())loss = 0.0for i in range(0, self.n_classes):loss += self._adaptive_size(inputs[:, i], target[:, i])return loss / self.n_classes