【DL】FocalLoss的PyTorch實現
此篇不介紹FocalLoss的原理,僅展示PyTorch實現FocalLoss的兩種方式。個人認為相關原理已在文章《FocalLoss原理通俗解釋及其二分類和多分類場景下的原理與實現》中講得很清晰,故此篇不再介紹。
方式一
同時計算一個batch中所有樣本關于FocalLoss的損失值(來自文章《FocalLoss原理通俗解釋及其二分類和多分類場景下的原理與實現》,個人補充了一些注釋):
import torch
from torch import nn
import random
class FocalLoss(nn.Module):"""參考 https://github.com/lonePatient/TorchBlocks"""def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):super(FocalLoss, self).__init__()self.gamma = gammaif isinstance(alpha, list):self.alpha = torch.Tensor(alpha, device=device)else:self.alpha = alphaself.epsilon = epsilon'''batch中所有樣本一起計算loss'''def forward(self, input, target):"""Args:input: model's output, shape of [batch_size, num_cls]target: ground truth labels, shape of [batch_size]Returns:shape of [batch_size]"""num_labels = input.size(-1) # 類別數量idx = target.view(-1, 1).long() # 行向量target變成列向量idxone_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)one_hot_key = one_hot_key.scatter_(1, idx, 1) # one_hot_key矩陣中的每一行對應相應樣本的標簽one_hot向量,利用scatter_方法將樣本的標簽類別標記為1,其余位置為0one_hot_key[:, 0] = 0 # ignore 0 index. 此行需要視具體情況決定是否保留,如果標簽中存在類別0(而不是直接從類別1開始),此行應當注釋、不使用logits = torch.softmax(input, dim=-1)loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() # 計算FocalLossloss = loss.sum(1)return loss.mean()# 固定隨機數種子,方便復現
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Trueif __name__ == '__main__':loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])# 設置隨機數種子setup_seed(20) input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]output = loss(input, target)# print(output)output.backward()
方式二
一個batch中逐個樣本計算關于FocalLoss的損失值,將它們求平均,返回一個batch內所有樣本的FocalLoss的平均值:
import torch
from torch import nn
import random
class FocalLoss(nn.Module):"""參考 https://github.com/lonePatient/TorchBlocks"""def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):super(FocalLoss, self).__init__()self.gamma = gammaif isinstance(alpha, list):self.alpha = torch.Tensor(alpha, device=device)else:self.alpha = alphaself.epsilon = epsilon'''逐個樣本計算loss''' def forward(self, input, target):"""Args:input: model's output, shape of [batch_size, num_cls]target: ground truth labels, shape of [batch_size]Returns:shape of [batch_size]"""num_labels = input.size(-1) # 類別數量loss = []for i, sample in enumerate(input):one_hot_key = torch.zeros(1, num_labels, dtype=torch.float32, device=input.device)one_hot_key.scatter_(1, target[i].view(1, -1), 1)logits = torch.softmax(sample, dim=-1)loss_this_sample = - self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()loss_this_sample = loss_this_sample.sum(1)if i == 0:loss = loss_this_sampleelse:loss = torch.cat((loss, loss_this_sample))return loss.mean()# 固定隨機數種子,方便復現
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Trueif __name__ == '__main__':loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])# 設置隨機數種子setup_seed(20) input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]output = loss(input, target)# print(output)output.backward()