用一個實際例子和簡單代碼來清晰解釋 batch、epoch 和 iteration 的關系:
------------------------------------------------------------------------------------
假設場景
-
你有一個數據集:1000 張貓狗圖片
-
你設置 batch_size = 100(每次處理 100 張圖片)
-
你計劃訓練 5 個 epoch
概念關系圖解
1個Epoch = 完整遍歷整個數據集1次│├── Iteration 1: 處理第1-100張圖片 (batch 1)├── Iteration 2: 處理第101-200張圖片 (batch 2)├── ...└── Iteration 10: 處理第901-1000張圖片 (batch 10)
具體計算
-
總樣本數:1000 張圖片
-
Batch size:100(每次處理的圖片數)
-
Iterations per epoch = 總樣本數 / batch_size = 1000 / 100 = 10 次
-
Total iterations = Iterations per epoch × Epochs = 10 × 5 = 50 次
代碼示例說明
import torch
from torch.utils.data import DataLoader, TensorDataset# 創建模擬數據集:1000張圖片(用1000個數字代替)
data = torch.arange(0, 1000)? # [0, 1, 2, ..., 999]
dataset = TensorDataset(data)# 創建數據加載器:設置batch_size=100
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)# 訓練5個epoch
for epoch in range(5):
??? print(f"\n=== 開始第 {epoch+1} 個epoch ===")
?? ?
??? # 每個epoch內遍歷所有batch
??? for batch_idx, batch_data in enumerate(dataloader):
??????? # 獲取當前batch的數據
??????? images = batch_data[0]? # 當前batch的100張"圖片"
?????? ?
??????? # 這里應該是你的訓練代碼:
??????? # 1. 正向傳播
??????? # 2. 計算損失
??????? # 3. 反向傳播
??????? # 4. 更新權重
?????? ?
??????? # 打印當前iteration信息
??????? print(f"Epoch {epoch+1} | Iteration {batch_idx+1} | 處理圖片: {images[0].item()}到{images[-1].item()}")print("\n訓練完成!")
print(f"總Iteration次數: {5 * len(dataloader)} 次")
import torch
from torch.utils.data import DataLoader, TensorDataset# 創建模擬數據集:1000張圖片(用1000個數字代替)
data = torch.arange(0, 1000) # [0, 1, 2, ..., 999]
dataset = TensorDataset(data)# 創建數據加載器:設置batch_size=100
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)# 訓練5個epoch
for epoch in range(5):print(f"\n=== 開始第 {epoch+1} 個epoch ===")# 每個epoch內遍歷所有batchfor batch_idx, batch_data in enumerate(dataloader):# 獲取當前batch的數據images = batch_data[0] # 當前batch的100張"圖片"# 這里應該是你的訓練代碼:# 1. 正向傳播# 2. 計算損失# 3. 反向傳播# 4. 更新權重# 打印當前iteration信息print(f"Epoch {epoch+1} | Iteration {batch_idx+1} | 處理圖片: {images[0].item()}到{images[-1].item()}")print("\n訓練完成!")
print(f"總Iteration次數: {5 * len(dataloader)} 次")
輸出示例
=== 開始第 1 個epoch ===
Epoch 1 | Iteration 1 | 處理圖片: 12到980
Epoch 1 | Iteration 2 | 處理圖片: 88到799
...
Epoch 1 | Iteration 10 | 處理圖片: 36到995=== 開始第 2 個epoch ===
Epoch 2 | Iteration 1 | 處理圖片: 44到932
...
(共5個epoch,每個epoch 10個iteration)訓練完成!
總Iteration次數: 50 次
關鍵概念解析
-
Batch(批):
-
每次實際輸入模型的數據子集
-
代碼中的
images = batch_data[0]
獲取的就是一個batch -
大小由
batch_size=100
決定
-
-
Iteration(迭代):
-
完成一個batch訓練所需的步驟
-
每個iteration包含:
-
從dataloader取出一個batch數據
-
執行正向傳播 → 計算損失 → 反向傳播 → 更新權重
-
-
代碼中
for batch_idx, ...
循環的每次執行就是一個iteration
-
-
Epoch(輪次):
-
完整遍歷整個數據集一次
-
每個epoch包含多個iteration:
-
1000樣本 ÷ 100 batch_size = 10 iterations/epoch
-
-
外層循環
for epoch in range(5)
控制epoch數量
-
為什么需要batch?
-
內存限制:無法一次性加載所有數據(如100萬張圖片)
-
訓練效率:小批量數據更適合GPU并行計算
-
梯度穩定性:批量梯度下降比單個樣本更穩定
-
正則化效果:小批量帶來輕微噪聲,有助于防止過擬合
實際訓練中的選擇
Batch Size | 迭代次數 | 內存占用 | 訓練穩定性 |
---|---|---|---|
小 (8-32) | 多 | 低 | 較低 |
中 (64-256) | 中等 | 中等 | 較好 |
大 (512+) | 少 | 高 | 高 |
初學者建議從 batch_size=64 開始嘗試,這是常用基準值。你之前提到的8也是合理的,尤其當使用大模型或顯存有限時。