詳解序數回歸損失函數ordinal_regression_loss:原理與實現

在醫療 AI 領域,很多分類任務具有有序類別的特性,如疾病嚴重程度(輕度→中度→重度)、腫瘤分級(G1→G2→G3)等。這類任務被稱為序數回歸(Ordinal Regression),需要特殊的損失函數設計。本文將深入解析序數回歸損失函數的原理及其實現代碼。

一、序數回歸與傳統分類的區別

傳統分類任務(如疾病類型識別)假設類別之間是無序的,而序數回歸的類別具有自然順序。例如:

  • 疾病嚴重程度:0(正常)→1(輕度)→2(中度)→3(重度)
  • 影像評分:1 分→2 分→3 分→4 分→5 分

對于這類任務,傳統的交叉熵損失存在局限性:它只關注類別預測的正確性,而忽略了類別之間的順序關系。例如,將真實標簽為 "中度"(2)的樣本預測為 "重度"(3),與預測為 "輕度"(1),在交叉熵損失中被視為同等錯誤,但實際上前者的錯誤程度更小。

二、序數回歸損失函數的核心思想

序數回歸損失函數的設計目標是:不僅要正確分類,還要保持類別之間的順序關系。常見的實現方法有以下幾種:

  1. 累積概率模型:將序數分類轉化為一系列二分類問題
  2. 相鄰類別比較:比較相鄰類別的預測概率
  3. 距離敏感損失:懲罰與真實類別距離更遠的錯誤預測

代碼中實現的是累積概率模型,這是最常用的序數回歸方法之一。

三、累積概率模型的數學原理

累積概率模型的核心思想是:將序數類別轉化為一系列累積概率。對于有K個類別的問題,定義K-1個閾值cutspoints,,則樣本屬于類別k的概率為:,其中:

四、代碼實現解析

下面詳細解析序數回歸損失函數的實現代碼:

def ordinal_regression_loss(self, pred, label, num_classes, train_cutpoints=False, scale=20.0):# 1. 計算閾值(cutpoints)num_cutpoints = num_classes - 1#計算閾值數量cutpoints = torch.arange(num_cutpoints, device=pred.device).float() * scale / (num_classes - 2) - scale / 2cutpoints = nn.Parameter(cutpoints, requires_grad=train_cutpoints)# 2. 計算累積概率sigmoids = torch.sigmoid(cutpoints - pred)# 3. 構建概率矩陣:將累積概率轉換為每個類別的概率link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]  # 中間類別的概率link_mat = torch.cat((sigmoids[:, [0]],         # 第一個類別的概率link_mat,                 # 中間類別的概率(1 - sigmoids[:, [-1]])   # 最后一個類別的概率), dim=1)# 4. 數值穩定性處理:防止對數計算時出現NaNeps = 1e-15likelihoods = torch.clamp(link_mat, eps, 1 - eps)# 5. 計算負對數似然損失neg_log_likelihood = torch.log(likelihoods)if label is None:loss = 0else:loss = -torch.gather(neg_log_likelihood, 1, label).mean()return loss, likelihoods

五、關鍵步驟詳解

1. 閾值(Cutpoints)計算
cutpoints = torch.arange(num_cutpoints, device=pred.device).float() * scale / (num_classes - 2) - scale / 2
  • 作用:生成均勻分布的閾值點,將連續空間劃分為多個區間

例如:

  • 參數
    • scale:控制閾值的范圍,默認 20.0
    • train_cutpoints:是否將閾值作為可訓練參數(默認為 False)
  • 基礎序列torch.arange(num_cutpoints):對于K個類別,生成序列[0,1,2,...,K-2]
  • 縮放因子scale / (num_classes - 2)調整閾值之間的間隔
  • 線性變換* scale / (num_classes - 2) - scale / 2:將基礎序列映射到?[-scale/2, scale/2]?區間。

這兩行代碼的核心是將連續的預測空間均勻劃分為多個有序區間,每個區間對應一個類別。通過調整?scale?參數,可以控制區間的寬度,適應不同的任務需求。當?train_cutpoints=True?時,模型會在訓練過程中自動學習最優的閾值位置,進一步提升序數回歸的性能。

2. 累積概率計算
sigmoids = torch.sigmoid(cutpoints - pred)
  • 作用:將模型預測值與閾值的差值通過 sigmoid 函數轉換為累積概率
  • 示例:對于 3 個類別(2 個閾值),累積概率為:

將模型輸出的抽象分數?pred,通過與閾值?cutpoints?的比較,轉換為 “屬于某個類別或更低等級” 的概率。這個概率越接近 1,說明?pred?越可能落在該類別或更低等級的區間里。

3. 類別概率矩陣構建
link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]
link_mat = torch.cat((sigmoids[:, [0]], link_mat, 1 - sigmoids[:, [-1]]), dim=1)

  • sigmoids[:, 1:]?→ 取所有樣本的第二個及以后的累積概率
  • sigmoids[:, :-1]?→ 取所有樣本的第一個及以前的累積概率
4.數值穩定性處理:防止對數計算時出現NaN

