梯度下降優化算法-Adam

Adam(Adaptive Moment Estimation)是一種結合了動量法(Momentum)和 RMSProp 的自適應學習率優化算法。它通過計算梯度的一階矩(均值)和二階矩(未中心化的方差)來調整每個參數的學習率,從而在深度學習中表現出色。


1. Adam 的數學原理

1.1 動量法和 RMSProp 的回顧

  • 動量法:通過引入動量變量,加速梯度下降并減少震蕩。
  • RMSProp:通過指數加權移動平均計算歷史梯度平方和,自適應調整學習率。

Adam 結合了這兩種方法的優點,同時計算梯度的一階矩和二階矩。


1.2 Adam 的更新規則

Adam 的更新規則分為以下幾個步驟:

1.2.1 梯度計算

首先,計算當前時刻的梯度:

g t = ? θ J ( θ t ) g_t = \nabla_\theta J(\theta_t) gt?=?θ?J(θt?)

其中:

  • g t g_t gt? 是當前時刻的梯度向量,形狀與參數 θ t \theta_t θt? 相同。

1.2.2 一階矩估計(動量)

Adam 使用指數加權移動平均來計算梯度的一階矩(均值):

m t = β 1 ? m t ? 1 + ( 1 ? β 1 ) ? g t m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t mt?=β1??mt?1?+(1?β1?)?gt?

其中:

  • m t m_t mt? 是梯度的一階矩估計。
  • β 1 \beta_1 β1? 是一階矩的衰減率,通常取值在 [ 0.9 , 0.99 ) [0.9, 0.99) [0.9,0.99) 之間。
  • 初始時, m 0 m_0 m0? 通常設置為 0。

1.2.3 二階矩估計(RMSProp)

Adam 使用指數加權移動平均來計算梯度的二階矩(未中心化的方差):

v t = β 2 ? v t ? 1 + ( 1 ? β 2 ) ? g t 2 v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 vt?=β2??vt?1?+(1?β2?)?gt2?

其中:

  • v t v_t vt? 是梯度的二階矩估計。
  • β 2 \beta_2 β2? 是二階矩的衰減率,通常取值在 [ 0.99 , 0.999 ) [0.99, 0.999) [0.99,0.999) 之間。
  • g t 2 g_t^2 gt2? 表示對梯度向量 g t g_t gt? 逐元素平方。
  • 初始時, v 0 v_0 v0? 通常設置為 0。

1.2.4 偏差校正

由于 m t m_t mt? v t v_t vt? 初始值為 0,在訓練初期會偏向 0,因此需要進行偏差校正:

m ^ t = m t 1 ? β 1 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t} m^t?=1?β1t?mt??

v ^ t = v t 1 ? β 2 t \hat{v}_t = \frac{v_t}{1 - \beta_2^t} v^t?=1?β2t?vt??

其中:

  • m ^ t \hat{m}_t m^t? 是校正后的一階矩估計。
  • v ^ t \hat{v}_t v^t? 是校正后的二階矩估計。
  • t t t 是當前時間步。

1.2.5 參數更新

最后,Adam 的參數更新公式為:

θ t + 1 = θ t ? η v ^ t + ? ? m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t θt+1?=θt??v^t? ?+?η??m^t?

其中:

  • η \eta η 是全局學習率。
  • ? \epsilon ? 是一個很小的常數(通常為 1 0 ? 8 10^{-8} 10?8),用于避免分母為零。
  • v ^ t + ? \sqrt{\hat{v}_t} + \epsilon v^t? ?+? 是對校正后的二階矩估計逐元素開平方。

2. Adam 的詳細推導

2.1 一階矩和二階矩的意義

  • 一階矩 m t m_t mt?:類似于動量法,表示梯度的指數加權移動平均,用于加速收斂。
  • 二階矩 v t v_t vt?:類似于 RMSProp,表示梯度平方的指數加權移動平均,用于自適應調整學習率。

