淺談 EMP-SSL + 代碼解讀:自監督對比學習的一種極簡主義風

論文鏈接:https://arxiv.org/pdf/2304.03977.pdf

代碼:https://github.com/tsb0601/EMP-SSL

其他學習鏈接:突破自監督學習效率極限!馬毅、LeCun聯合發布EMP-SSL:無需花哨trick,30個epoch即可實現SOTA


主要思想

如圖,一張圖片裁剪成不同的 patch,對不同的 patch 做數據增強,分別輸入 encoder,得到多個 embedding,對它們求均值,得到?\bar z?作為這張圖片的 embedding。最后,拉近每個 patch 的 embedding 和圖片的 embedding(\bar z)之間的余弦距離;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 對所有輸入都輸出相同的 embedding)

圖片

圖片

Total Coding Rate(TCR)

公式如下:

圖片

其中,det 表示求矩陣的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查該公式的含義:expand all features of Z as large as possible,即盡可能拉遠矩陣中特征之間的距離。

源自 PPT 第 24 頁:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于為什么最大化該公式的值就可以拉遠矩陣中特征之間的距離,這背后的數學原理真難啃啊 /(ㄒoㄒ)/~~


核心代碼解讀

數據處理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 數據 為:長度為 num_patches 個 tensor 的列表。其中,每個?tensor 的 shape 為 (B, C, H, W)。

主函數

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

這里要稍微注意一下幾個變量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是對來自同一張圖片的不同?patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是計算每個 patch 的 embedding 和均值(\bar z)的余弦距離:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩陣之間特征的距離,即拉遠負樣本(不是來自同一個樣本的 patches)之間的距離

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函數輸入的 z 是?z_proj,形狀為(num_patches * B,C)。

所以,函數內部?z_list?的形狀為(num_patches,B,C),即將數據分為了?num_patches 個組,每個組包含了來自不同圖片里 patch 的 embedding。再分別對每個組求 TCR loss,最大化組內(不同圖片的 patch)特征的距離。

所以,公式中的?Z?指的是一組來自不同圖片里 patch 的 embedding,形狀為(B,C)。

每個組內求 TCR loss 的代碼按照公式計算,如下:?

圖片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/38733.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/38733.shtml
英文地址,請注明出處:http://en.pswp.cn/news/38733.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

08 qt進程和網絡編程(cs模型)

一 、qt進程 qt中進程最主要的任務就是啟動額外應用程序 并且跟他們之間通信。進程類為QProcess 定義用途Header:#include qmake:QT += coreInherits:QIODevice//繼承于IO設備類1.1 QProcess基本使用 第一步:創建一個QProcess對象 // process = new QProcess(this); //說明…

資訊速遞 | ArkUI-X 預覽版已正式開源!

OpenHarmony項目群技術指導委員會(以下簡稱“TSC”)-跨平臺應用開發框架TSG所孵化項目 —— ArkUI-X,近期已正式開源 ,開發者基于一套主代碼,就可以將在OpenHarmony上開發的精美、高性能應用同時運行在Android、iOS等其…

LNMP環境搭建wordpress以及跳轉后臺報404解決

基于上文配置好的LNMP環境繼續搭建wordpress 目錄 一.到官網下載tar.gz包,并上傳到Linux上,也可以通過復制鏈接地址進行下載 二. 將wordpress中的所有文件移動到你nginx.conf中指定目錄中 三.為wordpress配置數據庫 四.到瀏覽器進行注冊 1.剛開始…

maven編譯始終提示無效的目標發行版的解決方法

摘自個人印象筆記2021-05-07:https://app.yinxiang.com/fx/55e1d5f4-aeea-446a-a768-0f1a48195f5b(圖顯示不完整可查看原筆記內容)1:確保IDE中的編譯版本正確 在idea中,主要看項目屬性中和setting的java compiler中對應的jdk版本是否正確&…

好用的安卓手機投屏到mac分享

工具推薦:scrcpy github地址:https://github.com/Genymobile/scrcpy/tree/master mac使用方式 安裝環境,打開terminal,執行以下命令,沒有brew的先安裝brew brew install scrcpy brew install android-platform-too…

學習 Iterator 迭代器

今天看到一個面試題, 讓下面解構賦值成立。 let [a,b] {a:1,b:2} 如果我們直接在瀏覽器輸出這行代碼,會直接報錯,說是 {a:1,b:2} 不能迭代。 看了es6文檔后,具有迭代器的就一下幾種類型,沒有Object類型,…

404. 左葉子之和

給定二叉樹的根節點 root ,返回所有左葉子之和。 示例 1: 輸入: root [3,9,20,null,null,15,7] 輸出: 24 解釋: 在這個二叉樹中,有兩個左葉子,分別是 9 和 15,所以返回 24示例 2: 輸入: root [1] 輸出: 0提示: 節點…

