對比損失的PyTorch實現詳解

對比損失的PyTorch實現詳解

本文以SiT代碼中對比損失的實現為例作介紹。

論文:https://arxiv.org/abs/2104.03602
代碼:https://github.com/Sara-Ahmed/SiT

對比損失簡介

作為一種經典的自監督損失,對比損失就是對一張原圖像做不同的圖像擴增方法,得到來自同一原圖的兩張輸入圖像,由于圖像擴增不會改變圖像本身的語義,因此,認為這兩張來自同一原圖的輸入圖像的特征表示應該越相似越好(通常用余弦相似度來進行距離測度),而來自不同原圖像的輸入圖像應該越遠離越好。來自同一原圖的輸入圖像可做正樣本,同一個batch內的不同輸入圖像可用作負樣本。如下圖所示(粗箭頭向上表示相似度越高越好,向下表示越低越好)。
在這里插入圖片描述

論文中的公式

lcontrxi,xj(W)=esim(SiTcontr(xi),SiTcontr(xj))/τ∑k=1,k≠i2Nesim(SiTcontr(xi),SiTcontr(xk))/τ(1)l^{x_i,x_j}_{contr}(W)=\frac{e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_j))/\tau}}{\sum_{k=1,k\ne i}^{2N}e^{sim(SiT_{contr}(x_i),SiT_{contr}(x_k))/\tau}} \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) lcontrxi?,xj??(W)=k=1,k?=i2N?esim(SiTcontr?(xi?),SiTcontr?(xk?))/τesim(SiTcontr?(xi?),SiTcontr?(xj?))/τ???????????????????(1)

L=?1N∑j=1Nloglxj,xjˉ(W)(2)\mathcal{L}=-\frac{1}{N}\sum_{j=1}^Nlogl^{x_j,x_{\bar{j}}}(W) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2) L=?N1?j=1N?loglxj?,xjˉ??(W)??????????????????(2)

SiT論文中的對比損失公式如上所示。其中xix_ixi?xjx_jxj?分別表示兩個不同的輸入圖像,sim(?,?)sim(\cdot,\cdot)sim(?,?)表示余弦相似度,即歸一化之后的點積,τ\tauτ是超參數溫度,xjx_jxj?xjˉx_{\bar{j}}xjˉ??是來自同一原圖的兩種不同數據增強的輸入圖像, SiTcontr(?)SiT_{contr}(\cdot)SiTcontr?(?) 表示從對比頭中得到的圖像表示,沒看過原文的話,就直接理解為輸入圖像經過一系列神經網絡,得到一個dimdimdim 維度的特征向量作為圖像的特征表示,網絡不是本文的重點,重點是怎樣根據得到的特征向量計算對比損失

與最近很火的infoNCE對比損失基本一樣,只是寫法不同。

代碼實現

class ContrastiveLoss(nn.Module):def __init__(self, batch_size, device='cuda', temperature=0.5):super().__init__()self.batch_size = batch_sizeself.register_buffer("temperature", torch.tensor(temperature).to(device))			# 超參數 溫度self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool).to(device)).float())		# 主對角線為0,其余位置全為1的mask矩陣def forward(self, emb_i, emb_j):		# emb_i, emb_j 是來自同一圖像的兩種不同的預處理方法得到z_i = F.normalize(emb_i, dim=1)     # (bs, dim)  --->  (bs, dim)z_j = F.normalize(emb_j, dim=1)     # (bs, dim)  --->  (bs, dim)representations = torch.cat([z_i, z_j], dim=0)          # repre: (2*bs, dim)similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)      # simi_mat: (2*bs, 2*bs)sim_ij = torch.diag(similarity_matrix, self.batch_size)         # bssim_ji = torch.diag(similarity_matrix, -self.batch_size)        # bspositives = torch.cat([sim_ij, sim_ji], dim=0)                  # 2*bsnominator = torch.exp(positives / self.temperature)             # 2*bsdenominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature)             # 2*bs, 2*bsloss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))        # 2*bsloss = torch.sum(loss_partial) / (2 * self.batch_size)return loss

以下是SiT論文的對比損失代碼實現,筆者已經將debug過程中得到的張量形狀在注釋中標注了出來,供大家參考,其中dim是得到的特征向量的維度,bs是批尺寸batch size。

筆者簡單畫了一張similarity_matrix的圖示來說明整個過程。本圖以bs==4為例,a,b,c,da,b,c,da,b,c,d分別代表同一個batch內的不同樣本,下表0和1表示兩種不同的圖像擴增方法。圖中每個方格則是對應行列的圖像特征(dim維的向量)表示計算相似度的結果值。

