【大模型LLM】大模型訓練加速 - 梯度累積(Gradient Accumulation)原理詳解

在這里插入圖片描述

梯度累積(Gradient Accumulation)原理詳解

梯度累積是一種在深度學習訓練中常用的技術,特別適用于顯存有限但希望使用較大批量大小(batch size)的情況。通過梯度累積,可以在不增加單個批次大小的情況下模擬較大的批量大小,從而提高模型的穩定性和收斂速度。

基本概念

在標準的隨機梯度下降(SGD)及其變體(如Adam、RMSprop等)中,每次更新模型參數時都需要計算整個批次數據的損失函數梯度,并立即用這個梯度來更新模型參數。然而,在處理大規模數據集或使用非常大的模型時,單個批次的數據量可能會超出GPU顯存的容量。此時,梯度累積技術就可以發揮作用。

工作原理

梯度累積的核心思想是:將多個小批次(mini-batch)的梯度累加起來,然后一次性執行一次參數更新。具體步驟如下:

  1. 初始化梯度累積器:在每個訓練步驟開始時,初始化一個梯度累積器(通常為零)。
  2. 前向傳播與梯度計算
    • 對于每一個小批次 i(從 1 到 k),執行前向傳播計算損失。
    • 執行反向傳播計算該小批次的梯度。
  3. 累積梯度:將當前小批次的梯度累加到梯度累積器中。
  4. 參數更新:當累積了 k 個小批次的梯度后,使用累積的梯度來更新模型參數,并重置梯度累積器。
詳細步驟

假設我們希望使用的批量大小是 N,但由于顯存限制只能使用較小的批量大小 n(其中 N = k * n),那么我們可以進行 k 次前向和后向傳播,每次都計算一個小批次的梯度并將其累加,直到累積了 k 個小批次的梯度之后,再進行一次參數更新。

示例代碼

以下是一個簡單的PyTorch示例,展示了如何實現梯度累積:

import torch
import torch.nn as nn
import torch.optim as optim# 假設有一個簡單的模型
model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 設置梯度累積步數
accumulation_steps = 4
optimizer.zero_grad()  # 清空梯度for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)# 將損失除以累積步數,使得總的損失不變loss = loss / accumulation_steps# 反向傳播計算梯度loss.backward()if (i + 1) % accumulation_steps == 0:# 累積足夠步數后,執行優化步驟optimizer.step()optimizer.zero_grad()  # 清空梯度
關鍵點解釋
  1. 損失縮放:由于我們將一個大批次分成多個小批次,并且每次只計算一個小批次的損失,因此需要將每個小批次的損失除以累積步數 accumulation_steps,以確保總的損失值保持不變。

  2. 梯度累積:每次反向傳播后,梯度會被累加而不是立即用于更新參數。只有當累積了足夠的步數后,才會使用累積的梯度進行一次參數更新。

  3. 參數更新:在累積了足夠的梯度后,調用 optimizer.step() 來更新模型參數,并清空梯度累積器(即調用 optimizer.zero_grad())。

優點
  • 突破顯存限制:通過使用較小的批量大小,可以有效地減少每一步所需的顯存量,從而允許在有限的硬件資源上訓練更大的模型或使用更大的批量大小。
  • 模擬大批次訓練效果:梯度累積實際上模擬了使用較大批量大小的效果,有助于提高模型訓練的穩定性和收斂速度。
  • 靈活性:可以根據實際硬件條件靈活調整累積步數,適應不同的訓練需求。
注意事項
  • 學習率調整:由于梯度累積實際上是將多個小批次的梯度累加起來進行一次更新,因此需要相應地調整學習率。例如,如果原始設置的學習率為 lr,并且使用了 k 步梯度累積,則新的有效學習率應為 lr * k
  • 隨機性影響:梯度累積可能會引入一定的隨機性,因為不同小批次之間的順序可能會影響最終的梯度累積結果。不過,在實踐中這種影響通常是可以接受的。
總結

梯度累積是一種非常實用的技術,特別是在顯存受限但希望利用更大批量大小的情況下。它不僅幫助克服了硬件限制,還能夠保持甚至提升模型訓練的質量。通過合理配置梯度累積步數和學習率,可以顯著改善訓練效率和效果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/93889.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/93889.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/93889.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【數據分享】各省文旅融合耦合協調度及原始數據(2012-2022)

數據介紹引言 文旅融合是推動區域經濟高質量發展、促進共同富裕的重要路徑。黨的二十大報告明確提出“推進文化和旅游深度融合發展”的戰略目標,文旅產業通過資源整合與業態創新,可顯著縮小城鄉、區域差距,提升物質與精神雙重福祉&#xff08…

Linux編程: 10、線程池與初識網絡編程

今天我計劃通過一個小型項目,系統講解線程池與網絡編程的核心原理及實踐。項目將圍繞 “利用線程池實現高并發網絡通信” 這一核心需求展開,具體設計如下: 為保證線程安全,線程池采用單例模式設計,確保全局唯一實例避…

藏云閣 Logo 庫(開源項目SVG/PNG高清Logo)

在日常技術方案設計、架構圖繪制或PPT制作中,常常會遇到一些問題,比如: 找不到統一風格的開源項目組件圖標,PPT中的logo五花八門下載的圖標分辨率不足,放大后模糊失真不同來源的圖標顏色風格沖突,破壞整體…

從0開始學習R語言--Day64--決策樹回歸

對于沒有特征或者說需要尋找另類關系的數據集,我們通常會用聚合或KNN近鄰的方法來分類,但這樣的分類或許在結果上是好的,但是解釋性并不好,有時候我們甚至能看到好的結果反直覺;而決策樹回歸做出的結果,由于…

