從代碼學習深度學習 - 學習率調度器 PyTorch 版

文章目錄

  • 前言
  • 一、理論背景
  • 二、代碼解析
    • 2.1. 基本問題和環境設置
    • 2.2. 訓練函數
    • 2.3. 無學習率調度器實驗
    • 2.4. SquareRootScheduler 實驗
    • 2.5. FactorScheduler 實驗
    • 2.6. MultiFactorScheduler 實驗
    • 2.7. CosineScheduler 實驗
    • 2.8. 帶預熱的 CosineScheduler 實驗
  • 三、結果對比與分析
  • 總結


前言

學習率是深度學習優化中的關鍵超參數,決定了模型參數更新的步長。固定學習率可能導致訓練初期收斂過慢或后期在次優解附近震蕩。學習率調度器(Learning Rate Scheduler)通過動態調整學習率,幫助模型在不同訓練階段高效優化,平衡快速收斂與精細調整的需求。本文基于 PyTorch,在 Fashion-MNIST 數據集上使用 LeNet 模型,展示五種學習率調度策略:無調度器、SquareRootScheduler、FactorScheduler、MultiFactorScheduler 和 CosineScheduler(包括帶預熱的版本)。通過代碼實現、實驗結果和可視化,我們將深入探討每種調度器的理論基礎和實際效果,幫助讀者從代碼角度理解學習率調度器的核心作用。
值得注意的是,本文展示的代碼不完整,僅展示了與學習率調度器相關的部分,完整代碼包含了可視化、數據加載和訓練輔助函數,完整代碼可以通過下方鏈接下載。
完整代碼:下載鏈接


一、理論背景

學習率調度器的設計需要考慮以下幾個關鍵因素:

  1. 學習率大小:過大的學習率可能導致優化發散,過小則使訓練緩慢或陷入次優解。問題條件數(最不敏感與最敏感方向變化的比率)影響學習率的選擇。
  2. 衰減速率:學習率需要逐步降低以避免在最小值附近震蕩,但衰減不能過快(如 ( O(t^{-1/2}) ) 是凸問題優化的一個合理選擇)。
  3. 預熱(Warmup):在訓練初期,隨機初始化的參數可能導致不穩定的更新方向。通過逐漸增加學習率(預熱),可以穩定初期優化。
  4. 周期性調整:某些調度器(如余弦調度器)通過周期性調整學習率,探索更優的解空間。

本文將通過實驗驗證這些因素如何影響模型性能。

二、代碼解析

以下是完整的 PyTorch 實現,包含模型定義、訓練函數和五種調度器實驗。

2.1. 基本問題和環境設置

我們使用 LeNet 模型在 Fashion-MNIST 數據集上進行分類,設置損失函數、設備和數據加載器。

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
import utils_for_train
import utils_for_data
import utils_for_huitudef net_fn():"""定義LeNet神經網絡模型"""model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),  # 輸出: [batch_size, 6, 28, 28]nn.MaxPool2d(kernel_size=2, stride=2),  # 輸出: [batch_size, 6, 14, 14]nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),  # 輸出: [batch_size, 16, 10, 10]nn.MaxPool2d(kernel_size=2, stride=2),  # 輸出: [batch_size, 16, 5, 5]nn.Flatten(),  # 輸出: [batch_size, 16*5*5]nn.Linear(16 * 5 * 5, 120), nn.ReLU(),  # 輸出: [batch_size, 120]nn.Linear(120, 84), nn.ReLU(),  # 輸出: [batch_size, 84]nn.Linear(84, 10)  # 輸出: [batch_size, 10])return model# 定義損失函數
loss = nn.CrossEntropyLoss()# 選擇計算設備
device = utils_for_train.try_gpu()# 設置批量大小和訓練輪數
batch_size = 256
num_epochs = 30# 加載Fashion-MNIST數據集
train_iter, test_iter = utils_for_data.load_data_fashion_mnist(batch_size=batch_size)

解析

  • LeNet 模型:適用于 Fashion-MNIST 的 28x28 灰度圖像分類,包含兩層卷積+池化和三層全連接層。
  • 損失函數:交叉熵損失,適合多分類任務。
  • 數據加載:批量大小為 256,輸入維度為 [batch_size, 1, 28, 28],標簽維度為 [batch_size]

2.2. 訓練函數

訓練函數支持多種學習率調度器,負責模型訓練、評估和可視化。