2.2 偏差校正的作用

偏差校正的目的是解決初始階段 m t m_t mt? v t v_t vt? 偏向 0 的問題。通過除以 1 ? β 1 t 1 - \beta_1^t 1?β1t? 1 ? β 2 t 1 - \beta_2^t 1?β2t?,可以校正估計值,使其更接近真實值。


2.3 小常數 ? \epsilon ? 的作用

小常數 ? \epsilon ? 的作用是避免分母為零。具體來說:

  • v ^ t \hat{v}_t v^t? 很小時, v ^ t + ? \sqrt{\hat{v}_t} + \epsilon v^t? ?+? 接近于 ? \epsilon ?,避免學習率過大。
  • v ^ t \hat{v}_t v^t? 很大時, ? \epsilon ? 的影響可以忽略不計。

3. PyTorch 中的 Adam 實現

在 PyTorch 中,Adam 通過 torch.optim.Adam 實現。以下是 torch.optim.Adam 的主要參數:

參數名含義
params需要優化的參數(通常是模型的參數)。
lr全局學習率(learning rate),即 η \eta η,默認值為 1 0 ? 3 10^{-3} 10?3
betas一階矩和二階矩的衰減率,即 ( β 1 , β 2 ) (\beta_1, \beta_2) (β1?,β2?),默認值為 (0.9, 0.999)。
eps分母中的小常數 ? \epsilon ?,用于避免除零,默認值為 1 0 ? 8 10^{-8} 10?8
weight_decay權重衰減(L2 正則化)系數,默認值為 0。
amsgrad是否使用 AMSGrad 變體,默認值為 False

3.1 使用 Adam 的代碼示例

以下是一個使用 Adam 的完整代碼示例:

import torch
import torch.nn as nn
import torch.optim as optim# 定義一個簡單的線性模型
model = nn.Linear(10, 1)# 定義損失函數
criterion = nn.MSELoss()# 定義優化器,使用 Adam
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)# 模擬輸入數據和目標數據
inputs = torch.randn(32, 10)  # 32 個樣本,每個樣本 10 維
targets = torch.randn(32, 1)  # 32 個目標值# 訓練過程
for epoch in range(100):# 前向傳播outputs = model(inputs)loss = criterion(outputs, targets)# 反向傳播optimizer.zero_grad()  # 清空梯度loss.backward()        # 計算梯度# 更新參數optimizer.step()       # 更新參數# 打印損失if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/100], Loss: {loss.item():.4f}")

3.2 參數設置說明

  1. 學習率 (lr)

    • 學習率 η \eta η 控制每次參數更新的步長。
    • 在 Adam 中,學習率會自適應調整,因此初始學習率可以設置得稍小一些。
  2. 衰減率 (betas)

    • 一階矩衰減率 β 1 \beta_1 β1? 和二階矩衰減率 β 2 \beta_2 β2? 分別控制一階矩和二階矩的衰減速度。
    • 默認值為 (0.9, 0.999),適用于大多數情況。
  3. 小常數 (eps)

    • 小常數 ? \epsilon ? 用于避免分母為零,通常設置為 1 0 ? 8 10^{-8} 10?8
  4. 權重衰減 (weight_decay)

    • 權重衰減系數用于 L2 正則化,防止過擬合。
  5. AMSGrad (amsgrad)

    • 如果設置為 True,則使用 AMSGrad 變體,解決 Adam 在某些情況下的收斂問題。

