【大模型LLM】梯度累積(Gradient Accumulation)原理詳解

在這里插入圖片描述

梯度累積(Gradient Accumulation)原理詳解

梯度累積是一種在深度學習訓練中常用的技術,特別適用于顯存有限但希望使用較大批量大小(batch size)的情況。通過梯度累積,可以在不增加單個批次大小的情況下模擬較大的批量大小,從而提高模型的穩定性和收斂速度。

基本概念

在標準的隨機梯度下降(SGD)及其變體(如Adam、RMSprop等)中,每次更新模型參數時都需要計算整個批次數據的損失函數梯度,并立即用這個梯度來更新模型參數。然而,在處理大規模數據集或使用非常大的模型時,單個批次的數據量可能會超出GPU顯存的容量。此時,梯度累積技術就可以發揮作用。

工作原理

梯度累積的核心思想是:將多個小批次(mini-batch)的梯度累加起來,然后一次性執行一次參數更新。具體步驟如下:

  1. 初始化梯度累積器:在每個訓練步驟開始時,初始化一個梯度累積器(通常為零)。
  2. 前向傳播與梯度計算
    • 對于每一個小批次 i(從 1 到 k),執行前向傳播計算損失。
    • 執行反向傳播計算該小批次的梯度。
  3. 累積梯度:將當前小批次的梯度累加到梯度累積器中。
  4. 參數更新:當累積了 k 個小批次的梯度后,使用累積的梯度來更新模型參數,并重置梯度累積器。
詳細步驟

假設我們希望使用的批量大小是 N,但由于顯存限制只能使用較小的批量大小 n(其中 N = k * n),那么我們可以進行 k 次前向和后向傳播,每次都計算一個小批次的梯度并將其累加,直到累積了 k 個小批次的梯度之后,再進行一次參數更新。

示例代碼

以下是一個簡單的PyTorch示例,展示了如何實現梯度累積:

import torch
import torch.nn as nn
import torch.optim as optim# 假設有一個簡單的模型
model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 設置梯度累積步數
accumulation_steps = 4
optimizer.zero_grad()  # 清空梯度for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)# 將損失除以累積步數,使得總的損失不變loss = loss / accumulation_steps# 反向傳播計算梯度loss.backward()if (i + 1) % accumulation_steps == 0:# 累積足夠步數后,執行優化步驟optimizer.step()optimizer.zero_grad()  # 清空梯度
關鍵點解釋
  1. 損失縮放:由于我們將一個大批次分成多個小批次,并且每次只計算一個小批次的損失,因此需要將每個小批次的損失除以累積步數 accumulation_steps,以確保總的損失值保持不變。

  2. 梯度累積:每次反向傳播后,梯度會被累加而不是立即用于更新參數。只有當累積了足夠的步數后,才會使用累積的梯度進行一次參數更新。

  3. 參數更新:在累積了足夠的梯度后,調用 optimizer.step() 來更新模型參數,并清空梯度累積器(即調用 optimizer.zero_grad())。

優點
  • 突破顯存限制:通過使用較小的批量大小,可以有效地減少每一步所需的顯存量,從而允許在有限的硬件資源上訓練更大的模型或使用更大的批量大小。
  • 模擬大批次訓練效果:梯度累積實際上模擬了使用較大批量大小的效果,有助于提高模型訓練的穩定性和收斂速度。
  • 靈活性:可以根據實際硬件條件靈活調整累積步數,適應不同的訓練需求。
注意事項
  • 學習率調整:由于梯度累積實際上是將多個小批次的梯度累加起來進行一次更新,因此需要相應地調整學習率。例如,如果原始設置的學習率為 lr,并且使用了 k 步梯度累積,則新的有效學習率應為 lr * k
  • 隨機性影響:梯度累積可能會引入一定的隨機性,因為不同小批次之間的順序可能會影響最終的梯度累積結果。不過,在實踐中這種影響通常是可以接受的。
總結