在深度學習中,當計算概率的對數時(如交叉熵損失中的?log(p)),如果概率?p?非常接近 0(如 1e-20),會導致以下問題:

  1. 數值下溢:計算機無法精確表示極小數,可能返回 0
  2. 對數計算錯誤log(0)?會返回負無窮(-inf
  3. 梯度爆炸:反向傳播時,-inf?的梯度會導致參數更新異常

同樣,當概率?p?接近 1 時,1-p?接近 0,也會引發類似問題。

  • torch.clamp(input, min, max)?將輸入張量的每個元素限制在?[min, max]?范圍內
  • 確保所有概率值在?[1e-15, 1-1e-15]?之間,避免過于接近 0 或 1

5. 負對數似然損失計算
neg_log_likelihood = torch.log(likelihoods)
loss = -torch.gather(neg_log_likelihood, 1, label).mean()
  • 作用:計算每個樣本的真實類別對應的負對數概率,并取平均

通過最大似然估計,讓模型預測的真實類別概率最大化。具體步驟為:

  1. 計算對數似然:將概率轉換為對數空間
  2. 按標簽選擇:提取真實類別對應的對數似然
  3. 取負平均:轉換為損失(越小越好)

六、為什么選擇序數回歸損失?

在醫療分類任務中,序數回歸損失有以下優勢:

  1. 利用順序信息:充分利用類別之間的順序關系,提高模型對程度差異的敏感性
  2. 減少信息損失:相比將序數問題簡單視為分類問題,保留了更多結構信息
  3. 更好的校準:輸出的概率具有更明確的臨床意義(如疾病嚴重程度的概率)
  4. 提升性能:在序數分類任務中,通常比傳統分類損失取得更好的性能

七、實踐建議

  1. 閾值初始化

    • 代碼中的線性初始化是常用方法,但對于特定任務,可根據先驗知識自定義閾值
    • train_cutpoints=True時,模型會學習最優閾值位置
  2. 模型輸出設計

    • 模型最后一層應輸出單個連續值(而非類別概率),作為序數回歸的預測值
    • 可通過全連接層實現:nn.Linear(input_dim, 1)
  3. 超參數調整

    • scale參數影響閾值的分布范圍,需根據具體任務調整
    • 對于嚴重不平衡的序數類別,可考慮加權損失
  4. 評估指標

    • 除準確率外,建議使用 Kendall's τ 或 Spearman 相關性等評估順序一致性
    • 醫學場景中,還需關注不同嚴重程度類別的敏感性和特異性

八、總結

序數回歸損失函數為具有順序關系的醫療分類任務提供了更合適的優化目標。通過將類別轉化為累積概率,它不僅能正確分類,還能保持類別之間的順序關系,特別適合疾病嚴重程度分級、影像評分等醫療場景。

在實際應用中,可根據任務特點調整閾值初始化方式和損失函數參數,結合適當的評估指標,構建更符合臨床需求的醫療 AI 模型。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914933.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914933.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914933.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SQL增查

建完庫與建完表后后:1.分別查詢student表和score表的所有記錄student表:score表:2.查詢student表的第2條到5條記錄SELECT * FROM student LIMIT 1,4;3.從student表中查詢計算機系和英語系的學生的信息SELECT * FROM student-> WHERE department IN (計算機系, 英…

二分答案之最大化最小值

參考資料來源靈神在力扣所發的題單,僅供分享學習筆記和記錄,無商業用途。 核心思路:本質上是求最大 應用場景:在滿足條件的最小值區間內使最大化 檢查函數:保證數據都要大于等于答案 補充:為什么需要滿…

OCR 賦能檔案數字化:讓沉睡的檔案 “活” 起來

添加圖片注釋,不超過 140 字(可選)企業產品檔案包含設計圖紙、檢測報告、生產記錄等,傳統數字化僅靠掃描存檔,后續檢索需人工逐份翻閱,效率極低。?OCR 產品檔案解決方案直擊痛點:通過智能識別技…

力扣118.楊輝三角

思路1.新建一個vector的vector2.先把空間開出來,然后再把里面的值給一個個修改開空間的手段:new、構造函數、reserve、resize因為我們之后要修改里面的數據,這就意味著我們需要去讀取這個數據并修改,如果用reserve的話&#xff0c…

Python 網絡爬蟲 —— 提交信息到網頁

一、模塊核心邏輯“提交信息到網頁” 是網絡交互關鍵環節,借助 requests 庫的 post() 函數,能模擬瀏覽器向網頁發數據(如表單、文件 ),實現信息上傳,讓我們能與網頁背后的服務器 “溝通”,像改密…

SpringMVC4

一、SpringMVC 注解與項目開發流程1.1注解的生命周期- Target、Retention 等元注解:- Target(ElementType.TYPE) :說明這個注解只能用在類、接口上。- Retention(RetentionPolicy.RUNTIME) :說明注解在運行時保留,能通過反射獲取…

數據結構排序算法總結(C語言實現)

以下是常見排序算法的總結及C語言實現,包含時間復雜度、空間復雜度和穩定性分析:1. 冒泡排序 (Bubble Sort)思想:重復比較相鄰元素,將較大元素向后移動。 時間復雜度:O(n)(最好O(n),最壞O(n)) 空…

嵌入式學習-PyTorch(2)-day19

很久沒有學了,期間打點滴打了一個多星期,太累了,再加上學了一下Python語法基礎,再終于開始重新學習pytorchtensorboard 的使用import torch from torch.utils.tensorboard import SummaryWriter writer SummaryWriter("logs…

Prompt Engineering 快速入門+實戰案例

資料來源:火山引擎-開發者社區 引言 什么是 prompt A prompt is an input to a Generative AI model, that is used to guide its output. Prompt engineering is the process of writing effective instructions for a model, such that it consistently generat…

「源力覺醒 創作者計劃」_文心開源模型(ERNIE-4.5-VL-28B-A3B-PT)使用心得

文章目錄背景操作流程開源模型選擇算力服務器平臺開通部署一個算力服務器登錄GPU算力服務器進行模型的部署FastDeploy 快速部署服務安裝paddlepaddle-gpu1. 降級沖突的庫版本安裝fastdeploy直接部署模型(此處大約花費15分鐘時間)放行服務端口供公網訪問最…

P10719 [GESP202406 五級] 黑白格

題目傳送門 前言:不是這樣例有點過分了哈: 這是我沒考慮到無解的情況的得分: 這是我考慮了的得分: 總而言之,就是一個Subtask 你沒考慮無解的情況(除了Subtask #0),就會WA一大片,然后這個Subt…

AWS RDS PostgreSQL可觀測性最佳實踐

AWS RDS PostgreSQL 介紹AWS RDS PostgreSQL 是亞馬遜云服務(AWS)提供的托管型 PostgreSQL 數據庫服務。托管服務:AWS 管理數據庫的底層基礎設施,包括硬件、操作系統、數據庫引擎等,用戶無需自行維護。高性能&#xff…

C++——set,map的模擬實現

文章目錄前言紅黑樹的改變set的模擬實現基本框架迭代器插入源碼map模擬實現基礎框架迭代器插入賦值重載源碼測試代碼前言 set,map底層使用紅黑樹這種平衡二叉搜索樹來組織元素 ,這使得set, map能夠提供對數時間復雜度的查找、插入和刪除操作。 下面都是基…

LabVIEW液壓機智能監控

?基于LabVIEW平臺,結合西門子、研華等硬件,構建液壓機實時監控系統。通過 OPC 通信技術實現上位機與 PLC 的數據交互,解決傳統監控系統數據采集滯后、存儲有限、參數調控不便等問題,可精準采集沖壓過程中的位置、速度、壓力等參數…

15. 什么是 xss 攻擊?怎么防護

總結 跨站腳本攻擊&#xff0c;注入惡意腳本敏感字符轉義&#xff1a;“<”,“/”前端可以抓包篡改主要后臺處理&#xff0c;轉義什么是 XSS 攻擊&#xff1f;怎么防護 概述 XSS&#xff08;Cross-Site Scripting&#xff0c;跨站腳本攻擊&#xff09;是一種常見的 Web 安全…

更換docker工作目錄

使用環境 由于默認系統盤比較小docker鏡像很容易就占滿&#xff0c;需要掛載新的磁盤修改docker的默認工作目錄 環境&#xff1a;centos7 docker默認工作目錄: /var/lib/docker/ 新的工作目錄&#xff1a;/home/docker-data【自己手動創建&#xff0c;一般掛在新加的磁盤下面】…

算法學習筆記:26.二叉搜索樹(生日限定版)——從原理到實戰,涵蓋 LeetCode 與考研 408 例題

二叉搜索樹&#xff08;Binary Search Tree&#xff0c;簡稱 BST&#xff09;是一種特殊的二叉樹&#xff0c;因其高效的查找、插入和刪除操作&#xff0c;成為計算機科學中最重要的數據結構之一。BST 的核心特性是 “左小右大”&#xff0c;這一特性使其在數據檢索、排序和索引…

共生型企業:駕馭AI自動化(事+AI)與人類增強(人+AI)的雙重前沿

目錄 引言&#xff1a;人工智能的雙重前沿 第一部分&#xff1a;自動化范式&#xff08;事AI&#xff09;——重新定義卓越運營 第一章&#xff1a;智能自動化的機制 第二章&#xff1a;自動化驅動的行業轉型 第三章&#xff1a;自動化的經濟演算 第二部分&#xff1a;協…

TypeScript的export用法

在 TypeScript 中&#xff0c;export 用于將模塊中的變量、函數、類、類型等暴露給外部使用。export 語法允許將模塊化的代碼分割并在其他文件中導入。 1. 命名導出&#xff08;Named Export&#xff09; 命名導出是 TypeScript 中最常見的一種導出方式&#xff0c;它允許你導出…

數據結構-2(鏈表)

一、思維導圖二、鏈表的反轉def reverse(self):"""思路&#xff1a;1、設置previous_node、current、next_node三個變量,目標是將current和previous_node逐步向后循環并逐步進行反轉,知道所有元素都被反轉2、但唯一的問題是&#xff1a;一旦current.next反轉為向…