論文筆記
資料
1.代碼地址
https://github.com/google-research/sam
https://github.com/davda54/sam
2.論文地址
https://arxiv.org/abs/2010.01412
3.數據集地址
論文摘要的翻譯
在當今嚴重過度參數化的模型中,訓練損失的值很難保證模型的泛化能力。事實上,像通常所做的那樣,只優化訓練損失值很容易導致次優的模型質量。受損失的幾何圖像與泛化相結合的前人工作的啟發,我們引入了一種新的、有效的同時最小化損失值和損失銳度的方法。特別是,我們的方法,銳度感知最小化(SAM),尋找位于具有一致低損失的鄰域的參數;這個公式致使一個最小值-最大值的優化問題,在該問題上可以有效地執行梯度下降。我們提出的經驗結果表明,SAM提高了各種基準數據集(例如,CIFAR-{10,100}、ImageNet、微調任務)和模型的模型泛化能力,為幾個數據集帶來了新的最先進的性能。此外,我們發現SAM原生地提供了對標簽噪聲的魯棒性,這與專門針對帶有噪聲標簽的學習的過程所提供的魯棒性不相上下。
1背景
現代機器學習成功地在廣泛的任務中實現了越來越好的性能,這在很大程度上取決于越來越重的過度參數化,以及開發越來越有效的訓練算法,這些算法能夠找到很好地泛化的參數。事實上,許多現代神經網絡可以很容易地記住訓練數據,并具有容易過擬合的能力。目前需要這種嚴重的過度參數化才能在各種領域實現最先進的結果。反過來,至關重要的是,使用程序來訓練這些模型,以確保實際選擇的參數事實上超越了訓練集。
不幸的是,簡單地最小化訓練集上常用的損失函數(例如,交叉熵)通常不足以實現令人滿意的泛化。今天的模型的訓練損失景觀通常是復雜的和非凸的,具有多個局部和全局極小值,并且具有不同的全局極小值產生具有不同泛化能力的模型。因此,從許多可用的(例如,隨機梯度下降、Adam)、RMSProp和其他中選擇優化器(和相關的優化器設置)已成為一個重要的設計選擇,盡管對其與模型泛化的關系的理解仍處于初級階段。與此相關的是,已經提出了一系列修改訓練過程的方法,包括dropout,批量歸一化、隨機深度、數據增強和混合樣本增強.
損失圖像的幾何形狀——特別是最小值的平坦性和泛化之間的聯系已經從理論和實證的角度進行了廣泛的研究)。雖然這種聯系有望實現新的模型訓練方法,從而產生更好的泛化能力,但迄今為止,專門尋找更平坦的最小值并進一步有效提高一系列最先進模型泛化能力的實用高效算法一直難以實現;我們在第5節中對先前的工作進行了更詳細的討論)。
2論文的創新點
- 我們引入了銳度感知最小化(SAM),這是一種新的方法,通過同時最小化損失值和銳度來提高模型的泛化能力。SAM通過尋找位于具有一致低損耗值的鄰域中的參數(而不是僅具有低損耗值的參數,如圖1的中間和右側圖像所示)來工作,并且可以高效且容易地實現。
- 我們通過一項嚴格的實證研究表明,使用SAM提高了一系列廣泛研究的計算機視覺任務(例如,CIFAR-{10,100},ImageNet,微調任務)和模型的模型泛化能力,如圖1的左側曲線圖中總結的。例如,應用SAM為許多已經深入研究的任務,如ImageNet,CIFAR-{10,100},SVHN,Fashion-MNIST,和標準的圖像分類微調任務集(例如,Flowers,Stanford Cars,Oxford Pets,等)產生了新的最先進的性能。
- 我們還表明,SAM進一步提供了標簽噪聲的穩健性,與專門針對帶有噪聲標簽的學習的最先進的程序所提供的一樣。
- 通過SAM提供的視角,我們提出了一個很有使用價值的銳度的新概念,我們稱之為m-銳度,從而進一步闡明了損失銳度和泛化之間的聯系。
3 論文方法的概述
在本文中,我們將標量表示為 a a a,將向量表示為 a \mathbf{a} a、將矩陣表示為 A \Alpha A、將集合表示為 A \mathcal A A,將等式定義為,。給定訓練數據集 S ? ∪ i = 1 n { ( x i , y i ) } \mathcal{S}\triangleq\cup_{i=1}^n\{(x_i,y_i)\} S?∪i=1n?{(xi?,yi?)}從分布 D , \mathscr{D}, D,,中i.i.d.繪制,我們尋求學習一個很好地泛化的模型。特別地,考慮一組由 w ∈ W ? R d w\in\mathcal{W}\subseteq\mathbb{R}^d w∈W?Rd參數化的模型;給定一個逐數據點損失函數 l : W × X × Y → R + l:\mathcal{W}\times\mathcal{X}\times\mathcal{Y}\to\mathbb{R}_+ l:W×X×Y→R+?,我們定義了訓練集損失 L S ( w ) ? 1 n ∑ i = 1 n l ( w , x i , y i ) L_S(\boldsymbol{w})\triangleq\frac1n\sum_{i=1}^nl(\boldsymbol{w},\boldsymbol{x}_i,\boldsymbol{y}_i) LS?(w)?n1?∑i=1n?l(w,xi?,yi?)和總體損失 L D ( w ) ? E ( x , y ) ~ D [ l ( w , x , y ) ] L_{\mathscr{D}}(\boldsymbol{w})\triangleq\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\thicksim D}[l(\boldsymbol{w},\boldsymbol{x},\boldsymbol{y})] LD?(w)?E(x,y)~D?[l(w,x,y)]。在僅觀察到 S \mathcal S S的情況下,模型訓練的目標是選擇具有低總體損失 L D ( w ) L_{\mathscr{D}}(\boldsymbol{w}) LD?(w)的模型參數 w w w。
利用 L S ( w ) 作為 L D ( w ) L_S(\boldsymbol{w})作為L_{\mathscr{D}}(\boldsymbol{w}) LS?(w)作為LD?(w)的估計,通過使用諸如SGD或Adam之類的優化過程來求解 m i n w L S ( w ) min_w L_S(\boldsymbol{w}) minw?LS?(w)(可能與w上的正則化子結合)來激勵選擇參數w的標準方法。然而,不幸的是,對于現代的過度參數化模型,如深度神經網絡,典型的優化方法很容易在測試時導致次優性能。特別地,對于現代模型, L S ( w ) L_S(\boldsymbol{w}) LS?(w)通常在w中是非凸的,具有多個局部甚至全局最小值,這些局部甚至全局極小值可以產生相似的 L S ( w ) L_S(\boldsymbol{w}) LS?(w)值,同時具有顯著不同的泛化性能(即,顯著不同的 L D ( w ) L_{\mathscr{D}}(\boldsymbol{w}) LD?(w)值)。
受損失圖形的銳度和泛化之間的聯系的啟發,我們提出了一種不同的方法:我們不是尋找簡單地具有低訓練損失值 L S ( w ) L_S(\boldsymbol{w}) LS?(w)的參數值 w w w,而是尋找整個鄰域具有一致低訓練損失的參數值(相當于,具有低損失和低曲率的鄰域)。以下定理通過在鄰域訓練損失方面限制泛化能力來說明這種方法的動機(附錄A中的完整定理陳述和證明):
- 定理1
對于任意 ρ > 0 ρ>0 ρ>0、在由分布 D \mathscr D D生成的訓練集 S 上 \mathcal S上 S上具有高概率的情況, L D ( w ) ≤ max ? ∥ ? ∥ 2 ≤ ρ L S ( w + ? ) + h ( ∥ w ∥ 2 2 / ρ 2 ) , L_\mathscr{D}(\boldsymbol{w})\leq\max_{\|\boldsymbol{\epsilon}\|_2\leq\rho}L_\mathcal{S}(\boldsymbol{w}+\boldsymbol{\epsilon})+h(\|\boldsymbol{w}\|_2^2/\rho^2), LD?(w)≤∥?∥2?≤ρmax?LS?(w+?)+h(∥w∥22?/ρ2),其中h: R + → R + \mathbb{R}_+\to\mathbb{R}_+ R+?→R+?是嚴格遞增函數(在 L D ( w ) L_{\mathscr{D}}(\boldsymbol{w}) LD?(w)上的一些技術條件下)。為了明確我們的銳度項,我們可以將上面不等式的右側重寫為為了明確我們的銳度項,我們可以將上面不等式的右側重寫為 [ max ? ∥ ? ∥ 2 ≤ ρ L S ( w + ? ) ? L S ( w ) ] + L S ( w ) + h ( ∥ w ∥ 2 2 / ρ 2 ) . [\max_{\|\boldsymbol{\epsilon}\|_2\leq\rho}L_\mathcal{S}(\boldsymbol{w}+\boldsymbol{\epsilon})-L_\mathcal{S}(\boldsymbol{w})]+L_\mathcal{S}(\boldsymbol{w})+h(\|\boldsymbol{w}\|_2^2/\rho^2). [max∥?∥2?≤ρ?LS?(w+?)?LS?(w)]+LS?(w)+h(∥w∥22?/ρ2).
方括號中的項通過測量通過從 w w w移動到附近的參數值可以以多快的速度增加訓練損失來捕捉 L S L_S LS?在 w w w處的銳度;然后將該銳度項與訓練損失值本身和 w w w大小上的正則化子求和。給定特定函數h深受證明細節的影響,我們用 λ ∣ ∣ w ∣ ∣ 2 2 \lambda||w||_2^2 λ∣∣w∣∣22?代替超參數λ的第二項,得到標準L2正則化項。因此,受約束項的啟發,我們建議通過解決以下SharpnessAware最小化(SAM)問題來選擇參數值:
min ? w L S S A M ( w ) + λ ∥ w ∥ 2 2 w h e r e L S S A M ( w ) ? max ? ∣ ∣ ? ∣ ∣ p ≤ ρ L S ( w + ? ) , \min_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(\boldsymbol{w})+\lambda\|\boldsymbol{w}\|_2^2\quad\mathrm{~where~}\quad L_{\mathcal{S}}^{SAM}(\boldsymbol{w})\triangleq\max_{||\boldsymbol{\epsilon}||_p\leq\rho}L_S(\boldsymbol{w}+\boldsymbol{\epsilon}), wmin?LSSAM?(w)+λ∥w∥22??where?LSSAM?(w)?∣∣?∣∣p?≤ρmax?LS?(w+?),其中 ρ ≥ 0 ρ≥0 ρ≥0是一個超參數, p ∈ [ 1 , ∞ ] p∈[1,∞] p∈[1,∞](我們在最大化過程中從L2范數略微推廣到p范數,本文已經證明 p = 2 p=2 p=2通常是最優的)。圖1顯示了1通過最小化 L S ( w ) L_S(\boldsymbol{w}) LS?(w)或 L S S A M ( w ) L^{SAM}_S(\boldsymbol{w}) LSSAM?(w)而收斂到最小值的模型的損失情況,說明了銳度感知損失阻止了模型收斂到尖銳的最小值。
為了最小化 L S S A M ( w ) L^{SAM}_S(\boldsymbol{w}) LSSAM?(w),我們通過內部最大化進行微分,推導出了一個有效的近似值,從而使我們能夠將隨機梯度下降直接應用于SAM目標。沿著這條路前進,我們首先通過 L S ( w + ? ) w . r . t . ? L_{\mathcal{S}}(w+\epsilon)\mathrm{~w.r.t.~}\epsilon LS?(w+?)?w.r.t.?? 在0附近的一階泰勒展開來近似內部最大化問題,得到 ? w ) ? arg ? max ? ∥ ? ∥ p ≤ ρ L S ( w + ? ) ≈ arg ? max ? ∥ ? ∥ p ≤ ρ L S ( w ) + ? T ? w L S ( w ) = arg ? max ? ∥ ? ∥ p ≤ ρ ? T ? w L S ( w ) . \boldsymbol{\epsilon}\\\boldsymbol{w})\triangleq\arg\max_{\|\boldsymbol{\epsilon}\|_p\leq\rho}L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})\approx\arg\max_{\|\boldsymbol{\epsilon}\|_p\leq\rho}L_{\mathcal{S}}(\boldsymbol{w})+\boldsymbol{\epsilon}^T\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w})=\arg\max_{\|\boldsymbol{\epsilon}\|_p\leq\rho}\boldsymbol{\epsilon}^T\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w}). ?w)?arg∥?∥p?≤ρmax?LS?(w+?)≈arg∥?∥p?≤ρmax?LS?(w)+?T?w?LS?(w)=arg∥?∥p?≤ρmax??T?w?LS?(w).反過來,求解該近似的值 ξ ( w ) ξ(w) ξ(w)由經典對偶范數問題的解給出( ∣ ? ∣ q ? 1 |·|^{q?1} ∣?∣q?1表示元素絕對值和冪): ? ^ ( w ) = ρ s i g n ( ? w L S ( w ) ) ∣ ? w L S ( w ) ∣ q ? 1 / ( ∥ ? w L S ( w ) ∥ q q ) 1 / p (2) \hat{\boldsymbol{\epsilon}}(w)=\rho\mathrm{~sign}\left(\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)\right)|\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)|^{q-1}/\left(\|\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)\|_{q}^{q}\right)^{1/p}\text{(2)} ?^(w)=ρ?sign(?w?LS?(w))∣?w?LS?(w)∣q?1/(∥?w?LS?(w)∥qq?)1/p(2)其中1/p+1/q=1。代入方程(1)并進行微分,我們得到 ? w L S S A M ( w ) ≈ ? w L S ( w + ? ^ ( w ) ) = d ( w + ? ^ ( w ) ) d w ? w L S ( w ) ∣ w + ? ^ ( w ) \nabla_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(\boldsymbol{w})\approx\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))=\frac{d(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))}{d\boldsymbol{w}}\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(\boldsymbol{w})|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(w)} ?w?LSSAM?(w)≈?w?LS?(w+?^(w))=dwd(w+?^(w))??w?LS?(w)∣w+?^(w)?
可以通過自動微分直接計算這種對 ? w L S S A M ( w ) \nabla_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(w) ?w?LSSAM?(w)的近似,如在JAX、TensorFlow和PyTorch等公共庫中實現的那樣。盡管該計算隱含地依賴于LS(w)的Hessian,因為ξ。獲得我們的最終梯度近似: ? w L S S A M ( w ) ≈ ? w L S ( w ) ∣ w + ? ^ ( w ) . \nabla_{\boldsymbol{w}}L_{\mathcal{S}}^{SAM}(\boldsymbol{w})\approx\nabla_{\boldsymbol{w}}L_{\mathcal{S}}(w)|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}. ?w?LSSAM?(w)≈?w?LS?(w)∣w+?^(w)?.
我們通過將標準數值優化器(如隨機梯度下降(SGD))應用于SAM目標 L S S A M ( w ) L_\mathcal{S}^{SAM}(\boldsymbol{w}) LSSAM?(w),使用方程3來計算必要的目標函數梯度,從而獲得最終的SAM算法。算法1給出了完整SAM算法的偽代碼,使用SGD作為基本優化器,圖2示意性地說明了單個SAM參數更新。
4 論文實驗
為了評估SAM的功效,我們將其應用于一系列不同的任務,包括從頭開始的圖像分類(包括在CIFAR-10、CIFAR-100和ImageNet上)、微調預訓練的模型以及使用噪聲標簽進行學習。在所有情況下,我們都通過簡單地用SAM代替用于訓練現有模型的優化程序來衡量使用SAM的好處,并計算由此對模型泛化的影響。如下所示,SAM在絕大多數情況下都能顯著提高泛化性能。
4.1 圖像分類
我們首先評估了SAM對當今最先進的CIFAR-10和CIFAR-100模型(無需預訓練)泛化的影響:具有ShakeShake正則化的WideResNets和具有ShakeDrop正則化的PyramidNet。請注意,這些模型中的一些已經在先前的工作中進行了大量調整,并包括精心選擇的正則化方案,以防止過擬合;因此,顯著提高它們的泛化能力是非常重要的。我們已經確保在沒有SAM的情況下,我們的實現的泛化性能與先前工作中報告的性能相匹配或超過。
所有結果都使用了基本數據增強(水平翻轉、四像素填充和隨機裁剪)。我們還評估了更先進的數據增強方法,如剪切正則化和AutoAugment(,這些方法被先前的工作用來實現最先進的結果。
SAM具有單個超參數 ρ (鄰域大小) ρ(鄰域大小) ρ(鄰域大小),我們使用10%的訓練集作為驗證集,通過在 { 0.01 , 0.02 , 0.05 , 0.1 , 0.2 , 0.5 } \{0.01,0.02,0.05,0.1,0.2,0.5\} {0.01,0.02,0.05,0.1,0.2,0.5}上的網格搜索對其進行調整。有關所有超參數的值和其他訓練細節,請參見附錄C.1。由于每個SAM權重更新需要兩個反向傳播操作(一個用于計算 ξ ( w ) ξ(w) ξ(w),另一個用于估計最終梯度),我們允許每個非SAM訓練運行執行兩倍于每個SAM訓練運行的輪次,并且我們報告每個非SAM訓練運行在標準輪次計數或加倍輪次計數上獲得的最佳得分。我們運行每個實驗條件的五個獨立副本,報告結果(每個條件都有獨立的權重初始化和數據混洗),報告測試集的平均誤差(或準確性)以及相關的95%置信區間。我們的實現使用JAX,并且我們在具有8個NVidia V100 GPU的單個主機上訓練所有模型。為了在跨多個加速器并行時計算SAM更新,我們在加速器之間均勻地劃分每個數據批次,獨立地計算每個加速器上的SAM梯度,并對得到的子批次SAM梯度進行平均,以獲得最終的SAM更新。
如表1所示,SAM提高了針對CIFAR-10和CIFAR-100評估的所有設置的泛化能力。例如,SAM使簡單的WideResNet能夠獲得1.6%的測試誤差,而沒有SAM時的誤差為2.2%。這種增益以前只能通過使用更復雜的模型架構(例如,PyramidNet)和正則化方案(例如,Shake-Shake、ShakeDrop)來實現;SAM提供了一個易于實施的、獨立于模型的替代方案。此外,即使在已經使用復雜正則化的復雜架構上應用SAM,SAM也能提供改進:例如,將SAM應用于具有ShakeDrop正則化的PyramidNet,在CIFAR-100上產生10.3%的誤差,據我們所知,這是該數據集上的一個新的最先進技術,無需使用額外數據。
除了CIFAR-{101100}之外,我們還在SVHN和Fashion MNIST數據集上評估了SAM。再次,SAM使一個簡單的WideResNet能夠實現這些數據集達到或高于最先進水平的精度:SVHN的誤差為0.99%,Fashion MNIST為3.59%。詳細信息見附錄B.1。
為了更大規模地評估SAM的性能,我們將其應用于在ImageNet上訓練的不同深度(50,101,152)的ResNets。在這種情況下,根據先前的工作,我們將圖像調整大小并裁剪到224像素分辨率,對其進行歸一化,并使用批量大小4096、初始學習率1.0、余弦學習率計劃、動量為0.9的SGD優化器、標簽平滑度為0.1和權重衰減0.0001。當應用SAM時,我們使用ρ=0.05(通過對訓練了100個時期的ResNet-50進行網格搜索確定)。我們使用Google Cloud TPU3在ImageNet上訓練所有模型長達400個時期,并報告每個實驗條件的前1和前5測試錯誤率(5次獨立運行的平均值和95%置信區間)。
如表2所示,SAM再次持續提高性能,例如將ResNet-152的ImageNet top 1錯誤率從20.3%提高到18.4%。此外,請注意,SAM能夠增加訓練時期的數量,同時在不過度擬合的情況下繼續提高準確性。相反,當訓練從200個時期擴展到400個時期時,標準訓練過程(沒有SAM)通常顯著地過擬合。
4.2 微調
通過在大型相關數據集上預訓練模型,然后在感興趣的較小目標數據集上進行微調,遷移學習已成為一種強大且廣泛使用的技術,用于為各種不同的任務生成高質量的模型。我們在這里展示了SAM在這種情況下再次提供了相當大的好處,即使在微調非常大、最先進、已經高性能的模型時也是如此。
特別地,我們將SAM應用于微調EfficientNet-b7(在ImageNet上預訓練)和EfficientNet-L2(在ImageNet上預訓練加上未標記的JFT;輸入分辨率475)。我們將這些模型初始化為公開可用的檢查點6,分別用RandAugment(在ImageNet上的準確率為84.7%)和NoisyStudent(在Image Net上的正確率為88.2%)訓練。我們通過從上述檢查點開始訓練每個模型,在幾個目標數據集中的每個數據集上微調這些模型;有關使用的超參數的詳細信息,請參閱附錄。我們報告了每個數據集在5次獨立運行中前1個測試誤差的平均值和95%置信區間。
如表3所示,相對于沒有SAM的微調,SAM均勻地提高了性能。此外,在許多情況下,SAM產生了新的最先進的性能,包括CIFAR-10上0.30%的誤差、CIFAR-100上3.92%的誤差和ImageNet上11.39%的誤差。
4.3對標簽噪聲的魯棒性
SAM尋找對擾動具有魯棒性的模型參數這一事實表明,SAM有可能在訓練集中提供對噪聲的魯棒性(這將擾亂訓練損失景觀)。因此,我們在這里評估SAM為標記噪聲提供的魯棒性程度。
特別地,我們測量了在CIFAR-10的經典噪聲標簽設置中應用SAM的效果,其中訓練集的一小部分標簽被隨機翻轉;測試集保持未修改(即干凈)。為了確保與之前的工作進行有效的比較,之前的工作通常使用專門用于噪聲標簽設置的架構,我們在Jiang等人之后,為200個時期訓練了一個類似大小的簡單模型(ResNet-32)。我們評估了模型訓練的五種變體:標準SGD、帶Mixup的SGD、SAM和帶MixupSAM的“自舉”SGD變體(其中,首先像往常一樣訓練模型,然后在最初訓練的模型預測的標簽上從頭開始重新訓練)。當應用SAM時,我們對除80%之外的所有噪聲級使用ρ=0.1,對于80%,我們使用ρ=0.05來獲得更穩定的收斂。對于混合基線,我們嘗試了α∈{1,8,16,32}的所有值,并保守地報告每個噪聲水平的最佳得分。
如表4所示,SAM提供了對標簽噪聲的高度魯棒性,與專門針對具有噪聲標簽的學習的現有技術程序所提供的魯棒性不相上下。事實上,除了MentorMix(之外,簡單地用SAM訓練模型勝過所有專門針對標簽噪聲魯棒性的現有方法。然而,簡單地自舉SAM產生的性能與MentorMix相當(后者要復雜得多)。
5 SAM視角下的銳度與泛化
5.1 m-銳度
盡管我們對SAM的推導定義了整個訓練集的SAM目標,但當在實踐中使用SAM時,我們計算每個批次的SAM更新(如算法1所述),甚至通過平均每個加速器獨立計算的SAM更新來計算(其中每個加速器接收一個批次的大小為m的子集,如第3節所述)。后一種設置等效于修改SAM目標(等式1)以在一組獨立的最大化上求和,每個最大化對m個數據點的不相交子集上的每個數據點損失的總和執行,而不是在訓練集上的全局總和上執行最大化(這將等效于將m設置為總訓練集大小)。我們將損失圖像的銳度的相關度量稱為m-銳度。
為了更好地理解m對SAM的影響,我們在CIFAR-10上使用m值范圍的SAM訓練一個小的ResNet。如圖3(中間)所示,較小的m值往往產生具有更好泛化能力的模型。這種關系恰好符合跨多個加速器并行化的需要,以便為當今的許多模型擴展訓練。有趣的是,如圖3(右)所示,隨著m的減少,上述m銳度測量進一步與模型的實際泛化差距表現出更好的相關性。特別地,這意味著,與上述定理1所建議的全訓練集測度相比,m<n的m-清晰度產生了更好的泛化預測因子,這為理解泛化提供了一條有趣的新途徑。
5.2 HESSIAN SPECTRA
受損失圖像的幾何形狀和泛化之間的聯系的啟發,我們構建了SAM,以尋找具有低損失值和低曲率(即,低銳度)的訓練損失圖像的最小值。為了進一步證實SAM確實發現了具有低曲率的極小值,我們計算了在CIFAR-10上訓練300步的WideResNet40-10在訓練期間的不同時期的Hessian譜,包括有SAM和沒有SAM(沒有批處理規范,這往往會模糊對Hessian的解釋)。由于參數空間的維數,我們使用Lanczos算法來近似Hessian譜。
圖3(左)報告了由此產生的Hessian光譜。正如預期的那樣,用SAM訓練的模型收斂到具有較低曲率的最小值,如在特征值的總體分布中所見,收斂時的最大特征值( λ m a x λ_{max} λmax?)(無SAM時約為24,有SAM時為1.0),以及大部分頻譜(比率 λ m a x / λ 5 λ_{max}/λ_{5} λmax?/λ5?,通常用作銳度的代理;在沒有SAM的情況下高達11.4,在有SAM的情況中高達2.6)。
6 總結
在這項工作中,我們引入了SAM,這是一種新的算法,通過同時最小化損失值和損失清晰度來提高泛化能力;我們已經通過嚴格的大規模實證評估證明了SAM的有效性。我們已經為未來的工作提供了許多有趣的途徑。在理論方面,m-銳度產生的每個數據點-銳度的概念(與過去通常研究的在整個訓練集上計算的全局清晰度形成對比)提出了一個有趣的新視角,通過它來研究泛化。從方法上講,我們的結果表明,在目前依賴Mixup的穩健或半監督方法中,SAM有可能取代Mixup(例如,提供MentorSAM)。