【DL】FocalLoss的PyTorch實現

【DL】FocalLoss的PyTorch實現

此篇不介紹FocalLoss的原理,僅展示PyTorch實現FocalLoss的兩種方式。個人認為相關原理已在文章《FocalLoss原理通俗解釋及其二分類和多分類場景下的原理與實現》中講得很清晰,故此篇不再介紹。

方式一

同時計算一個batch中所有樣本關于FocalLoss的損失值(來自文章《FocalLoss原理通俗解釋及其二分類和多分類場景下的原理與實現》,個人補充了一些注釋):

import torch
from torch import nn
import random
class FocalLoss(nn.Module):"""參考 https://github.com/lonePatient/TorchBlocks"""def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):super(FocalLoss, self).__init__()self.gamma = gammaif isinstance(alpha, list):self.alpha = torch.Tensor(alpha, device=device)else:self.alpha = alphaself.epsilon = epsilon'''batch中所有樣本一起計算loss'''def forward(self, input, target):"""Args:input: model's output, shape of [batch_size, num_cls]target: ground truth labels, shape of [batch_size]Returns:shape of [batch_size]"""num_labels = input.size(-1) # 類別數量idx = target.view(-1, 1).long() # 行向量target變成列向量idxone_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)one_hot_key = one_hot_key.scatter_(1, idx, 1) # one_hot_key矩陣中的每一行對應相應樣本的標簽one_hot向量,利用scatter_方法將樣本的標簽類別標記為1,其余位置為0one_hot_key[:, 0] = 0  # ignore 0 index. 此行需要視具體情況決定是否保留,如果標簽中存在類別0(而不是直接從類別1開始),此行應當注釋、不使用logits = torch.softmax(input, dim=-1)loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() # 計算FocalLossloss = loss.sum(1)return loss.mean()# 固定隨機數種子,方便復現
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Trueif __name__ == '__main__':loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])# 設置隨機數種子setup_seed(20) input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]output = loss(input, target)# print(output)output.backward()

方式二

一個batch中逐個樣本計算關于FocalLoss的損失值,將它們求平均,返回一個batch內所有樣本的FocalLoss的平均值:

import torch
from torch import nn
import random
class FocalLoss(nn.Module):"""參考 https://github.com/lonePatient/TorchBlocks"""def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):super(FocalLoss, self).__init__()self.gamma = gammaif isinstance(alpha, list):self.alpha = torch.Tensor(alpha, device=device)else:self.alpha = alphaself.epsilon = epsilon'''逐個樣本計算loss'''    def forward(self, input, target):"""Args:input: model's output, shape of [batch_size, num_cls]target: ground truth labels, shape of [batch_size]Returns:shape of [batch_size]"""num_labels = input.size(-1) # 類別數量loss = []for i, sample in enumerate(input):one_hot_key = torch.zeros(1, num_labels, dtype=torch.float32, device=input.device)one_hot_key.scatter_(1, target[i].view(1, -1), 1)logits = torch.softmax(sample, dim=-1)loss_this_sample = - self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()loss_this_sample = loss_this_sample.sum(1)if i == 0:loss = loss_this_sampleelse:loss = torch.cat((loss, loss_this_sample))return loss.mean()# 固定隨機數種子,方便復現
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Trueif __name__ == '__main__':loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])# 設置隨機數種子setup_seed(20) input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]output = loss(input, target)# print(output)output.backward()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/9955.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/9955.shtml
英文地址,請注明出處:http://en.pswp.cn/web/9955.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【iOS】frame與bounds區別

文章目錄 前言framebounds兩者區別size的區別總結 前言 在學習響應者鏈的過程中用到了frame與bounds的混用,這兩個屬性經常出現在我們的開發中,特別撰寫一篇博客分析區別 首先,我們來看一下iOS特有的坐標系,在iOS坐標系中以左上…

C語言如何查看進程中環境變量中所有的值

