PyTorch優化器

????????PyTorch 提供了多種優化算法用于神經網絡的參數優化。以下是對 PyTorch 中主要優化器的全面介紹,包括它們的原理、使用方法和適用場景。

一、基本優化器

1. SGD (隨機梯度下降)

torch.optim.SGD(params, lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
  • 特點:

    • 最基本的優化器

    • 可以添加動量(momentum)加速收斂

    • 支持Nesterov動量

  • 參數:

    • lr: 學習率(必需)

    • momentum: 動量因子(0-1)

    • weight_decay: L2正則化系數

  • 適用場景: 大多數基礎任務

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

2. Adam (自適應矩估計)

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
  • 特點:

    • 自適應學習率

    • 結合了動量法和RMSProp的優點

    • 通常需要較少調參

  • 參數:

    • betas: 用于計算梯度及其平方的移動平均系數

    • eps: 數值穩定項

    • amsgrad: 是否使用AMSGrad變體

  • 適用場景: 深度學習默認選擇

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

二、自適應優化器

1. Adagrad

torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
  • 特點:

    • 自適應學習率

    • 為每個參數保留學習率

    • 適合稀疏數據

  • 缺點: 學習率會單調遞減

2. RMSprop

torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
  • 特點:

    • 解決Adagrad學習率急劇下降問題

    • 適合非平穩目標

    • 常用于RNN

?3. Adadelta

torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
  • 特點:

    • 不需要設置初始學習率

    • 是Adagrad的擴展

三、其他優化器?

1. AdamW

torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
  • 特點:

    • Adam的改進版

    • 更正確的權重衰減實現

    • 通常優于Adam

2. SparseAdam

torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
  • 特點: 專為稀疏張量優化

3. LBFGS?

torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100)
  • 特點:

    • 準牛頓方法

    • 內存消耗大

    • 適合小批量數據

四、優化器選擇指南

優化器適用場景優點缺點
SGD基礎任務簡單可控需要手動調整學習率
SGD+momentum大多數任務加速收斂需要調參
Adam深度學習默認自適應學習率可能不如SGD泛化好
AdamW帶權重衰減的任務更正確的實現-
Adagrad稀疏數據自動調整學習率學習率單調減
RMSpropRNN/非平穩目標解決Adagrad問題-

五、學習率調度器

PyTorch還提供了學習率調度器,可與優化器配合使用:

# 創建優化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 創建調度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)# 訓練循環中
for epoch in range(100):train(...)validate(...)scheduler.step()  # 更新學習率

常用調度器:

  • LambdaLR: 自定義函數調整

  • MultiplicativeLR: 乘法更新

  • StepLR: 固定步長衰減

  • MultiStepLR: 多步長衰減

  • ExponentialLR: 指數衰減

  • CosineAnnealingLR: 余弦退火

  • ReduceLROnPlateau: 根據指標動態調整

六、優化器使用技巧

  1. 參數分組: 不同層使用不同學習率

    optimizer = torch.optim.SGD([{'params': model.base.parameters(), 'lr': 0.001},{'params': model.classifier.parameters(), 'lr': 0.01}
    ], momentum=0.9)
  2. 梯度裁剪: 防止梯度爆炸

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  3. 零梯度: 每次迭代前清空梯度

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/75445.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/75445.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/75445.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C++的UDP連接解析域名地址錯誤

背景 使用c開發一個udp連接功能的腳本,可以接收發送數據,而且地址是經過內網穿透到外網的 經過 通常發送數據給目標地址,需要把目的地址結構化,要么使用inet_addr解析ip地址,要么使用inet_pton sockaddr_in target…

Spark,上傳文件

上傳文件 1.上傳 先使用命令打開HDFS的NameNode [roothadoop100 hadoop-3.1.3]$ sbin/start-dfs.sh [roothadoop100 hadoop-3.1.3]$ sbin/stop-dfs.sh 和YARN的Job [roothadoop101 hadoop-3.1.3]$ sbin/start-yarn.sh [roothadoop101 hadoop-3.1.3]$ sbin/stop-yarn.sh 在Nam…

如何為Linux/Android Kernel 5.4和5.15添加 fuse passthrough透傳功能 ?

背景 參考:Google文檔 FUSE 透傳 參考此文檔,目前kernel.org提供的fuse passthrough補丁在6.9版本之后,但想要在5.4和5.15版本內核做移植應該如何簡單點呢?文檔中提到 Android的內核為5.4 和 5.15版本內核做了fuse passthrough功…

Ubuntu 防火墻配置

Ubuntu 的防火墻配置可以參考文章:Firewall - Ubuntu Server documentation 22 端口 需要注意的是,在啟動防火墻之前,需要先開放 22 端口。 否則 SSH 將會拒絕你連接防火墻。 開放 22 端口的命令為:sudo ufw allow 22 添加端…

Jetson 設備卸載 OpenCV 4.5.4 并編譯安裝 OpenCV 4.2.0

?一、卸載 OpenCV 4.5.4? 清除已安裝的 OpenCV 庫? sudo apt-get purge libopencv* python3-opencv # 卸載所有APT安裝的OpenCV包?:ml-citation{ref"1,3" data"citationList"}sudo apt autoremove # 清理殘留依賴?:ml-citation{ref"1,4"…

《AI大模型應知應會100篇》第57篇:LlamaIndex使用指南:構建高效知識庫

