# 深度學習中的優化算法詳解

深度學習中的優化算法詳解

優化算法是深度學習的核心組成部分,用于最小化損失函數以更新神經網絡的參數。本文將詳細介紹深度學習中常用的優化算法,包括其概念、數學公式、代碼示例、實際案例以及圖解,幫助讀者全面理解優化算法的原理與應用。


一、優化算法的基本概念

在深度學習中,優化算法的目標是通過迭代更新模型參數 θ \theta θ,最小化損失函數 L ( θ ) L(\theta) L(θ)。損失函數通常表示為:

L ( θ ) = 1 N ∑ i = 1 N l ( f ( x i ; θ ) , y i ) L(\theta) = \frac{1}{N} \sum_{i=1}^N l(f(x_i; \theta), y_i) L(θ)=N1?i=1N?l(f(xi?;θ),yi?)

其中:

  • f ( x i ; θ ) f(x_i; \theta) f(xi?;θ):模型對輸入 x i x_i xi? 的預測;
  • y i y_i yi?:真實標簽;
  • l l l:單個樣本的損失(如均方誤差或交叉熵);
  • N N N:樣本數量。

優化算法通過計算梯度 ? θ L ( θ ) \nabla_\theta L(\theta) ?θ?L(θ),按照一定規則更新參數 θ \theta θ,以逼近損失函數的最優解。


二、常見優化算法詳解

以下是深度學習中常用的優化算法,逐一分析其原理、公式、優缺點及代碼實現。

1. 梯度下降(Gradient Descent, GD)

概念

梯度下降通過計算整個訓練集的梯度來更新參數,公式為:

θ t + 1 = θ t ? η ? θ L ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t) θt+1?=θt??η?θ?L(θt?)

其中:

  • η \eta η:學習率,控制步長;
  • ? θ L ( θ t ) \nabla_\theta L(\theta_t) ?θ?L(θt?):損失函數對參數的梯度。
優缺點
  • 優點:全局梯度信息準確,適合簡單凸優化問題。
  • 缺點:計算全量梯度開銷大,速度慢,易陷入局部極小值。
代碼示例
import numpy as np# 模擬損失函數 L = (theta - 2)^2
def loss_function(theta):return (theta - 2) ** 2def gradient(theta):return 2 * (theta - 2)# 梯度下降
theta = 0.0  # 初始參數
eta = 0.1    # 學習率
for _ in range(100):grad = gradient(theta)theta -= eta * grad
print(f"優化后的參數: {theta}")  # 接近 2

在這里插入圖片描述

參數沿梯度方向逐步逼近損失函數的最優解。*


2. 隨機梯度下降(Stochastic Gradient Descent, SGD)

概念

SGD 每次僅基于單個樣本計算梯度,更新公式為:

θ t + 1 = θ t ? η ? θ l ( f ( x i ; θ t ) , y i ) \theta_{t+1} = \theta_t - \eta \nabla_\theta l(f(x_i; \theta_t), y_i) θt+1?=θt??η?θ?l(f(xi?;θt?),yi?)

優缺點
  • 優點:計算效率高,適合大規模數據集,隨機性有助于逃離局部極小值。
  • 缺點:梯度噪聲大,收斂路徑不穩定。
代碼示例
# 模擬 SGD
np.random.seed(42)
data = np.random.randn(100, 2)  # 模擬數據
labels = data[:, 0] * 2 + 1     # 模擬標簽theta = np.zeros(2)  # 初始參數
eta = 0.01
for _ in range(100):i = np.random.randint(0, len(data))x, y = data[i], labels[i]grad = -2 * (y - np.dot(theta, x)) * x  # 均方誤差梯度theta -= eta * grad
print(f"優化后的參數: {theta}")

SGD 的更新路徑波動較大,但整體趨向最優解。*


3. 小批量梯度下降(Mini-Batch Gradient Descent)

概念

Mini-Batch GD 結合 GD 和 SGD 的優點,使用小批量樣本計算梯度:

θ t + 1 = θ t ? η 1 B ∑ i ∈ batch ? θ l ( f ( x i ; θ t ) , y i ) \theta_{t+1} = \theta_t - \eta \frac{1}{B} \sum_{i \in \text{batch}} \nabla_\theta l(f(x_i; \theta_t), y_i) θt+1?=θt??ηB1?ibatch??θ?l(f(xi?;θt?),yi?)

其中 B B B 為批量大小。