示例代碼&#xff1a;查看進程中環境變量中所有的值。 #include <stdio.h>int main(){extern char** environ;for (char** pp environ; *pp; pp){printf("%s\n", *pp);}return 0; }輸出結果&#xff1a; SHELL/bin/bash WSL2_GUI_APPS_ENABLED1 WSL_DISTRO_…

【debug】如何使用pycharm對代碼調試

后續會將所有debug中遇到的知識放入&#xff0c;建議關注收藏 本站友情鏈接&#xff1a; 基本理論專欄&#xff08;當前更新好的debug所有內容都在這里&#xff09; 【debug】報錯解決方法&#xff08;CondaHTTPError&#xff1a;HTTP 000 connection failed for url&#xff…

【回溯 狀態壓縮 深度優先】37. 解數獨

本文涉及知識點 回溯 狀態壓縮 深度優先 LeetCode37. 解數獨 編寫一個程序&#xff0c;通過填充空格來解決數獨問題。 數獨的解法需 遵循如下規則&#xff1a; 數字 1-9 在每一行只能出現一次。 數字 1-9 在每一列只能出現一次。 數字 1-9 在每一個以粗實線分隔的 3x3 宮內只…

leetCode刷題記錄4-面試經典150題-2

文章目錄 不要擺&#xff0c;沒事干就刷題&#xff0c;只有好處&#xff0c;沒有壞處&#xff0c;實在不行&#xff0c;看看競賽題面試經典 150 題 - 2210. 課程表 II 不要擺&#xff0c;沒事干就刷題&#xff0c;只有好處&#xff0c;沒有壞處&#xff0c;實在不行&#xff0c…

[C++核心編程-06]----C++類和對象之對象模型和this指針

&#x1f3a9; 歡迎來到技術探索的奇幻世界&#x1f468;?&#x1f4bb; &#x1f4dc; 個人主頁&#xff1a;一倫明悅-CSDN博客 ?&#x1f3fb; 作者簡介&#xff1a; C軟件開發、Python機器學習愛好者 &#x1f5e3;? 互動與支持&#xff1a;&#x1f4ac;評論 &…

Microsoft 365 for Mac v16.84 office365全套辦公軟件

Microsoft 365 for Mac是一款功能豐富的辦公軟件套件&#xff0c;為Mac用戶提供了豐富的功能和工具&#xff0c;提高了工作效率和協作能力。Microsoft 365 for Mac是一款專為Mac用戶設計的訂閱式辦公軟件套件&#xff0c;旨在提高生產力和效率。 Microsoft 365 for Mac v16.84正…

數據賦能(83)——數據要素:數據要素管理與數據管理

數據要素管理則更關注數據作為生產性資源在創造經濟價值中的作用&#xff1b;數據管理更側重于數據在整個生命周期中的控制、保護和價值提升。 數據要素管理是對數據作為關鍵生產要素進行系統性管理的過程。它聚焦于數據在經濟和社會活動中的價值創造和貢獻&#xff0c;將數據…

ubantu安裝docker以及docker-compose

ubantu安裝docker以及docker-compose 安裝docker1、從官方存儲庫中安裝Docker2、啟動Docker服務3、驗證 安裝docker compose使用docker部署服務1、需要再opt文件夾下創建以下文件夾&#xff0c;/opt文件夾目錄說明2、可將已備份對應文件夾拷至對應文件夾下3、在/opt/compose目錄…

python集合

集合是一個無序的不重復元素序列&#xff0c;集合中的元素必須是不可變類型 集合的創建與刪除 用{}直接創建 用集合推導式創建 用ser&#xff08;&#xff09;函數將列表&#xff0c;元組&#xff0c;range對象轉換成集合 numset1{1,2,3,4,5}numset2{x**2 for x in range(…

【代碼】Mysql 查詢近一個月各類型設備新增數量

