文章目錄
- 1. 問題背景
- SimCLR的原始公式
- 2. 數值溢出問題
- 為什么會出現數值溢出?
- 浮點數的表示范圍
- 3. 數值穩定性處理方法
- 核心思想
- 數學推導
- 4. 代碼實現分解
- 代碼與公式的對應關系
- 5. 具體數值示例
- 示例:相似度矩陣
- 方法1:直接計算exp(x)
- 方法2:減去最大值后計算
- 驗證結果等價性
- 6. 為什么減去最大值有效?
- 關鍵原理
- 7. 實際應用場景
- 8. 實現建議
- 總結
在深度學習實現中,特別是涉及指數和對數運算的損失函數計算過程中,數值穩定性是一個核心問題。本文以SimCLR對比學習損失為例,詳細解析數值穩定性處理的原理、實現和重要性。
1. 問題背景
SimCLR是一種自監督學習方法,其核心是InfoNCE損失函數。這個損失函數的計算涉及大量指數運算,容易導致數值溢出或下溢問題。
SimCLR的原始公式
SimCLR的核心損失函數(InfoNCE損失)公式為:
L i = ? log ? exp ? ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ) ? 1 k ≠ i L_i = -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} Li?=?log∑k=12N?exp(sim(zi?,zk?)/τ)?1k=i?exp(sim(zi?,zj?)/τ)?
其中:
- z i z_i zi?是錨點特征
- z j z_j zj?是與 z i z_i zi?對應的正樣本特征
- τ \tau τ是溫度參數
- s i m ( ) sim() sim()是相似度函數(通常是點積)
- 1 k ≠ i \mathbf{1}_{k \neq i} 1k=i?表示排除自身對比的指示函數
2. 數值溢出問題
為什么會出現數值溢出?
當我們計算 exp ? ( x ) \exp(x) exp(x)時:
- 如果 x x x很大(如 x = 100 x = 100 x=100), exp ? ( 100 ) ≈ 2.7 × 1 0 43 \exp(100) \approx 2.7 \times 10^{43} exp(100)≈2.7×1043,可能超出浮點數表示范圍
- 如果 x x x是很小的負數(如 x = ? 100 x = -100 x=?100), exp ? ( ? 100 ) ≈ 3.7 × 1 0 ? 44 \exp(-100) \approx 3.7 \times 10^{-44} exp(?100)≈3.7×10?44,可能導致下溢為0
在SimCLR中, s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi?,zk?)/τ可能很大,特別是當:
- 特征向量高度相似( s i m sim sim接近1)
- 溫度參數 τ \tau τ很小(如0.07)
浮點數的表示范圍
浮點數的表示范圍是有限的:
- 單精度浮點數(32位):約 ± 3.4 × 1 0 38 \pm 3.4 \times 10^{38} ±3.4×1038
- 雙精度浮點數(64位):約 ± 1.8 × 1 0 308 \pm 1.8 \times 10^{308} ±1.8×10308
3. 數值穩定性處理方法
SimCLR實現中使用了一種簡單而有效的數值穩定性處理技術,代碼如下:
# 數值穩定性處理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
核心思想
這種處理的核心思想是:
- 找出每行相似度的最大值
- 將每行的所有值減去這個最大值
- 然后再進行指數計算
數學推導
這種操作是數學等價的。對原始公式進行變換:
L i = ? log ? exp ? ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ) ? 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau) \cdot \mathbf{1}_{k \neq i}} \\ \end{align} Li??=?log∑k=12N?exp(sim(zi?,zk?)/τ)?1k=i?exp(sim(zi?,zj?)/τ)???
引入最大值 M i = max ? k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi?=maxk?(sim(zi?,zk?)/τ):
L i = ? log ? exp ? ( s i m ( z i , z j ) / τ ? M i + M i ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ? M i + M i ) ? 1 k ≠ i = ? log ? exp ? ( M i ) ? exp ? ( s i m ( z i , z j ) / τ ? M i ) exp ? ( M i ) ? ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i = ? log ? exp ? ( s i m ( z i , z j ) / τ ? M i ) ∑ k = 1 2 N exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i \begin{align} L_i &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i + M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i + M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(M_i) \cdot \exp(sim(z_i, z_j)/\tau - M_i)}{\exp(M_i) \cdot \sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \\ &= -\log \frac{\exp(sim(z_i, z_j)/\tau - M_i)}{\sum_{k=1}^{2N} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} \end{align} Li??=?log∑k=12N?exp(sim(zi?,zk?)/τ?Mi?+Mi?)?1k=i?exp(sim(zi?,zj?)/τ?Mi?+Mi?)?=?logexp(Mi?)?∑k=12N?exp(sim(zi?,zk?)/τ?Mi?)?1k=i?exp(Mi?)?exp(sim(zi?,zj?)/τ?Mi?)?=?log∑k=12N?exp(sim(zi?,zk?)/τ?Mi?)?1k=i?exp(sim(zi?,zj?)/τ?Mi?)???
因為分子和分母中的 exp ? ( M i ) \exp(M_i) exp(Mi?)相互抵消,所以最終結果不變。
4. 代碼實現分解
完整的SimCLR損失計算代碼(包含數值穩定性處理):
# 計算相似度矩陣并除以溫度系數
anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T),self.temperature)# 數值穩定性處理
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()# 創建和應用掩碼
mask = mask.repeat(anchor_count, contrast_count)
logits_mask = torch.scatter(torch.ones_like(mask),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0
)
mask = mask * logits_mask# 計算損失
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()
代碼與公式的對應關系
anchor_dot_contrast
→ s i m ( z i , z k ) / τ sim(z_i, z_k)/\tau sim(zi?,zk?)/τlogits_max
→ M i = max ? k ( s i m ( z i , z k ) / τ ) M_i = \max_k (sim(z_i, z_k)/\tau) Mi?=maxk?(sim(zi?,zk?)/τ)logits
→ s i m ( z i , z k ) / τ ? M i sim(z_i, z_k)/\tau - M_i sim(zi?,zk?)/τ?Mi?exp_logits
→ exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i} exp(sim(zi?,zk?)/τ?Mi?)?1k=i?log_prob
→ log ? exp ? ( s i m ( z i , z k ) / τ ? M i ) ∑ k exp ? ( s i m ( z i , z k ) / τ ? M i ) ? 1 k ≠ i \log \frac{\exp(sim(z_i, z_k)/\tau - M_i)}{\sum_{k} \exp(sim(z_i, z_k)/\tau - M_i) \cdot \mathbf{1}_{k \neq i}} log∑k?exp(sim(zi?,zk?)/τ?Mi?)?1k=i?exp(sim(zi?,zk?)/τ?Mi?)?
5. 具體數值示例
為了直觀理解,我們用一個簡化的例子來說明為什么減去最大值能防止數值溢出。
示例:相似度矩陣
假設有一個計算得到的相似度矩陣(已除以溫度τ=0.07):
sim(z_i, z_k)/τ = [[80, 50, 60, 70, 40],[60, 90, 70, 80, 50],[70, 60, 85, 75, 55],[50, 40, 60, 75, 45]
]
方法1:直接計算exp(x)
直接計算exp(sim(z_i, z_k)/τ)
:
exp(sim(z_i, z_k)/τ) ≈ [[5.54e+34, 5.18e+21, 1.14e+26, 2.51e+30, 2.35e+17],[1.14e+26, 1.22e+39, 2.51e+30, 5.54e+34, 5.18e+21],[2.51e+30, 1.14e+26, 5.91e+36, 3.58e+32, 1.14e+24],[5.18e+21, 2.35e+17, 1.14e+26, 3.58e+32, 3.49e+19]
]
這些值極其巨大,相加時很容易溢出。例如第一行的和約為5.54e+34,已經接近單精度浮點數的上限。
方法2:減去最大值后計算
找出每行的最大值:
max_values = [80, 90, 85, 75]
減去最大值:
adjusted_logits = [[0, -30, -20, -10, -40],[-30, 0, -20, -10, -40],[-15, -25, 0, -10, -30],[-25, -35, -15, 0, -30]
]
計算exp(adjusted_logits)
:
exp(adjusted_logits) ≈ [[1.0, 9.36e-14, 2.06e-9, 4.54e-5, 4.25e-18],[9.36e-14, 1.0, 2.06e-9, 4.54e-5, 4.25e-18],[3.06e-7, 1.39e-11, 1.0, 4.54e-5, 9.36e-14],[1.39e-11, 6.31e-16, 3.06e-7, 1.0, 9.36e-14]
]
這些值都在[0,1]范圍內,完全避免了溢出問題。同時,正樣本對和負樣本對之間的相對比例關系保持不變。
驗證結果等價性
例如,對于第一行計算最終的歸一化概率:
原始方法:
P(z_0 -> z_0) = exp(80) / sum(exp(row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(50) / sum(exp(row_0)) ≈ 9.35e-14
...
減去最大值后:
P(z_0 -> z_0) = exp(0) / sum(exp(adjusted_row_0)) ≈ 1.0
P(z_0 -> z_1) = exp(-30) / sum(exp(adjusted_row_0)) ≈ 9.35e-14
...
兩種計算方法得到的概率分布是相同的,但后者避免了數值溢出風險。
6. 為什么減去最大值有效?
關鍵原理
減去最大值的處理之所以有效,是因為:
-
將范圍控制在安全區間:
- 減去最大值后,所有值都≤0
- 因此所有
exp(x)
的結果都≤1,避免了上溢 - 同時最大值對應的
exp(0)=1
,避免了整體下溢為0
-
保持相對比例關系:
- 對每行減去相同的常數不改變值之間的相對大小
- 對于
exp()
函數來說,這等價于同時除以一個常數因子 - 在計算Softmax或對數概率時,這個常數因子在分子和分母中抵消
-
數學等價性:
exp(a-b) = exp(a)/exp(b)
的性質保證了結果的正確性- 這相當于將原始公式的分子和分母同時除以
exp(max_value)
7. 實際應用場景
這種數值穩定性技術不僅適用于SimCLR,還廣泛應用于:
- Softmax計算:幾乎所有需要計算Softmax的地方都需要
- 交叉熵損失:分類任務中常用
- 注意力機制:Transformer中的attention計算
- 所有對比學習方法:MoCo、BYOL、CLIP等
8. 實現建議
在實現涉及指數計算的函數時,建議:
- 始終使用數值穩定性處理
- 對每個batch/樣本獨立進行處理(找到每行/每個樣本的最大值)
- 使用
.detach()
阻止梯度通過最大值操作傳播 - 注意掩碼操作,確保不包括自身對比或特定的負樣本
總結
數值穩定性處理是深度學習實現中一個看似簡單但至關重要的技術。通過簡單地減去每行的最大值,我們可以有效防止數值溢出/下溢問題,同時保持計算結果的數學等價性。這種技術尤其重要,因為隨著模型和批量大小的增加,數值問題更容易出現,而且往往難以診斷。