優缺點
  • 優點:平衡了計算效率和梯度穩定性,廣泛應用于深度學習框架。
  • 缺點:批量大小需調優,學習率敏感。
代碼示例
import torch# 模擬數據
X = torch.randn(100, 2)
y = X[:, 0] * 2 + 1
theta = torch.zeros(2, requires_grad=True)
optimizer = torch.optim.SGD([theta], lr=0.01)# Mini-Batch GD
batch_size = 16
for _ in range(100):indices = torch.randperm(100)[:batch_size]batch_X, batch_y = X[indices], y[indices]pred = batch_X @ thetaloss = ((pred - batch_y) ** 2).mean()optimizer.zero_grad()loss.backward()optimizer.step()
print(f"優化后的參數: {theta}")

4. 動量法(Momentum)

概念

動量法通過引入速度項 v t v_t vt?,加速梯度下降,公式為:

v t + 1 = μ v t + ? θ L ( θ t ) v_{t+1} = \mu v_t + \nabla_\theta L(\theta_t) vt+1?=μvt?+?θ?L(θt?)
θ t + 1 = θ t ? η v t + 1 \theta_{t+1} = \theta_t - \eta v_{t+1} θt+1?=θt??ηvt+1?

其中 μ \mu μ 為動量系數(通常為 0.9)。

優缺點
  • 優點:加速收斂,減少震蕩。
  • 缺點:超參數需調優,可能超調。
代碼示例
# 動量法
theta = 0.0
v = 0.0
eta, mu = 0.1, 0.9
for _ in range(100):grad = gradient(theta)v = mu * v + gradtheta -= eta * v
print(f"優化后的參數: {theta}")

動量法通過累積速度平滑更新路徑。*


5. Adam(Adaptive Moment Estimation)

概念

Adam 結合動量法和自適應學習率,通過一階動量(均值)和二階動量(方差)更新參數:

m t + 1 = β 1 m t + ( 1 ? β 1 ) ? θ L ( θ t ) m_{t+1} = \beta_1 m_t + (1 - \beta_1) \nabla_\theta L(\theta_t) mt+1?=β1?mt?+(1?β1?)?θ?L(θt?)
v t + 1 = β 2 v t + ( 1 ? β 2 ) ( ? θ L ( θ t ) ) 2 v_{t+1} = \beta_2 v_t + (1 - \beta_2) (\nabla_\theta L(\theta_t))^2 vt+1?=β2?vt?+(1?β2?)(?θ?L(θt?))2
m ^ t + 1 = m t + 1 1 ? β 1 t + 1 , v ^ t + 1 = v t + 1 1 ? β 2 t + 1 \hat{m}_{t+1} = \frac{m_{t+1}}{1 - \beta_1^{t+1}}, \quad \hat{v}_{t+1} = \frac{v_{t+1}}{1 - \beta_2^{t+1}} m^t+1?=1?β1t+1?mt+1??,v^t+1?=1?β2t+1?vt+1??
θ t + 1 = θ t ? η m ^ t + 1 v ^ t + 1 + ? \theta_{t+1} = \theta_t - \eta \frac{\hat{m}_{t+1}}{\sqrt{\hat{v}_{t+1}} + \epsilon} θt+1?=θt??ηv^t+1? ?+?m^t+1??

其中:

  • β 1 = 0.9 \beta_1 = 0.9 β1?=0.9 β 2 = 0.999 \beta_2 = 0.999 β2?=0.999
  • ? = 1 0 ? 8 \epsilon = 10^{-8} ?=10?8,防止除零。
優缺點
  • 優點:自適應學習率,收斂快,適合復雜模型。
  • 缺點:可能過早收斂到次優解。
代碼示例
import torch.optim as optim# 使用 PyTorch 的 Adam
model = torch.nn.Linear(2, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
for _ in range(100):pred = model(X)loss = ((pred - y) ** 2).mean()optimizer.zero_grad()loss.backward()optimizer.step()
print(f"優化后的參數: {model.weight}")

Adam 通過自適應步長快速逼近最優解。*


三、實際案例:優化神經網絡

任務

使用 PyTorch 訓練一個簡單的二分類神經網絡,比較 SGD 和 Adam 的性能。

代碼實現
import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 生成模擬數據
X = torch.randn(1000, 2)
y = (X[:, 0] + X[:, 1] > 0).float().reshape(-1, 1)# 定義模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Linear(2, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 訓練函數
def train(model, optimizer, epochs=100):criterion = nn.BCELoss()losses = []for _ in range(epochs):pred = model(X)loss = criterion(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())return losses# 比較 SGD 和 Adam
model_sgd = Net()
model_adam = Net()
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=0.01)
optimizer_adam = optim.Adam(model_adam.parameters(), lr=0.001)losses_sgd = train(model_sgd, optimizer_sgd)
losses_adam = train(model_adam, optimizer_adam)# 繪制損失曲線
plt.plot(losses_sgd, label="SGD")
plt.plot(losses_adam, label="Adam")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
結果分析

