神經網絡優化器-從SGD到AdamW

優化器準則

凸優化基本概念

  • 先定義凸集,集合中的兩個點連接的線還在集合里面,就是凸集,用數學語言來表示就是:對于集合中的任意兩個元素x,y以及任意實數 λ ∈ ( 0 , 1 ) \lambda \in (0,1) λ(0,1),有 λ x + ( 1 ? λ ) y ∈ C \lambda x + (1 - \lambda) y \in C λx+(1?λ)yC,則稱為凸集。
  • 再定義凸函數: f ( λ x + ( 1 ? λ ) y ) ≤ λ f ( x ) + ( 1 ? λ ) f ( y ) f(\lambda x + (1 - \lambda) y) \leq \lambda f(x) + (1 - \lambda) f(y) f(λx+(1?λ)y)λf(x)+(1?λ)f(y)其中, λ \lambda λ是一個滿足 0 ≤ λ ≤ 1 0 \leq \lambda \leq 1 0λ1的實數參數。
  • 可以看出,凸函數的定義域必須是凸集。直觀上,凸函數的圖像不會在任何地方凹陷,這使得凸函數的局部最小值也是全局最小值,這使得優化問題更容易解決。

現在再定義凸優化問題:
凸優化是數學優化理論中的一個重要分支,它研究的是凸函數的優化問題。用數學語言來表示就是:
minimize f ( x ) subject?to g i ( x ) ≤ 0 , i = 1 , … , m and h j ( x ) = 0 , j = 1 , … , p \begin{align*} \text{minimize} \quad & f(x) \\ \text{subject to} \quad & g_i(x) \leq 0, \quad i = 1, \ldots, m \\ \text{and} \quad & h_j(x) = 0, \quad j = 1, \ldots, p \\ \end{align*} minimizesubject?toand?f(x)gi?(x)0,i=1,,mhj?(x)=0,j=1,,p?
其中, f ( x ) f(x) f(x) 是目標函數, g i ( x ) g_i(x) gi?(x)是不等式約束函數, h j ( x ) h_j(x) hj?(x)是等式約束函數, x x x是決策變量。如果目標函數 f ( x ) f(x) f(x) 和所有約束函數 g i ( x ) g_i(x) gi?(x) h j ( x ) h_j(x) hj?(x)都是凸函數,并且可行域(滿足所有約束的 x x x的集合)也是凸集,那么這個問題就是一個凸優化問題。

凸優化問題有以下特點:

  • 局部最優即全局最優:如果一個點是局部最小點,那么它也是全局最小點。這使得尋找最優解變得更加容易。
  • 對偶性:凸優化問題具有良好的對偶性質,即原問題的對偶問題也是一個凸優化問題。
  • 存在性:如果目標函數和約束函數都是下閉的,并且可行域非空,那么凸優化問題總是有解的。
  • 穩定性:凸優化問題的解對問題的微小變化是穩定的。

研究方法有:

  • 梯度下降法:通過迭代地沿著目標函數的負梯度方向移動來找到最小點。
  • 牛頓法:利用目標函數的二階導數(Hessian)來加速梯度下降法。
  • 內點法:一種專門用于解決有約束凸優化問題的算法。
  • 次梯度法:對于非光滑的凸函數,使用次梯度而不是梯度來優化。
  • 對偶方法:通過解決對偶問題來找到原問題的解。

神經網絡優化問題定義

現在我們可以開始討論優化器了:
深度學習模型的訓練就是一個優化問題,模型權重就是我們上面提到的決策變量 x x x,目標函數就是我們所設計的損失函數(所以我們將損失函數設計成凸函數,以滿足凸優化的條件),模型本身就是一個等式或者不等式約束,輸出結果必須滿足事先知道的label。優化器,就是解決這個凸優化問題的實現方案。
我們的優化問題用數學來表示就是:
f ( W ) = min ? w 1 N ∑ i = 1 N L ( y i , F ( x i ) ) + ∑ j = 1 n λ ∥ w j ∥ f(W) = \min_{w} {\frac{1}{N}\sum_{i=1}^{N} L(y_i,F(x_i)) + \sum_{j=1 }^{n}\lambda \left \| w_{j} \right \| } f(W)=wmin?N1?i=1N?L(yi?,F(xi?))+j=1n?λwj?
W W W是模型的所有參數
前一項是損失,其中 w w w是參數, N N N是樣本總數, y i y_i yi?是樣本標簽, F ( x i ) F(x_i) F(xi?)是模型結果,L是損失函數,
后一項是正則損失,用于避免過擬合現象的, λ \lambda λ是正則化系數, ∥ w j ∥ \left \| w_{j} \right \| wj?是參數的范數,常見有L1,L2等