在這里插入圖片描述

  1. emb_i,emb_j 是來自同一圖像的兩種不同的預處理方法得到的輸入圖像的特征表示。首先是通過F.normalize()emb_iemb_j進行歸一化。

  2. 然后將二者拼接起來的到維度為2*bs的representations。再將representations分別轉換為列向量和行向量計算相似度矩陣similarity_matrix(見圖)。

  3. 在通過偏移的對角線(圖中藍線)的到sim_ijsim_ji,并拼接的到positives。請注意藍線對應的行列坐標,分別是a0,a1a_0,a_1a0?,a1?b0,b1b_0,b_1b0?,b1?等,即藍線對應的網格即是來自同一張原圖的不同處理的輸入圖像。這在損失的設計中即是我們的正樣本。

  4. 然后nominator(分子)即可根據公式計算的到。

  5. 而在計算denominator時需注意要乘上self.negatives_mask。該變量在__init__中定義,是對2*bs的方針對角陣取反,即主對角線全是0,其余位置全是1 。這是為了在負樣本中屏蔽自己與自己的相似度結果(圖中紅線),即使得similarity_matrix的主對角錢全為0。因為自己與自己的相似度肯定是1,加入到計算中沒有意義。

  6. 再到后面loss_partial的計算(第22行)其實是計算出公式(1),torch.sum()計算的是(1)中分母上的∑\sum符號。

  7. 第23行就是計算公式(2),其中與公式相比分母上多了除了個2,是因為本實現為了方便將similarity_matrix的維度擴展為2*bs。即相當于將公式(2)中的lcontrxj,xjˉl_{contr}^{x_j,x_{\bar{j}}}lcontrxj?,xjˉ???lcontrxjˉ,xjl_{contr}^{x_{\bar{j}},x_j}lcontrxjˉ??,xj?? 分別計算了一遍。所以要多除個2。

自行驗證

大家可以將上面的ContrastiveLoss類復制到自己的測試的文件中,并構造幾個輸入進行測試,打印中間結果,驗證自己是否真正地理解了對比損失的代碼實現計算過程。

loss_func = losses.ContrastiveLoss(batch_size=4)
emb_i = torch.rand(4, 512).cuda()
emb_j = torch.rand(4, 512).cuda()loss_contra = loss_func(emb_i, emb_j)
print(loss_contra)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532773.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532773.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532773.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

android 融云瀏覽大圖,融云 Android sdk kit 頭像昵稱更新機制

先申明筆者的實現方式不是唯一 也不一定是最優化的方案 如果您看到此篇博文 有不同看法 或者 更好的優化 更高的效率 歡迎在評論發表意見 融云官網點我融云頭像機制相關視頻詳解首先跟大家說一下 kit 跟 lib 的頭像機制 kit 是已經包含融云已經給開發者定制好的界面 諸如 會話界…

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. 報錯信息 報錯信息: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates tha…

android訪問重定向地址,如何從android中重定向url加載圖像(示例代碼)

嗨,我正面臨這個問題我從RESTCall獲取了一個URL網址是http://hck.re/kWWxUI但是當我在瀏覽器中檢查時,它會重定向到https://s3-ap-southeast-1.amazonaws.com/he-public-data/afreen2ac5a33.jpg如何將此圖像加載到我的imageView中我已經知道如何將畢加索…

Linux中的awk、sed、grep及正則表達式詳解

Linux中的awk、sed、grep及正則表達式詳解 簡介 awk、sed和grep是Linux中文本操作的三大利器。 其中awk適用于取列,sed適用于取行,grep適用于過濾。 正則表達式 首先我們來介紹一下正則表達式,正則表達式(regular expression)描述了一種…

android聚焦時如何給控件加邊框,edittext設置獲得焦點時的邊框顏色

第一步:為了更好的比較,準備兩個一模一樣的EditText(當Activity啟動時,焦點會在第一個EditText上,如果你不希望這樣只需要寫一個高度和寬帶為0的EditText即可避免,這里就不這么做了),代碼如下:a…

gcc參數 -i, -L, -l, -include

gcc參數 -i, -L, -l, -include -i,-L,-l,-include -l和-L -l參數就是用來指定程序要鏈接的庫,-l參數緊接著就是庫名,那么庫名跟真正的庫文件名有什么關系呢?就拿數學庫來說,他的庫名是m&…

xargs 命令教程

xargs 命令教程 轉自:http://www.ruanyifeng.com/blog/2019/08/xargs-tutorial.html 作者: 阮一峰 日期: 2019年8月 8日 xargs是 Unix 系統的一個很有用的命令,但是常常被忽視,很多人不了解它的用法。 本文介紹如…

android strictmode有什么作用,Android 性能優化 之 StrictMode

