對比損失的PyTorch實現詳解
本文以SiT代碼中對比損失的實現為例作介紹。
論文:https://arxiv.org/abs/2104.03602
代碼:https://github.com/Sara-Ahmed/SiT
對比損失簡介
作為一種經典的自監督損失,對比損失就是對一張原圖像做不同的圖像擴增方法,得到來自同一原圖的兩張輸入圖像,由于圖像擴增不會改變圖像本身的語義,因此,認為這兩張來自同一原圖的輸入圖像的特征表示應該越相似越好(通常用余弦相似度來進行距離測度),而來自不同原圖像的輸入圖像應該越遠離越好。來自同一原圖的輸入圖像可做正樣本,同一個batch內的不同輸入圖像可用作負樣本。如下圖所示(粗箭頭向上表示相似度越高越好,向下表示越低越好)。
論文中的公式
lcontrxi,xj(W)=esim(SiTcontr(xi),SiTcontr(xj))/τ∑k=1,k≠i2Nesim(SiTcontr(xi),SiTcontr(xk))/τ(1)l^{x_i,x_j}_{contr}(W)=\frac{e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_j))/\tau}}{\sum_{k=1,k\ne i}^{2N}e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_k))/\tau}} \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) lcontrxi?,xj??(W)=∑k=1,k?=i2N?esim(SiTcontr?(xi?),SiTcontr?(xk?))/τesim(SiTcontr?(xi?),SiTcontr?(xj?))/τ???????????????????(1)
L=?1N∑j=1Nloglxj,xjˉ(W)(2)\mathcal{L}=-\frac{1}{N}\sum_{j=1}^Nlogl^{x_j,x_{\bar{j}}}(W) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) L=?N1?j=1∑N?loglxj?,xjˉ??(W)??????????????????(2)
SiT論文中的對比損失公式如上所示。其中xix_ixi?,xjx_jxj?分別表示兩個不同的輸入圖像,sim(?,?)sim(\cdot,\cdot)sim(?,?)表示余弦相似度,即歸一化之后的點積,τ\tauτ是超參數溫度,xjx_jxj?和xjˉx_{\bar{j}}xjˉ??是來自同一原圖的兩種不同數據增強的輸入圖像, SiTcontr(?)SiT_{contr}(\cdot)SiTcontr?(?) 表示從對比頭中得到的圖像表示,沒看過原文的話,就直接理解為輸入圖像經過一系列神經網絡,得到一個dimdimdim 維度的特征向量作為圖像的特征表示,網絡不是本文的重點,重點是怎樣根據得到的特征向量計算對比損失。
與最近很火的infoNCE對比損失基本一樣,只是寫法不同。
代碼實現
class ContrastiveLoss(nn.Module):def __init__(self, batch_size, device='cuda', temperature=0.5):super().__init__()self.batch_size = batch_sizeself.register_buffer("temperature", torch.tensor(temperature).to(device)) # 超參數 溫度self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float()) # 主對角線為0,其余位置全為1的mask矩陣def forward(self, emb_i, emb_j): # emb_i, emb_j 是來自同一圖像的兩種不同的預處理方法得到z_i = F.normalize(emb_i, dim=1) # (bs, dim) ---> (bs, dim)z_j = F.normalize(emb_j, dim=1) # (bs, dim) ---> (bs, dim)representations = torch.cat([z_i, z_j], dim=0) # repre: (2*bs, dim)similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2) # simi_mat: (2*bs, 2*bs)sim_ij = torch.diag(similarity_matrix, self.batch_size) # bssim_ji = torch.diag(similarity_matrix, -self.batch_size) # bspositives = torch.cat([sim_ij, sim_ji], dim=0) # 2*bsnominator = torch.exp(positives / self.temperature) # 2*bsdenominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature) # 2*bs, 2*bsloss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) # 2*bsloss = torch.sum(loss_partial) / (2 * self.batch_size)return loss
以下是SiT論文的對比損失代碼實現,筆者已經將debug過程中得到的張量形狀在注釋中標注了出來,供大家參考,其中dim是得到的特征向量的維度,bs是批尺寸batch size。
筆者簡單畫了一張similarity_matrix
的圖示來說明整個過程。本圖以bs==4為例,a,b,c,da,b,c,da,b,c,d分別代表同一個batch內的不同樣本,下表0和1表示兩種不同的圖像擴增方法。圖中每個方格則是對應行列的圖像特征(dim維的向量)表示計算相似度的結果值。
-
emb_i
,emb_j
是來自同一圖像的兩種不同的預處理方法得到的輸入圖像的特征表示。首先是通過F.normalize()
將emb_i
,emb_j
進行歸一化。 -
然后將二者拼接起來的到維度為2*bs的
representations
。再將representations
分別轉換為列向量和行向量計算相似度矩陣similarity_matrix
(見圖)。 -
在通過偏移的對角線(圖中藍線)的到
sim_ij
和sim_ji
,并拼接的到positives
。請注意藍線對應的行列坐標,分別是a0,a1a_0,a_1a0?,a1?、b0,b1b_0,b_1b0?,b1?等,即藍線對應的網格即是來自同一張原圖的不同處理的輸入圖像。這在損失的設計中即是我們的正樣本。 -
然后
nominator
(分子)即可根據公式計算的到。 -
而在計算
denominator
時需注意要乘上self.negatives_mask
。該變量在__init__
中定義,是對2*bs的方針對角陣取反,即主對角線全是0,其余位置全是1 。這是為了在負樣本中屏蔽自己與自己的相似度結果(圖中紅線),即使得similarity_matrix
的主對角錢全為0。因為自己與自己的相似度肯定是1,加入到計算中沒有意義。 -
再到后面
loss_partial
的計算(第22行)其實是計算出公式(1),torch.sum()
計算的是(1)中分母上的∑\sum∑符號。 -
第23行就是計算公式(2),其中與公式相比分母上多了除了個2,是因為本實現為了方便將
similarity_matrix
的維度擴展為2*bs。即相當于將公式(2)中的lcontrxj,xjˉl_{contr}^{x_j,x_{\bar{j}}}lcontrxj?,xjˉ??? 和 lcontrxjˉ,xjl_{contr}^{x_{\bar{j}},x_j}lcontrxjˉ??,xj?? 分別計算了一遍。所以要多除個2。
自行驗證
大家可以將上面的ContrastiveLoss
類復制到自己的測試的文件中,并構造幾個輸入進行測試,打印中間結果,驗證自己是否真正地理解了對比損失的代碼實現計算過程。
loss_func = losses.ContrastiveLoss(batch_size=4)
emb_i = torch.rand(4, 512).cuda()
emb_j = torch.rand(4, 512).cuda()loss_contra = loss_func(emb_i, emb_j)
print(loss_contra)