模型訓練中梯度累積步數(gradient_accumulation_steps)的作用
flyfish
在使用訓練大模型時,TrainingArguments有一個參數梯度累積步數(gradient_accumulation_steps)
from transformers import TrainingArguments
梯度累積是一種在訓練深度學習模型時用于處理內存限制問題的技術。在每次迭代中,模型的梯度是通過反向傳播計算得到的,而梯度累積步數(gradient_accumulation_steps)指定了在執行實際的參數更新之前,要累積多少個小批次(mini - batch)的梯度。
以代碼來說gradient_accumulation_steps的作用
import torch
from torch import nn, optim# 生成更合理的數據集,假設目標關系是y = 3 * x + 2 加上一些噪聲
def generate_dataset(num_samples):inputs = torch.randn(num_samples, 10)# 根據線性關系生成標簽,添加一些隨機噪聲模擬真實情況labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5return list(zip(inputs, labels))# 生成數據集,這里生成2000個樣本(可根據實際情況調整數據量)
your_dataset = generate_dataset(2000)# 模型、損失和優化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法來初始化模型參數,有助于緩解梯度消失和爆炸問題,提升訓練效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 適當調整學習率,這里改為0.1,可根據實際情況進一步微調
optimizer = optim.Adam(model.parameters(), lr=0.1)# 配置梯度累積步數
gradient_accumulation_steps = 4
global_step = 0# 模擬訓練循環
for epoch in range(20): # 訓練20個周期for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):# 前向傳播outputs = model(inputs)loss = criterion(outputs, labels)# 反向傳播(累積梯度)loss.backward()# 執行梯度更新if (step + 1) % gradient_accumulation_steps == 0:optimizer.step()optimizer.zero_grad()global_step += 1print(f"更新了模型參數,當前全局步數: {global_step}, 當前損失: {loss.item()}")
解釋:
batch_size=8
:每個梯度計算時,模型會處理 8 張圖像。gradient_accumulation_steps=4
:表示每次參數更新前累積 4 次梯度。
因此:
- 每個 step: 處理 8 張圖像。
- 累積 4 個 step: 共處理 8 × 4 = 32 8 \times 4 = 32 8×4=32 張圖像。
關鍵點:
- 一個 step: 是指一次前向和后向傳播(不包含參數更新)。
- 一次參數更新: 在累積 4 個 step 后,進行一次模型參數更新。
等效有效批次:
有效批次大小 = batch_size
× gradient_accumulation_steps
即: 8 × 4 = 32 8 \times 4 = 32 8×4=32。
這意味著,即使顯存有限,模型仍然能以有效批次大小 32 的方式進行訓練