梯度累積是一種非常實用的技術,特別是在顯存受限但希望利用更大批量大小的情況下。它不僅幫助克服了硬件限制,還能夠保持甚至提升模型訓練的質量。通過合理配置梯度累積步數和學習率,可以顯著改善訓練效率和效果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/90998.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/90998.shtml
英文地址,請注明出處:http://en.pswp.cn/web/90998.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

阿里云Ubuntu 22.04 ssh隔一段時間自動斷開的解決方法

在使用ssh連接阿里云ubuntu22.04隔一段時間之后就自動斷開,很影響體驗,按照如下配置就可以解決vim /etc/ssh/sshd_config

R中匹配函數

在 R 中,字符串匹配是一個常見的任務,可以使用正則表達式或非正則表達式的方法來完成。以下是對這些方法的總結,包括在向量和數據框中的應用。 正則表達式匹配 常用函數grepl: 功能:檢查向量中的每個元素是否匹配某個正…

Ubuntu服務器上JSP運行緩慢怎么辦?全面排查與優化方案

隨著企業系統越來越多地部署在Linux平臺上,Ubuntu成為JSP Web系統常見的部署環境。但不少開發者會遇到一個共同的問題:在Ubuntu服務器上運行的JSP項目訪問緩慢、頁面加載時間長,甚至出現卡頓現象。這類問題如果不及時解決,容易導致…

web刷題

[極客大挑戰 2019]RCE ME 打開環境,代碼邏輯還是很簡單的 思路是傳參code參數,一般傳參shell然后用蟻劍連接看flag,但是這題做了之后就會發現思路是沒錯但是這題多了一些驗證,這題就是無字符rce,可以考慮用取反&…

FFmpeg+javacpp中FFmpegFrameGrabber

FFmpegjavacpp中FFmpegFrameGrabber1、FFmpegFrameGrabber1.1 Demo使用1.2 音頻相關1.3 視頻相關2、Frame屬性2.1 視頻幀屬性2.2 音頻幀屬性2.3 音頻視頻區分JavaCV 1.5.12 API JavaCPP Presets for FFmpeg 7.1.1-1.5.12 API1、FFmpegFrameGrabber org\bytedeco\javacv\FFmpeg…

1-FPGA的LUT理解

FPGA的LUT理解 FPGA的4輸入LUT中,SRAM存儲的16位二進制數(如 0110100110010110)直接對應真值表的輸出值。下面通過具體例子詳細解釋其含義: 1. 4輸入LUT 4輸入LUT的本質是一個161的SRAM,它通過存儲真值表的方式實現任意…

Vue2文件上傳相關

導入彈窗<template><el-dialog:title"title":visible.sync"fileUploadVisible"append-to-bodyclose-on-click-modalclose-on-press-escapewidth"420px"><div v-if"showDatePicker">選擇時間&#xff1a;<el-date…

vue使用xlsx庫導出excel

引入xlsx庫 import XLSX from "xlsx";將后端接口返回的數據和列名&#xff0c;拼接到XLSX.utils.aoa_to_sheet中exportExcel() {debugger;if (!this.feedingTableData || this.feedingTableData.length "0") {this.$message.error("投料信息為空&…

卷積神經網絡(CNN)處理流程(簡化版)

前言 是看了這個大佬的視頻后想進行一下自己的整理&#xff08;流程只到了扁平化&#xff09;&#xff0c;如果有問題希望各位大佬能夠給予指正。卷積神經網絡&#xff08;CNN&#xff09;到底卷了啥&#xff1f;8分鐘帶你快速了解&#xff01;_嗶哩嗶哩_bilibilihttps://www.…

DBSyncer:開源免費的全能數據同步工具,多數據源無縫支持!

DBSyncer&#xff08;英[dbs??k??]&#xff0c;美[dbs??k?? 簡稱dbs&#xff09;是一款開源的數據同步中間件&#xff0c;提供MySQL、Oracle、SqlServer、PostgreSQL、Elasticsearch(ES)、Kafka、File、SQL等同步場景。支持上傳插件自定義同步轉換業務&#xff0c;提供…

kafka開啟Kerberos使用方式

