知識點回顧
- 圖像數據的格式:灰度和彩色數據
- 模型的定義
- 顯存占用的4種地方
- 模型參數+梯度參數
- 優化器參數
- 數據批量所占顯存
- 神經元輸出中間狀態
- batchisize和訓練的關系
作業:今日代碼較少,理解內容即可
在 PyTorch 中,圖像數據的形狀通常遵循 (通道數, 高度, 寬度) 的格式(即 Channel First 格式),這與常見的 (高度, 寬度, 通道數)(Channel Last,如 NumPy 數組)不同。---注意順序關系,
注意點:
- 如果用matplotlib庫來畫圖,需要轉換下順序 image = np.transpose(image.numpy(), (1, 2, 0)
- 模型輸入通常需要批次維度(Batch Size),形狀變為 (批次大小, 通道數, 高度, 寬度)。例如,批量輸入 10 張 MNIST 圖像時,形狀為 (10, 1, 28, 28)
對于圖像數據集比如MNIST構建神經網絡來訓練的話,比起之前的結構化數據多了一個展平操作:
# 定義兩層MLP神經網絡
class MLP(nn.Module):def __init__(self, input_size=784, hidden_size=128, num_classes=10):super().__init__()self.flatten = nn.Flatten() # 將28x28的圖像展平為784維向量self.layer1 = nn.Linear(input_size, hidden_size) # 第一層:784個輸入,128個神經元self.relu = nn.ReLU() # 激活函數self.layer2 = nn.Linear(hidden_size, num_classes) # 第二層:128個輸入,10個輸出(對應10個數字類別)def forward(self, x):x = self.flatten(x) # 展平圖像x = self.layer1(x) # 第一層線性變換x = self.relu(x) # 應用ReLU激活函數x = self.layer2(x) # 第二層線性變換,輸出logitsreturn x# 初始化模型
model = MLP()
MLP的輸入層要求輸入是一維向量,但 MNIST 圖像是二維結構(28×28 像素),形狀為 [1, 28, 28](通道 × 高 × 寬)。nn.Flatten() 展平操作將二維圖像 “拉成” 一維向量(784=28×28 個元素),使其符合全連接層的輸入格式
在面對數據集過大的情況下,由于無法一次性將數據全部加入到顯存中,所以采取了分批次加載這種方式。所以實際應用中,輸入圖像還存在batch_size這一維度,但在PyTorch中,模型定義和輸入尺寸的指定不依賴于batch_size,無論設置多大的batch_size,模型結構和輸入尺寸的寫法都是不變的,batch_size是在數據加載階段定義的(之前提過這是DataLoader的參數)
那么顯存設置多少合適呢?如果設置的太小,那么每個batch_size的訓練不足以發揮顯卡的能力,浪費計算資源;如果設置的太大,會出現OOM(out of memory)顯存一般被以下內容占用:
- 模型參數與梯度:模型的權重和對應的梯度會占用顯存,尤其是深度神經網絡(如 Transformer、ResNet 等),一個 1 億參數的模型(如 BERT-base),單精度(float32)參數占用約 400MB(1e8×4Byte),加上梯度則翻倍至 800MB(每個權重參數都有其對應的梯度)
- 部分優化器(如 Adam)會為每個參數存儲動量(Momentum)和平方梯度(Square Gradient),進一步增加顯存占用(通常為參數大小的 2-3 倍)
- 其他開銷
- 單張圖像尺寸:1×28×28(通道×高×寬),歸一化轉換為張量后為float32類型,顯存占用:1×28×28×4 Byte = 3,136 Byte ≈ 3 KB
- 批量數據占用:batch_size × 單張圖像占用,例如batch_size=64時,數據占用為64×3 KB ≈ 192 KB
對于batch_size的設置,大規模數據時,通常從16開始測試,然后逐漸增加,確保代碼運行正常且不報錯,直到出現內存不足(OOM)報錯或訓練效果下降,此時選擇略小于該值的 batch_size。訓練時候搭配 nvidia-smi 監控顯存占用,合適的 batch_size = 硬件顯存允許的最大值 × 0.8(預留安全空間),并通過訓練效果驗證調整
@浙大疏錦行