B+樹高效實現與優化技巧

B樹的定義 一顆M階B樹T,滿足以下條件 每個結點至多擁有M課子樹 根結點至少擁有兩顆子樹 除了根結點以外,其余每個分支結點至少擁有M/2課子樹 所有的葉結點都在同一層上 有k棵子樹的分支結點則存在k-1個關鍵字,關鍵字按照遞增順序進行排序 關鍵字數量滿足 ceil( M/2 ) - 1 &…

Android 基礎入門學習目錄(持續更新)

四大組件 Activity: Service: BroadcastReceiver: ContentProvider: UI 與交互開發 常見的UI布局和UI控件 樣式與主題 Fragment Intent 數據存儲 自定義View和自定義Group 自定義View 自定義ViewGroup 事件分發 Key…

Linux移動大量文件命令

背景 使用 mv 命令報“/bin/mv: 參數列表過長”,也是第一遇到,查了一下,最后用rsync命令解決了。還好每臺服務器,都必裝rsync了,記錄如下。 命令 nohup rsync -av --remove-source-files --progress /public/tmp/video…

SQL中的HAVING用法

HAVING 是 SQL 中專門對 “分組之后的聚合結果” 再做篩選的子句。 它一般跟在 GROUP BY 后面,不能單獨使用,作用類似于分組版的 WHERE。? 1. 語法位置 SELECT 列1, 聚合函數(列2) AS 別名 FROM 表 GROUP BY 列1 HAVING 聚合條件; -- 這里寫對聚合…

【Halcon 】Halcon 實戰:如何為 XLD 模板添加極性信息以提升匹配精度?

Halcon 實戰:如何為 XLD 模板添加極性信息以提升匹配精度? 在使用 Halcon 進行模板匹配時,我們通常有兩種方式創建模板: 基于圖像灰度(CreateScaledShapeModel)基于輪廓 XLD(CreateScaledShapeM…

grafana/lock-stack 日志 Pipeline 配置

前言 本文使用的是 grafana/loki-stack chart 抓取的 k8s 日志。其他 chart 配置都差不多。 日志問題 docker 容器運行時 pod 內原始日志 [cpu-4] Hello, 第 9788 次報時,時間:2025-08-01T06:35:420000 {"HOSTNAME":"cpu-4",&qu…

appium2.0+之PointerActions詳解

以下內容在 夜神模擬器 上進行。 一、應用場景 一些針對手勢的操作,比如滑動、長按、拖動等。可以將這些基本手勢組合成一個相對復雜的手勢。 二、使用步驟創建觸摸輸入設備(模擬手指操作) touch_input PointerInput(interaction.POINTER_TO…

Java HTTPS 請求失敗排查與證書導入全過程

文章目錄Java HTTPS 請求失敗排查與證書導入全過程問題背景問題初步分析排查過程查看目標地址證書導入證書驗證證書是否導入成功重啟應用進一步驗證:是否真的是證書問題?1. 瀏覽器訪問2. 抓包工具驗證(如 Charles、Wireshark)補充…

android APT技術

1,背景 對于注解的使用,想必大家都不陌生,它出現在我們的源碼中,以及大部分框架中,比如ButterKnife、Arouter、Retrofit,但它們是有區別的,其中前2個是編譯時注解,最后一個是運行時注…

MySQL 和 PostgreSQL綜合比對分析匯總

面對大數據項目或其它類型項目中,面對關系型數據庫選擇一直是很總要的一點,本文針對MySQL 和 PostgreSQL進行綜合比對分析匯總,內容僅供參考。MySQL 和 PostgreSQL 是兩款主流的開源關系型數據庫(RDBMS),但…

Linux---make和makefile

一、基本概念1.是什么make是一條命令,makefile是一個文件2.對應在vs中按一下f5就能運行代碼,在Linux中make就相當于f5,使用makefile來封裝從而實現我, 想要的功能3.使用①創建makefile文件②編輯makefile解釋:test.exe…

【DAB收音機】DAB收音機協議及其他資料匯總

目錄[ETSI DAB標準協議文檔](https://www.etsi.org/standards)Other DAB資料DAB收音機相關的專利DAB收音機相關的期刊及學位論文DAB開源項目代碼倉庫qt-dab工具welle.io工具dablin工具【eti廣播工具】?? 項目對比與選型建議Other 收音機資料Other資料ETSI DAB標準協議文檔 官…

RabbitMQ的特點和消息可靠性保障

掌握RabbitMQ的核心知識,需從其特點和消息可靠性保障(尤其是消息丟失解決方案)兩方面入手,以下是詳細說明: 一、RabbitMQ的核心特點 RabbitMQ是基于AMQP(Advanced Message Queuing Protocol)協議…

項目升級啦

公司要新做一個醫療行業的業務,經過業務端和產品端的評估該業務與公司已有的產品線關聯不大,用戶后續也不想在老系統那臺老爺車上繼續使用,話說老系統到現在差不多10年了,中間經歷過的前后端開發者形形色色,維護者換了…

Android中頁面生命周期變化

一、Activity切換的生命周期變化(A啟動B)1. 標準流程(B完全覆蓋A)完整生命周期路徑:Activity A:onPause():失去焦點,仍部分可見onStop():完全不可見(當B完全覆…

自動駕駛控制算法——PID算法

自動駕駛控制算法——PID算法 文章目錄自動駕駛控制算法——PID算法一、PID 是什么?二、PID 原理2.1 **比例環節(P)**2.2 **積分環節(I)**2.3 **微分環節(D)**2.4 特點總結2.5 案例分析 —— 小…