深度學習中的數值穩定性處理詳解:以SimCLR損失為例

文章目錄

    • 1. 問題背景
      • SimCLR的原始公式
    • 2. 數值溢出問題
      • 為什么會出現數值溢出?
      • 浮點數的表示范圍
    • 3. 數值穩定性處理方法
      • 核心思想
      • 數學推導
    • 4. 代碼實現分解
      • 代碼與公式的對應關系
    • 5. 具體數值示例
      • 示例:相似度矩陣
      • 方法1:直接計算exp(x)
      • 方法2:減去最大值后計算
      • 驗證結果等價性
    • 6. 為什么減去最大值有效?
      • 關鍵原理
    • 7. 實際應用場景
    • 8. 實現建議
    • 總結

在深度學習實現中,特別是涉及指數和對數運算的損失函數計算過程中,數值穩定性是一個核心問題。本文以SimCLR對比學習損失為例,詳細解析數值穩定性處理的原理、實現和重要性。

1. 問題背景

SimCLR是一種自監督學習方法,其核心是InfoNCE損失函數。這個損失函數的計算涉及大量指數運算,容易導致數值溢出或下溢問題。

SimCLR的原始公式

SimCLR的核心損失函數(InfoNCE損失)公式為:

L i = ? log ? exp ? ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ) ? 1 k ≠ i L_i = -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} Li?=?logk=12N?exp(sim(zi?,zk?)/τ)?1k=i?exp(sim(zi?,zj?)/τ)?

其中:

  • z i z_i zi?是錨點特征
  • z j z_j zj?是與 z i z_i zi?對應的正樣本特征
  • τ \tau τ是溫度參數
  • s i m ( ) sim() sim()是相似度函數(通常是點積)
  • 1 k ≠ i \mathbf{1}_{k \neq i} 1k=i?表示排除自身對比的指示函數

2. 數值溢出問題

為什么會出現數值溢出?

當我們計算 exp ? ( x ) \exp(x) exp(x)時:

  • 如果 x x x很大(如 x = 100 x = 100 x=100), exp ? ( 100 ) ≈ 2.7 × 1 0 43 \exp(100) \approx 2.7 \times 10^{43} exp(100)2.7×1043,可能超出浮點數表示范圍
  • 如果 x x x是很小的負數(如 x = ? 100 x = -100 x=?100), exp ? ( ? 100 ) ≈ 3.7 × 1 0 ? 44 \exp(-100) \approx 3.7 \times 10^{-44} exp(?100)3.7×10?44,可能導致下溢為0

在SimCLR中, s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi?,zk?)/τ可能很大,特別是當:

  • 特征向量高度相似( s i m sim sim接近1)
  • 溫度參數 τ \tau τ很小(如0.07)

浮點數的表示范圍

浮點數的表示范圍是有限的:

  • 單精度浮點數(32位):約 ± 3.4 × 1 0 38 \pm 3.4 \times 10^{38} ±3.4×1038
  • 雙精度浮點數(64位):約 ± 1.8 × 1 0 308 \pm 1.8 \times 10^{308} ±1.8×10308

3. 數值穩定性處理方法

SimCLR實現中使用了一種簡單而有效的數值穩定性處理技術,代碼如下:

# 數值穩定性處理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

核心思想

這種處理的核心思想是:

  1. 找出每行相似度的最大值
  2. 將每行的所有值減去這個最大值
  3. 然后再進行指數計算

數學推導

這種操作是數學等價的。對原始公式進行變換:

L i = ? log ? exp ? ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ) ? 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} \\ \end{align} Li??=?logk=12N?exp(sim(zi?,zk?)/τ)?1k=i?exp(sim(zi?,zj?)/τ)???

引入最大值 M i = max ? k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi?=maxk?(sim(zi?,zk?)/τ)

L i = ? log ? exp ? ( s i m ( z i , z j ) / τ ? M i + M i ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ? M i + M i ) ? 1 k ≠ i = ? log ? exp ? ( M i ) ? exp ? ( s i m ( z i , z j ) / τ ? M i ) exp ? ( M i ) ? ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i = ? log ? exp ? ( s i m ( z i , z j ) / τ ? M i ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i + M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i + M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(M_i) \cdot \exp(sim(z_i, z_j)/\tau - M_i)}{\exp(M_i) \cdot \sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \end{align} Li??=?logk=12N?exp(sim(zi?,zk?)/τ?Mi?+Mi?)?1k=i?exp(sim(zi?,zj?)/τ?Mi?+Mi?)?=?logexp(Mi?)?k=12N?exp(sim(zi?,zk?)/τ?Mi?)?1k=i?exp(Mi?)?exp(sim(zi?,zj?)/τ?Mi?)?=?logk=12N?exp(sim(zi?,zk?)/τ?Mi?)?1k=i?exp(sim(zi?,zj?)/τ?Mi?)???

因為分子和分母中的 exp ? ( M i ) \exp(M_i) exp(Mi?)相互抵消,所以最終結果不變。

4. 代碼實現分解

完整的SimCLR損失計算代碼(包含數值穩定性處理):

# 計算相似度矩陣并除以溫度系數
anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T),self.temperature)# 數值穩定性處理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()# 創建和應用掩碼
mask = mask.repeat(anchor_count, contrast_count)
logits_mask = torch.scatter(torch.ones_like(mask),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0
)
mask = mask * logits_mask# 計算損失
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()

