CVPR2022人臉識別Partial FC論文及代碼學習筆記

論文鏈接:https://openaccess.thecvf.com/content/CVPR2022/papers/An_Killing_Two_Birds_With_One_Stone_Efficient_and_Robust_Training_CVPR_2022_paper.pdf

代碼鏈接:insightface/recognition/arcface_torch at master · deepinsight/insightface · GitHub

背景

使用基于百萬規模的數據集和基于margin的softmax損失函數來學習區分性的embeddings是當前人臉識別的SOTA方法。然而,全連接層的內存和計算成本隨著訓練集中ID數量的增加而線性增加。此外,大規模訓練數據存在類間沖突(同一個人被分成不同ID)和長尾分布的問題。

傳統FC

將傳統的FC層應用在大規模的數據集上時,存在以下缺陷:

1、gradient confusion under interclass conflict

WebFace42M里有很多不同類別對之間的余弦相似度大于0.4,這表明類間沖突仍然存在于這些清洗過的數據集中。直接優化的話會導致gradient confusion(同一個人的特征非常相似卻要掰成兩個ID)

2、centers of tail classes undergo too many passive updates

每個iteration都優化圖片數量很少的id,可能會導致負優化

3、the storage and calculation of the FC layer can easily exceed current GPU capabilities

PartialFC

在訓練期間仍然維護所有類別中心,但只隨機采樣一小部分負類別中心來計算基于margin的softmax損失,而不是在每次迭代中使用所有負類別中心。更具體地說,首先從每個GPU收集embeddings和標簽,然后將組合的特征和標簽分布到所有GPU。為了平衡每個GPU的內存使用和計算成本,為每個GPU設置了一個內存緩沖區(下面代碼中的perm)。內存緩沖區的大小由類別總數和負類別中心的采樣率決定。在每個GPU上,首先通過標簽選擇正類中心并放入緩沖區,然后隨機選擇一小部分負類中心(負類中心的數量為self.sample_rate * self.num_local)填充緩沖區的其余部分,

def sample(self, labels, index_positive):"""This functions will change the value of labelsParameters:-----------labels: torch.Tensorpassindex_positive: torch.Tensorpassoptimizer: torch.optim.Optimizerpass"""with torch.no_grad():positive = torch.unique(labels[index_positive], sorted=True).cuda()if self.num_sample - positive.size(0) >= 0:perm = torch.rand(size=[self.num_local]).cuda()perm[positive] = 2.0index = torch.topk(perm, k=self.num_sample)[1].cuda()index = index.sort()[0].cuda()else:index = positiveself.weight_index = indexlabels[index_positive] = torch.searchsorted(index, labels[index_positive])return self.weight[self.weight_index]

隨后,使用選出的樣本中心去與特征相乘并計算基于margin的softmax損失。

PFC在DDP框架下的流程圖如下圖所示,

整體代碼如下,

