梯度檢查點(Gradient Checkpointing)是一種在深度學習訓練中優化顯存使用的技術,尤其適用于處理大型模型(如Transformer架構)時顯存不足的情況。下面用簡單的例子解釋其工作原理和優缺點:
核心原理
深度學習訓練中的顯存占用主要來自三個方面:
- 模型參數(如權重、偏置)
- 優化器狀態(如Adam的動量項)
- 中間激活值(forward過程中產生的張量,如注意力圖、隱藏層輸出等)
其中,中間激活值通常占用最大的顯存空間,尤其是在深層網絡中。梯度檢查點的核心思想是:
- 在正向傳播時:只保存少量關鍵的中間結果(稱為“檢查點”),其余中間值在計算后立即丟棄。
- 在反向傳播時:利用保存的檢查點重新計算被丟棄的中間值,從而獲得計算梯度所需的全部信息。
如下圖所示:
這種方法通過犧牲計算時間(重新計算)來節省顯存空間(無需保存所有中間值)。
為什么需要梯度檢查點?
假設你有一個包含100層的Transformer模型,每層在forward過程中產生1GB的中間激活值:
- 傳統訓練:需要保存所有100層的中間值,總顯存需求為100GB。
- 梯度檢查點:只保存10個檢查點(每層1GB),反向傳播時通過檢查點重新計算其余90層,總顯存需求降至10GB。
代碼中的應用
在你的代碼中,gradient_checkpointing=True
的配置會使模型在訓練時啟用梯度檢查點:
trainable_model = Model(# ...其他參數gradient_checkpointing=training_config.get("gradient_checkpointing", False),# ...
)
這意味著:
- 正向傳播時,模型不會保存所有注意力圖和隱藏層輸出。
- 反向傳播時,PyTorch會利用檢查點重新計算這些值,從而減少顯存占用。
優缺點
- 優點:顯著減少顯存使用(通常能節省30%-50%的顯存),允許訓練更大的模型或使用更大的批次大小。
- 缺點:增加訓練時間(通常慢20%-30%),因為需要重新計算中間值。
何時使用?
- 顯存不足:當模型因顯存限制無法訓練時,梯度檢查點是一種有效的解決方案。
- 計算資源充足:如果你的GPU算力充足但顯存有限,可以通過延長訓練時間換取更小的顯存占用。
技術細節
在PyTorch中,梯度檢查點通過torch.utils.checkpoint
模塊實現。例如:
from torch.utils.checkpoint import checkpointdef forward(self, x):# 普通forward:保存所有中間值x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 使用梯度檢查點:只保存關鍵檢查點
def forward(self, x):x = checkpoint(self.layer1, x) # 只保存layer1的輸出x = checkpoint(self.layer2, x) # 只保存layer2的輸出x = self.layer3(x)return x
PyTorch Lightning的gradient_checkpointing
參數會自動為模型的所有層應用這種優化。