深度學習 - 梯度下降優化方法

梯度下降的基本概念

梯度下降(Gradient Descent)是一種用于優化機器學習模型參數的算法,其目的是最小化損失函數,從而提高模型的預測精度。梯度下降的核心思想是通過迭代地調整參數,沿著損失函數下降的方向前進,最終找到最優解。

生活中的背景例子:尋找山谷的最低點

想象你站在一個山谷中,眼睛被蒙住,只能用腳感受地面的坡度來找到山谷的最低點(即損失函數的最小值)。你每一步都想朝著坡度下降最快的方向走,直到你感覺不到坡度,也就是你到了最低點。這就好比在優化一個模型時,通過不斷調整參數,使得模型的預測誤差(損失函數)越來越小,最終找到最佳參數組合。

梯度下降的具體方法及其優化

1. 批量梯度下降(Batch Gradient Descent)

生活中的例子
你決定每次移動之前,都要先測量整個山谷的坡度,然后再決定移動的方向和步幅。雖然每一步的方向和步幅都很準確,但每次都要花很多時間來測量整個山谷的坡度。

公式
θ : = θ ? η ? ? θ J ( θ ) \theta := \theta - \eta \cdot \nabla_{\theta} J(\theta) θ:=θ?η??θ?J(θ)
其中:

  • θ \theta θ是模型參數
  • η \eta η是學習率
  • ? θ J ( θ ) \nabla_{\theta} J(\theta) ?θ?J(θ)是損失函數 J ( θ ) J(\theta) J(θ)關于 θ \theta θ的梯度

API
TensorFlow

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

PyTorch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
2. 隨機梯度下降(Stochastic Gradient Descent, SGD)

生活中的例子
你決定每一步都只根據當前所在位置的坡度來移動。雖然這樣可以快速決定下一步怎么走,但由于只考慮當前點,可能會導致路徑不穩定,有時候會走過頭。

公式
θ : = θ ? η ? ? θ J ( θ ; x ( i ) , y ( i ) ) \theta := \theta - \eta \cdot \nabla_{\theta} J(\theta; x^{(i)}, y^{(i)}) θ:=θ?η??θ?J(θ;x(i),y(i))
其中 ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i))是當前樣本的數據

API
TensorFlowPyTorch 中的API與批量梯度下降相同,具體行為取決于數據的加載方式。例如在訓練時可以一批數據包含一個樣本。

3. 小批量梯度下降(Mini-Batch Gradient Descent)

生活中的例子
你決定每次移動之前,只測量周圍一小部分區域的坡度,然后根據這小部分區域的平均坡度來決定方向和步幅。這樣既不需要花太多時間測量整個山谷,也不會因為只看一個點而導致路徑不穩定。

公式
θ : = θ ? η ? ? θ J ( θ ; B ) \theta := \theta - \eta \cdot \nabla_{\theta} J(\theta; \mathcal{B}) θ:=θ?η??θ?J(θ;B)
其中 B \mathcal{B} B是當前小批量的數據

API
TensorFlowPyTorch 中的API與批量梯度下降相同,但在數據加載時使用小批量。

4. 動量法(Momentum)

生活中的例子
你在移動時,不僅考慮當前的坡度,還考慮之前幾步的移動方向,就像帶著慣性一樣。如果前幾步一直往一個方向走,那么你會傾向于繼續往這個方向走,減少來回震蕩。

公式
v : = β v + ( 1 ? β ) ? θ J ( θ ) v := \beta v + (1 - \beta) \nabla_{\theta} J(\theta) v:=βv+(1?β)?θ?J(θ)
θ : = θ ? η v \theta := \theta - \eta v θ:=θ?ηv
其中:

  • v v v是動量項
  • β \beta β是動量系數(通常接近1,如0.9)

API
TensorFlow

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)

PyTorch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
5. RMSProp

生活中的例子
你在移動時,會根據最近一段時間內每一步的坡度情況,動態調整步幅。比如,當坡度變化劇烈時,你會邁小步,當坡度變化平緩時,你會邁大步。

公式
s : = β s + ( 1 ? β ) ( ? θ J ( θ ) ) 2 s := \beta s + (1 - \beta) (\nabla_{\theta} J(\theta))^2 s:=βs+(1?β)(?θ?J(θ))2
θ : = θ ? η s + ? ? θ J ( θ ) \theta := \theta - \frac{\eta}{\sqrt{s + \epsilon}} \nabla_{\theta} J(\theta) θ:=θ?s+? ?η??θ?J(θ)
其中:

  • s s s是梯度平方的加權平均值
  • ? \epsilon ?是一個小常數,防止除零錯誤

API
TensorFlow

optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)

PyTorch

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
6. Adam(Adaptive Moment Estimation)

生活中的例子
你在移動時,結合動量法和RMSProp的優點,不僅考慮之前的移動方向(動量),還根據最近一段時間內的坡度變化情況(調整步幅),從而使移動更加平穩和高效。