class PartialFC_V2(torch.nn.Module):"""https://arxiv.org/abs/2203.15565A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).When sample rate less than 1, in each iteration, positive class centers and a random subset ofnegative class centers are selected to compute the margin-based softmax loss, all classcenters are still maintained throughout the whole training process, but only a subset isselected and updated in each iteration... note::When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).Example:-------->>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)>>> for img, labels in data_loader:>>>     embeddings = net(img)>>>     loss = module_pfc(embeddings, labels)>>>     loss.backward()>>>     optimizer.step()"""_version = 2def __init__(self,margin_loss: Callable,embedding_size: int,num_classes: int,sample_rate: float = 1.0,fp16: bool = False,):"""Paramenters:-----------embedding_size: intThe dimension of embedding, requirednum_classes: intTotal number of classes, requiredsample_rate: floatThe rate of negative centers participating in the calculation, default is 1.0."""super(PartialFC_V2, self).__init__()assert (distributed.is_initialized()), "must initialize distributed before create this"self.rank = distributed.get_rank()self.world_size = distributed.get_world_size()self.dist_cross_entropy = DistCrossEntropy()self.embedding_size = embedding_sizeself.sample_rate: float = sample_rateself.fp16 = fp16self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)self.class_start: int = num_classes // self.world_size * self.rank + min(self.rank, num_classes % self.world_size)self.num_sample: int = int(self.sample_rate * self.num_local)self.last_batch_size: int = 0self.is_updated: bool = Trueself.init_weight_update: bool = Trueself.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))# margin_lossif isinstance(margin_loss, Callable):self.margin_softmax = margin_losselse:raisedef sample(self, labels, index_positive):"""This functions will change the value of labelsParameters:-----------labels: torch.Tensorpassindex_positive: torch.Tensorpassoptimizer: torch.optim.Optimizerpass"""with torch.no_grad():positive = torch.unique(labels[index_positive], sorted=True).cuda()if self.num_sample - positive.size(0) >= 0:perm = torch.rand(size=[self.num_local]).cuda()perm[positive] = 2.0index = torch.topk(perm, k=self.num_sample)[1].cuda()index = index.sort()[0].cuda()else:index = positiveself.weight_index = indexlabels[index_positive] = torch.searchsorted(index, labels[index_positive])return self.weight[self.weight_index]def forward(self,local_embeddings: torch.Tensor,local_labels: torch.Tensor,):"""Parameters:----------local_embeddings: torch.Tensorfeature embeddings on each GPU(Rank).local_labels: torch.Tensorlabels on each GPU(Rank).Returns:-------loss: torch.Tensorpass"""local_labels.squeeze_()local_labels = local_labels.long()batch_size = local_embeddings.size(0)if self.last_batch_size == 0:self.last_batch_size = batch_sizeassert self.last_batch_size == batch_size, (f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")_gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda()for _ in range(self.world_size)]_gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)distributed.all_gather(_gather_labels, local_labels)embeddings = torch.cat(_list_embeddings)labels = torch.cat(_gather_labels)## 選出落在本進程對應的類別范圍內的數據labels = labels.view(-1, 1)index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)## 標簽不在本類別段的, 將其類別標簽設為-1labels[~index_positive] = -1## 將類別ID平移到原點(因為不同進程都會初始化對應的self.weight, 若不平移回去, 則label與self.weight中的index會對應不上)labels[index_positive] -= self.class_startif self.sample_rate < 1:weight = self.sample(labels, index_positive)else:weight = self.weightwith torch.cuda.amp.autocast(self.fp16):norm_embeddings = normalize(embeddings)norm_weight_activated = normalize(weight)logits = linear(norm_embeddings, norm_weight_activated)if self.fp16:logits = logits.float()logits = logits.clamp(-1, 1)logits = self.margin_softmax(logits, labels)loss = self.dist_cross_entropy(logits, labels)return loss

實驗結果

將PFC替換掉傳統FC后,模型在WebFace(包括4m、12m、42m)上的性能會有所提升,

?消融實驗的結果如下,

與SOTA方法的性能對比如下,?

結論與討論

結論

作者提出了一種用于在大規模數據集上訓練人臉識別模型的方法——Partial FC (PFC)。在PFC的每次迭代中,僅選擇一小部分類別中心來計算基于邊際的softmax損失,這樣可以顯著減少類間沖突的概率、尾類中心的被動更新頻率以及計算需求。通過廣泛的實驗,作者驗證了所提出的PFC的有效性、魯棒性和高效性。

局限性

盡管在WebFace上訓練的PFC模型在高質量測試集上取得了不錯的結果,但在人臉分辨率較低或低光照條件下拍攝的人臉上,PFC模型的表現可能較差。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/12201.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/12201.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/12201.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

DeepLab V3+: 引入可分離卷積與Decoder網絡

文章目錄 摘要引入深度可分離卷積普通卷積深度卷積,Depthwise點卷積Introduction & Related WorkMethodsEncoder-Decoder with Atrous ConvolutionModified Aligned Xception實驗結果Decoder部分的消融實驗ResNet-101作為backbone

基于lidar的多目標跟蹤