Adam 通常比 SGD 收斂更快,損失下降更平穩,但在某些任務中 SGD 配合動量可能獲得更好的泛化性能。


四、優化算法選擇建議

  1. 小型數據集:SGD + 動量,簡單且泛化能力強。
  2. 復雜模型(如深度神經網絡):Adam 或其變體(如 AdamW),收斂速度快。
  3. 超參數調優
    • 學習率:嘗試 1 0 ? 3 10^{-3} 10?3 1 0 ? 5 10^{-5} 10?5
    • 批量大小:16、32 或 64;
    • 動量系數:0.9 或 0.99。

五、總結

優化算法是深度學習訓練的基石,從簡單的梯度下降到自適應的 Adam,每種算法都有其適用場景。通過理解其數學原理、代碼實現和實際表現,開發者可以根據任務需求選擇合適的優化策略。


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/900802.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/900802.shtml
英文地址,請注明出處:http://en.pswp.cn/news/900802.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

汽車的四大工藝

文章目錄 沖壓工藝核心流程關鍵技術 焊接工藝核心流程 涂裝工藝核心流程 總裝工藝核心流程終檢與測試靜態檢查動態檢查四輪定位制動轉鼓測試淋雨測試總結 簡單總結下汽車的四大工藝(從網上找了一張圖,感覺挺全面的)。 沖壓工藝 將金屬板材通過…

Perl 發送郵件

Perl 發送郵件 概述 Perl 是一種強大的編程語言,廣泛應用于系統管理、網絡編程和數據分析等領域。其中,使用 Perl 發送郵件是一項非常實用的技能。本文將詳細介紹使用 Perl 發送郵件的方法,包括必要的配置、代碼示例以及注意事項。 準備工…

關于柔性數組

以前確實沒關注過這個問題,一直都是直接定義固定長度的數組,盡量減少指針的操作。 柔性數組主要是再結構體里面定義一個長度為0的數組,這里和定義一個指針式存在明顯去別的。定義一個指針會占用內存,但是定義一個長度為0的數組不會…

NOIP2011提高組.瑪雅游戲

目錄 題目算法標簽: 模擬, 搜索, d f s dfs dfs, 剪枝優化思路*詳細注釋版代碼精簡注釋版代碼 題目 185. 瑪雅游戲 算法標簽: 模擬, 搜索, d f s dfs dfs, 剪枝優化 思路 可行性剪枝 如果某個顏色的格子數量少于 3 3 3一定無解因為要求字典序最小, 因此當一個格子左邊有…

go游戲后端開發29:實現游戲內聊天

接下來,我們再來開發一個功能,這個功能相對簡單,就是聊天。在游戲里,我們會收到一個聊天請求,我們只需要做一個聊天推送即可。具體來說,就是誰發的消息,就推送給所有人,包括消息內容…

基于大數據的美團外賣數據可視化分析系統

【大數據】基于大數據的美團外賣數據可視化分析系統 (完整系統源碼開發筆記詳細部署教程)? 目錄 一、項目簡介二、項目界面展示三、項目視頻展示 一、項目簡介 該系統通過對海量外賣數據的深度挖掘與分析,能夠為美團外賣平臺提供運營決策支…

[ctfshow web入門] web32

前置知識 協議相關博客:https://blog.csdn.net/m0_73353130/article/details/136212770 include:include "filename"這是最常用的方法,除此之外還可以 include url,被包含的文件會被當做代碼執行。 data://&#xff1a…

kotlin中const 和val的區別

在 Kotlin 中,const 和 val 都是用來聲明常量的,但它們的使用場景和功能有所不同: 1. val: val 用于聲明只讀變量,也就是不可修改的變量(類似于 Java 中的 final 變量)。它可以是任何類型,包括…

【STM32】綜合練習——智能風扇系統

目錄 0 前言 1 硬件準備 2 功能介紹 3 前置配置 3.1 時鐘配置 3.2 文件配置 4 功能實現 4.1 按鍵功能 4.2 屏幕功能 4.3 調速功能 4.4 倒計時功能 4.5 搖頭功能 4.6 測距待機功能 0 前言 由于時間關系,暫停詳細更新,本文章中,…