公式
m : = β 1 m + ( 1 ? β 1 ) ? θ J ( θ ) m := \beta_1 m + (1 - \beta_1) \nabla_{\theta} J(\theta) m:=β1?m+(1?β1?)?θ?J(θ)
v : = β 2 v + ( 1 ? β 2 ) ( ? θ J ( θ ) ) 2 v := \beta_2 v + (1 - \beta_2) (\nabla_{\theta} J(\theta))^2 v:=β2?v+(1?β2?)(?θ?J(θ))2
m ^ : = m 1 ? β 1 t \hat{m} := \frac{m}{1 - \beta_1^t} m^:=1?β1t?m?
v ^ : = v 1 ? β 2 t \hat{v} := \frac{v}{1 - \beta_2^t} v^:=1?β2t?v?
θ : = θ ? η m ^ v ^ + ? \theta := \theta - \eta \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} θ:=θ?ηv^ ?+?m^?
其中:

  • m m m v v v分別是梯度的一階和二階動量
  • β 1 \beta_1 β1? β 2 \beta_2 β2?是動量系數(通常分別取0.9和0.999)
  • m ^ \hat{m} m^ v ^ \hat{v} v^是偏差校正后的動量項
  • t t t是時間步

API
TensorFlow

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

PyTorch

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

綜合應用示例

假設我們在使用TensorFlow和PyTorch訓練一個簡單的神經網絡,以下是如何應用這些優化方法的示例代碼。

TensorFlow 示例

import tensorflow as tf# 定義模型
model = tf.keras.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dense(10, activation='softmax')
])# 編譯模型并選擇優化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 準備數據
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 訓練模型
model.fit(x_train, y_train, epochs=10, batch_size=32)

PyTorch 示例

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定義模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNN()# 選擇優化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()# 準備數據
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 訓練模型
for epoch in range(10):for batch in train_loader:x_train, y_train = batchx_train = x_train.view(x_train.size(0), -1)  # Flatten the imagesoptimizer.zero_grad()outputs = model(x_train)loss = criterion(outputs, y_train)loss.backward()optimizer.step()

更多問題咨詢

CosAI

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:
http://www.pswp.cn/web/24360.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/24360.shtml
英文地址,請注明出處:http://en.pswp.cn/web/24360.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

人體感應提醒 大聲公+微波模塊

文章目錄 模塊簡介接線程序示例 模塊簡介 微波感應開關模塊 RCWL-0516是一款采用多普勒雷達技術,專門檢測物體移動的微波感應模塊。采用 2.7G 微波信號檢測,該模塊具有靈敏度高,感應距離遠,可靠性強,感應角度大&#…

Ruoyi-Vue-Plus 下載啟動后菜單無法點擊展開,

1.Ruoyi-Vue-Plus框架下載后運行 2.使用mock數據 3.進入頁面后無法點擊菜單 本以為是動態路由或者菜單邏輯出了問題,最后發現是websocket的問題 解決辦法 把這兩行代碼注釋 頁面菜單即可點擊。 以上。

【ROS使用記錄】—— ros使用過程中的rosbag錄制播放和ros話題信息相關的指令與操作記錄

