鉤子函數僅在backward()
時才會觸發。其中,鉤子函數接受梯度作為輸入,返回操作后的梯度,操作后的梯度必須要輸入的梯度同類型、同形狀,否則報錯。
主要功能包括:
- 監控當前的梯度(不返回值);
- 對當前的梯度進行操作,返回新的梯度以覆蓋原梯度;
- 在模型中對梯度進行監控或者修改。
案例 1:監控梯度值
import torch# 創建一個張量,并啟用梯度追蹤
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定義鉤子函數
def hook_fn(grad):'''作用:打印梯度'''print("Hook triggered, gradient:", grad)# 注冊鉤子:將鉤子函數注冊到x上,反向傳播計算x梯度時自動觸發鉤子函數
x.register_hook(hook_fn)# 觸發反向傳播和鉤子函數
y.backward()
結果:
Hook triggered, gradient: tensor([2.])
案例 2:修改梯度值
import torch# 創建一個張量,并啟用梯度追蹤
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定義鉤子函數
def hook_fn(grad):'''作用:修改輸入的梯度'''print('原梯度:',grad)return grad * 3# 注冊鉤子:將鉤子函數注冊到x上,反向傳播計算x梯度時自動觸發鉤子函數
x.register_hook(hook_fn)# 觸發反向傳播和鉤子函數
y.backward() print("修改后的梯度:", x.grad)
結果:
原梯度: tensor([2.])
修改后的梯度: tensor([6.])
案例 3:在模型中使用 register_hook
import torch
import torch.nn as nnmodel = nn.Linear(1, 1)
weight = model.weight # 模型權重# 定義鉤子函數
def hook_fn(grad):'''作用:打印梯度'''print("Gradient of weight:", grad)# 注冊鉤子:將鉤子函數注冊到weight上,反向傳播計算weight梯度時自動觸發鉤子函數
weight.register_hook(hook_fn)# 輸入數據
x = torch.tensor([[1.0]])
target = torch.tensor([[3.0]])# 前向傳播
output = model(x)
print(output)# 損失函數
loss = (output - target).pow(2)# 觸發反向傳播和鉤子函數
loss.backward()
結果:
Gradient of weight: tensor([[-6.1532]])
注意:
在實際使用中,必須使用clone()
來確保梯度操作的安全性和計算圖完整性,例如:
def hook_fn(grad):return grad.clone() * 3
- 通過
grad.clone()
創建梯度副本后進行操作,所有修改僅作用于副本,不會觸碰原始梯度存儲。不采用克隆,直接對原始梯度進行操作,PyTorch 會檢測到對計算圖中張量的潛在原地修改(in-place operation),并拋出異常。 - 不采用克隆,會破壞計算圖路徑,導致梯度回傳中斷或錯誤。