錯誤示例 SELECT COUNT(*) AS count, p.type, d.active_date FROM device d LEFT JOIN product p ON d.product_id p.pid WHERE MONTH (active_date) MONTH (CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR (active_date) YEAR (CURRENT_DATE - INTERVAL 1 MONTH) group by p.…

mysql高可用集群MGR組復制的介紹、部署及配置說明

前言 MGR全稱MySQL Group Replication(Mysql組復制),是MySQL官方于2016年12月推出的一個全新的高可用與高擴展的解決方案。MGR提供了高可用、高擴展、高可靠的MySQL集群服務。 高一致性:基于分布式paxos協議實現組復制,保證數據一致性; 高容錯性:自動檢測機制,只要不…

霍金《時間簡史 A Brief History of Time》書后索引(A--D)

圖源&#xff1a;Wikipedia INDEX A Abacus Absolute position Absolute time Absolute zero Acceleration Age of the universe Air resistance Albrecht, Andreas Alpha Centauri Alpher, Ralph Anthropic principle Antigravity Antiparticles Aristotle Arrows of time …

基于Vant UI的微信小程序開發(隨時更新的寫手)

基于Vant UI的微信小程序開發? &#xff08;一&#xff09;懸浮浮動1、效果圖&#xff1a;只要無腦引用樣式就可以了2、頁面代碼3、js代碼4、樣式代碼 &#xff08;二&#xff09;底部跳轉1、效果圖&#xff1a;點擊我要發布跳轉到發布的頁面2、js代碼3、頁面代碼4、app.json代…

vue項目設置主題色

在vue開發過程中&#xff0c;很多頁面為了保持主題顏色統一&#xff0c;且方便后期管理&#xff0c;通常會設有主題色&#xff0c;通過主題色可以使得頁面上的按鈕單選框等控件保持顏色統一。 接下來介紹其中一種方法&#xff1a; 1.先建立一個js文件用于存放主題色&#xff…

我覺得POC應該貼近實際

今天我看到一位老師給我一份測試數據。 這是三個國產數據庫。算是分布式的。其中有兩個和我比較熟悉&#xff0c;但是這個數據看上去并不好。看上去第一個黃色的數據庫數據是這里最好的了。但是即使如此&#xff0c;我相信大部分做數據庫的人都知道。MySQL和PostgreSQL平時拿出…

Spark Streaming筆記總結(保姆級)

萬字長文警告&#xff01;&#xff01;&#xff01; 目錄 一、離線計算與流式計算 1.1 離線計算 1.1.1 離線計算的特點 1.1.2 離線計算的應用場景 1.1.3 離線計算代表技術 1.2 流式計算 1.2.1 流式計算的特點 1.2.2 流式計算的應用場景 1.2.3 流式計算的代表技術 二…

最小生成樹刷題筆記

算法基礎&#xff1a; 首先是prim算法三部曲&#xff1a; &#xff08;1&#xff09;找到距離最小生成樹最近的節點。 &#xff08;2&#xff09;將距離最小生成樹最近的節點加入到最小生成樹中。 &#xff08;3&#xff09;更新非最小生成樹節點到最小生成樹的距離。 實現…

HTML批量文件上傳3—Servlet批量文件處理FileUpLoad

作者:私語茶館 1.開源的文件上傳組件介紹 本文使用的是Apache Commons下面的一個子項目FileUpload,另外一個常見組件是SmartUpload。FileUpload遵循RFC 1897,即“Form-based File Upload in HTML”,對于請求需要滿足:HTTP協議,Post請求,content Type=“multipart/form-d…

Kafka 面試題(五)

1. kafka的消費者是pull(拉)還是push(推)模式&#xff0c;這種模式有什么好處&#xff1f; Kafka的消費者是pull&#xff08;拉&#xff09;模式。在這種模式下&#xff0c;消費者主動從Kafka的broker中拉取數據來進行消費。 這種pull模式的好處主要體現在以下幾個方面&#…