提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔 文章目錄 前言一、rosbag的介紹二、rosbag的在線和離線錄制三、rosbag的播放相關的指令四、其他rosbag和ros話題相關的指令總結 前言 rosbag是ROS(機器人操作系統…

Suse Linux ssh配置免密后仍需要輸入密碼

【問題描述】 Suse Linux已經配置了ssh免密,但無法ssh到目標服務器。 對自身的ssh登陸也需要輸入密碼。 系統–Suse 15 SP5 【重現步驟】 1.使用ssh-keygen -t rsa生產key文件 2.使用ssh-copy-id拷貝public key到目標機器(或者自身) 3.配置成功后ssh 目標時仍需要輸…

電商API在維護數據安全與合規性中的重要性

摘要 在數字化時代,數據安全和合規性是電商企業不可忽視的重大議題。本文將探討電商API如何在保護敏感數據、遵守法律法規和防范網絡威脅方面發揮關鍵作用。 引言 隨著大量敏感數據的電子化處理和存儲,電商企業面臨的安全挑戰日益嚴峻。API接口技術成為…

手機模擬操作進階:1.某團獲取附近商店情況

0.以超市便利為例分析: 超市便利的xp (//android.widget.ImageView[@resource-id="com.sankuai.meituan:id/channel_icon"])[5] 附近的xp //android.widget.TextView[@text="全部200+店"] 商家信息列表區: //android.support.v7.widget.RecyclerView[@…

《青少年編程與數學》課程方案:2、課程內容 4_4

《青少年編程與數學》課程方案:2、課程內容 4_4 十四、數學(三)高中數學(四)微機分(五)線性代數(六)概率論與數理統計(七)離散數學(八…

娛閑放鬆篇1

最近在B站看了挺多的動漫,挺小說化的,我這個人比較哲學,故和大家分享一下 B站娛閑 1.蘇老大的動漫 1.<<人類清除計劃>> 本來看的過癮,但沒想到,連小說也停更了..... 2.黑山羊遊戲 挺劇本的 3.顧毅 一個小說的主人公,第一個能力是無限推演... 崇山醫…

[C#]使用OpenCvSharp圖像濾波中值濾波均值濾波高通濾波雙邊濾波銳化濾波自定義濾波

在使用OpenCvSharp進行圖像濾波處理時&#xff0c;各種濾波方法都有其特定的用途和效果。以下是對中值濾波、均值濾波、高通濾波、雙邊濾波、銳化濾波和自定義濾波的詳細解釋和歸納&#xff1a; 中值濾波&#xff08;MedianBlur&#xff09; 原理與作用&#xff1a;中值濾波是…

Stable diffusion采樣器詳解

在我們使用SD web UI的過程中&#xff0c;有很多采樣器可以選擇&#xff0c;那么什么是采樣器&#xff1f;它們是如何工作的&#xff1f;它們之間有什么區別&#xff1f;你應該使用哪一個&#xff1f;這篇文章將會給你想要的答案。 什么是采樣&#xff1f; Stable Diffusion模…

UI學習--導航控制器

導航控制器 導航控制器基礎基本概念具體使用 導航控制器切換演示具體使用注意 導航欄與工具欄基本概念具體使用&#xff1a; 總結 導航控制器基礎 基本概念 根視圖控制器&#xff08;Root View Controller&#xff09;&#xff1a;導航控制器的第一個視圖控制器&#xff0c;通…

壓縮大文件消耗電腦CPU資源達到33%以上

今天用7-Zip壓縮一個大文件&#xff0c;文件大小是9G多&#xff0c;這時能聽到電腦風扇聲音&#xff0c;查看了一下電腦資源使用情況&#xff0c;確實增加了不少。 下面是兩張圖片&#xff0c;圖片上有電腦資源使用數據。

Spring系統學習 -Spring IOC 的XML管理Bean之bean的獲取、依賴注入值的方式

在Spring框架中&#xff0c;XML配置是最傳統和最常見的方式之一&#xff0c;用于管理Bean的創建、依賴注入和生命周期等。這個在Spring中我們使用算是常用的&#xff0c;我們需要根據Spring的基于XML管理Bean了解相關Spring中常用的獲取bean的方式、依賴注入值的幾種方式等等。…

c++ namespace以及使用建議

命名空間就是用來區分你使用的這個變量和函數是屬于那一塊的。用來防止不同的人所寫函數和變量&#xff0c;名字相同產生沖突。 在寫c代碼的時候&#xff0c;經常會使用標準庫中的函數&#xff0c;使用之前我們必須在前面添加一個std::&#xff0c;因為c標準庫的函數是在命名空…

關閉Cloudflare Pages的訪問策略

curl API 獲取相應的 uid curl -X GET "https://api.cloudflare.com/client/v4/accounts/賬戶標識符/access/apps" \-H "X-Auth-Email: 郵箱" \-H "X-Auth-Key: Global API KEY" \-H "Content-Type: application/json"賬戶標識符是登…

Dubbo面試題甄選及參考答案

目錄 Dubbo是什么? Dubbo的主要使用場景有哪些? Dubbo的核心功能有哪些? Dubbo與Spring框架的集成方式是什么? Dubbo的RPC調用原理是什么? Dubbo的架構中包含哪些核心組件? Provider、Consumer、Registry、Monitor在Dubbo中分別承擔什么角色? Container在Dubbo中…

Maven項目打包成jar項目后運行報錯誤: 找不到或無法加載主類 Main.Main 和 jar中沒有主清單屬性解決方案

已經用maven工程的package功能進行了打包 找不到或無法加載主類 Main.Main 規定主類 主要在maven的配置文件當中 這邊一定要綁定自己的啟動類 jar中沒有主清單屬性 刪掉這一行就行哈 正確的插件代碼 <plugin><groupId>org.springframework.boot</groupId&…

毫米波SDK使用1

本文檔是AM273x等毫米波雷達處理器SDK的配置和使用&#xff0c;主要參考TI的官方文檔《mmwave mcuplus sdk user guide》。這里僅摘取其中重要的部分&#xff0c;其余枝節可參考原文。 2 系統概覽 mmWave SDK分為兩個主要組件:mmWave套件和mmWave演示。 2.1. mmWave套件 mmWa…

AXI Quad SPI IP核基于AXI-Lite接口的標準SPI設計指南

在標準SPI配置下&#xff0c;SPI設備除了包含基本的SPI特性外&#xff0c;還具備以下一些標準功能&#xff0c;這些功能如下所示&#xff1a; 支持FPGA內部的多主設備配置&#xff0c;其中使用單獨的_I&#xff08;輸入&#xff09;、_O&#xff08;輸出&#xff09;、_T&…

FM148A,FM146B運行備件

FM148A,FM146B運行備件。電源保險絲倉主控底座的保險絲倉示意圖底座上共有兩個保險絲&#xff08;800mA&#xff09;&#xff0c;FM148A,FM146B運行備件。&#xff08;10&#xff5e;73&#xff09;30/195主控單元2.K-CUT014槽底座地址接口主控站地址撥開關從上到下為二進制數的…