這就是深度學習優化的數學定義,這樣我們就可以去使用數學方法來解決這個問題了。

  • 使用梯度下降法就是: W t = W t ? 1 ? α ? ▽ f ( W t ? 1 ) W_t = W_{t-1}-\alpha *\bigtriangledown f(W_{t-1}) Wt?=Wt?1??α?f(Wt?1?)其中 ▽ f ( W t ? 1 ) \bigtriangledown f(W_{t-1}) f(Wt?1?)是函數的梯度向量。
  • 使用牛頓法就是: W t = W t ? 1 ? α ? H t ? 1 ? 1 ? ▽ f ( W t ? 1 ) W_t = W_{t-1}-\alpha*H_{t-1}^{-1} *\bigtriangledown f(W_{t-1}) Wt?=Wt?1??α?Ht?1?1??f(Wt?1?)其中 H t ? 1 ? 1 H_{t-1}^{-1} Ht?1?1?為Hessian矩陣的逆矩陣即二階偏導矩陣的逆矩陣。

這也是深度學習中隨機梯度下降的由來,從最優化的梯度下降借鑒過來的。
優化器可以做的事情,就是對解決問題方法中的:梯度gt,學習率,參數正則項,參數初始化這幾個因素進行調整。
不同的優化器,他們的區別就是這四項的不同。

優化器分類與發展

隨機梯度

  • SGD:梯度計算的變種,主要區別在于gt的計算方式,原始梯度下降算法叫做GD,計算所有梯度然后更新,SGD叫做隨機梯度下降,因為它每次只采用一小批訓練樣本作為梯度更新參數,然后根據這個梯度更新模型參數。這種方法的優點是計算效率高,因為不需要計算整個訓練集上的梯度,這在數據量很大時尤其有用。
  • 動量SGD:mSGD,gt不光包括計算出的梯度,還包括了部分過去的梯度信息,好處是會加速收斂,并且跳過一些局部最優
    RMS等。

自適應梯度

算法的核心思想是根據參數的歷史更新信息來調整每個參數的學習率,從而提高收斂速度并減少訓練時間。

  • Adaptive Gradient:自適應梯度算法,它通過為每個參數維護一個累積的梯度平方和來調整學習率。AdaGrad 的更新規則如下:
    θ i = θ i ? η G i i + ? ? ? θ L ( θ ) i \theta_i = \theta_i - \frac{\eta}{\sqrt{G_{ii} + \epsilon}} \cdot \nabla_\theta L(\theta)_i θi?=θi??Gii?+? ?η???θ?L(θ)i?
    其中, G G G 是一個對角矩陣, G i i G_{ii} Gii?是參數 θ i \theta_i θi?的累積梯度平方和, ? \epsilon ?是一個很小的常數,用來保證數值穩定性。這種算法的缺點是因為下面的累計梯度平方和越來越大,越往后訓練的效果越弱,如果有出現異常梯度值,那直接后面的訓練就約等于無效了。
  • RMSProp(均方根傳播):
    RMSProp 是一種指數加權的移動平均算法,用于計算梯度的平方的指數衰減平均。它與 AdaGrad 類似,但是使用了梯度平方的指數衰減平均而不是累積和,避免了學習率變得過小的問題。更新規則如下:
    G i i = γ G i i + ( 1 ? γ ) ? ( ? θ L ( θ ) i ) 2 G_{ii} = \gamma G_{ii} + (1 - \gamma) \cdot (\nabla_\theta L(\theta)_i)^2 Gii?=γGii?+(1?γ)?(?θ?L(θ)i?)2
    θ i = θ i ? η G i i + ? ? ? θ L ( θ ) i \theta_i = \theta_i - \frac{\eta}{\sqrt{G_{ii} + \epsilon}} \cdot \nabla_\theta L(\theta)_i θi?=θi??Gii?+? ?η???θ?L(θ)i?。其中, γ \gamma γ是衰減率。

    歷史的梯度只占一部分,避免了因為歷史梯度導致G不斷增大,進而出現無法更新的情況。