def train(net, train_iter, test_iter, num_epochs, loss, trainer, device, scheduler=None):"""訓練模型函數參數:net: 神經網絡模型train_iter: 訓練數據迭代器, 維度: [batch_size, 1, 28, 28], [batch_size]test_iter: 測試數據迭代器, 維度: [batch_size, 1, 28, 28], [batch_size]num_epochs: 訓練輪數, 標量loss: 損失函數trainer: 優化器device: 計算設備(GPU/CPU)scheduler: 學習率調度器, 默認為None"""net.to(device)animator = utils_for_huitu.Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):metric = utils_for_train.Accumulator(3)  # [總損失, 準確預測數, 樣本總數]for i, (X, y) in enumerate(train_iter):net.train()trainer.zero_grad()X, y = X.to(device), y.

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/77654.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/77654.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/77654.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

k8s 基礎入門篇之開啟 firewalld

前面在部署k8s時,都是直接關閉的防火墻。由于生產環境需要開啟防火墻,只能放行一些特定的端口, 簡單記錄一下過程。 1. firewall 與 iptables 的關系 1.1 防火墻(Firewall) 定義: 防火墻是網絡安全系統&…

RSS 2025|蘇黎世提出「LLM-MPC混合架構」增強自動駕駛,推理速度提升10.5倍!

論文題目:Enhancing Autonomous Driving Systems with On-Board Deployed Large Language Models 論文作者:Nicolas Baumann,Cheng Hu,Paviththiren Sivasothilingam,Haotong Qin,Lei Xie,Miche…

list的學習

list的介紹 list文檔的介紹 list是可以在常數范圍內在任意位置進行插入和刪除的序列式容器,并且該容器可以前后雙向迭代。list的底層是雙向鏈表結構,雙向鏈表中每個元素存儲在互不相關的獨立節點中,在節點中通過指針指向其前一個元素和后一…

生物信息學技能樹(Bioinformatics)與學習路徑

李升偉 整理 生物信息學是一門跨學科領域,涉及生物學、計算機科學以及統計學等多個方面。以下是關于生物信息學的學習路徑及相關技能的詳細介紹。 一、基礎理論知識 1. 生物學基礎知識 需要掌握分子生物學、遺傳學、細胞生物學等相關概念。 對基因組結構、蛋白質…

AOSP Android14 Launcher3——遠程窗口動畫關鍵類SurfaceControl詳解

在 Launcher3 執行涉及其他應用窗口(即“遠程窗口”)的動畫時,例如“點擊桌面圖標啟動應用”或“從應用上滑回到桌面”的過渡動畫,SurfaceControl 扮演著至關重要的角色。它是實現這些跨進程、高性能、精確定制動畫的核心技術。 …

超詳細實現單鏈表的基礎增刪改查——基于C語言實現

文章目錄 1、鏈表的概念與分類1.1 鏈表的概念1.2 鏈表的分類 2、單鏈表的結構和定義2.1 單鏈表的結構2.2 單鏈表的定義 3、單鏈表的實現3.1 創建新節點3.2 頭插和尾插的實現3.3 頭刪和尾刪的實現3.4 鏈表的查找3.5 指定位置之前和之后插入數據3.6 刪除指定位置的數據和刪除指定…

17.整體代碼講解

從入門AI到手寫Transformer-17.整體代碼講解 17.整體代碼講解代碼 整理自視頻 老袁不說話 。 17.整體代碼講解 代碼 import collectionsimport math import torch from torch import nn import os import time import numpy as np from matplotlib import pyplot as plt fro…

前端性能優化:所有權轉移

前端性能優化:所有權轉移 在學習rust過程中,學到了所有權概念,于是便聯想到了前端,前端是否有相關內容,于是進行了一些實驗,并整理了這些內容。 所有權轉移(Transfer of Ownership)…

Missashe考研日記-day23

Missashe考研日記-day23 0 寫在前面 博主前幾天有事回家去了,斷更幾天了不好意思,就當回家休息一下調整一下狀態了,今天接著開始更新。雖然每天的博客寫的內容不算多,但其實還是挺費時間的,比如這篇就花了我40多分鐘…

Docker 中將文件映射到 Linux 宿主機

在 Docker 中,有多種方式可以將文件映射到 Linux 宿主機,以下是常見的幾種方法: 使用-v參數? 基本語法:docker run -v [宿主機文件路徑]:[容器內文件路徑] 容器名稱? 示例:docker run -it -v /home/user/myfile.txt:…

