信息最大化在目標域無標簽的域自適應任務中,它迫使模型在沒有真實標簽的情況下,對未標記數據產生高置信度且類別均衡的預測。此外,這些預測也可以作為偽標簽用于自訓練。
例如,在目標域沒有標簽時,信息最大化損失可以應用于目標域數據,使模型適應目標域并產生有意義的預測,緩解源域和目標域之間的分布偏移。在自訓練或生成模型中,信息最大化通過要求整體預測分布均衡,有效防止模型將所有樣本都預測到少數幾個類別上。
信息最大化損失函數
信息最大化的損失函數可以表達為[1]:
L I M = L e n t + L d i v = H ( Y ∣ X ) ? H ( Y ) \begin{align} \mathcal{L}_{IM} &= \mathcal{L}_{ent}+\mathcal{L}_{div} \\ &= H(Y|X)-H(Y) \end{align} LIM??=Lent?+Ldiv?=H(Y∣X)?H(Y)??
式中, H ( Y ∣ X ) H(Y|X) H(Y∣X)模型預測輸出標簽的信息熵,最小化條件熵讓模型的預測P(Y|X)更加自信。 H ( Y ) H(Y) H(Y)是預測類別標簽的邊緣熵,由于損失前面有個負號,則需要最大化邊緣熵,迫使模型預測的各個類別均勻分布,而不是偏向其中某個類別。
信息最大化本質是最大化預測標簽 Y Y Y 與輸入 X X X 的互信息 I ( X ; Y ) = H ( Y ) ? H ( Y ∣ X ) I(X;Y) = H(Y) - H(Y|X) I(X;Y)=H(Y)?H(Y∣X),因此 L I M = ? I ( X ; Y ) \mathcal{L}_{IM} = -I(X;Y) LIM?=?I(X;Y)。最小化該損失等價于提升輸入與預測標簽之間的互信息。
① 熵最小化損失 (Entropy Minimization):
L e n t = ? E x ∑ c = 1 C δ c ( f ( x ) ) log ? δ c ( f ( x ) ) \begin{align} \mathcal{L}_{ent} = -\mathbb{E}_x \sum_{c=1}^C \delta_c(f(x)) \log \delta_c(f(x)) \end{align} Lent?=?Ex?c=1∑C?δc?(f(x))logδc?(f(x))??
式中, f ( x ) f(x) f(x)表示模型的預測輸出, δ c \delta_c δc?是softmax函數,代表樣本 x x x是類別 c c c的概率值。
② 多樣性最大化損失 (Diversity Regularization):
L d i v = ? ( ? ∑ c = 1 C p ^ c log ? p ^ c ) = ∑ c = 1 C p ^ c log ? p ^ c \begin{align} \mathcal{L}_{div} &=-(- \sum_{c=1}^C \hat{p}_c \log \hat{p}_c)\\ &= \sum_{c=1}^C \hat{p}_c \log \hat{p}_c \end{align} Ldiv??=?(?c=1∑C?p^?c?logp^?c?)=c=1∑C?p^?c?logp^?c???
其中, p ^ c = 1 N ∑ i = 1 N p i c \hat{p}_c=\frac{1}{N}\sum_{i=1}^{N}p_{ic} p^?c?=N1?∑i=1N?pic?表示類別 c c c在整個批次上的平均預測概率,也就是類別 c c c的邊緣分布。注意:在數學上, P ( Y = c ) P(Y=c) P(Y=c)的邊緣概率應該為 P ( Y = c ) = ∑ i = 1 N P ( X = x i , Y = c ) P(Y = c) = \sum_{i=1}^N P(X = x_i, Y = c) P(Y=c)=∑i=1N?P(X=xi?,Y=c),這里采用的均值而非總和,即采用批次均值 p ^ c \hat{p}_c p^?c? 作為無偏估計。
總結:
- 最小化條件熵 H ( Y ∣ X ) H(Y|X) H(Y∣X):迫使模型對每個目標域樣本做出確定性預測。
- 最大化邊緣熵 H ( Y ) H(Y) H(Y):確保模型在整個目標域上的預測類別分布均勻,防止坍縮到少數類。
代碼
import torch
import torch.nn as nn
import torch.nn.functional as Fclass IMLoss(nn.Module):def __init__(self, lambda_div=0.1, eps=1e-8):"""信息最大化損失函數:無監督學習范式Args:lambda_div (float): 多樣性損失的權重系數,默認0.1eps (float): 數值穩定項,防止log(0),默認1e-8"""super().__init__()self.lambda_div = lambda_divself.eps = epsdef forward(self, logits):"""計算信息最大化損失Args:logits (torch.Tensor): 模型輸出的logits張量,形狀為(batch_size, num_classes)Returns:torch.Tensor: 計算得到的總損失值"""# 計算softmax概率probs = F.softmax(logits, dim=1)# 1. L_ent: 熵最小化損失,使預測更確定entropy_per_sample = -torch.sum(probs * torch.log(probs + self.eps), dim=1)entropy_loss = torch.mean(entropy_per_sample)# 2. L_div: 多樣性最大化損失, 使類別分布均勻mean_probs = torch.mean(probs, dim=0) # 邊緣分布,由于樣本是獨立同分布的,這里考慮概率的平均值而非總和diversity_loss = -torch.sum(mean_probs * torch.log(mean_probs + self.eps))# L_IM總損失total_loss = entropy_loss - self.lambda_div * diversity_lossreturn total_lossnum_classes=3
bs=2logits = torch.randn(bs, num_classes) # 模型輸出的邏輯值:model(x)
loss_fn = IMLoss()
loss = loss_fn(logits)
知識點
邊緣分布/邊際分布 (Marginal distribution)定義[2]:Given a known joint distribution of two discrete random variables, say, X X X and Y Y Y, the marginal distribution of either variable – X X X for example – is the probability distribution of X X X when the values of Y Y Y are not taken into consideration.
p X ( x i ) = ∑ j p ( x i , y j ) , p Y ( y j ) = ∑ i p ( x i , y j ) p_X(x_i)=\sum_jp(x_i,y_j),\\ p_Y(y_j)=\sum_i p(x_i,y_j) pX?(xi?)=j∑?p(xi?,yj?),pY?(yj?)=i∑?p(xi?,yj?)
案例:
如下表所示,一個批次有3個樣本,類別為4,對應的隨機變量Y的邊緣分布為最后一行。
y 1 y_1 y1? | y 2 y_2 y2? | y 3 y_3 y3? | y 4 y_4 y4? | p X ( x ) p_X(x) pX?(x) | |
---|---|---|---|---|---|
x 1 x_1 x1? | 4 32 \frac{4}{32} 324? | 2 32 \frac{2}{32} 322? | 1 32 \frac{1}{32} 321? | 1 32 \frac{1}{32} 321? | 8 32 \frac{8}{32} 328? |
x 2 x_2 x2? | 3 32 \frac{3}{32} 323? | 6 32 \frac{6}{32} 326? | 3 32 \frac{3}{32} 323? | 3 32 \frac{3}{32} 323? | 15 32 \frac{15}{32} 3215? |
x 3 x_3 x3? | 9 32 \frac{9}{32} 329? | 0 0 0 | 0 0 0 | 0 0 0 | 9 32 \frac{9}{32} 329? |
p Y ( y ) p_Y(y) pY?(y) | 16 32 \frac{16}{32} 3216? | 8 32 \frac{8}{32} 328? | 4 32 \frac{4}{32} 324? | 4 32 \frac{4}{32} 324? | 32 32 \frac{32}{32} 3232? |
參考:
[1] [2002.08546] Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation
[2] Marginal distribution - Wikipedia