代碼與公式的對應關系

  1. anchor_dot_contrast s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi?,zk?)/τ
  2. logits_max M i = max ? k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi?=maxk?(sim(zi?,zk?)/τ)
  3. logits s i m ( z i , z k ) / τ ? M i sim(z_i, z_k)/\tau - M_i sim(zi?,zk?)/τ?Mi?
  4. exp_logits exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i} exp(sim(zi?,zk?)/τ?Mi?)?1k=i?
  5. log_prob log ? exp ? ( s i m ( z i , z k ) / τ ? M i ) ∑ k exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i \log \frac{\exp(sim(z_i, z_k)/\tau - M_i)}{\sum_{k} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} logk?exp(sim(zi?,zk?)/τ?Mi?)?1k=i?exp(sim(zi?,zk?)/τ?Mi?)?

5. 具體數值示例

為了直觀理解,我們用一個簡化的例子來說明為什么減去最大值能防止數值溢出。

示例:相似度矩陣

假設有一個計算得到的相似度矩陣(已除以溫度τ=0.07):

sim(z_i, z_k)/τ = [[80, 50, 60, 70, 40],[60, 90, 70, 80, 50],[70, 60, 85, 75, 55],[50, 40, 60, 75, 45]
]

方法1:直接計算exp(x)

直接計算exp(sim(z_i, z_k)/τ)

exp(sim(z_i, z_k)/τ) ≈ [[5.54e+34, 5.18e+21, 1.14e+26, 2.51e+30, 2.35e+17],[1.14e+26, 1.22e+39, 2.51e+30, 5.54e+34, 5.18e+21],[2.51e+30, 1.14e+26, 5.91e+36, 3.58e+32, 1.14e+24],[5.18e+21, 2.35e+17, 1.14e+26, 3.58e+32, 3.49e+19]
]

這些值極其巨大,相加時很容易溢出。例如第一行的和約為5.54e+34,已經接近單精度浮點數的上限。

方法2:減去最大值后計算

找出每行的最大值:

max_values = [80, 90, 85, 75]

減去最大值:

adjusted_logits = [[0, -30, -20, -10, -40],[-30, 0, -20, -10, -40],[-15, -25, 0, -10, -30],[-25, -35, -15, 0, -30]
]

計算exp(adjusted_logits)

exp(adjusted_logits) ≈ [[1.0, 9.36e-14, 2.06e-9, 4.54e-5, 4.25e-18],[9.36e-14, 1.0, 2.06e-9, 4.54e-5, 4.25e-18],[3.06e-7, 1.39e-11, 1.0, 4.54e-5, 9.36e-14],[1.39e-11, 6.31e-16, 3.06e-7, 1.0, 9.36e-14]
]

這些值都在[0,1]范圍內,完全避免了溢出問題。同時,正樣本對和負樣本對之間的相對比例關系保持不變。

驗證結果等價性

例如,對于第一行計算最終的歸一化概率:

原始方法:

