[預備知識]6. 優化理論(二)

優化理論

本章節介紹深度學習中的高級優化技術,包括學習率衰減、梯度裁剪和批量歸一化。這些技術能夠顯著提升模型的訓練效果和穩定性。

學習率衰減(Learning Rate Decay)

數學原理與可視化

學習率衰減策略的數學表達:

  1. 步進式衰減
    α t = α 0 × γ ? t / s ? \alpha_t = \alpha_0 \times \gamma^{\lfloor t/s \rfloor} αt?=α0?×γ?t/s?
    其中 s s s為衰減周期, γ \gamma γ為衰減因子

  2. 指數衰減
    α t = α 0 × e ? γ t \alpha_t = \alpha_0 \times e^{-\gamma t} αt?=α0?×e?γt

  3. 余弦衰減
    α t = α min + 1 2 ( α 0 ? α min ) ( 1 + cos ? ( t π T ) ) \alpha_t = \alpha_{\text{min}} + \frac{1}{2}(\alpha_0 - \alpha_{\text{min}})(1 + \cos(\frac{t\pi}{T})) αt?=αmin?+21?(α0??αmin?)(1+cos(Ttπ?))

import matplotlib.pyplot as plt# 衰減策略可視化
epochs = 100
initial_lr = 0.1# 計算各策略學習率
step_lrs = [initial_lr * (0.1 ** (i//30)) for i in range(epochs)]
expo_lrs = [initial_lr * (0.95 ** i) for i in range(epochs)]
cosine_lrs = [0.01 + 0.5*(0.1-0.01)*(1 + np.cos(np.pi*i/epochs)) for i in range(epochs)]# 繪制對比圖
plt.figure(figsize=(12,6))
plt.plot(step_lrs, label='Step Decay (每30步×0.1)')
plt.plot(expo_lrs, label='Exponential Decay (γ=0.95)')
plt.plot(cosine_lrs, label='Cosine Decay (T=100)')
plt.title("不同學習率衰減策略對比", fontsize=14)
plt.xlabel("訓練周期", fontsize=12)
plt.ylabel("學習率", fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

在這里插入圖片描述

最佳實踐

# 組合使用多種調度器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 前50步使用余弦衰減
scheduler1 = CosineAnnealingLR(optimizer, T_max=50)
# 之后使用步進衰減
scheduler2 = StepLR(optimizer, step_size=10, gamma=0.5)for epoch in range(100):train(...)if epoch < 50:scheduler1.step()else:scheduler2.step()

梯度裁剪(Gradient Clipping)

數學原理

梯度裁剪通過限制梯度范數防止參數更新過大:
if? ∥ g ∥ > c : g ← c ∥ g ∥ g \text{if } \|g\| > c: \quad g \gets \frac{c}{\|g\|}g if?g>c:ggc?g
其中 c c c為裁剪閾值, ∥ g ∥ \|g\| g為梯度范數

梯度動態可視化

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 定義一個簡單的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 初始化模型、優化器和損失函數
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()grad_norms = []
clipped_grad_norms = []for _ in range(1000):# 生成隨機輸入和目標inputs = torch.randn(32, 10)targets = torch.randn(32, 1)# 前向傳播outputs = model(inputs)loss = criterion(outputs, targets)# 反向傳播optimizer.zero_grad()loss.backward()# 記錄裁剪前梯度grad_norms.append(torch.norm(torch.cat([p.grad.view(-1) for p in model.parameters()])))# 執行裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)# 記錄裁剪后梯度clipped_grad_norms.append(torch.norm(torch.cat([p.grad.view(-1) for p in model.parameters()])))# 更新參數optimizer.step()# 繪制梯度變化
plt.figure(figsize=(12, 6))
plt.plot(grad_norms, alpha=0.6, label='Original Gradient Norm')
plt.plot(clipped_grad_norms, alpha=0.6, label='Clipped Gradient Norm')
plt.axhline(1.0, color='r', linestyle='--', label='Clipping Threshold')
plt.yscale('log')
plt.title("Gradient Clipping Effect Monitoring", fontsize=14)
plt.xlabel("Training Steps", fontsize=12)
plt.ylabel("Gradient L2 Norm (log scale)", fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

在這里插入圖片描述

實踐技巧

  1. RNN中推薦值:LSTM/GRU 中 max_norm 取 1.0 或 2.0
  2. 結合學習率:較高學習率需配合較小裁剪閾值
  3. 監控策略:定期輸出梯度統計量
print(f"梯度均值: {grad.mean().item():.3e} ± {grad.std().item():.3e}")

批量歸一化(Batch Normalization)

數學推導

對于輸入批次 B = { x 1 , . . . , x m } B = \{x_1,...,x_m\} B={x1?,...,xm?}

  1. 計算統計量:
    μ B = 1 m ∑ i = 1 m x i σ B 2 = 1 m ∑ i = 1 m ( x i ? μ B ) 2 \mu_B = \frac{1}{m}\sum_{i=1}^m x_i \\ \sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2 μB?=m1?i=1m?xi?σB2?=m1?i=1m?(xi??μB?)2
  2. 標準化:
    x ^ i = x i ? μ B σ B 2 + ? \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i?=σB2?+? ?xi??μB??
  3. 仿射變換:
    y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi?=γx^i?+β

訓練/評估模式對比

import torch.nn as nn
import torch
# 創建BN層
bn = nn.BatchNorm1d(64)# 訓練模式
bn.train()
for _ in range(100):x = torch.randn(32, 64)  # 批大小32y = bn(x)
print("訓練模式統計:", bn.running_mean[:5].detach().numpy())  # 顯示部分通道# 評估模式
bn.eval()
with torch.no_grad():x = torch.randn(32, 64)y = bn(x)
print("評估模式統計:", bn.running_mean[:5].detach().numpy())

可視化BN效果

# 生成模擬數據
data = torch.cat([torch.normal(2.0, 1.0, (100, 1)),torch.normal(-1.0, 0.5, (100, 1))
], dim=1)# 應用BN
bn = nn.BatchNorm1d(2)
output = bn(data)# 繪制分布對比
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
sns.histplot(data[:,0], kde=True, label='原始特征1')
sns.histplot(data[:,1], kde=True, label='原始特征2')
plt.title("Distribution of features before BN")plt.subplot(1,2,2)
sns.histplot(output[:,0], kde=True, label='BN后特征1')
sns.histplot(output[:,1], kde=True, label='BN后特征2') 
plt.title("Distribution of features after BN")plt.tight_layout()

在這里插入圖片描述


技術組合應用案例

圖像分類任務

# 自定義CNN模型
class CustomCNN(nn.Module):def __init__(self):super().__init__()# 卷積層 使用BNself.conv_layers = nn.Sequential(nn.Conv2d(3, 64, 3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64, 128, 3),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2),)# 全連接層self.fc = nn.Linear(128*5*5, 10)def forward(self, x):x = self.conv_layers(x)return self.fc(x.view(x.size(0), -1))# 初始化模型、優化器和調度器
model = CustomCNN()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=200)# 帶梯度裁剪的訓練循環
max_grad_norm = 5.0  # 裁剪閾值
for epoch in range(200):model.train()  # 模型進入訓練模式for inputs, targets in train_loader:  # 訓練數據加載器outputs = model(inputs)  # 前向傳播loss = F.cross_entropy(outputs, targets)  # 計算損失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向傳播# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)optimizer.step()  # 參數更新scheduler.step()  # 學習率更新

關鍵技術總結

技術主要作用典型應用場景注意事項
學習率衰減精細收斂深層網絡訓練配合warmup效果更佳
梯度裁剪穩定訓練RNN、Transformer閾值需隨batch size調整
批量歸一化加速收斂CNN、全連接網絡小batch效果差

組合策略建議

  1. CNN架構:BN + 動量SGD + 余弦衰減
  2. RNN架構:梯度裁剪 + Adam + 步進衰減
  3. Transformer:預熱 + 梯度裁剪 + AdamW
# Transformer優化示例
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.98))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda step: min(step**-0.5, step*(4000**-1.5))  # 預熱
)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/78137.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/78137.shtml
英文地址,請注明出處:http://en.pswp.cn/web/78137.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【計算機視覺】語義分割:Mask2Former:統一分割框架的技術突破與實戰指南

深度解析Mask2Former&#xff1a;統一分割框架的技術突破與實戰指南 技術架構與創新設計核心設計理念關鍵技術組件 環境配置與安裝指南硬件要求安裝步驟預訓練模型下載 實戰全流程解析1. 數據準備2. 配置文件定制3. 訓練流程4. 推理與可視化 核心技術深度解析1. 掩膜注意力機制…

數字智慧方案5857丨智慧機場解決方案與應用(53頁PPT)(文末有下載方式)

資料解讀&#xff1a;智慧機場解決方案與應用 詳細資料請看本解讀文章的最后內容。 隨著科技的飛速發展&#xff0c;智慧機場的建設已成為現代機場發展的重要方向。智慧機場不僅提升了旅客的出行體驗&#xff0c;還極大地提高了機場的運營效率。本文將詳細解讀沃土數字平臺在…

【C到Java的深度躍遷:從指針到對象,從過程到生態】第五模塊·生態征服篇 —— 第二十章 項目實戰:從C系統到Java架構的蛻變

一、跨語言重構&#xff1a;用Java重寫Redis核心模塊 1.1 Redis的C語言基因解析 Redis 6.0源碼核心結構&#xff1a; // redis.h typedef struct redisObject { unsigned type:4; // 數據類型&#xff08;String/List等&#xff09; unsigned encoding:4; // …

ES6異步編程中Promise與Proxy對象

Promise 對象 Promise對象用于解決Javascript中的地獄回調問題&#xff0c;有效的減少了程序回調的嵌套調用。 創建 如果要創建一個Promise對象&#xff0c;最簡單的方法就是直接new一個。但是&#xff0c;如果深入學習&#xff0c;會發現使用Promise下的靜態方法Promise.re…

UE自動索敵插件Target System Component

https://www.fab.com/zh-cn/listings/9088334d-3bde-4e10-a937-baeb780f880f ? 一個完全用 C 編寫的 UE插件&#xff0c;添加了對簡單相機鎖定/瞄準系統的支持。它最初??在藍圖中開發和測試&#xff0c;然后轉換并重寫為 C 模塊和插件。 特征&#xff1a; 可通過一組可在…

中小企業MES系統概要設計

版本&#xff1a;V1.0 日期&#xff1a;2025年5月2日 一、系統架構設計 1.1 整體架構模式 采用分層微服務架構&#xff0c;實現模塊解耦與靈活擴展&#xff0c;支持混合云部署&#xff1a; #mermaid-svg-drxS3XaKEg8H8rAJ {font-family:"trebuchet ms",verdana,ari…

STM32移植U8G2

STM32 移植 U8G2 u8g2 &#xff08;Universal 8bit Graphics Library version2 的縮寫&#xff09;是用于嵌入式設備的單色圖形庫&#xff0c;可以在單色屏幕中繪制 GUI。u8g2 內部附帶了例如 SSD13xx&#xff0c;ST7xx 等很多 OLED&#xff0c;LCD 驅動。內置多種不同大小和風…

Langchain,為何要名為langchian?

來聽聽 DeepSeek 怎么說 Human 2025-05-02T01:13:43.627Z langchain 是一個大語言模型開發框架。我的理解中&#xff0c;lang 是詞根"語言"&#xff0c;chain是單詞"鏈"&#xff0c;langchain 便是將語言模型和組件串聯成鏈的框架。而 langchain 的圖標是…

Windows下Python3腳本傳到Linux下./example.py執行失敗

1. 背景 大多數情況下通過pycharm編寫Python代碼&#xff0c;編寫調試完&#xff0c;到Linux下發布執行。 以example.py腳本為例 #! /usr/bin/env python3 #! -*- encoding: utf-8 -*- def test(x,y): xint x yint y cxy return c if _name_"__main__": print(test(2…

當MCP撞進云宇宙:多芯片封裝如何重構云計算的“芯“未來?

當MCP撞進云宇宙:多芯片封裝如何重構云計算的"芯"未來? 2024年3月,AMD發布了震撼業界的MI300A/B芯片——這顆為AI計算而生的"超級芯片",首次在單封裝內集成了13個計算芯片(包括3D V-Cache緩存、CDNA3 GPU和Zen4 CPU),用多芯片封裝(Multi-Chip Pac…

用定時器做微妙延時注意事項

注意定時器來著APB1還是APB2&#xff0c;二者頻率不一樣&#xff0c;配置PSC要注意 &#xff08;1&#xff09;高級定時器timer1&#xff0c; timer8以及通用定時器timer9&#xff0c; timer10&#xff0c; timer11的時鐘來源是APB2總線 &#xff08;2&#xff09;通用定時器ti…

三類思維坐標空間與時空序位信息處理架構

三類思維坐標空間與時空序位信息處理架構 一、靜態信息元子與元組的數據結構設計 三維思維坐標空間定義 形象思維軸&#xff08;x&#xff09;&#xff1a;存儲多媒體數據元子&#xff08;圖像/音頻/視頻片段&#xff09; 元子結構&#xff1a;{ID, 數據塊, 特征向量, 語義…

spring boot中@Validated

在 Spring Boot 中&#xff0c;Validated 是用于觸發參數校驗的注解&#xff0c;通常與 ??JSR-303/JSR-380??&#xff08;Bean Validation&#xff09;提供的校驗注解一起使用。以下是常見的校驗注解及其用法&#xff1a; ?1. 基本校驗注解?? 這些注解可以直接用于字段…

Hadoop 單機模式(Standalone Mode)部署與 WordCount 測試

通過本次實驗&#xff0c;成功搭建了 Hadoop 單機環境并運行了基礎 MapReduce 程序&#xff0c;為后續分布式計算學習奠定了基礎。 掌握 Hadoop 單機模式的安裝與配置方法。 熟悉 Hadoop 環境變量的配置及 Java 依賴管理。 使用 Hadoop 自帶的 WordCount 示例程序進行簡單的 …

歷史數據分析——運輸服務

運輸服務板塊簡介: 運輸服務板塊主要是為貨物與人員流動提供核心服務的企業的集合,涵蓋鐵路、公路、航空、海運、物流等細分領域。該板塊具有強周期屬性,與經濟復蘇、政策調控、供需關系密切關聯,尤其是海運領域。有不少國內股市的鐵路、公路等相關的上市公司同時屬于紅利…

openEuler 22.03 安裝 Mysql 5.7,TAR離線安裝

目錄 一、檢查系統是否安裝其他版本Mariadb數據庫二、環境檢查2.1 必要環境檢查2.2 在線安裝&#xff08;有網絡&#xff09;2.3 離線安裝&#xff08;無網絡&#xff09; 三、下載Mysql2.1 在線下載2.2 離線下載 四、安裝Mysql五、配置Mysql六、開放防火墻端口七、數據備份八、…

噴泉碼技術在現代物聯網中的應用 設計

噴泉碼技術在現代物聯網中的應用 摘 要 噴泉碼作為一種無速率編碼技術,憑借其動態生成編碼包的特性,在物聯網通信中展現出獨特的優勢。其核心思想在于接收端只需接收到足夠數量的任意編碼包即可恢復原始數據,這種特性使其特別適用于動態信道和多用戶場景。噴泉碼的實現主要…

GZIPInputStream 類詳解

GZIPInputStream 類詳解 GZIPInputStream 是 Java 中用于解壓縮 GZIP 格式數據的流類,屬于 java.util.zip 包。它是 InflaterInputStream 的子類,專門處理 GZIP 壓縮格式(.gz 文件)。 1. 核心功能 解壓 GZIP 格式數據(RFC 1952 標準)自動處理 GZIP 頭尾信息(校驗和、時…

網絡編程——TCP和UDP詳細講解

文章目錄 TCP/UDP全面詳解什么是TCP和UDP&#xff1f;TCP如何保證可靠性&#xff1f;1. 序列號&#xff08;Sequence Number&#xff09;2. 確認應答&#xff08;ACK&#xff09;3. 超時重傳&#xff08;Timeout Retransmission&#xff09;4. 窗口控制&#xff08;Sliding Win…

性能測試工具篇

文章目錄 目錄1. JMeter介紹1.1 安裝JMeter1.2 打開JMeter1.3 JMeter基礎配置1.4 JMeter基本使用流程1.5 JMeter元件作用域和執行順序 2. 重點組件2.1 線程組2.2 HTTP取樣器2.3 查看結果樹2.4 HTTP請求默認值2.5 JSON提取器2.6 用戶定義的變量2.7 JSON斷言2.8 同步定時器&#…