其實如果是FP32的訓練,基本的調試方法還是差不多,這里就講一下混合精度訓練過程中的nan。
混合精度訓練使用較低的數值精度(通常是半精度浮點數,例如FP16)來加速模型訓練,但在一些情況下,可能會引發數值不穩定性的問題,導致 NaN 的出現。處理混合精度訓練中的 NaN 問題時,可以考慮以下步驟:
數值檢查: 在訓練過程中,定期檢查模型參數、梯度等是否包含 NaN 或 Inf(無窮大)值。你可以在訓練循環中添加斷言語句,及時發現異常值
assert not torch.isnan(model.parameters()).any(), "Model parameters contain NaN!"
梯度縮放(Gradient Scaling): 在混合精度訓練中,通常會使用梯度縮放來抵消使用較低精度帶來的梯度范圍減小的問題。你可以嘗試調整梯度縮放的比例。
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
注意,相比與前向出nan,混合精度訓練會多一個梯度縮放的過程,這個是前向沒有出nan的前提下實現的,影響的梯度更新:
前向計算過程中沒有nan,loss算完后,乘以scale后導致inf,這時候再往后反向傳播出nan了,那在梯度更新的時候就會在梯度更新前進行數值檢查,check finite and unscale過程會去檢查權重的梯度發現有nan或者inf就會跳過更新,此時就可以調整scale的值,把scale降低,然后跑下一個step的前向。如果scale調整后,乘以loss,沒有inf,就調成功了,繼續正常更新參數,如果還是inf就得繼續調小scale