SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION--論文筆記

論文筆記

資料

1.代碼地址

https://github.com/google-research/sam
https://github.com/davda54/sam

2.論文地址

https://arxiv.org/abs/2010.01412

3.數據集地址

論文摘要的翻譯

在當今嚴重過度參數化的模型中,訓練損失的值很難保證模型的泛化能力。事實上,像通常所做的那樣,只優化訓練損失值很容易導致次優的模型質量。受損失的幾何圖像與泛化相結合的前人工作的啟發,我們引入了一種新的、有效的同時最小化損失值和損失銳度的方法。特別是,我們的方法,銳度感知最小化(SAM),尋找位于具有一致低損失的鄰域的參數;這個公式致使一個最小值-最大值的優化問題,在該問題上可以有效地執行梯度下降。我們提出的經驗結果表明,SAM提高了各種基準數據集(例如,CIFAR-{10,100}、ImageNet、微調任務)和模型的模型泛化能力,為幾個數據集帶來了新的最先進的性能。此外,我們發現SAM原生地提供了對標簽噪聲的魯棒性,這與專門針對帶有噪聲標簽的學習的過程所提供的魯棒性不相上下。

1背景

現代機器學習成功地在廣泛的任務中實現了越來越好的性能,這在很大程度上取決于越來越重的過度參數化,以及開發越來越有效的訓練算法,這些算法能夠找到很好地泛化的參數。事實上,許多現代神經網絡可以很容易地記住訓練數據,并具有容易過擬合的能力。目前需要這種嚴重的過度參數化才能在各種領域實現最先進的結果。反過來,至關重要的是,使用程序來訓練這些模型,以確保實際選擇的參數事實上超越了訓練集。
不幸的是,簡單地最小化訓練集上常用的損失函數(例如,交叉熵)通常不足以實現令人滿意的泛化。今天的模型的訓練損失景觀通常是復雜的和非凸的,具有多個局部和全局極小值,并且具有不同的全局極小值產生具有不同泛化能力的模型。因此,從許多可用的(例如,隨機梯度下降、Adam)、RMSProp和其他中選擇優化器(和相關的優化器設置)已成為一個重要的設計選擇,盡管對其與模型泛化的關系的理解仍處于初級階段。與此相關的是,已經提出了一系列修改訓練過程的方法,包括dropout,批量歸一化、隨機深度、數據增強和混合樣本增強.
損失圖像的幾何形狀——特別是最小值的平坦性和泛化之間的聯系已經從理論和實證的角度進行了廣泛的研究)。雖然這種聯系有望實現新的模型訓練方法,從而產生更好的泛化能力,但迄今為止,專門尋找更平坦的最小值并進一步有效提高一系列最先進模型泛化能力的實用高效算法一直難以實現;我們在第5節中對先前的工作進行了更詳細的討論)。

2論文的創新點

  • 我們引入了銳度感知最小化(SAM),這是一種新的方法,通過同時最小化損失值和銳度來提高模型的泛化能力。SAM通過尋找位于具有一致低損耗值的鄰域中的參數(而不是僅具有低損耗值的參數,如圖1的中間和右側圖像所示)來工作,并且可以高效且容易地實現。

在這里插入圖片描述

  • 我們通過一項嚴格的實證研究表明,使用SAM提高了一系列廣泛研究的計算機視覺任務(例如,CIFAR-{10,100},ImageNet,微調任務)和模型的模型泛化能力,如圖1的左側曲線圖中總結的。例如,應用SAM為許多已經深入研究的任務,如ImageNet,CIFAR-{10,100},SVHN,Fashion-MNIST,和標準的圖像分類微調任務集(例如,Flowers,Stanford Cars,Oxford Pets,等)產生了新的最先進的性能。
  • 我們還表明,SAM進一步提供了標簽噪聲的穩健性,與專門針對帶有噪聲標簽的學習的最先進的程序所提供的一樣。
  • 通過SAM提供的視角,我們提出了一個很有使用價值的銳度的新概念,我們稱之為m-銳度,從而進一步闡明了損失銳度和泛化之間的聯系。

3 論文方法的概述