文章目錄 基本流程編譯過程注意事項基本流程 基于雷達點云的目標追蹤主要包括以下幾個步驟: 點云預處理: 濾除噪點和無效點(如NaN值)進行平面分割,提取地面點云對剩余的點云進行聚類,得到可能的目標點云目標檢測 對聚類后的點云進行分析,判斷是否為有效目標可以利用目標的尺寸…

怎么轉換音頻?看這3款音頻轉換器

隨著數字媒體的發展&#xff0c;音頻文件在我們的日常生活中占據了越來越重要的地位。有時候在不同的應用場景里&#xff0c;無論是音樂、語音還是其他類型的音頻內容&#xff0c;我們都需要對其進行轉換以滿足不同的需求。 本文將為您介紹3款常用的音頻轉換器&#xff0c;幫助…

如何讓Linux崩潰?

如何使 Linux 系統崩潰 警告 下面的代碼行是 Bash shell 的一個簡短而甜蜜的 fork 炸彈。分叉炸彈之所以有效&#xff0c;是因為它能夠產生無限數量的進程。最終&#xff0c;Linux無法處理所有這些&#xff0c;并且會崩潰。 fork 炸彈的一大優點是你不需要 root 權限即可執行它…

Springboot+mybatis-plus+dynamic-datasource+繼承DynamicRoutingDataSource切換數據源

Springbootmybatis-plusdynamic-datasource繼承DynamicRoutingDataSource切換數據源 背景 最近公司要求支持saas&#xff0c;實現動態切換庫的操作&#xff0c;默認會加載主租戶的數據源&#xff0c;其他租戶數據源在使用過程中自動創建加入。 解決問題 1.通過請求中設置租…

數據可視化訓練第7天(json文件讀取國家人口數據,找出前10和后10)

數據 https://restcountries.com/v3.1/all&#xff1b;建議下載下來&#xff0c;并不是很大 import numpy as np import matplotlib.pyplot as plt import requests import json #由于訪問url過于慢&#xff1b;將數據下載到本地是json數據 #urlhttps://restcountries.com/v3…

MATLAB蟻群算法求解帶時間窗的旅行商TSPTW問題代碼實例

MATLAB蟻群算法求解帶時間窗的旅行商TSPTW問題代碼實例 蟻群算法編程求解TSPTW問題實例&#xff1a; 在經緯度范圍為(121, 43)到(123, 45)的矩形區域內&#xff0c;散布著1個商家&#xff08;編號1&#xff09;和25個顧客點&#xff08;編號為226&#xff09;&#xff0c;各個…

前端工程化實踐:Monorepo與Lerna管理

前端工程化實踐中&#xff0c;Monorepo&#xff08;單倉庫&#xff09;管理和Lerna是兩種流行的方式&#xff0c;用于大型項目或組件庫的組織和版本管理。 2500G計算機入門到高級架構師開發資料超級大禮包免費送&#xff01; Monorepo簡介 Monorepo&#xff08;單倉庫&#…

web入門練手案例(二)

下面是一下web入門案例和實現的代碼&#xff0c;帶有部分注釋&#xff0c;倘若代碼中有任何問題或疑問&#xff0c;歡迎留言交流~ 數字變色Logo 案例描述 “Logo”是“商標”的英文說法&#xff0c;是企業最基本的視覺識別形象&#xff0c;通過商標的推廣可以讓消費者了解企…

第一個Rust程序

在安裝好Rust以后&#xff0c;我們就可以編寫程序了。 首先&#xff0c;我們執行下面的命令&#xff0c;盡量讓你的rust版本和我的版本相同&#xff0c;或者比我的版本大。 zhangdapengzhangdapeng:~$ cargo --version cargo 1.78.0 (54d8815d0 2024-03-26) zhangdapengzhangd…

C語言(指針)2

Hi~&#xff01;這里是奮斗的小羊&#xff0c;很榮幸各位能閱讀我的文章&#xff0c;誠請評論指點&#xff0c;關注收藏&#xff0c;歡迎歡迎~~ &#x1f4a5;個人主頁&#xff1a;小羊在奮斗 &#x1f4a5;所屬專欄&#xff1a;C語言 本系列文章為個人學習筆記&#x…