Adam & AdamW

  • Adam(自適應矩估計):
    14年提出,Adam 結合了 AdaGrad 和 RMSProp 的優點,同時計算了梯度的一階矩(均值)和二階矩(方差)的指數加權移動平均。Adam 的更新規則較為復雜,涉及兩個時刻的估計量:
    m t = β 1 m t ? 1 + ( 1 ? β 1 ) ? ? θ L ( θ ) m_t = \beta_1 m_{t-1} + (1 - \beta_1) \cdot \nabla_\theta L(\theta) mt?=β1?mt?1?+(1?β1?)??θ?L(θ)

    就是上面提到的一階動量部分,借鑒Momentum部分

    v t = β 2 v t ? 1 + ( 1 ? β 2 ) ? ( ? θ L ( θ ) ) 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) \cdot (\nabla_\theta L(\theta))^2 vt?=β2?vt?1?+(1?β2?)?(?θ?L(θ))2

    二階動量部分,也就是借鑒RMSProp部分

    這樣雖然避免了后期無法更新的問題,但是引入了一個新的問題,那就是因為有衰弱因數,導致在剛開始訓練的時候梯度信息積累太慢,因此在更新的時候設一個無偏估計,使用該無偏估計來進行更新
    m ^ t = m t 1 ? β 1 t \hat{m}_t = \frac{m_t}{1-\beta_1^t} m^t?=1?β1t?mt??
    v ^ t = v t 1 ? β 1 t \hat{v}_t = \frac{v_t}{1-\beta_1^t} v^t?=1?β1t?vt??
    θ t + 1 = θ t ? η v ^ t + ? ? m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t θt+1?=θt??v^t? ?+?η??m^t?
    其中, m t m_t mt? v t v_t vt? 分別是梯度的一階矩和二階矩的估計, m ^ t \hat{m}_t m^t? v ^ t \hat{v}_t v^t? 是它們的無偏估計量,( \beta_1 ) 和 ( \beta_2 ) 是超參數。

    無偏估計的意思是:在大量數據的時候,估計量(estimator)的期望值(或平均值)等于被估計的參數的真實值。

在transformer模型中常用,因為transformer的Lipschitz常量很大,每一層的Lipschitz常量差異又很大,學習率很難估計,而且學習完表現也比較差。所以mSGD基本不用,都是用Adam。

mSGD在卷積網絡的時候效果還是不錯的,能夠和Adam打個平手
Lipschitz常量是指在Lipschitz連續中的一個量,能夠體現凸函數的變化率。Lipschitz常量差異大就表示不同函數間相同的自變量變化導致的因變量變化差異大,簡而言之也就是學習率需要設置的不同。

AdamW

AdamW和Adam基本一致,只有對正則項的處理不一致。Adam和前面是其他的一樣,都是在損失函數里面加一個正則項,但是當訓練時,前期梯度太大,會把正則項淹沒掉,后期梯度太小,正則項又會把梯度信息淹沒掉,AdamW的目的是為了平衡這兩項。
AdamW中W 代表權重衰減(Weight Decay),將原本的正則項改為weight decay,將原本在損失函數中的項,放到了權重更新公式中:
θ t + 1 = θ t ? η v ^ t + ? ? m ^ t ? λ θ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t - \lambda \theta_{t} θt+1?=θt??v^t? ?+?η??m^t??λθt?
在 AdamW 中,權重衰減不是直接從參數更新中減去,而是作為參數更新的一部分。這樣做的好處是:

  • 保持自適應學習率的一致性:權重衰減與自適應學習率相結合,確保了不同參數的學習率保持一致。
  • 提高收斂性和穩定性:調整后的權重衰減有助于算法更快地收斂,并提高了訓練過程的穩定性。

優化器內存占用

在進行小模型訓練時,對于優化器的內存占用不是很關注,但是在進行大模型訓練時,優化器的內存占用非常大,就需要專門考慮了,大模型常用的優化器為AdamW。
AdamW算法的內存占用相對較高,因為它需要同時保存一階和二階矩。具體來說,AdamW算法在優化過程中需要存儲以下內容:

  1. 參數的當前值 θ \theta θ
  2. 梯度的一階矩估計(即一階動量) m \mathbf{m} m
  3. 梯度的二階矩估計(即二階動量) v \mathbf{v} v

