模型訓練中梯度累積步數(gradient_accumulation_steps)的作用

模型訓練中梯度累積步數(gradient_accumulation_steps)的作用

flyfish

在使用訓練大模型時,TrainingArguments有一個參數梯度累積步數(gradient_accumulation_steps)

from transformers import TrainingArguments

梯度累積是一種在訓練深度學習模型時用于處理內存限制問題的技術。在每次迭代中,模型的梯度是通過反向傳播計算得到的,而梯度累積步數(gradient_accumulation_steps)指定了在執行實際的參數更新之前,要累積多少個小批次(mini - batch)的梯度。

以代碼來說gradient_accumulation_steps的作用

import torch
from torch import nn, optim# 生成更合理的數據集,假設目標關系是y = 3 * x + 2 加上一些噪聲
def generate_dataset(num_samples):inputs = torch.randn(num_samples, 10)# 根據線性關系生成標簽,添加一些隨機噪聲模擬真實情況labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5return list(zip(inputs, labels))# 生成數據集,這里生成2000個樣本(可根據實際情況調整數據量)
your_dataset = generate_dataset(2000)# 模型、損失和優化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法來初始化模型參數,有助于緩解梯度消失和爆炸問題,提升訓練效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 適當調整學習率,這里改為0.1,可根據實際情況進一步微調
optimizer = optim.Adam(model.parameters(), lr=0.1)# 配置梯度累積步數
gradient_accumulation_steps = 4
global_step = 0# 模擬訓練循環
for epoch in range(20):  # 訓練20個周期for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):# 前向傳播outputs = model(inputs)loss = criterion(outputs, labels)# 反向傳播(累積梯度)loss.backward()# 執行梯度更新if (step + 1) % gradient_accumulation_steps == 0:optimizer.step()optimizer.zero_grad()global_step += 1print(f"更新了模型參數,當前全局步數: {global_step}, 當前損失: {loss.item()}")

解釋:

  • batch_size=8:每個梯度計算時,模型會處理 8 張圖像。
  • gradient_accumulation_steps=4:表示每次參數更新前累積 4 次梯度。

因此:

  • 每個 step: 處理 8 張圖像。
  • 累積 4 個 step: 共處理 8 × 4 = 32 8 \times 4 = 32 8×4=32 張圖像。

關鍵點:

  • 一個 step: 是指一次前向和后向傳播(不包含參數更新)。
  • 一次參數更新: 在累積 4 個 step 后,進行一次模型參數更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps
即: 8 × 4 = 32 8 \times 4 = 32 8×4=32

這意味著,即使顯存有限,模型仍然能以有效批次大小 32 的方式進行訓練

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63181.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63181.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63181.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

技術速遞|.NET 9 簡介

作者:.NET 團隊 排版:Alan Wang 今天,我們非常激動地宣布 .NET 9的發布,這是迄今為止最高效、最現代、最安全、最智能、性能最高的 .NET 版本。這是來自世界各地數千名開發人員又一年努力的成果。這個新版本包括數千項性能、安全和…

Vue項目打包部署到服務器

1. Vue項目打包部署到服務器 1.1. 配置 (1)修改package.json文件同級目錄下的vue.config.js文件。 // vue.config.js module.exports {publicPath: ./, }(2)檢查router下的index.js文件下配置的mode模式。 ??檢查如果模式改…

【jpa】springboot使用jpa示例

目錄 1. 請求示例2. pom依賴3. application.yaml4.controller5. service6. repository7. 實體8. 啟動類 1. 請求示例 curl --location --request POST http://127.0.0.1:8080/user \ --header User-Agent: Apifox/1.0.0 (https://apifox.com) \ --header Content-Type: applic…

uniapp 常用的指令語句

uniapp 是一個使用 Vue.js 開發的跨平臺應用框架,因此,它繼承了 Vue.js 的大部分指令。以下是一些在 uniapp 中常用的 Vue 指令語句及其用途: v-if / v-else-if / v-else 條件渲染。v-if 有條件地渲染元素,v-else-if 和 v-else 用…

中企出海-德國會計準則和IFRS間的差異

根據提供的網頁內容,德國的公認會計準則(HGB)與國際會計準則(IFRS)之間的主要差異可以從以下幾個方面進行比較: 財務報告的目的: IFRS:財務報告主要是供投資者做決策使用&#xff0c…

NPU是什么?電腦NPU和CPU、GPU區別介紹

隨著人工智能技術的飛速發展,計算機硬件架構也在不斷演進以適應日益復雜的AI應用場景。其中,NPU(Neural Processing Unit,神經網絡處理器)作為一種專為深度學習和神經網絡運算設計的新型處理器,正逐漸嶄露頭…

使用skywalking,grafana實現從請求跟蹤、 指標收集和日志記錄的完整信息記錄

Skywalking是由國內開源愛好者吳晟開源并提交到Apache孵化器的開源項目, 2017年12月SkyWalking成為Apache國內首個個人孵化項目, 2019年4月17日SkyWalking從Apache基金會的孵化器畢業成為頂級項目, 目前SkyWalking支持Java、 .Net、 Node.js、…

純CSS實現文本或表格特效(連續滾動與首尾相連)