任務擴展-輸入商品原價,折扣并計算促銷后的價格

1.在HbuilderX軟件中創建項目,把項目的路徑放在xampp中的htdocs 2.創建php文件:price.php,price_from.php 3.在瀏覽器中,運行項目效果,通過xampp中admin進行運行瀏覽,在后添加文件名稱即可,注意&#xff…

3D Gaussian Splatting as MCMC 與gsplat中的應用實現

3D高斯潑濺(3D Gaussian splatting)自2023年提出以后,相關研究paper井噴式增長,盡管出現了許多改進版本,但依舊面臨著諸多挑戰,例如實現照片級真實感、應對高存儲需求,而 “懸浮的高斯核” 問題就是其中之一。浮動高斯核通常由輸入圖像中的曝光或顏色不一致引發,也可能…

【軟件測試】Postman中如何搭建Mock服務

在 Postman 中,Mock 服務是一項非常有用的功能,允許你在沒有實際后端服務器的情況下模擬 API 響應。通過創建 Mock 服務,你可以在開發階段或測試中模擬 API 的行為,幫助團隊成員進行前端開發、API 測試和集成測試等工作。 Mock 服…

Spring-MVC

Spring-MVC 1.SpringMVC簡介 - SpringMVC概述 SpringMVC是一個基于Spring開發的MVC輕量級框架,Spring3.0后發布的組件,SpringMVC和Spring可以無縫整合,使用DispatcherServlet作為前端控制器,且內部提供了處理器映射器、處理器適…

關于Spring MVC中@RequestParam注解的詳細說明,用于在前后端參數名稱不一致時實現參數映射。包含代碼示例和總結表格

以下是關于Spring MVC中RequestParam注解的詳細說明,用于在前后端參數名稱不一致時實現參數映射。包含代碼示例和總結表格: 1. 核心作用 RequestParam用于顯式綁定HTTP請求參數到方法參數,支持以下場景: 參數名不一致&#xff1…

MySQL主從復制技術詳解:原理、實現與最佳實踐

目錄 引言:MySQL主從復制的技術基礎 MySQL主從復制的實現機制 復制架構與線程模型 復制連接建立過程 數據變更與傳輸流程 MySQL不同復制方式的特點與適用場景 異步復制(Asynchronous Replication) 全同步復制(Fully Synch…

ROS Master多設備連接

Bash Shell Shell是位于用戶與操作系統內核之間的橋梁,當用戶在終端敲入命令后,這些輸入首先會進入內核中的tty子系統,TTY子系統負責捕獲并處理終端的輸入輸出流,確保數據正確無誤的在終端和系統內核之中。Shell在此過程不僅僅是…

Trae + LangGPT 生成結構化 Prompt

Trae LangGPT 生成結構化 Prompt 0. 引言1. 安裝 Trae2. 克隆 LangGPT3. Trae 和 LangGPT 聯動4. 集成到 Dify 中 0. 引言 Github 上 LangGPT 這個項目,主要向我們介紹了寫結構化Prompt的一些方法和示例,我們怎么直接使用這個項目,輔助我們…

《安富萊嵌入式周報》第352期:手持開源終端,基于參數陣列的定向揚聲器,炫酷ASCII播放器,PCB電阻箱,支持1Ω到500KΩ,Pebble智能手表代碼重構

周報匯總地址:嵌入式周報 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬漢嵌入式論壇 - Powered by Discuz! 視頻版 https://www.bilibili.com/video/BV1DEf3YiEqE/ 《安富萊嵌入式周報》第352期:手持開源終端&#x…

python 淺拷貝copy與深拷貝deepcopy 理解

一 淺拷貝與深拷貝 1. 淺拷貝 淺拷貝只復制了對象本身(即c中的引用)。 2. 深拷貝 深拷貝創建一個新的對象,同時也會創建所有子對象的副本,因此新對象與原對象之間完全獨立。 二 代碼理解 1. 案例一 a 10 b a b 20 print…

day22 學習筆記

文章目錄 前言一、遍歷1.行遍歷2.列遍歷3.直接遍歷 二、排序三、去重四、分組 前言 通過今天的學習,我掌握了對Pandas的數據類型進行基本操作,包括遍歷,去重,排序,分組 一、遍歷 1.行遍歷 intertuples方法用于遍歷D…