文章目錄
- 什么是gradient checkpoint
- 原理
- 使用場景
- 注意事項
什么是gradient checkpoint
- gradient checkpoint是一種優化深度學習模型內存使用的技術,尤其在訓練大型模型時非常有用。它通過犧牲計算時間為代價來減少顯存占用。
- 大多數情況下,transformers庫中的gradient checkpoint粒度是“一個Transformer Block(也叫layer)為單位。
原理
-在標準的反向傳播中,為了計算梯度,需要保存所有中間激活值(activations),這會占用大量顯存。
- Gradient Checkpointing 的核心思想是只保留部分層的激活值,其余層在反向傳播時重新計算,從而節省顯存。【一般只保存transformer block的輸入輸出,這樣節省了大量的存儲】
使用場景
- 顯存受限時(如訓練大模型)
- batch size 需要增大但受顯存限制
- 模型層數較多(如Transformer)
注意事項
- 會增加訓練時間(因為需要重復計算激活值)【如果計算是瓶頸,那么這個方法會增加訓練時長。】
- 不適用于所有模型結構,建議先測試是否有效
- 可能與某些優化器或混合精度訓練有兼容性問題