P(z_0 -> z_0) = exp(80) / sum(exp(row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(50) / sum(exp(row_0)) ≈ 9.35e-14
...

減去最大值后:

P(z_0 -> z_0) = exp(0) / sum(exp(adjusted_row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(-30) / sum(exp(adjusted_row_0)) ≈ 9.35e-14
...

兩種計算方法得到的概率分布是相同的,但后者避免了數值溢出風險。

6. 為什么減去最大值有效?

關鍵原理

減去最大值的處理之所以有效,是因為:

  1. 將范圍控制在安全區間

    • 減去最大值后,所有值都≤0
    • 因此所有exp(x)的結果都≤1,避免了上溢
    • 同時最大值對應的exp(0)=1,避免了整體下溢為0
  2. 保持相對比例關系

    • 對每行減去相同的常數不改變值之間的相對大小
    • 對于exp()函數來說,這等價于同時除以一個常數因子
    • 在計算Softmax或對數概率時,這個常數因子在分子和分母中抵消
  3. 數學等價性

    • exp(a-b) = exp(a)/exp(b)的性質保證了結果的正確性
    • 這相當于將原始公式的分子和分母同時除以exp(max_value)

7. 實際應用場景

這種數值穩定性技術不僅適用于SimCLR,還廣泛應用于:

  1. Softmax計算:幾乎所有需要計算Softmax的地方都需要
  2. 交叉熵損失:分類任務中常用
  3. 注意力機制:Transformer中的attention計算
  4. 所有對比學習方法:MoCo、BYOL、CLIP等

8. 實現建議

在實現涉及指數計算的函數時,建議:

  1. 始終使用數值穩定性處理
  2. 對每個batch/樣本獨立進行處理(找到每行/每個樣本的最大值)
  3. 使用.detach()阻止梯度通過最大值操作傳播
  4. 注意掩碼操作,確保不包括自身對比或特定的負樣本

總結

數值穩定性處理是深度學習實現中一個看似簡單但至關重要的技術。通過簡單地減去每行的最大值,我們可以有效防止數值溢出/下溢問題,同時保持計算結果的數學等價性。這種技術尤其重要,因為隨著模型和批量大小的增加,數值問題更容易出現,而且往往難以診斷。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/79019.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/79019.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/79019.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SQL(9):創建數據庫,表,簡單

1、創建數據庫,一句SQL語句搞定 CREATE DATDBASE 數據庫名 CREATE DATABASE my_db;2、創建表 CREATE TABLE 表名(字段名 類型) CREATE TABLE Persons ( PersonID int, LastName varchar(255), FirstName varchar(255), Address varchar(255), City varchar(255)…

QT Sqlite數據庫-教程002 查詢數據-下

【1】數據庫查詢的優化:prepare prepare語句是一種在執行之前將SQL語句編譯為字節碼的機制,可以提高執行效率并防止SQL注入攻擊。 【2】使用prepare查詢一張表 QString myTable "myTable" ; QString cmd QString("SELECT * FROM %1…

cline 提示詞工程指南-架構篇

cline 提示詞工程指南-架構篇 本篇是 cline 提示詞工程指南的學習和擴展,可以參閱: https://docs.cline.bot/improving-your-prompting-skills/prompting 前言 cline 是 vscode 的插件,用來在 vscode 里實現 ai 編程。 它使得你可以接入…

算法---子序列[動態規劃解決](最長遞增子序列)

最長遞增子序列 子序列包含子數組&#xff01; 說白了&#xff0c;要用到雙層循環&#xff01; 用雙層循環中的dp[i]和dp[j]把所有子序列情況考慮到位 class Solution { public:int lengthOfLIS(vector<int>& nums) {vector<int> dp(nums.size(),1);for(int i …

kubectl命令補全以及oc命令補全

kubectl命令補全 1.安裝bash-completion 如果你用的是Bash(默認情況下是)&#xff0c;先安裝補全功能支持包 sudo apt update sudo apt install bash-completion -y2.為kubectl 啟用補全功能 會話中臨時&#xff1a; source <(kubectl completion bash)持久化配置&#x…

48、Spring Boot 詳細講義(五)

3、集成MyBatis 3.1 MyBatis 概述 3.1.1 核心功能和優勢 MyBatis 是一個 Java 持久層框架,它通過 XML 或注解配置 SQL 語句,將 Java 方法與 SQL 語句映射起來,消除了大量的 JDBC 代碼,簡化了數據庫操作。MyBatis 的核心功能和優勢包括: ORM(對象關系映射):通過 XML …

BERT - Bert模型框架復現

本節將實現一個基于Transformer架構的BERT模型。 1. MultiHeadAttention 類 這個類實現了多頭自注意力機制&#xff08;Multi-Head Self-Attention&#xff09;&#xff0c;是Transformer架構的核心部分。 在前幾篇文章中均有講解&#xff0c;直接上代碼 class MultiHeadAtt…

解決 Spring Boot 啟動報錯:數據源配置引發的啟動失敗

啟動項目時&#xff0c;控制臺輸出了如下錯誤信息&#xff1a; Error starting ApplicationContext. To display the condition evaluation report re-run your application with debug enabled. 2025-04-14 21:13:33.005 [main] ERROR o.s.b.d.LoggingFailureAnalysisReporte…

履帶小車+六軸機械臂(2)

本次介紹原理圖部分 開發板部分&#xff0c;電源供電部分&#xff0c;六路舵機&#xff0c;PS2手柄接收器&#xff0c;HC-05藍牙模塊&#xff0c;蜂鳴器&#xff0c;串口&#xff0c;TB6612電機驅動模塊&#xff0c;LDO線性穩壓電路&#xff0c;按鍵部分 1、開發板部分 需要注…

【開發記錄】服務外包大賽記錄

參加服務外包大賽的A07賽道中&#xff0c;最近因為頻繁的DEBUG&#xff0c;心態爆炸 記錄錯誤 以防止再次出現錯誤浪費時間。。。 2025.4.13 項目在上傳圖片之后 會自動刷新 沒有等待后端返回 Network中的fetch /upload顯示canceled. 然而這是使用了VS的live Server插件才這樣&…

基于FreeRTOS和LVGL的多功能低功耗智能手表(硬件篇)

目錄 一、簡介 二、板子構成 三、核心板 3.1 MCU最小系統板電路 3.2 電源電路 3.3 LCD電路 3.4 EEPROM電路 3.5 硬件看門狗電路 四、背板 4.1 傳感器電路 4.2 充電盤 4.3 藍牙模塊電路 五、總結 一、簡介 本篇開始介紹這個項目的硬件部分&#xff0c;從最小電路設…

為 Kubernetes 提供智能的 LLM 推理路由:Gateway API Inference Extension 深度解析

現代生成式 AI 和大語言模型&#xff08;LLM&#xff09;服務給 Kubernetes 帶來了獨特的流量路由挑戰。與典型的短時、無狀態 Web 請求不同&#xff0c;LLM 推理會話通常是長時運行、資源密集且部分有狀態的。例如&#xff0c;一個基于 GPU 的模型服務器可能同時維護多個活躍的…

MacOs下解決遠程終端內容復制并到本地粘貼板

常常需要在服務器上搗鼓東西&#xff0c;同時需要將內容復制到本地的需求。 1-內容是在遠程終端用vim打開&#xff0c;如何用vim的類似指令達到快速復制到本地呢&#xff1f; 假設待復制的內容&#xff1a; #include <iostream> #include <cstring> using names…

STM32 vs ESP32:如何選擇最適合你的單片機?

引言 在嵌入式開發中&#xff0c;STM32 和 ESP32 是兩種最熱門的微控制器方案。但許多開發者面對項目選型時仍會感到困惑&#xff1a;到底是選擇功能強大的 STM32&#xff0c;還是集成無線的 ESP32&#xff1f; 本文將通過 硬件資源、開發場景、成本分析 等多維度對比&#xf…

【blender小技巧】Blender導出帶貼圖的FBX模型,并在unity中提取材質模型使用

前言 這其實是我之前做過的操作&#xff0c;我只是單獨提取出來了而已。感興趣可以去看看&#xff1a;【blender小技巧】使用Blender將VRM或者其他模型轉化為FBX模型&#xff0c;并在unity使用&#xff0c;導出帶貼圖的FBX模型&#xff0c;貼圖材質問題修復 一、導出帶貼圖的…

如何保證本地緩存和redis的一致性

1. Cache Aside Pattern&#xff08;旁路緩存模式&#xff09;?? ?核心思想?&#xff1a;應用代碼直接管理緩存與數據的同步&#xff0c;分為讀寫兩個流程&#xff1a; ?讀取數據?&#xff1a; 先查本地緩存&#xff08;如 Guava Cache&#xff09;。若本地未命中&…

k8s通過service標簽實現藍綠發布

k8s通過service標簽實現藍綠發布 通過k8s service label標簽實現藍綠發布方法1:使用kubelet完成藍綠切換1. 創建綠色版本1.1 創建綠色版本 Deployment1.2 創建綠色版本 Service 2. 創建藍色版本2.1 創建藍色版本 Deployment2.2 創建藍色版本 Service 3. 創建藍綠切換SVC (用于外…

智慧酒店企業站官網-前端靜態網站模板【前端練習項目】

最近又寫了一個靜態網站&#xff0c;智慧酒店宣傳官網。 使用的技術 html css js 。 特別適合編程學習者進行網頁制作和前端開發的實踐。 項目包含七個核心模塊&#xff1a;首頁、整體解決方案、優勢、全國案例、行業觀點、合作加盟、關于我們。 通過該項目&#xff0c;小伙伴們…

Epplus 8+ 許可證設置

Epplus 8 之后非商業許可證的設置變了如果還用普通的方法會報錯 Unhandled exception. OfficeOpenXml.LicenseContextPropertyObsoleteException: Please use the static ‘ExcelPackage.License’ property to set the required license information from EPPlus 8 and later …

CST1016.基于Spring Boot+Vue高校競賽管理系統

計算機/JAVA畢業設計 【CST1016.基于Spring BootVue高校競賽管理系統】 【項目介紹】 高校競賽管理系統&#xff0c;基于 DeepSeek Spring AI Spring Boot Vue 實現&#xff0c;功能豐富、界面精美 【業務模塊】 系統共有兩類用戶&#xff0c;分別是學生用戶和管理員用戶&a…