【NetCore】09-中間件

文章目錄 中間件:掌控請求處理過程的關鍵1. 中間件1.1 中間件工作原理1.2 中間件核心對象 2.異常處理中間件:區分真異常和邏輯異常2.1 處理異常的方式2.1.1 日常錯誤處理--定義錯誤頁的方法2.1.2 使用代理方法處理異常2.1.3 異常過濾器 IExceptionFilter2.1.4 特性過…

go web框架 gin-gonic源碼解讀02————router

go web框架 gin-gonic源碼解讀02————router 本來想先寫context,但是發現context能簡單講講的東西不多,就準備直接和router合在一起講好了 router是web服務的路由,是指講來自客戶端的http請求與服務器端的處理邏輯或者資源相映射的機制。&…

react實現對數組做增刪改操作自定義hook

需求 實現對數組的增刪改操作。 實現 import { useState } from react;const useArray (currList) > {const [list, setList] useState(currList);// 增const addItem (item) > {setList([...list, item]);};// 刪const removeItem (idx) > {const _arr [...l…

實戰指南,SpringBoot + Mybatis 如何對接多數據源

系列文章目錄 MyBatis緩存原理 Mybatis plugin 的使用及原理 MyBatisSpringboot 啟動到SQL執行全流程 數據庫操作不再困難,MyBatis動態Sql標簽解析 從零開始,手把手教你搭建Spring Boot后臺工程并說明 Spring框架與SpringBoot的關聯與區別 Spring監聽器…

輕松解決docker容器啟動閃退

docker run -p 3306:3306 --name mysql8 \ -v /usr/local/mysql/log:/var/log/mysql \ -v /usr/local/mysql/data:/var/lib/mysql \ -v /usr/local/mysql/conf:/etc/mysql \ -e MYSQL_ROOT_PASSWORD666 -d mysql:8.0.32執行這個命令的時候閃退,其實這個是命令是對你…

[cv] stable diffusion——2、公式

背景: 在圖像生成領域中,最常見的生成模型是GAN和VAE。然而,在2020年,提出了一種新的模型,即DDPM(Denoising Diffusion Probabilistic Model),也被稱為擴散模型(Diffusi…

基于eBPF技術構建一種應用層網絡管控解決方案

引言 隨著網絡應用的不斷發展,在linux系統中對應用層網絡管控的需求也日益增加,而傳統的iptables、firewalld等工具難以針對應用層進行網絡管控。因此需要一種創新的解決方案來提升網絡應用的可管理性。 本文將探討如何使用eBPF技術構建一種應用層網絡…

【CSS】禁用元素鼠標事件(例如實現元素禁用效果)

文章目錄 基本用法 基本用法 pointer-events 屬性指定在什么情況下 (如果有) 某個特定的圖形元素可以成為鼠標事件。實際運用中可以通過對auto 和none動態控制,來動態實現元素的禁用效果。 屬性描述auto與pointer-events屬性未指定時的表現效果相同,對…

【筆試題心得】排序算法總結整理

排序算法匯總 常用十大排序算法_calm_G的博客-CSDN博客 以下動圖參考 十大經典排序算法 Python 版實現(附動圖演示) - 知乎 冒泡排序 排序過程如下圖所示: 比較相鄰的元素。如果第一個比第二個大,就交換他們兩個。對每一對相鄰…

【LeetCode-簡單】劍指 Offer 29. 順時針打印矩陣(詳解)

題目 輸入一個矩陣,按照從外向里以順時針的順序依次打印出每一個數字。 示例 1: 輸入:matrix [[1,2,3],[4,5,6],[7,8,9]] 輸出:[1,2,3,6,9,8,7,4,5]示例 2: 輸入:matrix [[1,2,3,4],[5,6,7,8],[9,10,1…

互聯網發展歷程:速度與效率,交換機的登場

互聯網的演進就像一場追求速度與效率的競賽,每一次的技術升級都為我們帶來更快、更高效的網絡體驗。然而,在網絡的初期階段,人們面臨著數據傳輸速度不夠快的問題。一項關鍵的技術應運而生,那就是“交換機”。 速度不足的困境&…

CloudEvents—云原生事件規范

我們的系統中或多或少都會用到如下兩類業務技術: 異步任務,用于降低接口時延或削峰,提升用戶體驗,降低系統并發壓力;通知類RPC,用于微服務間狀態變更,用戶行為的聯動等場景; 以上兩種…

Go和Java實現解釋器模式

Go和Java實現解釋器模式 下面通過一個四則運算來說明解釋器模式的使用。 1、解釋器模式 解釋器模式提供了評估語言的語法或表達式的方式,它屬于行為型模式。這種模式實現了一個表達式接口,該接口 解釋一個特定的上下文。這種模式被用在 SQL 解析、符…