8種機械鍵盤軸體對比本人程序員,要買一個寫代碼的鍵盤,請問紅軸和茶軸怎么選?StrictMode概述StrictMode 是用來檢測程序中違例情況的開發者工具。使用StrictMode,系統檢測出主線程違例的情況會做出相應的反應,如日志打…

curl 的用法指南

curl 的用法指南 轉自:http://www.ruanyifeng.com/blog/2019/09/curl-reference.html 作者: 阮一峰 日期: 2019年9月 5日 簡介 curl 是常用的命令行工具,用來請求 Web 服務器。它的名字就是客戶端(client&#xf…

怎么在html顯示已登錄狀態,jQuery Ajax 實現在html頁面實時顯示用戶登錄狀態

當網站是全靜態的html頁面時,而又希望網站會員在登錄之后并在所有頁面頭部顯示登錄狀態,如用戶名等,如果未登錄就是未登錄狀態,下面給大家來分享實現的方法。一、在html靜態頁面中加入div,并指定ID如:二、新…

互斥鎖、條件變量、信號量淺析

互斥鎖、條件變量、信號量淺析 互斥鎖與條件變量 條件變量是為了保證同步 條件變量用在多線程多任務同步的,一個線程完成了某一個動作就通過條件變量告訴別的線程,別的線程再進行某些動作(大家都在semtake的時候,就阻塞在哪里&a…

xpwifi熱點設置android,教你在XP電腦中開啟設置WiFi熱點使用的步驟

對于系統中網絡的連接問題是最重要的,那在處理不同的錯誤的情況中,對于無線網絡的設置也就是我們說的WiFi的使用也是會遇到問題的,那在操作的時候對于電腦中是怎么實現設置WiFi熱點的的,對于這個問題今天小編就來跟大家分享一下教…

C/C++ 指針詳解

指針詳解 參考視頻:https://www.bilibili.com/video/BV1bo4y1Z7xf/,感謝Bilibilifengmuzi2003的搬運翻譯及后續勘誤,也感謝已故原作者Harsha Suryanarayana的講解,RIP。 學習完之后,回看找特定的知識點,善…

android雙聯動列表,Android Fragment實現列表和內容聯動

在平板上經常能看到這種的情況:左邊是一個列表,右邊是列表項對應的內容,當點擊某一個列表時,右邊內容區也會隨之改變。下面使用fragment簡單的demo:思路:在mainactivity定義一個回調接口,并在列…

linux /proc 詳解

linux /proc 詳解 本文整理了一下 linux /proc下的幾個常用的目錄和文件,可供查閱,之后在學習工作中有別的用到的話會再補充。 /proc 簡介 Linux系統上的/proc目錄是一種文件系統,即proc文件系統。與其它常見的文件系統不同的是&#xff0…

android模擬器太卡,安卓模擬器安裝之后太卡怎么解決

用安卓模擬器玩游戲原理就是在電腦上安裝了一部手機,如果你的電腦配置不是非常高,能不卡頓嗎?遇到卡頓怎么解決?1、安裝最新版本的顯卡驅動。逍遙模擬器對于顯卡的性能要求很高,因此升級至最新版本的顯卡驅動,是確保逍遙模擬器流…

編程環境中Runtime(運行時)的三個含義

編程環境中Runtime(運行時)的三個含義 轉自:https://www.zhihu.com/question/20607178 知乎答主doodlewind 三個含義 實際上編程語境中的 runtime 至少有三個含義,分別是: 指「程序運行的時候」,即程序…

非常不錯的一款html5【404頁面】,不含js腳本可以左右擺動,原生JavaScript實現日歷功能代碼實例(無引用Jq)...

這篇文章主要介紹了原生JavaScript實現日歷功能代碼實例(無引用Jq),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下成品顯示,可左右切換月份html 代碼移動端日歷日一二三四五六css代碼*{margin: 0;pa…

12 [虛擬化] 進程抽象;fork,execve,exit

12 [虛擬化] 進程抽象;fork,execve,exit 南京大學操作系統課蔣炎巖老師網絡課程筆記。 視頻:https://www.bilibili.com/video/BV1N741177F5?p12 講義:http://jyywiki.cn/OS/2021/slides/8.slides#/ 本講概述 回到“…

計算機應用與基礎實踐怎么考,自考計算機基礎應用科目筆試和實踐性考試怎么考...

自考計算機基礎應用科目筆試和實踐性考試怎么考? 報考自考的考生有些專業的考生會在自己的課程科目中發現計算機基礎應用不僅有理論知識考試還有實踐性考試,那么自考計算機基礎應用科目的筆試和實踐性考試怎么考?自考計算機基礎應用科目筆試怎…