Pytorch Lightning 進階 1 - 梯度檢查點(Gradient Checkpointing)

梯度檢查點(Gradient Checkpointing)是一種在深度學習訓練中優化顯存使用的技術,尤其適用于處理大型模型(如Transformer架構)時顯存不足的情況。下面用簡單的例子解釋其工作原理和優缺點:

核心原理

深度學習訓練中的顯存占用主要來自三個方面:

  1. 模型參數(如權重、偏置)
  2. 優化器狀態(如Adam的動量項)
  3. 中間激活值(forward過程中產生的張量,如注意力圖、隱藏層輸出等)

其中,中間激活值通常占用最大的顯存空間,尤其是在深層網絡中。梯度檢查點的核心思想是:

  • 在正向傳播時:只保存少量關鍵的中間結果(稱為“檢查點”),其余中間值在計算后立即丟棄。
  • 在反向傳播時:利用保存的檢查點重新計算被丟棄的中間值,從而獲得計算梯度所需的全部信息。

如下圖所示:
在這里插入圖片描述

這種方法通過犧牲計算時間(重新計算)來節省顯存空間(無需保存所有中間值)。

為什么需要梯度檢查點?

假設你有一個包含100層的Transformer模型,每層在forward過程中產生1GB的中間激活值:

  • 傳統訓練:需要保存所有100層的中間值,總顯存需求為100GB。
  • 梯度檢查點:只保存10個檢查點(每層1GB),反向傳播時通過檢查點重新計算其余90層,總顯存需求降至10GB。

代碼中的應用

在你的代碼中,gradient_checkpointing=True的配置會使模型在訓練時啟用梯度檢查點:

trainable_model = Model(# ...其他參數gradient_checkpointing=training_config.get("gradient_checkpointing", False),# ...
)

這意味著:

  1. 正向傳播時,模型不會保存所有注意力圖隱藏層輸出
  2. 反向傳播時,PyTorch會利用檢查點重新計算這些值,從而減少顯存占用

優缺點

  • 優點:顯著減少顯存使用(通常能節省30%-50%的顯存),允許訓練更大的模型或使用更大的批次大小。
  • 缺點:增加訓練時間(通常慢20%-30%),因為需要重新計算中間值。

何時使用?

  • 顯存不足:當模型因顯存限制無法訓練時,梯度檢查點是一種有效的解決方案。
  • 計算資源充足:如果你的GPU算力充足但顯存有限,可以通過延長訓練時間換取更小的顯存占用。

技術細節

在PyTorch中,梯度檢查點通過torch.utils.checkpoint模塊實現。例如:

from torch.utils.checkpoint import checkpointdef forward(self, x):# 普通forward:保存所有中間值x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 使用梯度檢查點:只保存關鍵檢查點
def forward(self, x):x = checkpoint(self.layer1, x)  # 只保存layer1的輸出x = checkpoint(self.layer2, x)  # 只保存layer2的輸出x = self.layer3(x)return x

PyTorch Lightning的gradient_checkpointing參數會自動為模型的所有層應用這種優化。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/86111.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/86111.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/86111.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SpreadJS 迷你圖:數據趨勢可視化的利器

引言 在數據處理和分析領域,直觀地展示數據趨勢對于理解數據和做出決策至關重要。迷你圖作為一種簡潔而有效的數據可視化方式,在顯示數據趨勢方面發揮著重要作用,尤其在與他人共享數據時,能夠快速傳達關鍵信息。SpreadJS 作為一款…

GESP2024年12月認證C++一級( 第三部分編程題(1)溫度轉換)

參考程序1&#xff1a; #include <cstdio> using namespace std;int main() {double K;scanf("%lf", &K);double C K - 273.15; //轉換為攝氏溫度 double F 32 C * 1.8; //轉換為華氏溫度 if (F > 212) //條件判斷 print…

從零開始手寫redis(18)緩存淘汰算法 FIFO 優化

項目簡介 大家好&#xff0c;我是老馬。 Cache 用于實現一個可拓展的高性能本地緩存。 有人的地方&#xff0c;就有江湖。有高性能的地方&#xff0c;就有 cache。 v1.0.0 版本 以前的 FIFO 實現比較簡單&#xff0c;但是 queue 循環一遍刪除的話&#xff0c;性能實在是太…

用Zynq實現脈沖多普勒雷達信號處理:架構、算法與實現詳解