純CSS實現文本連續向左滾動首尾相連 1.效果圖&#xff1a; 2.實現代碼&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, init…

【LeetCode刷題之路】622.設計循環隊列

LeetCode刷題記錄 &#x1f310; 我的博客主頁&#xff1a;iiiiiankor&#x1f3af; 如果你覺得我的內容對你有幫助&#xff0c;不妨點個贊&#x1f44d;、留個評論?&#xff0c;或者收藏?&#xff0c;讓我們一起進步&#xff01;&#x1f4dd; 專欄系列&#xff1a;LeetCode…

Node.js基礎入門

1.Node.js 簡介 Node 是一個讓 JavaScript (獨立)運行在服務端的開發平臺,它讓 JavaScript 成為與PHP、Python、Perl、Ruby 等服務端語言平起平坐的腳本語言。 發布于2009年5月,由Ryan Dahl開發,實質是對Chrome V8引擎進行了封裝。 簡單的說 Node.js 就是運行在服務端的…

#思科模擬器通過服務配置保障無線網絡安全Radius

演示拓撲圖&#xff1a; 搭建拓撲時要注意&#xff1a; 只能連接它的Ethernet接口&#xff0c;不然會不通 MAC地址綁定 要求 &#xff1a;通過配置MAC地址過濾禁止非內部員工連接WiFi 打開無線路由器GUI界面&#xff0c;點開下圖頁面&#xff0c;配置路由器無線網絡MAC地址過…

docker 部署kafka集群

docker run 部署 docker run -d --name zookeeper --restart always -p 2181:2181 wurstmeister/zookeeperdocker run -d --name kafka1 --restart always -p 9094:9092 \-e KAFKA_ADVERTISED_HOST_NAME182.54.14.45 \-e KAFKA_ZOOKEEPER_CONNECT182.54.14.45:2181 \-e KAFKA_…

Qt-chart 畫折線圖(以時間為x軸)

上圖 代碼 #include <iostream> #include <random> #include <qcategoryaxis.h>void MainWindow::testLine() {//1、創建圖表視圖QChartView* view new QChartView(this);//2.創建圖表QChart* chart new QChart();//3.將圖表設置給圖表視圖view->setCh…

C++多線程常用方法

在 C 中&#xff0c;線程相關功能主要通過頭文件提供的類和函數來實現&#xff0c;以下是一些常用的線程接口方法和使用技巧&#xff1a; std::thread類 構造函數&#xff1a; 可以通過傳入可調用對象&#xff08;如函數指針、函數對象、lambda 表達式等&#xff09;來創建一…

up主親測,ToDesk/青椒云/順網云這三款云電腦玩轉AIGC場景

文章目錄 1. 前言2. 云電腦性能分析3. 基礎硬件數據3.1 硬件配置3.2 AI 評測跑分 4. 云電腦 AIGC 上手實測4.1 ToDesk4.1.1 AIGC 技術集成情況4.1.2 界面及功能4.1.3 項目部署4.1.4 黑神話悟空 AI 換臉4.1.6 AIGC 文生圖體驗 4.2 青椒云4.2.1 AIGC 技術集成情況4.2.2 界面及功能…

C++(十八)

前言&#xff1a; 本文依據上一篇&#xff0c;繼續對C中的函數進行學習。 一&#xff0c;內聯函數。 再執行函數代碼時&#xff0c;比不使用函數花費了更多時間&#xff0c;因為總結步驟&#xff0c;傳遞參數和返回值都很花費時間。 因此&#xff0c;在調試小型函數時&…

功能篇:JAVA后端實現跨域配置

在Java后端實現跨域配置&#xff08;CORS&#xff0c;Cross-Origin Resource Sharing&#xff09;有多種方法&#xff0c;具體取決于你使用的框架。如果你使用的是Spring Boot或Spring MVC&#xff0c;可以通過以下幾種方式來配置CORS。 ### 方法一&#xff1a;全局配置 對于所…

數獨游戲app制作拆解(之一)——功能介紹

android studio版本&#xff1a;2023.3.1 例程名稱&#xff1a;shudu666 前陣子作了一個EXCEL版的數獨&#xff0c;再早之前就想作這個數獨app,但一直沒動手&#xff0c;一方面懶&#xff0c;另一方面我把自己繞到坑里了&#xff0c;之前做的是一解數獨的app,那個是有點難&am…

Spring注解篇:@Configuration詳解

前言 在Spring框架中&#xff0c;Configuration注解是實現Java配置的核心。它允許開發者以編程的方式定義Bean的創建過程&#xff0c;而不是使用XML文件。這種基于注解的配置方式&#xff0c;不僅簡化了配置的復雜性&#xff0c;還提高了代碼的可讀性和可維護性。 摘要 本文…

通過一個例子學習回溯算法:從方法論到實際應用

回溯算法&#xff1a;從方法論到實際應用 回溯算法&#xff08;Backtracking&#xff09;是一種通過窮舉法尋找問題所有解的算法&#xff0c;它的核心思想是逐步構建解空間樹&#xff0c;在每個步驟中判斷當前解是否合法。如果不合法&#xff0c;就“回溯”到上一步&#xff0…