Table of Contents
partial cross entropy loss?
GitHub - LiheYoung/UniMatch: [CVPR 2023] Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation
partial cross entropy loss?
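In weakly- and semi-supervised segmentation (e.g., training from scribble or point annotations), only a subset of pixels carries a ground-truth label. Partial cross-entropy is just the standard cross-entropy restricted to that labeled subset. A standard formulation (added here for context, not from the original note): with $\Omega_L$ the set of labeled pixels and $p_i(y_i)$ the predicted softmax probability of the true class $y_i$ at pixel $i$,

$$\mathcal{L}_{\mathrm{pCE}} = -\frac{1}{|\Omega_L|} \sum_{i \in \Omega_L} \log p_i(y_i)$$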
import torch
import torch.nn.functional as F


def partial_cross_entropy_loss(inputs, targets, ignore_index=-1):
    """Partial cross-entropy loss that ignores pixels labeled with ignore_index.

    :param inputs: model output of shape (N, C, H, W), where N is the batch size,
        C the number of classes, and H and W the height and width.
    :param targets: ground-truth labels of shape (N, H, W).
    :param ignore_index: label value to ignore; defaults to -1.
    :return: the computed loss.
    """
    # Compute log softmax over the class dimension
    log_probs = F.log_softmax(inputs, dim=1)

    # Flatten predictions and targets so labeled pixels can be selected
    log_probs = log_probs.permute(0, 2, 3, 1)               # (N, H, W, C)
    log_probs = log_probs.reshape(-1, log_probs.shape[-1])  # (N*H*W, C)
    targets = targets.view(-1)                              # (N*H*W,)

    # Mask out unlabeled pixels
    mask = targets != ignore_index
    log_probs = log_probs[mask]
    targets = targets[mask]

    # Compute the loss over labeled pixels only
    # (assumes at least one labeled pixel; an all-ignored batch yields NaN)
    loss = F.nll_loss(log_probs, targets, reduction='mean')
    return loss
# Simulated model output and ground-truth labels
outputs = torch.randn(2, 3, 5, 5)  # random logits (2 samples, 3 classes, 5x5 image)
targets = torch.tensor([[[-1, 1, -1, 0, -1],
                         [1, -1, 2, 2, 1],
                         [-1, -1, 1, -1, 0],
                         [2, 2, 2, -1, 1],
                         [-1, 0, -1, 0, 1]],
                        [[1, 0, -1, 1, -1],
                         [2, 2, -1, 0, 0],
                         [-1, 1, 1, 0, -1],
                         [0, 0, 2, -1, 1],
                         [2, -1, 0, -1, -1]]])  # labels with unlabeled (-1) regions

# Compute the loss
loss = partial_cross_entropy_loss(outputs, targets)
print(f"Loss: {loss.item()}")