kafka SASL_PLAINTEXT serviceName 配置&#xff1a; /etc/security/keytabs/kafka.service.keytab 對應的用戶名 $ cat /home/sunxy/kafka/jaas25.conf KafkaClient { com.sun.security.auth.module.Krb5LoginModule required useKeyTabtrue renewTickettrue serviceName“ocd…

Unity教程(二十四)技能系統 投劍技能(中)技能變種實現

Unity開發2D類銀河惡魔城游戲學習筆記 Unity開發2D類銀河惡魔城游戲學習筆記目錄 技能系統 Unity教程&#xff08;二十一&#xff09;技能系統 基礎部分 Unity教程&#xff08;二十二&#xff09;技能系統 分身技能 Unity教程&#xff08;二十三&#xff09;技能系統 擲劍技能…

局域網TCP通過組播放地址rtp推流和拉流實現實時喊話

應用場景&#xff0c;安卓端局域網不用ip通過組播放地址實現實時對講功能發送端: ffmpeg -f alsa -i hw:1 -acodec aac -ab 64k -ac 2 -ar 16000 -frtp -sdp file stream.sdp rtp://224.0.0.1:14556接收端: ffmpeg -protocol whitelist file,udp,rtp -i stream.sdp -acodec pcm…

基于深度學習的醫學圖像分析:使用YOLOv5實現細胞檢測

最近研學過程中發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊鏈接跳轉到網站人工智能及編程語言學習教程。讀者們可以通過里面的文章詳細了解一下人工智能及其編程等教程和學習方法。下面開始對正文內容的…

32.768KHZ 3215晶振CM315D與NX3215SA應用全場景

在現代電子設備中&#xff0c;一粒米大小的晶振&#xff0c;卻是掌控時間精度的“心臟”。CITIZEN的CM315D系列與NDK的NX3215SA系列晶振便是其中的佼佼者&#xff0c;它們以 3.2 1.5 mm 的小尺寸”(厚度不足1mm)&#xff0c;成為智能設備中隱形的節奏大師。精準計時的奧秘這兩…

嵌軟面試——ARM Cortex-M寄存器組

Cortex-M內存架構包含16個通用寄存器&#xff0c;其中R0-R12是13個32位的通用寄存器&#xff0c;另外三個寄存器是特殊用途&#xff0c;分別是R13&#xff08;棧指針&#xff09;,R14&#xff08;鏈接寄存器&#xff09;,R15&#xff08;程序計數器&#xff09;。對于處理器來說…

7.DRF 過濾、排序、分頁

過濾Filtering 對于列表數據可能需要根據字段進行過濾&#xff0c;我們可以通過添加django-fitlter擴展來增強支持。 pip install django-filter在配置文件中增加過濾器類的全局設置&#xff1a; """drf配置信息必須全部寫在REST_FRAMEWORK配置項中""…

二、CUDA、Pytorch與依賴的工具包

CUDA Compute Unified Device Architecture&#xff08;統一計算架構&#xff09;。專門用于 GPU 通用計算 的平臺 編程接口。CUDA可以使你的程序&#xff08;比如矩陣、神經網絡&#xff09;由 GPU 執行&#xff0c;這比CPU能快幾十甚至上百倍。 PyTorch 是一個深度學習框架…

SpringCloude快速入門

近期簡單了解一下SpringCloude微服務,本文主要就是我學習中所記錄的筆記,然后具體原理可能等以后再來深究,本文可能有些地方用詞不專業還望包容一下,感興趣可以參考官方文檔來深入學習一下微服務,然后我的下一步學習就是docker和linux了。 nacos: Nacos 快速開始 | Nacos 官網…

GPT Agent與Comet AI Aent瀏覽器對比橫評

1. 架構設計差異GPT Agent的雙瀏覽器架構&#xff1a;文本瀏覽器&#xff1a;專門用于高效處理大量文本內容&#xff0c;適合深度信息檢索和文獻追蹤&#xff0c;相當于Deep Research的延續可視化瀏覽器&#xff1a;具備界面識別與交互能力&#xff0c;可以點擊網頁按鈕、識別圖…