4. 總結

  • Adam 的核心思想:結合動量法和 RMSProp,通過計算梯度的一階矩和二階矩,自適應調整學習率。
  • Adam 的更新公式
    m t = β 1 ? m t ? 1 + ( 1 ? β 1 ) ? g t m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t mt?=β1??mt?1?+(1?β1?)?gt?
    v t = β 2 ? v t ? 1 + ( 1 ? β 2 ) ? g t 2 v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 vt?=β2??vt?1?+(1?β2?)?gt2?
    m ^ t = m t 1 ? β 1 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t} m^t?=1?β1t?mt??
    v ^ t = v t 1 ? β 2 t \hat{v}_t = \frac{v_t}{1 - \beta_2^t} v^t?=1?β2t?vt??
    θ t + 1 = θ t ? η v ^ t + ? ? m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t θt+1?=θt??v^t? ?+?η??m^t?
  • PyTorch 實現:使用 torch.optim.Adam,設置 lrbetaseps 等參數。
  • 優缺點
    • 優點:自適應學習率,適合非凸優化問題,收斂速度快。
    • 缺點:需要手動調整超參數(如 β 1 \beta_1 β1? β 2 \beta_2 β2?)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/67212.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/67212.shtml
英文地址,請注明出處:http://en.pswp.cn/web/67212.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

論文筆記(六十三)Understanding Diffusion Models: A Unified Perspective(六)(完結)

Understanding Diffusion Models: A Unified Perspective(六)(完結) 文章概括指導(Guidance)分類器指導無分類器引導(Classifier-Free Guidance) 總結 文章概括 引用: …

【PySide6快速入門】信號與槽的使用

文章目錄 前言什么是信號與槽信號與槽的功能最簡單的信號與槽控件連接信號與信號的連接總結 前言 在 PySide6 中,信號與槽機制是核心概念之一,它是 Qt 庫中事件通信的基礎。通過信號與槽,開發者能夠實現不同組件之間的解耦,從而使…

GOGOGO 枚舉

含義:一種類似于類的一種結構 作用:是Java提供的一個數據類型,可以設置值是固定的 【當某一個數據類型受自身限制的時候,使用枚舉】 語法格式: public enum 枚舉名{…… }有哪些成員? A、對象 public …

AWTK 骨骼動畫控件發布

Spine 是一款廣泛使用的 2D 骨骼動畫工具,專為游戲開發和動態圖形設計設計。它通過基于骨骼的動畫系統,幫助開發者創建流暢、高效的角色動畫。本項目是基于 Spine 實現的 AWTK 骨骼動畫控件。 代碼:https://gitee.com/zlgopen/awtk-widget-s…

[免費]基于Python的Django博客系統【論文+源碼+SQL腳本】

大家好,我是java1234_小鋒老師,看到一個不錯的基于Python的Django博客系統,分享下哈。 項目視頻演示 【免費】基于Python的Django博客系統 Python畢業設計_嗶哩嗶哩_bilibili 項目介紹 隨著互聯網技術的飛速發展,信息的傳播與…

如何將電腦桌面默認的C盤設置到D盤?詳細操作步驟!