在本文中,我們將標量表示為 a a a,將向量表示為 a \mathbf{a} a、將矩陣表示為 A \Alpha A、將集合表示為 A \mathcal A A,將等式定義為,。給定訓練數據集 S ? ∪ i = 1 n { ( x i , y i ) } \mathcal{S}\triangleq\cup_{i=1}^n\{(x_i,y_i)\} S?i=1n?{(xi?,yi?)}從分布 D , \mathscr{D}, D,,中i.i.d.繪制,我們尋求學習一個很好地泛化的模型。特別地,考慮一組由 w ∈ W ? R d w\in\mathcal{W}\subseteq\mathbb{R}^d wW?Rd參數化的模型;給定一個逐數據點損失函數 l : W × X × Y → R + l:\mathcal{W}\times\mathcal{X}\times\mathcal{Y}\to\mathbb{R}_+ l:W×X×YR+?,我們定義了訓練集損失 L S ( w ) ? 1 n ∑ i = 1 n l ( w , x i , y i ) L_S(\boldsymbol{w})\triangleq\frac1n\sum_{i=1}^nl(\boldsymbol{w},\boldsymbol{x}_i,\boldsymbol{y}_i) LS?(w)?n1?i=1n?l(w,xi?,yi?)和總體損失 L D ( w ) ? E ( x , y ) ~ D [ l ( w , x , y ) ] L_{\mathscr{D}}(\boldsymbol{w})\triangleq\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\thicksim D}[l(\boldsymbol{w},\boldsymbol{x},\boldsymbol{y})] LD?(w)?E(x,y)D?[l(w,x,y)]。在僅觀察到 S \mathcal S S的情況下,模型訓練的目標是選擇具有低總體損失 L D ( w ) L_{\mathscr{D}}(\boldsymbol{w}) LD?(w)的模型參數 w w w
利用 L S ( w ) 作為 L D ( w ) L_S(\boldsymbol{w})作為L_{\mathscr{D}}(\boldsymbol{w}) LS?(w)作為LD?(w)的估計,通過使用諸如SGD或Adam之類的優化過程來求解 m i n w L S ( w ) min_w L_S(\boldsymbol{w}) minw?LS?(w)(可能與w上的正則化子結合)來激勵選擇參數w的標準方法。然而,不幸的是,對于現代的過度參數化模型,如深度神經網絡,典型的優化方法很容易在測試時導致次優性能。特別地,對于現代模型, L S ( w ) L_S(\boldsymbol{w}) LS?(w)通常在w中是非凸的,具有多個局部甚至全局最小值,這些局部甚至全局極小值可以產生相似的 L S ( w ) L_S(\boldsymbol{w}) LS?(w)值,同時具有顯著不同的泛化性能(即,顯著不同的 L D ( w ) L_{\mathscr{D}}(\boldsymbol{w}) LD?(w)值)。
受損失圖形的銳度和泛化之間的聯系的啟發,我們提出了一種不同的方法:我們不是尋找簡單地具有低訓練損失值 L S ( w ) L_S(\boldsymbol{w}) LS?(w)的參數值 w w w,而是尋找整個鄰域具有一致低訓練損失的參數值(相當于,具有低損失和低曲率的鄰域)。以下定理通過在鄰域訓練損失方面限制泛化能力來說明這種方法的動機(附錄A中的完整定理陳述和證明):

  • 定理1
    對于任意 ρ > 0 ρ>0 ρ>0、在由分布 D \mathscr D D生成的訓練集 S 上 \mathcal S上 S具有高概率的情況, L D ( w ) ≤ max ? ∥ ? ∥ 2 ≤ ρ L S ( w + ? ) + h ( ∥ w ∥ 2 2 / ρ 2 ) , L_\mathscr{D}(\boldsymbol{w})\leq\max_{\|\boldsymbol{\epsilon}\|_2\leq\rho}L_\mathcal{S}(\boldsymbol{w}+\boldsymbol{\epsilon})+h(\|\boldsymbol{w}\|_2^2/\rho^2), LD?(w)?2?ρmax?LS?(w+?)+h(w22?/ρ2),其中h: R + → R + \mathbb{R}_+\to\mathbb{R}_+ R+?R+?是嚴格遞增函數(在 L D ( w ) L_{\mathscr{D}}(\boldsymbol{w}) LD?(w)上的一些技術條件下)。為了明確我們的銳度項,我們可以將上面不等式的右側重寫為為了明確我們的銳度項,我們可以將上面不等式的右側重寫為 [ max ? ∥ ? ∥ 2 ≤ ρ L S ( w + ? ) ? L S ( w ) ] + L S ( w ) + h ( ∥ w ∥ 2 2 / ρ 2 ) . [\max_{\|\boldsymbol{\epsilon}\|_2\leq\rho}L_\mathcal{S}(\boldsymbol{w}+\boldsymbol{\epsilon})-L_\mathcal{S}(\boldsymbol{w})]+L_\mathcal{S}(\boldsymbol{w})+h(\|\boldsymbol{w}\|_2^2/\rho^2). [max?2?ρ?LS?(w+?)?LS?(w)]+LS?(w)+h(w22?/ρ2).
    方括號中的項通過測量通過從 w w w移動到附近的參數值可以以多快的速度增加訓練損失來捕捉 L S L_S LS? w w w處的銳度;然后將該銳度項與訓練損失值本身和 w w w大小上的正則化子求和。給定特定函數h深受證明細節的影響,我們用 λ ∣ ∣ w ∣ ∣ 2 2 \lambda||w||_2^2 λ∣∣w22?代替超參數λ的第二項,得到標準L2正則化項。因此,受約束項的啟發,我們建議通過解決以下SharpnessAware最小化(SAM)問題來選擇參數值:
    min ? w L S S A M ( w ) + λ ∥ w ∥ 2 2 w h e r e L S S A M ( w ) ? max ? ∣ ∣ ? ∣ ∣ p ≤ ρ L S ( w + ? ) , \min_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(\boldsymbol{w})+\lambda\|\boldsymbol{w}\|_2^2\quad\mathrm{~where~}\quad L_{\mathcal{S}}^{SAM}(\boldsymbol{w})\triangleq\max_{||\boldsymbol{\epsilon}||_p\leq\rho}L_S(\boldsymbol{w}+\boldsymbol{\epsilon}), wmin?LSSAM?(w)+λw22??where?LSSAM?(w)?∣∣?p?ρmax?LS?(w+?),其中 ρ ≥ 0 ρ≥0 ρ0是一個超參數, p ∈ [ 1 , ∞ ] p∈[1,∞] p[1](我們在最大化過程中從L2范數略微推廣到p范數,本文已經證明 p = 2 p=2 p=2通常是最優的)。圖1顯示了1通過最小化 L S ( w ) L_S(\boldsymbol{w}) LS?(w) L S S A M ( w ) L^{SAM}_S(\boldsymbol{w}) LSSAM?(w)而收斂到最小值的模型的損失情況,說明了銳度感知損失阻止了模型收斂到尖銳的最小值。在這里插入圖片描述
    為了最小化 L S S A M ( w ) L^{SAM}_S(\boldsymbol{w}) LSSAM?(w),我們通過內部最大化進行微分,推導出了一個有效的近似值,從而使我們能夠將隨機梯度下降直接應用于SAM目標。沿著這條路前進,我們首先通過 L S ( w + ? ) w . r . t . ? L_{\mathcal{S}}(w+\epsilon)\mathrm{~w.r.t.~}\epsilon LS?(w+?)?w.r.t.?? 在0附近的一階泰勒展開來近似內部最大化問題,得到 ? w ) ? arg ? max ? ∥ ? ∥ p ≤ ρ L S ( w + ? ) ≈ arg ? max ? ∥ ? ∥ p ≤ ρ L S ( w ) + ? T ? w L S ( w ) = arg ? max ? ∥ ? ∥ p ≤ ρ ? T ? w L S ( w ) . \boldsymbol{\epsilon}\\\boldsymbol{w})\triangleq\arg\max_{\|\boldsymbol{\epsilon}\|_p\leq\rho}L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})\approx\arg\max_{\|\boldsymbol{\epsilon}\|_p\leq\rho}L_{\mathcal{S}}(\boldsymbol{w})+\boldsymbol{\epsilon}^T\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w})=\arg\max_{\|\boldsymbol{\epsilon}\|_p\leq\rho}\boldsymbol{\epsilon}^T\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w}). ?w)?arg?p?ρmax?LS?(w+?)arg?p?ρmax?LS?(w)+?T?w?LS?(w)=arg?p?ρmax??T?w?LS?(w).反過來,求解該近似的值 ξ ( w ) ξ(w) ξ(w)由經典對偶范數問題的解給出( ∣ ? ∣ q ? 1 |·|^{q?1} ?q?1表示元素絕對值和冪): ? ^ ( w ) = ρ s i g n ( ? w L S ( w ) ) ∣ ? w L S ( w ) ∣ q ? 1 / ( ∥ ? w L S ( w ) ∥ q q ) 1 / p (2) \hat{\boldsymbol{\epsilon}}(w)=\rho\mathrm{~sign}\left(\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)\right)|\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)|^{q-1}/\left(\|\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)\|_{q}^{q}\right)^{1/p}\text{(2)} ?^(w)=ρ?sign(?w?LS?(w))?w?LS?(w)q?1/(?w?LS?(w)qq?)1/p(2)其中1/p+1/q=1。代入方程(1)并進行微分,我們得到 ? w L S S A M ( w ) ≈ ? w L S ( w + ? ^ ( w ) ) = d ( w + ? ^ ( w ) ) d w ? w L S ( w ) ∣ w + ? ^ ( w ) \nabla_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(\boldsymbol{w})\approx\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))=\frac{d(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))}{d\boldsymbol{w}}\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w})|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(w)} ?w?LSSAM?(w)?w?LS?(w+?^(w))=dwd(w+?^(w))??w?LS?(w)w+?^(w)?
    可以通過自動微分直接計算這種對 ? w L S S A M ( w ) \nabla_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(w) ?w?LSSAM?(w)的近似,如在JAX、TensorFlow和PyTorch等公共庫中實現的那樣。盡管該計算隱含地依賴于LS(w)的Hessian,因為ξ。獲得我們的最終梯度近似: ? w L S S A M ( w ) ≈ ? w L S ( w ) ∣ w + ? ^ ( w ) . \nabla_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(\boldsymbol{w})\approx\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}. ?w?LSSAM?(w)?w?LS?(w)w+?^(w)?.
    我們通過將標準數值優化器(如隨機梯度下降(SGD))應用于SAM目標 L S S A M ( w ) L_\mathcal{S}^{SAM}(\boldsymbol{w}) LSSAM?(w),使用方程3來計算必要的目標函數梯度,從而獲得最終的SAM算法。算法1給出了完整SAM算法的偽代碼,使用SGD作為基本優化器,圖2示意性地說明了單個SAM參數更新。
    在這里插入圖片描述在這里插入圖片描述

