【漫話機器學習系列】277.梯度裁剪(Gradient Clipping)

【深度學習】什么是梯度裁剪(Gradient Clipping)?一張圖徹底搞懂!

在訓練深度神經網絡,尤其是 RNN、LSTM、Transformer 這類深層結構時,你是否遇到過以下情況:

  • 模型 loss 突然變成 NaN;

  • 梯度爆炸導致訓練中斷;

  • 訓練剛開始幾步模型就“失控”了。

這些問題,很多時候都是因為——梯度過大(梯度爆炸)。而應對這個問題的常見方案之一,就是本文要講的主角:梯度裁剪(Gradient Clipping)


一、梯度裁剪是什么?

我們先看一張圖,一圖勝千言:

圖中文字解讀如下:

  • 標題:梯度裁剪(Gradient Clipping)

  • 說明文字

    損失函數中的梯度懸崖會導致模型在學習過程中超出期望最小值。發生這種情況,是因為梯度陡峭。解決方法:阻止梯度選擇極端值。

  • 圖示公式

    if ‖g‖ > v:g ← (g / ‖g‖) * v
    

    意思是:

    • 如果梯度的范數(即長度)大于某個閾值 v,就將梯度縮放為長度為 v 的向量。

    • 這樣可以防止某些參數更新過大。


二、為什么需要梯度裁剪?

1. 梯度爆炸的根源

在反向傳播中,每一層的梯度是前面所有梯度的乘積。在深層網絡中,如果這些乘積的值都 > 1,最終梯度將呈指數級增長,導致所謂的梯度爆炸(Gradient Explosion)

表現形式

  • loss 一直上升,甚至變成 NaN

  • 參數更新過大,模型發散

  • 模型無法收斂

2. 梯度裁剪的作用

梯度裁剪并不會改變梯度的方向,它只是在梯度的模(大小)超過某個閾值時,進行縮放。這就像是給模型裝了一個“剎車”系統,一旦速度過快就減速。


三、梯度裁剪的數學原理

設:

  • 當前梯度為 g

  • 范數為 ∥g∥

  • 閾值為 v

裁剪操作如下:

\text{if } \|g\| > v, \quad g \leftarrow \frac{g}{\|g\|} \cdot v

也就是說:將梯度的模限制在最大值 vv 內,方向保持不變。


四、實戰中如何實現梯度裁剪?

在 PyTorch 中非常簡單:

import torch# 假設已經定義 optimizer 和 model
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

在 TensorFlow(Keras)中也可以:

optimizer = tf.keras.optimizers.Adam(clipnorm=1.0)

五、梯度裁剪 vs 梯度正則化

名稱作用是否改變方向
梯度裁剪控制梯度最大值,避免爆炸
L2 正則化(權重衰減)防止模型過擬合,限制權重大小

注意:梯度裁剪是為了“救訓練”,不是為了“提高精度”!


六、何時需要使用梯度裁剪?

  • 訓練深度模型如 RNN、LSTM、Transformer

  • loss 出現爆炸性增長,模型訓練不穩定;

  • 使用高學習率訓練時容易出問題;

  • 模型結構復雜,層數深,非線性強。


七、調參建議

參數建議取值說明
clip norm0.1 ~ 5通常從 1.0 開始嘗試,逐步調整
適用優化器Adam、SGD梯度裁剪不依賴特定優化器
使用頻率每次 step 前每次梯度更新前裁剪

八、總結

梯度裁剪是深度學習中極其實用的一種 訓練穩定性保障機制。它的作用不是提升模型能力,而是防止模型“發瘋”。在某些模型結構中(如 LSTM、GAN),它幾乎是標配操作。

一句話總結:梯度裁剪不是為了讓模型跑得快,而是為了別讓它翻車。


推薦閱讀

  • 《Deep Learning》by Ian Goodfellow(第 6 章)

  • PyTorch 官方文檔:clip_grad_norm_


如果你覺得本文對你有幫助,歡迎點贊、收藏、評論~
也歡迎你分享你在訓練中使用梯度裁剪的經驗或踩過的坑!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/82402.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/82402.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/82402.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

零基礎弄懂 ngx_http_slice_module分片緩存加速