HarmonyOS-ArkUI-動畫分類簡介

本文的目的是,了解一下HarmonyOS動畫體系中的分類。有個大致的了解即可。 動效與動畫簡介 動畫,是客戶端提升界面交互用戶體驗的一個重要的方式。可以使應用程序更加生動靈越,提高用戶體驗。 HarmonyOS對于界面的交互方面,圍繞回歸本源的設計理念,打造自然,流暢品質一提…

C++如何處理多線程環境下的異常?如何確保資源在異常情況下也能正確釋放

多線程編程的基本概念與挑戰 多線程編程的核心思想是將程序的執行劃分為多個并行運行的線程,每個線程可以獨立處理任務,從而充分利用多核處理器的性能優勢。在C中,開發者可以通過std::thread創建線程,并使用同步原語如std::mutex、…

區間選點詳解

步驟 operator< 的作用在 C 中&#xff0c; operator< 是一個運算符重載函數&#xff0c;它定義了如何比較兩個對象的大小。在 std::sort 函數中&#xff0c;它會用到這個比較函數來決定排序的順序。 在 sort 中&#xff0c;默認會使用 < 運算符來比較兩個對象…

前端配置代理解決發送cookie問題

場景&#xff1a; 在開發任務管理系統時&#xff0c;我遇到了一個典型的身份認證問題&#xff1a;??用戶登錄成功后&#xff0c;調獲取當前用戶信息接口卻提示"用戶未登錄"??。系統核心流程如下&#xff1a; ??用戶登錄??&#xff1a;調用 /login 接口&…

8.1 線性變換的思想

一、線性變換的概念 當一個矩陣 A A A 乘一個向量 v \boldsymbol v v 時&#xff0c;它將 v \boldsymbol v v “變換” 成另一個向量 A v A\boldsymbol v Av. 輸入 v \boldsymbol v v&#xff0c;輸出 T ( v ) A v T(\boldsymbol v)A\boldsymbol v T(v)Av. 變換 T T T…

【java實現+4種變體完整例子】排序算法中【冒泡排序】的詳細解析,包含基礎實現、常見變體的完整代碼示例,以及各變體的對比表格

以下是冒泡排序的詳細解析&#xff0c;包含基礎實現、常見變體的完整代碼示例&#xff0c;以及各變體的對比表格&#xff1a; 一、冒泡排序基礎實現 原理 通過重復遍歷數組&#xff0c;比較相鄰元素并交換逆序對&#xff0c;逐步將最大值“冒泡”到數組末尾。 代碼示例 pu…

系統架構設計(二):基于架構的軟件設計方法ABSD

“基于架構的軟件設計方法”&#xff08;Architecture-Based Software Design, ABSD&#xff09;是一種通過從軟件架構層面出發指導詳細設計的系統化方法。它旨在橋接架構設計與詳細設計之間的鴻溝&#xff0c;確保系統的高層結構能夠有效指導后續開發。 ABSD 的核心思想 ABS…

Office文件內容提取 | 獲取Word文件內容 |Javascript提取PDF文字內容 |PPT文檔文字內容提取

關于Office系列文件文字內容的提取 本文主要通過接口的方式獲取Office文件和PDF、OFD文件的文字內容。適用于需要獲取Word、OFD、PDF、PPT等文件內容的提取實現。例如在線文字統計以及論文文字內容的提取。 一、提取Word及WPS文檔的文字內容。 支持以下文件格式&#xff1a; …

Cesium學習筆記——dem/tif地形的分塊與加載

前言 在Cesium的學習中&#xff0c;學會讀文檔十分重要&#xff01;&#xff01;&#xff01;在這里附上Cesium中英文文檔1.117。 在Cesium項目中&#xff0c;在平坦坦地球中加入三維地形不僅可以增強真實感與可視化效果&#xff0c;還可以??提升用戶體驗與交互性&#xff0c…

Spring Boot 斷點續傳實戰:大文件上傳不再怕網絡中斷

精心整理了最新的面試資料和簡歷模板&#xff0c;有需要的可以自行獲取 點擊前往百度網盤獲取 點擊前往夸克網盤獲取 一、痛點與挑戰 在網絡傳輸大文件&#xff08;如視頻、數據集、設計稿&#xff09;時&#xff0c;常面臨&#xff1a; 上傳中途網絡中斷需重新開始服務器內…