用Zynq實現脈沖多普勒雷達信號處理:架構、算法與實現詳解 脈沖多普勒(PD)雷達是現代雷達系統的核心技術之一,廣泛應用于機載火控、氣象監測、交通監控等領域。其核心優勢在于能在強雜波背景下檢測運動目標,并精確測量其徑向速度。本文將深入探討如何利用Xilinx Zynq SoC(…

OpenCV CUDA模塊設備層-----線程塊級別的一個內存填充工具函數blockFill()

操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 在同一個線程塊&#xff08;thread block&#xff09;內&#xff0c;將 [beg, end) 范圍內的數據并行地填充為指定值 value。 它使用了 CUDA 線程…

SAP-ABAP:如何查詢 SAP 事務碼(T-Code)被包含在哪些權限角色或權限對象中

要查詢 SAP 事務碼&#xff08;T-Code&#xff09;被包含在哪些權限角色或權限對象中&#xff0c;可使用以下專業方法&#xff1a; &#x1f50d; 1. 通過權限瀏覽器 (SUIM) - 最推薦 事務碼&#xff1a;SUIM (權限信息系統) 操作步驟&#xff1a; 執行 SUIM → 選擇 “角色…

MySQL 多列 IN 查詢詳解:語法、性能與實戰技巧

在 MySQL 中&#xff0c;多列 IN 查詢是一種強大的篩選工具&#xff0c;它允許通過多字段組合快速過濾數據。相較于傳統的 OR 連接多個條件&#xff0c;這種語法更簡潔高效&#xff0c;尤其適合批量匹配復合鍵或聯合字段的場景。本文將深入解析其用法&#xff0c;并探討性能優化…

自由學習記錄(63)

編碼全稱&#xff1a;AV1&#xff08;Alliance for Open Media Video 1&#xff09;。 算力消耗大&#xff1a;目前&#xff08;截至 2025 年中&#xff09;軟件解碼 AV1 的 CPU 開銷非常高&#xff0c;如果沒有專門的硬件解碼單元&#xff0c;播放高清視頻時會很吃 CPU&#…

日本生活:日語語言學校-日語作文-溝通無國界(4)-題目:喜歡讀書

日本生活&#xff1a;日語語言學校-日語作文-溝通無國界&#xff08;4&#xff09;-題目&#xff1a;喜歡讀書 1-前言2-作文原稿3-作文日語和譯本&#xff08;1&#xff09;日文原文&#xff08;2&#xff09;對應中文&#xff08;3&#xff09;對應英文 4-老師評語5-自我感想&…

C++優化程序的Tips

轉自個人博客 1. 避免創建過多中間變量 過多的中間變量不利于代碼的可讀性&#xff0c;還會增加內存的使用&#xff0c;而且可能導致額外的計算開銷。 將用于同一種情況的變量統一管理&#xff0c;可以使用一種通用的變量來代替多個變量。 2. 函數中習慣使用引用傳參而不是返…

C#Blazor應用-跨平臺WEB開發VB.NET

在 C# 中實現 Blazor 應用需要結合 Razor 語法和 C# 代碼&#xff0c;Blazor 允許使用 C# 同時開發前端和后端邏輯。以下是一個完整的 C# Blazor 實現示例&#xff0c;包含項目創建、基礎組件和數據交互等內容&#xff1a; 一、創建 Blazor 項目 使用 Visual Studio 新建項目 …

前端的安全隱患之API惡意調用

永遠不要相信前端傳來的數據&#xff0c;對于資深開發者而言&#xff0c;這幾乎是一種本能&#xff0c;無需過多解釋。然而&#xff0c;初入職場的開發新手可能會感到困惑&#xff1a;為何要對前端傳來的數據持有如此不信任的態度&#xff1f;難道人與人之間連基本的信任都不存…

基于 Spark 實現 COS 海量數據處理

上周在組內分享了一下這個主題&#xff0c; 我覺得還是摘出一部分當文章輸出出來 分享主要包括三個方面&#xff1a; 1. 項目背景 2.Spark 原理 3. Spark 實戰 項目背景 主要是將海量日志進行多維度處理&#xff1b; 項目難點 1、數據量大&#xff08;壓縮包數量 6TB,60 億條數…

Unity3D 屏幕點擊特效

實現點擊屏幕任意位置播放點擊特效。 屏幕點擊特效 需求 現有一個需求&#xff0c;點擊屏幕任意位置&#xff0c;播放一個點擊特效。 美術已經做好了特效&#xff0c;效果如圖&#xff1a; 特效容器 首先&#xff0c;畫布是 Camera 模式&#xff0c;畫布底下有一個 UIClic…

MCU編程

MCU 編程基礎&#xff1a;概念、架構與實踐 一、什么是 MCU 編程&#xff1f; MCU&#xff08;Microcontroller Unit&#xff0c;微控制器&#xff09; 是將 CPU、內存、外設&#xff08;如 GPIO、UART、ADC&#xff09;集成在單一芯片上的小型計算機系統。MCU 編程即針對這些…

Go語言--語法基礎6--基本數據類型--數組類型(1)

Go 語言提供了數組類型的數據結構。 數組是具有相同唯一類型的一組已編號且長度固定的數據項序列&#xff0c;這種類型可以是任意的 原始類型例如整型、字符串或者自定義類型。相對于去聲明number0,number1, ..., and number99 的變量&#xff0c;使用數組形式 numbers[0], …

左神算法之給定一個數組arr,返回其中的數值的差值等于k的子數組有多少個

目錄 1. 題目2. 解釋3. 思路4. 代碼5. 總結 1. 題目 給定一個數組arr&#xff0c;返回其中的數值的差值等于k的子數組有多少個 2. 解釋 略 3. 思路 直接用hashSet進行存儲&#xff0c;查這個值加上k后的值是否在數組中 4. 代碼 public class Problem01_SubvalueEqualk {…

自回歸(AR)與掩碼(MLM)的核心區別:續寫還是補全?

自回歸(AR)與掩碼(MLM)的核心區別:用例子秒懂 一、核心機制對比:像“續寫”還是“完形填空”? 維度自回歸(Autoregressive)掩碼語言模型(Masked LM)核心目標根據已生成的token,預測下一個token(順序生成)預測句子中被“掩碼”的token(補全缺失信息)輸入輸出輸入…

后端開發兩個月實習總結

前言 本人目前在一家小公司后端開發實習差不多兩個月了&#xff0c;現在準備離職了&#xff0c;就這兩個月的實習經歷寫下這篇文章&#xff0c;既是對自己實習的一個總結&#xff0c;也是給正在找實習的小伙伴以及未來即將進入到后端開發這個行業的同學的分享一下經驗。 一、個…

Python基礎(??FAISS?和??Chroma?)

??1. 索引與查詢性能? ??指標????FAISS????Chroma????分析????索引構建速度??72.4秒&#xff08;5551個文本塊&#xff09;91.59秒&#xff08;相同數據集&#xff09;FAISS的底層優化&#xff08;如PQ量化&#xff09;加速索引構建&#xff0c;適合批…