一、為什么需要 Slice? 在 NGINX 反向代理或 CDN 場景中,大文件(視頻、軟件包、鏡像等)常因單體體積過大而令緩存命中率低、回源代價高。 ngx_http_slice_module 通過把一次完整響應拆分成 固定大小的字節塊(Slice&am…

機器人強化學習入門學習筆記(三)

強化學習(Reinforcement Learning, RL)與監督學習不同——你不需要預先準備訓練數據集,而是要設計環境、獎勵函數,讓智能體通過交互不斷探索和學習。 🎯 一、強化學習和訓練數據的關系 強化學習不依賴固定的數據集。它…

【python實戰】二手房房價數據分析與預測

個人主頁:大數據蟒行探索者 目錄 一、數據分析目標與任務 1.1背景介紹 1.2課程設計目標與任務 1.3研究方法與技術路線 二、數據預處理 2.1數據說明 2.2數據清洗 2.3數據處理 三、數據探索分析 四、數據分析模型 五、方案評估 摘要:隨著社會經…

Kotlin IR編譯器插件開發指南

在 Kotlin 中開發基于 IR(Intermediate Representation)的編譯器插件,可以深度定制語言功能或實現高級代碼轉換。以下是分步驟指南: 一、IR 編譯器插件基礎 IR 是什么? Kotlin 編譯器將源碼轉換為 IR 中間表示&#xf…

如何用 python 代碼復現 MATLAB simulink 的 PID

MATLAB在 Simulink 里做以下設置MATLAB 腳本調用示例 python 實現離散 PID 實現(并行形式) Simulink 中兩種 PID 結構(并聯形式, I-形式)下連續/離散時域里積分增益 I 的表示并聯(Parallel) vs 理想&#x…

黑馬點評--基于Redis實現共享session登錄

集群的session共享問題分析 session共享問題:多臺Tomcat無法共享session存儲空間,當請求切換到不同Tomcat服務時,原來存儲在一臺Tomcat服務中的數據,在其他Tomcat中是看不到的,這就導致了導致數據丟失的問題。 雖然系…

SkyWalking啟動失敗:OpenSearch分片數量達到上限的完美解決方案

?? 問題現象 SkyWalking OAP服務啟動時報錯: org.apache.skywalking.oap.server.library.module.ModuleStartException: java.lang.RuntimeException: {"error":{"root_cause":[{"type":"validation_exception", "reason&q…

向量數據庫選型實戰指南:Milvus架構深度解析與技術對比

導讀:隨著大語言模型和AI應用的快速普及,傳統數據庫在處理高維向量數據時面臨的性能瓶頸日益凸顯。當文檔經過嵌入模型處理生成768到1536維的向量后,傳統B-Tree索引的檢索效率會出現顯著下降,而現代應用對毫秒級響應的嚴苛要求使得…

MySQL#秘籍#一條SQL語句執行時間以及資源分析

背景 一條 SQL 語句的執行完,每個模塊耗時,不同資源(CPU/IO/IPC/SWAP)消耗情況我該如何知道呢?別慌俺有 - MySQL profiling 1. SQL語句執行前 - 開啟profiling -- profiling (0-關閉 1-開啟) -- 或者:show variables like prof…

【數據結構】實現方式、應用場景與優缺點的系統總結

以下是編程中常見的數據結構及其實現方式、應用場景與優缺點的系統總結: 一、線性數據結構 1. 數組 (Array) 定義:連續內存空間存儲相同類型元素。實現方式:int[] arr new int[10]; // Javaarr [0] * 10 # Python操作: 訪問&…

PyTorch中cdist和sum函數使用示例詳解

以下是PyTorch中cdist與sum函數的聯合使用詳解: 1. cdist函數解析 功能:計算兩個張量間的成對距離矩陣 輸入格式: X1:形狀為(B, P, M)的張量X2:形狀為(B, R, M)的張量p:距離類型(默認2表示歐式距離)輸出:形狀為(B, P, R)的距離矩陣,其中元素 d i j d_{ij} dij?表示…

Ansible配置文件常用選項詳解

Ansible 的配置文件采用 INI 格式,分為多個模塊,每個模塊包含特定功能的配置參數。 以下是ansible.cfg配置文件中對各部分的詳細解析: [defaults](全局默認配置) inventory 指定主機清單文件路徑,默認值為 …

了解FTP搜索引擎

根據資料, FTP搜索引擎是專門搜集匿名FTP服務器提供的目錄列表,并向用戶提供文件信息的網站; FTP搜索引擎專門針對FTP服務器上的文件進行搜索; 就是它的搜索結果是一些FTP資源; 知名的FTP搜索引擎如下, …

【大模型面試每日一題】Day 28:AdamW 相比 Adam 的核心改進是什么?

【大模型面試每日一題】Day 28:AdamW 相比 Adam 的核心改進是什么? 📌 題目重現 🌟🌟 面試官:AdamW 相比 Adam 的核心改進是什么? #mermaid-svg-BJoVHwvOm7TY1VkZ {font-family:"trebuch…

C++系統IO

C系統IO 頭文件的使用 1.使用系統IO必須包含相應的頭文件,通常使用#include預處理指令。 2.頭文件中包含了若干變量的聲明,用于實現系統IO。 3.頭文件的引用方式有雙引號和尖括號兩種,區別在于查找路徑的不同。 4.C標準庫提供的頭文件通常沒…

多模態理解大模型高性能優化丨前沿多模態模型開發與應用實戰第七期

一、引言 在前序課程中,我們系統剖析了多模態理解大模型(Qwen2.5-VL、DeepSeek-VL2)的架構設計。鑒于此類模型訓練需消耗千卡級算力與TB級數據,實際應用中絕大多數的用戶場景均圍繞推理部署展開,模型推理的效率影響著…

各個網絡協議的依賴關系

網絡協議的依賴關系 學習網絡協議之間的依賴關系具有多方面重要作用,具體如下: 幫助理解網絡工作原理 - 整體流程明晰:網絡協議分層且相互依賴,如TCP/IP協議族,應用層協議依賴傳輸層的TCP或UDP協議來傳輸數據&#…

11.8 LangGraph生產級AI Agent開發:從節點定義到高并發架構的終極指南

使用 LangGraph 構建生產級 AI Agent:LangGraph 節點與邊的實現 關鍵詞:LangGraph 節點定義, 條件邊實現, 狀態管理, 多會話控制, 生產級 Agent 架構 1. LangGraph 核心設計解析 LangGraph 通過圖結構抽象復雜 AI 工作流,其核心要素構成如下表所示: 組件作用描述代碼對應…

相機--基礎

在機器人開發領域,相機種類很多,作為一個機器人領域的開發人員,我們需要清楚幾個問題: 1,相機的種類有哪些? 2,各種相機的功能,使用場景? 3,需要使用的相機…

【備忘】 windows 11安裝 AdGuardHome,實現開機自啟,使用 DoH

windows 11安裝 AdGuardHome,實現開機自啟,使用 DoH 下載 AdGuardHome解壓 AdGuardHome啟動 AdGuard Home設置 AdGuardHome設置開機自啟安裝 NSSM設置開機自啟重啟電腦后我們可以訪問 **http://127.0.0.1/** 設置使用 AdGuardHome DNS 效果圖 下載 AdGua…