文章目錄
- 前向鉤子
- 反向鉤子的輸入
- 反向鉤子的輸出
前向鉤子
下面是一個測試用的計算圖的網絡,這里因為模型是自定義的緣故,可以直接把前向鉤子注冊在模型類里面,這樣會更加方便一些。其實像以前BERT之類的last_hidden_state
以及pool_output
之類的輸出應該也是用鉤子鉤出來的。
import torch
from torch import nn
from torch.nn import functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear_1 = nn.Linear(4, 3)self.linear_2 = nn.Linear(3, 3)self.linear_3 = nn.Linear(3, 3)self.linear_4 = nn.Linear(3, 1)self._register_hooks(["linear_3"])def _register_hooks(self, module_names):self.hook_outputs = {}def make_hook(name):def hook(module, input_, output):self.hook_outputs[name]["input"].append(input_)self.hook_outputs[name]["output"].append(output)return hookfor module_name in module_names:self.hook_outputs[module_name] = {"input": [], "output": []}eval(f"self.{module_name}").register_forward_hook(make_hook(module_name))def forward(self, x):y_1 = self.linear_1(x)y_1_a = F.sigmoid(y_1)y_2 = self.linear_2(y_1_a)y_2_a = F.sigmoid(y_2)print(y_1_a)print(y_2_a)y_3_1 = self.linear_3(y_1_a)print(y_3_1)y_3_2 = self.linear_3(y_2_a)print(y_3_2)x_4 = F.sigmoid(y_3_1) + F.sigmoid(y_3_2)y_4 = self.linear_4(x_4)y_4_a = F.sigmoid(y_4)return y_4_a
x = torch.FloatTensor([[1,2,3,4]])
net = Net()
y = net(x)
可視化是:
輸出結果:
y_1_a: tensor([[0.2428, 0.5258, 0.2866]], grad_fn=<SigmoidBackward0>)
y_2_a: tensor([[0.4860, 0.4801, 0.6515]], grad_fn=<SigmoidBackward0>)
y_3_1: tensor([[ 0.3423, 0.2477, -0.7132]], grad_fn=<AddmmBackward0>)
y_3_2: tensor([[ 0.4148, 0.2024, -0.9481]], grad_fn=<AddmmBackward0>)
而鉤子抓到的結果net.hook_outputs
中的內容形如:
{'linear_3': {'input': [[tensor([[0.2428, 0.5258, 0.2866]])],[tensor([[0.4860, 0.4801, 0.6515]])]],'output': [tensor([[ 0.3423, 0.2477, -0.7132]]),tensor([[ 0.4148, 0.2024, -0.9481]])]}}
tensor([[0.5139, 0.5634, 0.6205]], grad_fn=<SigmoidBackward0>)
tensor([[0.3508, 0.2681, 0.4771]], grad_fn=<SigmoidBackward0>)
tensor([[ 0.1624, -0.3406, 0.4669]], grad_fn=<AddmmBackward0>)
tensor([[ 0.2090, -0.4021, 0.3506]], grad_fn=<AddmmBackward0>)
{'linear_3': {'input': [(tensor([[0.3508, 0.2681, 0.4771]], grad_fn=<SigmoidBackward0>),),(tensor([[0.5139, 0.5634, 0.6205]], grad_fn=<SigmoidBackward0>),)],'output': [tensor([[ 0.1624, -0.3406, 0.4669]], grad_fn=<AddmmBackward0>),tensor([[ 0.2090, -0.4021, 0.3506]], grad_fn=<AddmmBackward0>)]}}
是完全對的上的,盡管L3被多次調用,但實際上每次調用都是1個輸入1個輸出,但是input
鉤到的是tuple
,但output
鉤到的卻是tensor
但是假如我稍作修改,比如把linear_3
的單獨化成一個模塊self.m = M1()
,它有兩個輸入,也有兩個輸出:
class M1(nn.Module):def __init__(self):super(M1, self).__init__()self.linear_3 = nn.Linear(3, 3)def forward(self, y_1_a, y_2_a):y_3_1 = self.linear_3(y_1_a)y_3_2 = self.linear_3(y_2_a)return y_3_1, y_3_2class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear_1 = nn.Linear(4, 3)self.linear_2 = nn.Linear(3, 3)self.m = M1()self.linear_4 = nn.Linear(3, 1)self._register_hooks(["m"])def _register_hooks(self, module_names):self.hook_outputs = {}def make_hook(name):def hook(module, input_, output):self.hook_outputs[name]["input"].append(input_)self.hook_outputs[name]["output"].append(output)return hookfor module_name in module_names:self.hook_outputs[module_name] = {"input": [], "output": []}eval(f"self.{module_name}").register_forward_hook(make_hook(module_name))def forward(self, x):y_1 = self.linear_1(x)y_1_a = F.sigmoid(y_1)y_2 = self.linear_2(y_1_a)y_2_a = F.sigmoid(y_2)print(y_1_a)print(y_2_a)y_3_1, y_3_2 = self.m(y_1_a, y_2_a)x_4 = F.sigmoid(y_3_1) + F.sigmoid(y_3_2)y_4 = self.linear_4(x_4)y_4_a = F.sigmoid(y_4)return y_4_ax = torch.FloatTensor([[1,2,3,4]])
net = Net()
y = net(x)
from pprint import pprint
pprint(net.hook_outputs)
此時輸出結果就是:
tensor([[0.6084, 0.6544, 0.6909]], grad_fn=<SigmoidBackward0>)
tensor([[0.2917, 0.4068, 0.2910]], grad_fn=<SigmoidBackward0>)
tensor([[-0.0419, 0.2307, -0.3825]], grad_fn=<AddmmBackward0>)
tensor([[-0.0515, 0.4510, 0.0154]], grad_fn=<AddmmBackward0>)
{'m': {'input': [(tensor([[0.6084, 0.6544, 0.6909]], grad_fn=<SigmoidBackward0>),tensor([[0.2917, 0.4068, 0.2910]], grad_fn=<SigmoidBackward0>))],'output': [(tensor([[-0.0419, 0.2307, -0.3825]], grad_fn=<AddmmBackward0>),tensor([[-0.0515, 0.4510, 0.0154]], grad_fn=<AddmmBackward0>))]}}
發現兩次結果的區別了沒有?觀察兩次鉤子的輸出:
第一次:
{'linear_3': {'input': [(tensor([[0.3508, 0.2681, 0.4771]], grad_fn=<SigmoidBackward0>),),(tensor([[0.5139, 0.5634, 0.6205]], grad_fn=<SigmoidBackward0>),)],'output': [tensor([[ 0.1624, -0.3406, 0.4669]], grad_fn=<AddmmBackward0>),tensor([[ 0.2090, -0.4021, 0.3506]], grad_fn=<AddmmBackward0>)]}}
第二次:
{'m': {'input': [(tensor([[0.6084, 0.6544, 0.6909]], grad_fn=<SigmoidBackward0>),tensor([[0.2917, 0.4068, 0.2910]], grad_fn=<SigmoidBackward0>))],'output': [(tensor([[-0.0419, 0.2307, -0.3825]], grad_fn=<AddmmBackward0>),tensor([[-0.0515, 0.4510, 0.0154]], grad_fn=<AddmmBackward0>))]}}
是的,input
不管怎么樣,總是默認是一個tuple
,哪怕里面只有一個輸入張量,但是輸出output
第一次其實就是tensor
,第二次則變成了tuple
也就說,如果一個module
有多個輸出的時候,依然是會變成tuple
的。
反向鉤子的輸入
把上面的代碼稍作修改,我們添加反向鉤子,然后隨便寫一個損失輸出出來進行反向傳播,看看情況:
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear_1 = nn.Linear(4, 3)self.linear_2 = nn.Linear(3, 3)self.linear_3 = nn.Linear(3, 3)self.linear_4 = nn.Linear(3, 1)self._register_hooks(["linear_3"])def _register_hooks(self, module_names):self.forward_hook_outputs = {}self.backward_hook_outputs = {}def make_forward_hook(name):def hook(module, input_, output):self.forward_hook_outputs[name]["input"].append(input_)self.forward_hook_outputs[name]["output"].append(output)return hookdef make_backward_hook(name):def hook(module, input_, output):self.backward_hook_outputs[name]["input"].append(input_)self.backward_hook_outputs[name]["output"].append(output)return hookfor module_name in module_names:self.forward_hook_outputs[module_name] = {"input": [], "output": []}self.backward_hook_outputs[module_name] = {"input": [], "output": []}eval(f"self.{module_name}").register_forward_hook(make_forward_hook(module_name))eval(f"self.{module_name}").register_backward_hook(make_backward_hook(module_name))def forward(self, x):y_1 = self.linear_1(x)y_1_a = F.sigmoid(y_1)y_2 = self.linear_2(y_1_a)y_2_a = F.sigmoid(y_2)print(y_1_a)print(y_2_a)y_3_1 = self.linear_3(y_1_a)y_3_2 = self.linear_3(y_2_a)print(y_3_1)print(y_3_2)x_4 = F.sigmoid(y_3_1) + F.sigmoid(y_3_2)y_4 = self.linear_4(x_4)y_4_a = F.sigmoid(y_4)return y_4_a, (y_4_a - torch.FloatTensor([[1]])) ** 2x = torch.FloatTensor([[1,2,3,4]])
net = Net()
y, loss = net(x)
loss.backward()from pprint import pprintpprint(net.forward_hook_outputs)
pprint(net.backward_hook_outputs)
輸出結果:
tensor([[0.7906, 0.2277, 0.3887]], grad_fn=<SigmoidBackward0>)
tensor([[0.5084, 0.4351, 0.3494]], grad_fn=<SigmoidBackward0>)
tensor([[-0.0058, -0.2894, -0.4183]], grad_fn=<AddmmBackward0>)
tensor([[-0.2083, -0.2993, -0.3650]], grad_fn=<AddmmBackward0>)
{'linear_3': {'input': [(tensor([[0.7906, 0.2277, 0.3887]], grad_fn=<SigmoidBackward0>),),(tensor([[0.5084, 0.4351, 0.3494]], grad_fn=<SigmoidBackward0>),)],'output': [tensor([[-0.0058, -0.2894, -0.4183]], grad_fn=<AddmmBackward0>),tensor([[-0.2083, -0.2993, -0.3650]], grad_fn=<AddmmBackward0>)]}}
{'linear_3': {'input': [(tensor([ 0.0061, -0.0089, -0.0080]),tensor([[ 0.0085, 0.0027, -0.0065]]),tensor([[ 0.0031, -0.0045, -0.0041],[ 0.0026, -0.0039, -0.0035],[ 0.0021, -0.0031, -0.0028]])),(tensor([ 0.0061, -0.0089, -0.0079]),tensor([[ 0.0085, 0.0027, -0.0064]]),tensor([[ 0.0048, -0.0071, -0.0063],[ 0.0014, -0.0020, -0.0018],[ 0.0024, -0.0035, -0.0031]]))],'output': [(tensor([[ 0.0061, -0.0089, -0.0080]]),),(tensor([[ 0.0061, -0.0089, -0.0079]]),)]}}
發現問題了沒有:
反向鉤子,哪怕output
只有一個元素,也是返回的是tuple
,而非tensor
以這個例子里linear_3
為例,它有兩個輸入(y_1_a
和y_2_a
),同時也有兩個輸出(y_3_1
和y_3_2
),所以它在反向傳播的時候,會產生兩次輸入輸出對,體現在上面捕獲的時候net.backward_hook_outputs["linear_3"]["input"]
與net.backward_hook_outputs["linear_3"]["input"]
兩個列表的長度都是2:
{'linear_3': {'input': [(tensor([ 0.0061, -0.0089, -0.0080]),tensor([[ 0.0085, 0.0027, -0.0065]]),tensor([[ 0.0031, -0.0045, -0.0041],[ 0.0026, -0.0039, -0.0035],[ 0.0021, -0.0031, -0.0028]])),(tensor([ 0.0061, -0.0089, -0.0079]),tensor([[ 0.0085, 0.0027, -0.0064]]),tensor([[ 0.0048, -0.0071, -0.0063],[ 0.0014, -0.0020, -0.0018],[ 0.0024, -0.0035, -0.0031]]))],'output': [(tensor([[ 0.0061, -0.0089, -0.0080]]),),(tensor([[ 0.0061, -0.0089, -0.0079]]),)]}}
反向傳播的梯度里,linear_3
這一層的input
每個都有3個張量,形狀分別是(3, ), (1, 3), (3, 3)
這個稍許有點費解,我不太能理解為什么會有3個,但是如果要用的話,因為linear_3
的形狀是(3, 3)
,所以它的W對應的梯度就是(3, 3)
,另外兩個里面還有一個是bias項的梯度(應該就是(3, )
的那個,即第1個),那剩下還有一個會是什么呢?
問了一下deepseek,結論是其實剩下的那個(1, 3)
是輸入 x x x的梯度,這倒是也比較合理
對于 Linear
層的反向鉤子(register_backward_hook
),其 grad_input
參數的組成確實需要仔細分析。下面我們徹底解析這個問題:
1. Linear 層的梯度構成
對于一個 Linear(in_features, out_features)
層,其前向計算為:
y = x W T + b y = xW^T + b y=xWT+b
反向傳播時,PyTorch 會計算以下梯度:
- 輸入梯度( ? L ? x \frac{\partial L}{\partial x} ?x?L?)
- 權重梯度( ? L ? W \frac{\partial L}{\partial W} ?W?L?)
- 偏置梯度( ? L ? b \frac{\partial L}{\partial b} ?b?L?)
在反向鉤子的 grad_input
元組中,三個張量的順序固定為:
- 輸入 x x x 的梯度(形狀與 x x x 相同)
- 權重 W W W 的梯度(形狀與 W W W 相同,即
(out_features, in_features)
) - 偏置 b b b 的梯度(形狀與 b b b 相同,即
(out_features,)
)
2. 你的具體案例解析
你觀察到 Linear(3, 3)
層的 grad_input
包含三個張量:
(3, )
→ 偏置 b b b 的梯度
(因為bias
的形狀是(3,)
)(1, 3)
→ 輸入 x x x 的梯度
(原始輸入形狀為(1, 3)
,所以梯度形狀相同)(3, 3)
→ 權重 W W W 的梯度
(權重矩陣的形狀是(3, 3)
)
為什么輸入梯度是 (1, 3)
?
- 你的前向輸入 x x x 的形狀可能是
(batch_size=1, in_features=3)
,因此梯度形狀與之匹配。
3. 驗證實驗
通過以下代碼可以明確驗證梯度順序:
import torch
import torch.nn as nn# 定義線性層
linear = nn.Linear(3, 3, bias=True)# 注冊反向鉤子
def backward_hook(module, grad_input, grad_output):print("\nGradients for Linear layer:")print(f"1. Input grad: shape={grad_input[0].shape} (x的梯度)")print(f"2. Weight grad: shape={grad_input[1].shape} (W的梯度)")print(f"3. Bias grad: shape={grad_input[2].shape} (b的梯度)")return None # 不修改梯度handle = linear.register_backward_hook(backward_hook)# 模擬前向和反向
x = torch.randn(1, 3, requires_grad=True) # 輸入形狀 (1, 3)
y = linear(x)
loss = y.sum()
loss.backward()handle.remove() # 移除鉤子
輸出結果
# Gradients for Linear layer:
# 1. Input grad: shape=torch.Size([1, 3]) (x的梯度)
# 2. Weight grad: shape=torch.Size([3, 3]) (W的梯度)
# 3. Bias grad: shape=torch.Size([3]) (b的梯度)
4. 特殊情況分析
(1) 當 bias=False
時
如果線性層無偏置(bias=False
),grad_input
將只包含兩個張量:
- 輸入梯度
- 權重梯度
(2) 批量輸入的情況
若輸入 x x x 的形狀為 (batch_size, 3)
,輸入梯度的形狀會變為 (batch_size, 3)
:
x = torch.randn(5, 3, requires_grad=True) # 批量大小=5
y = linear(x)
loss = y.sum()
loss.backward()# 輸出:
# 1. Input grad: shape=torch.Size([5, 3])
# 2. Weight grad: shape=torch.Size([3, 3])
# 3. Bias grad: shape=torch.Size([3])
5. 梯度計算原理
PyTorch 的 Linear
層反向傳播公式:
- 輸入梯度:
? L ? x = ? L ? y ? W \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot W ?x?L?=?y?L??W - 權重梯度:
? L ? W = ? L ? y T ? x \frac{\partial L}{\partial W} = \frac{\partial L}{\partial y}^T \cdot x ?W?L?=?y?L?T?x - 偏置梯度:
? L ? b = ∑ b a t c h ? L ? y \frac{\partial L}{\partial b} = \sum_{batch} \frac{\partial L}{\partial y} ?b?L?=batch∑??y?L?
6. 總結
grad_input 順序 | 張量含義 | 形狀示例(Linear(3,3) ) |
---|---|---|
第 0 項 | 輸入 x x x 的梯度 | (batch_size, 3) |
第 1 項 | 權重 W W W 的梯度 | (3, 3) |
第 2 項 | 偏置 b b b 的梯度 | (3,) |
注:這個存疑,我不知道這個順序是不是按照它這個順序來的,我感覺就觀察結果而言,應該第0項是bias,第1項是x,第2項是W。持保留意見。
關鍵結論
- 你觀察到的
(3,)
是偏置梯度,(1,3)
是輸入梯度,(3,3)
是權重梯度。 - 順序是固定的,與
forward
的輸入/參數順序無關。 - 如果層無偏置,
grad_input
長度會減 1。
反向鉤子的輸出
上面已經發現了,反向鉤子的輸出不管輸出結果是一項還是多項,輸出總是一個tuple(與前向鉤子是不同的),這個輸出本身是loss關于 y y y的梯度
y = x W ? + b y=xW^\top+b y=xW?+b
在反向傳播過程中,反向鉤子(register_backward_hook
)捕獲的 grad_output
本質上是 損失函數對模塊原始輸出的梯度,數學上表示為:
grad_output = ? L ? y \text{grad\_output} = \frac{\partial \mathcal{L}}{\partial y} grad_output=?y?L?
其中:
- L \mathcal{L} L 是損失函數(標量)
- y y y 是模塊的前向輸出(可能是張量或元組)
(1) 單輸出模塊(如 Linear
層)
- 前向計算:
y = x W T + b (假設輸入? x ∈ R B × d in , W ∈ R d out × d in , b ∈ R d out ) y = xW^T + b \quad \text{(假設輸入 } x \in \mathbb{R}^{B \times d_{\text{in}}}, W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}, b \in \mathbb{R}^{d_{\text{out}}}) y=xWT+b(假設輸入?x∈RB×din?,W∈Rdout?×din?,b∈Rdout?) - 反向梯度:
grad_output = ( ? L ? y ) ∈ R B × d out \text{grad\_output} = \left( \frac{\partial \mathcal{L}}{\partial y} \right) \in \mathbb{R}^{B \times d_{\text{out}}} grad_output=(?y?L?)∈RB×dout?
PyTorch 會將其包裝為單元素元組:(grad_output,)
(2) 多輸出模塊(如 LSTM
)
- 前向輸出:
y = ( h all , ( h n , c n ) ) (輸出序列、最后隱狀態和細胞狀態) y = (h_{\text{all}}, (h_n, c_n)) \quad \text{(輸出序列、最后隱狀態和細胞狀態)} y=(hall?,(hn?,cn?))(輸出序列、最后隱狀態和細胞狀態) - 反向梯度:
grad_output = ( ? L ? h all , ? L ? h n , ? L ? c n ) \text{grad\_output} = \left( \frac{\partial \mathcal{L}}{\partial h_{\text{all}}}, \frac{\partial \mathcal{L}}{\partial h_n}, \frac{\partial \mathcal{L}}{\partial c_n} \right) grad_output=(?hall??L?,?hn??L?,?cn??L?)
每個分量的形狀與前向輸出的對應張量形狀一致。
PyTorch 的反向鉤子(register_backward_hook
)和前向鉤子(register_forward_hook
)在輸出處理上確實存在這種關鍵差異。下面我們徹底解析這種設計差異的原因和具體行為:
1. 反向鉤子的輸出行為
(1) 輸出始終是 tuple
無論模塊的原始輸出是單個張量還是元組,反向鉤子的 grad_output
參數 總是以 tuple
形式傳遞,即使只有一個梯度張量。例如:
def backward_hook(module, grad_input, grad_output):print(type(grad_output)) # 永遠是 <class 'tuple'>return None
(2) 結構對應關系
模塊輸出類型 | 前向鉤子的 output | 反向鉤子的 grad_output |
---|---|---|
單個張量 | Tensor | (Tensor,) (單元素元組) |
元組/多個輸出 | Tuple[Tensor,...] | Tuple[Tensor,...] (同長度) |
2. 設計原因
(1) 一致性處理
PyTorch 選擇統一用 tuple
傳遞反向梯度,是為了:
- 避免條件判斷:無論單輸出還是多輸出,鉤子代碼無需檢查類型。
- 兼容自動微分系統:PyTorch 的 autograd 始終以
tuple
形式傳遞梯度。
(2) 與前向鉤子的對比
- 前向鉤子:保留原始輸出類型(張量或元組),因為用戶可能需要直接使用該值。
- 反向鉤子:梯度計算是系統行為,統一格式更安全。
3. 驗證實驗
(1) 單輸出模塊(如 Linear
)
import torch
import torch.nn as nnlinear = nn.Linear(3, 3)def hook(module, grad_in, grad_out):print(f"Linear層 grad_out類型: {type(grad_out)}, 長度: {len(grad_out)}")return Nonelinear.register_backward_hook(hook)x = torch.randn(2, 3, requires_grad=True)
y = linear(x) # 單輸出
loss = y.sum()
loss.backward()
輸出:
Linear層 grad_out類型: <class 'tuple'>, 長度: 1
(2) 多輸出模塊(如 LSTM
)
lstm = nn.LSTM(3, 3)def hook(module, grad_in, grad_out):print(f"LSTM層 grad_out類型: {type(grad_out)}, 長度: {len(grad_out)}")return Nonelstm.register_backward_hook(hook)x = torch.randn(5, 2, 3) # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x) # 多輸出
loss = output.sum() + h_n.sum()
loss.backward()
輸出:
LSTM層 grad_out類型: <class 'tuple'>, 長度: 3
4. 實際應用建議
(1) 安全訪問梯度
無論模塊輸出類型如何,始終按元組處理:
def backward_hook(module, grad_in, grad_out):# 安全獲取第一個梯度(即使單輸出)grad = grad_out[0] if len(grad_out) > 0 else Nonereturn None
(2) 多輸出模塊的梯度順序
對于多輸出模塊(如 LSTM
),grad_out
的順序與前向輸出的順序一致:
# 前向輸出順序: (output, (h_n, c_n))
# 反向梯度順序: (grad_output, grad_h_n, grad_c_n)
5. 深入原理
PyTorch 的 grad_output
設計源于其自動微分系統的實現:
- 計算圖構建:前向傳播時記錄輸出節點。
- 反向傳播:系統統一以
tuple
形式傳遞梯度,即使只有一個節點。 - 鉤子注入:反向鉤子接收到的是系統處理后的梯度結構。
總結
特性 | 前向鉤子 output | 反向鉤子 grad_output |
---|---|---|
類型 | 保持原始類型 | 強制轉為 tuple |
單輸出處理 | 直接返回 Tensor | 返回 (Tensor,) |
多輸出處理 | 返回 Tuple[Tensor,...] | 返回 Tuple[Tensor,...] |
這種設計確保了反向傳播梯度處理的統一性,而前向鉤子則更注重輸出值的原始性。理解這一差異能幫助你更安全地編寫調試工具或自定義梯度邏輯。