每個參數 θ \theta θ 都需要額外存儲兩個與其尺寸相同的向量 m \mathbf{m} m v \mathbf{v} v,這導致內存占用大約是原始參數內存的兩倍。此外,還需要存儲超參數。在大規模訓練或參數量非常大的模型中,這種內存占用可能會成為一個問題。例如,在訓練具有數百萬參數的模型時,使用AdamW可能會導致顯著的內存需求增加,這可能限制了模型的大小或訓練并行度。
對于參數量為 Φ \Phi Φ的模型,使用混合精度進行訓練,模型參數本身使用fp16存儲,占用 2 Φ 2\Phi 個字節,同樣模型梯度占用 2 Φ 2\Phi 個字節,Adam狀態(fp32的模型參數備份,fp32的momentum和fp32的variance)一共要占用 12 Φ 12\Phi 12Φ個字節,這兩個統稱模型狀態,共占用 16 Φ 16\Phi 16Φ個字節

混合精度訓練時,前向傳播和反向傳播都是fp16,但是參數更新時使用fp32。

針對顯存這個問題,微軟提出了ZeRO技術,將模型狀態進行分片,對于N個GPU,每個GPU中保存 1 N \frac{1}{N} N1?的模型狀態量

實現

Pytorch中優化器:官方教程

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/13800.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/13800.shtml
英文地址,請注明出處:http://en.pswp.cn/web/13800.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【NLP】詞性標注

詞 詞是自然語言處理的基本單位,自動詞法分析就是利用計算機對詞的形態進行分析,判斷詞的結構和類別。 詞性(Part of Speech)是詞匯最重要的特性,鏈接詞匯和句法 詞的分類 屈折語:形態分析 分析語&#…

k8s 1.24.x之后如果rest 訪問apiserver

1.由于 在 1.24 (還是 1.20 不清楚了)之后,下面這兩個apiserver的配置已經被棄用 了,簡單的說就是想不安全的訪問k8s是不可能了,所以只能走安全的訪問方式也就是 https://xx:6443了,所以需要證書。 - --ins…

Git系列:git rm 的高級使用技巧

💝💝💝歡迎蒞臨我的博客,很高興能夠在這里和您見面!希望您在這里可以感受到一份輕松愉快的氛圍,不僅可以獲得有趣的內容和知識,也可以暢所欲言、分享您的想法和見解。 推薦:「stormsha的主頁」…

【go項目01_學習記錄15】

重構MVC 1 Article 模型1.1 首先創建 Article 模型文件1.2 接下來創建獲取文章的方法1.3 新增 types.StringToUint64()函數1.4 修改控制器的調用1.5 重構 route 包1.6 通過 SetRoute 來傳參對象變量1.7 新增方法:1.8 控制器將 Int64ToString 改為 Uint64ToString1.9…

【數據結構】棧和隊列的相互實現

歡迎瀏覽高耳機的博客 希望我們彼此都有更好的收獲 感謝三連支持! 1.用棧實現隊列 當隊列中進入這些元素時,相應的棧1中元素出棧順序與出隊列相反,因此我們可以使用兩個棧來使元素的出棧順序相同; 通過將棧1元素出棧,再…

Databend 倒排索引的設計與實現

倒排索引是一種用于全文搜索的數據結構。它的主要功能是將文檔中的單詞作為索引項,映射到包含該單詞的文檔列表。通過倒排索引,可以快速準確地定位到與查詢詞相匹配的文檔列表,從而大幅提高查詢性能。倒排索引在搜索引擎、數據庫和信息檢索系…

matlab實現繪制煙花代碼

下面是一個簡化的示例,它使用MATLAB的繪圖功能來模擬煙花爆炸的視覺效果。請注意,這個示例是概念性的,并且可能需要根據您的具體需求進行調整。 % 初始化參數 num_fireworks 5; % 煙花數量 num_particles_per_firework 200; % 每個煙花…

前端 CSS 經典:3D 漸變輪播圖

前言&#xff1a;無論什么樣式的輪播圖&#xff0c;核心 JS 實現原理都差不多。所以小伙伴們&#xff0c;還是需要了解一下核心 JS 實驗原理的。 效果圖&#xff1a; 實現代碼&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta chars…

MySQL —— 復合查詢

一、基本的查詢回顧練習 前面兩章節整理了許多關于查詢用到的語句和關鍵字&#xff0c;以及MySQL的內置函數&#xff0c;我們先用一些簡單的查詢練習去回顧之前的知識 1. 前提準備 同樣是前面用到的用于測試的表格和數據&#xff0c;一張學生表和三張關于雇員信息表 雇員信息…

優化數據查詢性能:StarRocks 與 Apache Iceberg 的強強聯合

Apache Iceberg 是一種開源的表格格式&#xff0c;專為在數據湖中存儲大規模分析數據而設計。它與多種大數據生態系統組件高度兼容&#xff0c;相較于傳統的 Hive 表格格式&#xff0c;Iceberg 在設計上提供了更高的性能和更好的可擴展性。它支持 ACID 事務、Schema 演化、數據…