聽說SOLIDWORKS科研版可以節約研發成本?

近幾年來&#xff0c;政府越來越重視科研帶動產業&#xff0c;績效優良的產業技術研究院對于國家和地區的學術成果轉化、技術創新、產業發展等具有不可忽視的促進和帶動作用。研究院會承擔眾多新產業的基礎研究工作&#xff0c;而常規的基礎研究需要長期的積累&#xff0c;每個…

JAVA畢業設計141—基于Java+Springboot+Vue的物業管理系統(源代碼+數據庫)

畢設所有選題&#xff1a; https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootVue的物業管理系統(源代碼數據庫)141 一、系統介紹 本項目前后端分離&#xff0c;分為管理員、員工、用戶三種角色(角色權限可自行分配) 1、用戶&#xff1a; …

Nginx詳解:高性能HTTP和反向代理服務器

Nginx詳解&#xff1a;高性能HTTP和反向代理服務器 一、引言 Nginx&#xff08;發音為“engine x”&#xff09;是一個開源的高性能HTTP和反向代理服務器&#xff0c;也是一個IMAP/POP3/SMTP代理服務器。由于其出色的性能和穩定性&#xff0c;Nginx已經成為互聯網上最受歡迎的…

asp.net結課作業中遇到的問題解決4

目錄 1、vs2019每次運行一次項目之后&#xff0c;樣式表的格式就算在vs2019上改變了&#xff0c;在瀏覽器中顯示的還是以前的樣式&#xff0c;所以應該如何修改 2、如何實現選擇下拉框之后&#xff0c;顯示所選擇的這個類型的書籍的名稱 3、如何實現點擊首頁顯示的書籍&#…

高清模擬視頻采集卡CVBS四合一信號采集設備解析

介紹一款新產品——LCC261高清視頻采集與編解碼一體化采集卡。這款高品質的產品擁有卓越的性能表現和豐富多樣的功能特性&#xff0c;能夠滿足廣大用戶對于高清視頻采集、處理以及傳輸的需求。 首先&#xff0c;讓我們來了解一下LCC261的基本信息。它是一款基于靈卡技術研發的高…

Shell三劍客之sed

前言&#xff1a; Shell三劍客是grep、sed和awk三個工具的簡稱,因功能強大&#xff0c;使用方便且使用頻率高&#xff0c;因此被戲稱為三劍客&#xff0c;熟練使用這三個工具可以極大地提升運維效率。 sed是一個流編輯器&#xff0c;用于對文本進行編輯、替換、刪除等操作。sed…

LeetCode2095刪除鏈表的中間節點

題目描述 給你一個鏈表的頭節點 head 。刪除 鏈表的 中間節點 &#xff0c;并返回修改后的鏈表的頭節點 head 。長度為 n 鏈表的中間節點是從頭數起第 ?n / 2? 個節點&#xff08;下標從 0 開始&#xff09;&#xff0c;其中 ?x? 表示小于或等于 x 的最大整數。對于 n 1、…

深入探索Android簽名機制:從v1到v3的演進之旅

引言 在Android開發的世界中&#xff0c;APK的簽名機制是確保應用安全性的關鍵環節。隨著技術的不斷進步&#xff0c;Android簽名機制也經歷了從v1到v3的演進。本文將帶你深入了解Android簽名機制的演變過程&#xff0c;揭示每個版本背后的技術細節&#xff0c;并探討它們對開…

淺談下MYSQL表設計的幾條規則

作為后端開發人員&#xff0c;避免不了和數據庫打交道&#xff0c;可是我們怎么能夠設計出高效&#xff0c;可維護&#xff0c;可擴展的數據庫設計呢&#xff0c;在這里我總結了幾個點&#xff0c;供大家參考。 在寫之前&#xff0c;可能需要重復下數據庫設計的范式原則&#…