第57篇:LlamaIndex使用指南:構建高效知識庫 摘要 在大語言模型(LLM)驅動的智能應用中,如何高效地管理和利用海量知識數據是開發者面臨的核心挑戰之一。LlamaIndex(原 GPT Index) 是一個專為構建…

Sentinel[超詳細講解]-4

🚓 主要講解流控模式的 三種方式中的兩種: 直接、鏈路🚀 1?? 直接模式 🚎 直接模式:對資源本身進行限流,例如對某個接口進行限流,當該接口的訪問頻率超過設定的閾值時,直接拒絕新的…

工作記錄 2017-03-24

工作記錄 2017-03-24 序號 工作 相關人員 1 修改了郵件上的問題。 更新RD服務器。 郝 更新的問題 1、修改了New User時 init的保存。 2、文件的查詢加了ID。 3、加了 patient insurance secondary 4、修改了payment detail的處理。 識別引擎監控 Ps (iCDA LOG :剔除…

裴蜀定理:整數解的奧秘

裴蜀定理:整數解的奧秘 在數學的世界里,裴蜀定理(Bzout’s Theorem)是數論中一個非常重要的定理,它揭示了二次方程和整數解之間的關系。它不僅僅是純粹的理論知識,還在計算機科學、密碼學、算法優化等多個…

python之 “__init__.py” 文件

提示:python之 “init.py” 文件 文章目錄 前言一、Python 中 __init__.py 文件的理解1. What(是什么)2. Why(為什么需要)3. Where(在哪里使用)4. How(如何使用) 二、問題…

Gemini 2.5 Pro與Claude 3.7 Sonnet編程性能對比

AI領域的語言模型競賽日趨白熱化,尤其在編程輔助方面表現突出。 Gemini 2.5 Pro和Claude 3.7 Sonnet作為該領域的佼佼者,本文通過一系列編程測試與基準評估對兩者的編碼功能進行對比分析。 核心結論: ? Gemini 2.5 Pro在SWE Bench硬核編程測試中以63.8%的通過率略勝Clau…

On Superresolution Effects in Maximum Likelihood Adaptive Antenna Arrays論文閱讀

On Superresolution Effects in Maximum Likelihood Adaptive Antenna Arrays 1. 論文的研究目標與實際問題意義1.1 研究目標1.2 解決的實際問題1.3 實際意義2. 論文提出的新方法、模型與公式2.1 核心創新:標量化近似表達式關鍵推導步驟:公式優勢:2.2 與經典方法的對比傳統方…

GIT 撤銷上次推送

注意:在執行下述操作之前先備份現有工作進度,如果不慎未保存,在代碼編輯器中正在修改的文件下,使用CtrlZ 撤銷試試 撤銷推送的方法 情況 1:您剛剛推送到遠程倉庫 如果您的推送操作剛剛完成,并且沒有其他…

透視飛鶴2024財報:如何打贏奶粉罐里的科技戰?

去年乳制品行業壓力還是不小的,尼爾森IQ指出2024年國內乳品市場仍處在收縮區間。但是,總有龍頭能抗住壓力,飛鶴最近交出的2024財報中就有很多亮點。 比如,2024年飛鶴營收207.5億元、同比增長6%,凈利潤36.5億元&#x…

解決STM32CubeMX中文注釋亂碼

本人采用【修改系統環境變量】的方法 1. 使用快捷鍵 win X,打開【系統R】,點擊【高級系統設置】 2. 點擊【環境變量】 3. 點擊【新建】 4.按圖中輸入【JAVA_TOOL_OPTIONS】和【-Dfile.encodingUTF-8】,新建環境變量后重啟CubeMX即可。 解釋…

使用typescript實現游戲中的JPS跳點尋路算法

JPS是一種優化A*算法的路徑規劃算法,主要用于網格地圖,通過跳過不必要的節點來提高搜索效率。它利用路徑的對稱性,只擴展特定的“跳點”,從而減少計算量。 deepseek生成的總是無法完整運行,因此決定手寫一下。 需要注…

Jetpack Compose 狀態管理指南:從基礎到高級實踐

在Jetpack Compose中,界面狀態管理是構建響應式UI的核心。以下是Compose狀態管理的主要概念和實現方式: 基本狀態管理 1. 使用 mutableStateOf Composable fun Counter() {var count by remember { mutableStateOf(0) }Button(onClick { count }) {T…

vant4+vue3上傳一個pdf文件并實現pdf的預覽。使用插件pdf.js

注意下載的插件的版本"pdfjs-dist": "^2.2.228", npm i pdfjs-dist2.2.228 然后封裝一個pdf的遮罩。因為pdf文件有多頁,所以我用了swiper輪播的形式展示。因為用到移動端,手動滑動頁面這樣比點下一頁下一頁的方便多了。 直接貼代碼…

Leetcode hot 100(day 4)

翻轉二叉樹 做法:遞歸即可,注意判斷為空 class Solution { public:TreeNode* invertTree(TreeNode* root) {if(rootnullptr)return nullptr;TreeNode* noderoot->left;root->leftinvertTree(root->right);root->rightinvertTree(node);retu…

C,C++語言緩沖區溢出的產生和預防

緩沖區溢出的定義 緩沖區是內存中用于存儲數據的一塊連續區域,在 C 和 C 里,常使用數組、指針等方式來操作緩沖區。而緩沖區溢出指的是當程序向緩沖區寫入的數據量超出了該緩沖區本身能夠容納的最大數據量時,額外的數據就會覆蓋相鄰的內存區…