方法區別
在 PyTorch 中,disable_torch_init
和 torch.no_grad()
是兩種完全不同的機制,它們的作用和目的不同,以下是它們的區別:
1. disable_torch_init
- 作用:
disable_torch_init
通常用于某些特定的框架或庫中,目的是禁用 PyTorch 的默認初始化邏輯。例如,在某些情況下,框架可能希望自定義模型參數的初始化方式,而不是使用 PyTorch 默認的初始化方法。 - 顯存優化原理:禁用默認初始化可以減少初始化過程中不必要的顯存分配。例如,某些框架可能會在初始化時創建額外的臨時張量或執行復雜的初始化邏輯,這些操作可能會占用顯存。通過禁用這些默認初始化,可以節省這部分顯存。
- 使用場景:通常用于框架內部的優化,或者在某些特定的模型加載或訓練準備階段。
2. torch.no_grad()
- 作用:
torch.no_grad()
上下文管理器或裝飾器,用于禁用梯度計算。在torch.no_grad()
的上下文內,所有張量操作都不會記錄梯度信息,也不會構建計算圖。 - 顯存優化原理:在默認情況下,PyTorch 會為每個需要梯度的張量(
requires_grad=True
)保存中間結果,以便在反向傳播時計算梯度。這些中間結果會占用顯存。通過禁用梯度計算,torch.no_grad()
可以避免這些中間結果的存儲,從而顯著減少顯存占用。 - 使用場景:主要用于模型的推理(inference)階段,或者在不需要計算梯度的場景中。例如,在模型評估、數據預處理、特征提取等場景中,
torch.no_grad()
是常用的優化手段。
3. 具體區別
特性 | disable_torch_init | torch.no_grad() |
---|---|---|
作用范圍 | 禁用模型參數的初始化邏輯 | 禁用梯度計算和計算圖構建 |
顯存優化原理 | 減少初始化過程中不必要的顯存分配 | 避免存儲中間梯度和計算圖,減少顯存占用 |
使用場景 | 模型加載或訓練準備階段 | 模型推理、評估、數據預處理等 |
是否影響模型結構 | 可能影響模型參數的初始化方式 | 不影響模型結構,僅影響梯度計算 |
是否需要手動啟用 | 需要框架或用戶顯式調用 | 可通過上下文管理器或裝飾器顯式啟用 |
4. 總結
disable_torch_init
是一種針對模型初始化過程的優化機制,主要用于減少初始化階段的顯存占用。torch.no_grad()
是一種禁用梯度計算的工具,主要用于推理階段,通過避免計算圖的構建和梯度存儲來減少顯存占用。
兩者雖然都可以減少顯存占用,但作用機制和使用場景完全不同。在實際應用中,torch.no_grad()
是更常用且更通用的顯存優化手段,而 disable_torch_init
更多是框架內部的優化策略。
(常見)在評估前@torch.no_grad()
源代碼:
class no_grad(_DecoratorContextManager):r"""Context-manager that disabled gradient calculation.Disabling gradient calculation is useful for inference, when you are surethat you will not call :meth:`Tensor.backward()`. It will reduce memoryconsumption for computations that would otherwise have `requires_grad=True`.In this mode, the result of every computation will have`requires_grad=False`, even when the inputs have `requires_grad=True`.This context manager is thread local; it will not affect computationin other threads.Also functions as a decorator. (Make sure to instantiate with parenthesis.).. note::No-grad is one of several mechanisms that can enable ordisable gradients locally see :ref:`locally-disable-grad-doc` formore information on how they compare... note::This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.If you want to disable forward AD for a computation, you can unpackyour dual tensors.Example::>>> # xdoctest: +SKIP>>> x = torch.tensor([1.], requires_grad=True)>>> with torch.no_grad():... y = x * 2>>> y.requires_gradFalse>>> @torch.no_grad()... def doubler(x):... return x * 2>>> z = doubler(x)>>> z.requires_gradFalse"""def __init__(self) -> None:if not torch._jit_internal.is_scripting():super().__init__()self.prev = Falsedef __enter__(self) -> None:self.prev = torch.is_grad_enabled()torch.set_grad_enabled(False)def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:torch.set_grad_enabled(self.prev)
(放在評估函數里面)disable_torch_init()
源代碼:
def disable_torch_init():"""Disable the redundant torch default initialization to accelerate model creation."""import torchsetattr(torch.nn.Linear, "reset_parameters", lambda self: None)setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)