計算Dice損失的函數
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):n,c, h, w = inputs.size() #nt,ht, wt, ct = target.size() #nt,if h != ht and w != wt:inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)temp_target = target.view(n, -1, ct)#--------------------------------------------## 計算dice loss#--------------------------------------------#tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])fp = torch.sum(temp_inputs , axis=[0,1]) - tpfn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tpscore = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)dice_loss = 1 - torch.mean(score)return dice_loss
這段代碼是用于計算二分類問題的混淆矩陣(Confusion Matrix)中的True Positives(TP),False Positives(FP)和False Negatives(FN)。在混淆矩陣中,TP表示模型正確預測為正類的數量,FP表示模型錯誤地預測為正類的數量,FN表示實際為正類但模型沒有預測為正類的數量。
讓我們分解這段代碼來理解每個部分的作用:
-
temp_target[..., :-1] * temp_inputs
:temp_target[..., :-1]
獲取temp_target
張量中除了最后一維之外的所有元素。:-1
是一個切片操作,它表示從開始到倒數第二個元素。temp_inputs
是模型的預測輸出。- 這兩個張量進行元素相乘,只有當
temp_target
的最后一維等于 1 時,才會乘以temp_inputs
對應的位置的值。這模擬了只有當預測和真實標簽都為正類(1)時,才認為是真正的正類檢測。
-
torch.sum(..., axis=[0,1])
:- 這是一個求和操作,計算在指定維度上(這里是第0維和第1維)的總和。
axis=[0,1]
表示在第0維和第1維上進行求和。通常,第0維代表批量大小(batch size),第1維代表序列長度(sequence length)。- 這樣做的效果是將所有正類預測的和(TP)匯總起來,無論它們在批量中的哪個位置或序列中。
-
tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
:- 最終,
tp
保存了所有正類預測的數量。
- 最終,
-
fp = torch.sum(temp_inputs, axis=[0,1]) - tp
:torch.sum(temp_inputs, axis=[0,1])
計算了所有預測為正類的數量,無論它們是否真的是正類。- 然后從中減去
tp
,得到假正類的數量(FP),即模型錯誤地預測為正類的數量。
-
fn = torch.sum(temp_target[...,:-1], axis=[0,1]) - tp
:torch.sum(temp_target[...,:-1], axis=[0,1])
計算了實際為正類的數量,無論模型是否預測它們為正類。- 然后從中減去
tp
,得到假負類的數量(FN),即實際為正類但模型沒有預測為正類的數量。
綜上所述,這段代碼通過計算TP、FP和FN,來評估模型在二分類任務中的性能。這些值是計算精確度(Precision)、召回率(Recall)和F1得分的關鍵。