將電腦桌面默認的C盤設置到D盤的詳細操作步驟! 本博文介紹如何將電腦桌面(默認為C盤)設置在D盤下。 首先,在D盤建立文件夾Desktop,完整的路徑為D:\Desktop。winR,輸入Regedit命令。(或者單擊【…

C++ 寫一個簡單的加減法計算器

************* C topic:結構 ************* Structure is a very intersting issue. I really dont like concepts as it is boring. I would like to cases instead. If I want to learn something, donot hesitate to make shits. Like building a house. Wh…

excel如何查找一個表的數據在另外一個表是否存在

比如“Sheet1”有“張三”、“李四”“王五”三個人的數據,“Sheet2”只有“張三”、“李四”的數據。我們通過修改“Sheet1”的“民族”或者其他空的列,修改為“Sheet2”的某一列。這樣修改后篩選這個修改的列為空的或者為出錯的,就能找到兩…

MySQL 基礎學習(2): INSERT 操作

在這篇文章中,我們將專注于 MySQL 中的 INSERT 操作,深入了解如何高效地向表中插入數據,并探索插入操作中的一些常見錯誤與解決方案。 一、基礎 INSERT 語法 在 MySQL 中,INSERT 操作用于向表中插入新記錄,基本語法如…

CVE-2023-38831 漏洞復現:win10 壓縮包掛馬攻擊剖析

目錄 前言 漏洞介紹 漏洞原理 產生條件 影響范圍 防御措施 復現步驟 環境準備 具體操作 前言 在網絡安全這片沒有硝煙的戰場上,新型漏洞如同隱匿的暗箭,時刻威脅著我們的數字生活。其中,CVE - 2023 - 38831 這個關聯 Win10 壓縮包掛…

論文閱讀(二):理解概率圖模型的兩個要點:關于推理和學習的知識

1.論文鏈接:Essentials to Understand Probabilistic Graphical Models: A Tutorial about Inference and Learning 摘要: 本章的目的是為沒有概率圖形模型背景或沒有深入背景的科學家提供一個高級教程。對于更熟悉這些模型的讀者,本章將作為…

記錄 | 基于Docker Desktop的MaxKB安裝

目錄 前言一、MaxKBStep 1Step2 二、運行MaxKB更新時間 前言 參考文章:如何利用智譜全模態免費模型,生成大家都喜歡的圖、文、視并茂的文章! MaxKB的Github下載地址 參考視頻:【2025最新MaxKB教程】10分鐘學會一鍵部署本地私人專屬…

Go反射指南

概念: 官方對此有個非常簡明的介紹,兩句話耐人尋味: 反射提供一種讓程序檢查自身結構的能力反射是困惑的源泉 第1條,再精確點的描述是“反射是一種檢查interface變量的底層類型和值的機制”。 第2條,很有喜感的自嘲…

第26篇 基于ARM A9處理器用C語言實現中斷<二>

Q:基于ARM A9處理器怎樣編寫C語言工程,使用按鍵中斷將數字顯示在七段數碼管上呢? A:基本原理:主程序需要首先調用子程序set_A9_IRQ_stack()初始化IRQ模式的ARM A9堆棧指針;然后主程序調用子程序config_GIC…

基于GS(Gaussian Splatting)的機器人Sim2Real2Sim仿真平臺

項目地址:RoboGSim 背景簡介 已有的數據采集方法中,遙操作(下左)是數據質量高,但采集成本高、效率低下;傳統仿真流程成本低(下右),但真實度(如紋理、物理&…

「 機器人 」利用沖程對稱性調節實現仿生飛行器姿態與方向控制

前言 在仿生撲翼飛行器中,通過改變沖程對稱性這一技術手段,可以在上沖與下沖兩個階段引入不對稱性,進而產生額外的力或力矩,用于實現俯仰或其他姿態方向的控制。以下從原理、在仿生飛行器中的應用和典型實驗示例等方面進行梳理與闡述。 1. 沖程對稱性原理 1.1 概念:上沖與…

MongoDB部署模式

目錄 單節點模式(Standalone) 副本集模式(Replica Set) 分片集群模式(Sharded Cluster) MongoDB有多種部署模式,可以根據業務需求選擇適合的架構和部署方式。 單節點模式(Standa…

微服務搭建----springboot接入Nacos2.x

springboot接入Nacos2.x nacos之前用的版本是1.0的,現在重新搭建一個2.0版本的,學如逆水行舟,不進則退,廢話不多說,開搞 1、 nacos2.x搭建 1,首先第一步查詢下項目之間的版本對照,不然后期會…

react-native網絡調試工具Reactotron保姆級教程

在React Native開發過程中,調試和性能優化是至關重要的環節。今天,就來給大家分享一個非常強大的工具——Reactotron,它就像是一個貼心的助手,能幫助我們更輕松地追蹤問題、優化性能。下面就是一份保姆級教程哦! 一、…

npm啟動前端項目時報錯(vue) error:0308010C:digital envelope routines::unsupported

vue 啟動項目時,npm run serve 報下面的錯: error:0308010C:digital envelope routines::unsupported at new Hash (node:internal/crypto/hash:67:19) at Object.createHash (node:crypto:133:10) at FSReqCallback.readFileAfterClose [as on…