4 論文實驗

為了評估SAM的功效,我們將其應用于一系列不同的任務,包括從頭開始的圖像分類(包括在CIFAR-10、CIFAR-100和ImageNet上)、微調預訓練的模型以及使用噪聲標簽進行學習。在所有情況下,我們都通過簡單地用SAM代替用于訓練現有模型的優化程序來衡量使用SAM的好處,并計算由此對模型泛化的影響。如下所示,SAM在絕大多數情況下都能顯著提高泛化性能。

4.1 圖像分類

我們首先評估了SAM對當今最先進的CIFAR-10和CIFAR-100模型(無需預訓練)泛化的影響:具有ShakeShake正則化的WideResNets和具有ShakeDrop正則化的PyramidNet。請注意,這些模型中的一些已經在先前的工作中進行了大量調整,并包括精心選擇的正則化方案,以防止過擬合;因此,顯著提高它們的泛化能力是非常重要的。我們已經確保在沒有SAM的情況下,我們的實現的泛化性能與先前工作中報告的性能相匹配或超過。
所有結果都使用了基本數據增強(水平翻轉、四像素填充和隨機裁剪)。我們還評估了更先進的數據增強方法,如剪切正則化和AutoAugment(,這些方法被先前的工作用來實現最先進的結果。
SAM具有單個超參數 ρ (鄰域大小) ρ(鄰域大小) ρ(鄰域大小),我們使用10%的訓練集作為驗證集,通過在 { 0.01 , 0.02 , 0.05 , 0.1 , 0.2 , 0.5 } \{0.01,0.02,0.05,0.1,0.2,0.5\} {0.010.020.050.10.20.5}上的網格搜索對其進行調整。有關所有超參數的值和其他訓練細節,請參見附錄C.1。由于每個SAM權重更新需要兩個反向傳播操作(一個用于計算 ξ ( w ) ξ(w) ξw,另一個用于估計最終梯度),我們允許每個非SAM訓練運行執行兩倍于每個SAM訓練運行的輪次,并且我們報告每個非SAM訓練運行在標準輪次計數或加倍輪次計數上獲得的最佳得分。我們運行每個實驗條件的五個獨立副本,報告結果(每個條件都有獨立的權重初始化和數據混洗),報告測試集的平均誤差(或準確性)以及相關的95%置信區間。我們的實現使用JAX,并且我們在具有8個NVidia V100 GPU的單個主機上訓練所有模型。為了在跨多個加速器并行時計算SAM更新,我們在加速器之間均勻地劃分每個數據批次,獨立地計算每個加速器上的SAM梯度,并對得到的子批次SAM梯度進行平均,以獲得最終的SAM更新。
如表1所示,SAM提高了針對CIFAR-10和CIFAR-100評估的所有設置的泛化能力。例如,SAM使簡單的WideResNet能夠獲得1.6%的測試誤差,而沒有SAM時的誤差為2.2%。這種增益以前只能通過使用更復雜的模型架構(例如,PyramidNet)和正則化方案(例如,Shake-Shake、ShakeDrop)來實現;SAM提供了一個易于實施的、獨立于模型的替代方案。此外,即使在已經使用復雜正則化的復雜架構上應用SAM,SAM也能提供改進:例如,將SAM應用于具有ShakeDrop正則化的PyramidNet,在CIFAR-100上產生10.3%的誤差,據我們所知,這是該數據集上的一個新的最先進技術,無需使用額外數據。
在這里插入圖片描述
除了CIFAR-{101100}之外,我們還在SVHN和Fashion MNIST數據集上評估了SAM。再次,SAM使一個簡單的WideResNet能夠實現這些數據集達到或高于最先進水平的精度:SVHN的誤差為0.99%,Fashion MNIST為3.59%。詳細信息見附錄B.1。
為了更大規模地評估SAM的性能,我們將其應用于在ImageNet上訓練的不同深度(50,101,152)的ResNets。在這種情況下,根據先前的工作,我們將圖像調整大小并裁剪到224像素分辨率,對其進行歸一化,并使用批量大小4096、初始學習率1.0、余弦學習率計劃、動量為0.9的SGD優化器、標簽平滑度為0.1和權重衰減0.0001。當應用SAM時,我們使用ρ=0.05(通過對訓練了100個時期的ResNet-50進行網格搜索確定)。我們使用Google Cloud TPU3在ImageNet上訓練所有模型長達400個時期,并報告每個實驗條件的前1和前5測試錯誤率(5次獨立運行的平均值和95%置信區間)。
如表2所示,SAM再次持續提高性能,例如將ResNet-152的ImageNet top 1錯誤率從20.3%提高到18.4%。此外,請注意,SAM能夠增加訓練時期的數量,同時在不過度擬合的情況下繼續提高準確性。相反,當訓練從200個時期擴展到400個時期時,標準訓練過程(沒有SAM)通常顯著地過擬合。
在這里插入圖片描述

4.2 微調

通過在大型相關數據集上預訓練模型,然后在感興趣的較小目標數據集上進行微調,遷移學習已成為一種強大且廣泛使用的技術,用于為各種不同的任務生成高質量的模型。我們在這里展示了SAM在這種情況下再次提供了相當大的好處,即使在微調非常大、最先進、已經高性能的模型時也是如此。
特別地,我們將SAM應用于微調EfficientNet-b7(在ImageNet上預訓練)和EfficientNet-L2(在ImageNet上預訓練加上未標記的JFT;輸入分辨率475)。我們將這些模型初始化為公開可用的檢查點6,分別用RandAugment(在ImageNet上的準確率為84.7%)和NoisyStudent(在Image Net上的正確率為88.2%)訓練。我們通過從上述檢查點開始訓練每個模型,在幾個目標數據集中的每個數據集上微調這些模型;有關使用的超參數的詳細信息,請參閱附錄。我們報告了每個數據集在5次獨立運行中前1個測試誤差的平均值和95%置信區間。
如表3所示,相對于沒有SAM的微調,SAM均勻地提高了性能。此外,在許多情況下,SAM產生了新的最先進的性能,包括CIFAR-10上0.30%的誤差、CIFAR-100上3.92%的誤差和ImageNet上11.39%的誤差。
在這里插入圖片描述

4.3對標簽噪聲的魯棒性

SAM尋找對擾動具有魯棒性的模型參數這一事實表明,SAM有可能在訓練集中提供對噪聲的魯棒性(這將擾亂訓練損失景觀)。因此,我們在這里評估SAM為標記噪聲提供的魯棒性程度。
特別地,我們測量了在CIFAR-10的經典噪聲標簽設置中應用SAM的效果,其中訓練集的一小部分標簽被隨機翻轉;測試集保持未修改(即干凈)。為了確保與之前的工作進行有效的比較,之前的工作通常使用專門用于噪聲標簽設置的架構,我們在Jiang等人之后,為200個時期訓練了一個類似大小的簡單模型(ResNet-32)。我們評估了模型訓練的五種變體:標準SGD、帶Mixup的SGD、SAM和帶MixupSAM的“自舉”SGD變體(其中,首先像往常一樣訓練模型,然后在最初訓練的模型預測的標簽上從頭開始重新訓練)。當應用SAM時,我們對除80%之外的所有噪聲級使用ρ=0.1,對于80%,我們使用ρ=0.05來獲得更穩定的收斂。對于混合基線,我們嘗試了α∈{1,8,16,32}的所有值,并保守地報告每個噪聲水平的最佳得分。
如表4所示,SAM提供了對標簽噪聲的高度魯棒性,與專門針對具有噪聲標簽的學習的現有技術程序所提供的魯棒性不相上下。事實上,除了MentorMix(之外,簡單地用SAM訓練模型勝過所有專門針對標簽噪聲魯棒性的現有方法。然而,簡單地自舉SAM產生的性能與MentorMix相當(后者要復雜得多)。
在這里插入圖片描述

5 SAM視角下的銳度與泛化

5.1 m-銳度

盡管我們對SAM的推導定義了整個訓練集的SAM目標,但當在實踐中使用SAM時,我們計算每個批次的SAM更新(如算法1所述),甚至通過平均每個加速器獨立計算的SAM更新來計算(其中每個加速器接收一個批次的大小為m的子集,如第3節所述)。后一種設置等效于修改SAM目標(等式1)以在一組獨立的最大化上求和,每個最大化對m個數據點的不相交子集上的每個數據點損失的總和執行,而不是在訓練集上的全局總和上執行最大化(這將等效于將m設置為總訓練集大小)。我們將損失圖像的銳度的相關度量稱為m-銳度。
在這里插入圖片描述

為了更好地理解m對SAM的影響,我們在CIFAR-10上使用m值范圍的SAM訓練一個小的ResNet。如圖3(中間)所示,較小的m值往往產生具有更好泛化能力的模型。這種關系恰好符合跨多個加速器并行化的需要,以便為當今的許多模型擴展訓練。有趣的是,如圖3(右)所示,隨著m的減少,上述m銳度測量進一步與模型的實際泛化差距表現出更好的相關性。特別地,這意味著,與上述定理1所建議的全訓練集測度相比,m<n的m-清晰度產生了更好的泛化預測因子,這為理解泛化提供了一條有趣的新途徑。

5.2 HESSIAN SPECTRA

受損失圖像的幾何形狀和泛化之間的聯系的啟發,我們構建了SAM,以尋找具有低損失值和低曲率(即,低銳度)的訓練損失圖像的最小值。為了進一步證實SAM確實發現了具有低曲率的極小值,我們計算了在CIFAR-10上訓練300步的WideResNet40-10在訓練期間的不同時期的Hessian譜,包括有SAM和沒有SAM(沒有批處理規范,這往往會模糊對Hessian的解釋)。由于參數空間的維數,我們使用Lanczos算法來近似Hessian譜。
圖3(左)報告了由此產生的Hessian光譜。正如預期的那樣,用SAM訓練的模型收斂到具有較低曲率的最小值,如在特征值的總體分布中所見,收斂時的最大特征值( λ m a x λ_{max} λmax?)(無SAM時約為24,有SAM時為1.0),以及大部分頻譜(比率 λ m a x / λ 5 λ_{max}/λ_{5} λmax?/λ5?,通常用作銳度的代理;在沒有SAM的情況下高達11.4,在有SAM的情況中高達2.6)。

6 總結

在這項工作中,我們引入了SAM,這是一種新的算法,通過同時最小化損失值和損失清晰度來提高泛化能力;我們已經通過嚴格的大規模實證評估證明了SAM的有效性。我們已經為未來的工作提供了許多有趣的途徑。在理論方面,m-銳度產生的每個數據點-銳度的概念(與過去通常研究的在整個訓練集上計算的全局清晰度形成對比)提出了一個有趣的新視角,通過它來研究泛化。從方法上講,我們的結果表明,在目前依賴Mixup的穩健或半監督方法中,SAM有可能取代Mixup(例如,提供MentorSAM)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/43328.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/43328.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/43328.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

代碼隨想錄算法訓練營第三十天|62.不同路徑、63. 不同路徑 II

62.不同路徑 一個機器人位于一個 m x n 網格的左上角 &#xff08;起始點在下圖中標記為 “Start” &#xff09;。 機器人每次只能向下或者向右移動一步。機器人試圖達到網格的右下角&#xff08;在下圖中標記為 “Finish” &#xff09;。 問總共有多少條不同的路徑&#xff…

軟設之生成器模式

生成器模式的意圖是:將一個復雜的類表示與其構造分離&#xff0c;使得相同的構建過程能夠得出不同的表示 Builder:抽象建造者&#xff0c;為創建一個產品對象各個部件指定抽象接口&#xff0c;把產品的生產過程分解為不同的步驟&#xff0c;從而使具體建造者在具體的建造步驟上…

Java中的對象克隆詳解

Java中的對象克隆詳解 大家好&#xff0c;我是免費搭建查券返利機器人省錢賺傭金就用微賺淘客系統3.0的小編&#xff0c;也是冬天不穿秋褲&#xff0c;天冷也要風度的程序猿&#xff01; 對象克隆在Java編程中是一個重要的概念和技術。它允許我們創建一個對象的精確副本&…

MySQL第三次練習

作業三 一 先創建DB abc&#xff0c;創建table student 1、插入一條記錄 2、添加多條記錄 3、添加部分記錄 4、加0.5 5、刪除成績為空的記錄 二 1、創建一個用戶test1使他只能本地登錄擁有查詢student表的權限。 2、查詢用戶test1的權限。 3、刪除用戶test1. 全在一張圖上…

怎樣優化 PostgreSQL 中對日期時間范圍的模糊查詢?

文章目錄 一、問題分析&#xff08;一&#xff09;索引未有效利用&#xff08;二&#xff09;日期時間格式不統一&#xff08;三&#xff09;復雜的查詢條件 二、優化策略&#xff08;一&#xff09;使用合適的索引&#xff08;二&#xff09;規范日期時間格式&#xff08;三&a…

AI學習指南機器學習篇-層次聚類(Hierarchical Clustering)簡介

AI學習指南機器學習篇-層次聚類(Hierarchical Clustering)簡介 在機器學習領域中&#xff0c;層次聚類(Hierarchical Clustering)是一種常見的無監督學習算法&#xff0c;用于將數據集中的樣本分成具有相似特征的群組。層次聚類不需要預先指定要分成的群組數目&#xff0c;而是…

邏輯回歸模型(非回歸問題,而是分類問題)

目錄&#xff1a; 一、Sigmoid函數&#xff1a;二、邏輯回歸介紹&#xff1a;三、決策邊界四、邏輯回歸模型訓練過程&#xff1a;1.訓練目標&#xff1a;2.梯度下降調整參數&#xff1a; 一、Sigmoid函數&#xff1a; Sigmoid函數是構建邏輯回歸模型的重要函數&#xff0c;如下…

免費壓縮pdf文件大小軟件收費嗎?pdf如何壓縮文件大小?12款壓縮應用推薦!

在數字化時代&#xff0c;PDF文件因其跨平臺、格式統一的特點而廣受歡迎。然而&#xff0c;隨著文件內容的增加&#xff0c;PDF文件的大小也逐漸增大&#xff0c;給存儲和傳輸帶來了諸多不便。因此&#xff0c;尋找一款合適的PDF壓縮軟件成為了許多用戶的需求。本文將詳細介紹1…

單調隊列與單調棧(集訓day2)

一、目錄 1、單調隊列 2、單調棧 二、正文 1.單調棧題型&#xff1a; &#xff08;1&#xff09;給出一個數組找出其中每個數左邊第一個比它小&#xff08;大&#xff09;的數字 830. 單調棧 - AcWing題庫 &#xff08;2&#xff09;求直方圖中最大的矩形&…

電子設備常用的膠水有哪些?

目錄 1、502膠水 2、703膠水 3、704膠水 4、AB膠 5、紅膠 6、Underfill 7、導電膠 8、UV膠 9、熱熔膠 10、環氧樹脂膠 11、硅酮膠 12、聚氨酯膠 13、丙烯酸膠 14、丁基膠 1、502膠水 502膠水&#xff0c;也被稱為瞬間膠或快干膠&#xff0c;是一種非常常見的粘合…

電動卡丁車語音芯片方案選型:讓駕駛體驗更智能、更安全

在追求速度與激情的電動卡丁車領域&#xff0c;每一次升級都意味著更加極致的駕駛體驗。而今天&#xff0c;我們要介紹的&#xff0c;正是一款能夠顯著提升電動卡丁車智能化與安全性的語音芯片方案——為您的愛車增添一份獨特的魅力與安全保障。 智能化升級&#xff0c;從“聽…

[Python學習篇] Python面向對象——繼承

繼承是什么 繼承是面向對象編程&#xff08;OOP&#xff09;中的一個核心概念。繼承允許一個類&#xff08;稱為子類或派生類&#xff09;從另一個類&#xff08;稱為父類或基類&#xff09;繼承屬性和方法。這樣可以重用代碼&#xff0c;提高代碼的模塊化和可維護性。 父類&am…

js面試題2024

1.js的數據類型 boolean number string null undefined bigint symbol object 按存儲方式分&#xff0c;前面七種為基本數據類型&#xff0c;存儲在棧上&#xff0c;object是引用數據類型&#xff0c;存儲在堆上&#xff0c;在棧中存儲指針 按es標準分&#xff0c;bigint 和sym…

PHP框架講解 - symfony框架

Symfony 框架概述 Symfony 是一個用于構建 web 應用的 PHP 框架&#xff0c;它遵循 MVC&#xff08;模型-視圖-控制器&#xff09;模式&#xff0c;并且具有高度的可定制性。Symfony 是一個組件庫&#xff0c;它提供了許多用于構建現代 web 應用的工具和功能。以下是對 Symfon…

布隆過濾器 redis

一.為什么要用到布隆過濾器&#xff1f; 緩存穿透&#xff1a;查詢一條不存在的數據&#xff0c;緩存中沒有&#xff0c;則每次請求都打到數據庫中&#xff0c;導致數據庫瞬時請求壓力過大&#xff0c;多見于爬蟲惡性攻擊因為布隆過濾器是二進制的數組&#xff0c;如果使用了它…

FLD工作日志

在FLD的工作日志 一、技能掌握楊總經驗的傳輸 一、技能掌握 06.12 學會如何看小產品的代碼&#xff0c;看的消毒燈 07.08 1.學會嘉立創eda 楊總經驗的傳輸 07.07 什么能做就做什么&#xff0c;一刻也不要停不要看不起簡單的事情&#xff0c;量變引起質變

科普文:K8S中常見知識點梳理

簡單說一下k8s集群內外網絡如何互通的 要在 Kubernetes&#xff08;k8s&#xff09;集群內外建立網絡互通&#xff0c;可以采取以下措施&#xff1a; 使用service&#xff1a; 使用Service類型為NodePort或LoadBalancer的Kubernetes服務。這可以使服務具有一個公共IP地址或端口…

怎么發頂會論文

AI頂會論文成功發表路徑四&#xff1a;寫作關_嗶哩嗶哩_bilibili 全集都有&#xff0c;隨手記錄一下。 講的很好&#xff0c;我多努力。努力靠近一下。

Open3D 計算點云的平均密度

目錄 一、概述 1.1基于領域密度計算原理 1.2應用 二、代碼實現 三、實現效果 2.1點云顯示 2.2密度計算結果 一、概述 在點云處理中&#xff0c;點的密度通常表示為某個點周圍一定區域內的點的數量。高密度區域表示點云較密集&#xff0c;低密度區域表示點云較稀疏。計算…