leetcode-設計LRU緩存結構-112

題目要求 思路 雙鏈表哈希表 代碼實現 struct Node{int key, val;Node* next;Node* pre;Node(int _key, int _val): key(_key), val(_val), next(nullptr), pre(nullptr){} };class Solution { public: unordered_map<int, Node*> hash; Node* head; Node* tail; int …

普源DHO924示波器OFFSET設置

一、簡介 示波器是電子工程師常用的測量工具之一&#xff0c;能夠直觀地顯示電路信號的波形和參數。普源DHO924是一款優秀的數字示波器&#xff0c;具有優異的性能和易用性。其中OFFSET功能可以幫助用戶調整信號的垂直位置&#xff0c;使波形更清晰易讀。本文將詳細介紹DHO924…

專注于運動控制芯片、運動控制產品研發、生產與銷售為一體的技術型芯片代理商、方案商——青牛科技

深圳市青牛科技實業有限公司,是專注于運 動控制芯片、運動控制產品研發、生產與銷售為一體的技術型 芯片代理商、方案商。現今代理了國產品牌GLOBALCHIP&#xff0c;芯谷&#xff0c;矽普&#xff0c;TOPPOWER等品牌。其中代理品牌TOPPOWER為電源模塊&#xff0c;他們公司通過了…

cherry-pick的強大之處在于哪里

git cherry-pick 的強大之處在于它提供了一種靈活的方式來應用特定的提交到不同的分支上&#xff0c;而無需合并整個分支或拉取其他不需要的提交。以下是 git cherry-pick 的幾個主要優點和強大之處&#xff1a; 選擇性應用提交&#xff1a;你可以挑選一個或多個特定的提交&…

聲音轉文本(免費工具)

聲音轉文本&#xff1a;解鎖語音技術的無限可能 在當今這個數字化時代&#xff0c;信息的傳遞方式正以前所未有的速度進化。從手動輸入到觸控操作&#xff0c;再到如今的語音交互&#xff0c;技術的發展讓溝通變得更加自然與高效。聲音轉文本&#xff08;Speech-to-Text, STT&…

Plant Simulation驗證AGV算法

Plant Simulation驗證算法也是非常高效且直觀的&#xff0c;一直以來波哥在迭代算法的時候圖形顯示這塊都是使用Openframeworks去做&#xff0c;效果還是非常不錯的。 這里簡要介紹一下openFrameworks&#xff0c;openFrameworks是一個開源的、跨平臺的 C 工具包。旨在開發實時…

LeetCode hot100-49-N

236. 二叉樹的最近公共祖先 給定一個二叉樹, 找到該樹中兩個指定節點的最近公共祖先。百度百科中最近公共祖先的定義為&#xff1a;“對于有根樹 T 的兩個節點 p、q&#xff0c;最近公共祖先表示為一個節點 x&#xff0c;滿足 x 是 p、q 的祖先且 x 的深度盡可能大&#xff08;…

爬蟲學習--12.MySQL數據庫的基本操作(下)

MySQL查詢數據 MySQL 數據庫使用SQL SELECT語句來查詢數據。 語法&#xff1a;在MySQL數據庫中查詢數據通用的 SELECT 語法 SELECT 字段1&#xff0c;字段2&#xff0c;……&#xff0c;字段n FROM table_name [WHERE 條件] [LIMIT N] 查詢語句中你可以使用一個或者多個表&…

uni-app項目在微信開發者工具打開時報錯[ app.json 文件內容錯誤] app.json: 在項目根目錄未找到 app.json

uni-app項目在微信開發者工具打開時報錯[ app.json 文件內容錯誤] app.json: 在項目根目錄未找到 app.json 出現這個問題是因為打開的文件地址不對&#xff0c;解決這個問題首先我們要查看是否有unpackage文件夾&#xff0c;如果有&#xff0c;項目直接指向unpackage\dist\dev\…

vue3使用mitt.js進行各種組件間通信

我們在vue工程中&#xff0c;除開vue自帶的什么父子間&#xff0c;祖孫間通信&#xff0c;還有一個非常方便的通信方式&#xff0c;類似Vue2.x 使用 EventBus 進行組件通信&#xff0c;而 Vue3.x 推薦使用 mitt.js。可以實現各個組件間的通信 優點&#xff1a;首先它足夠小&…