關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題
Hook 是 PyTorch 中一個十分有用的特性。利用它,我們可以不必改變網絡輸入輸出的結構,方便地獲取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 feature、gradient,從而診斷神經網絡中可能出現的問題,分析網絡有效性。
Hook函數機制:不改變主體,實現額外的功能,像一個掛件一樣;
Hook函數本身不是本文介紹的重點,網上介紹的文章頗多,本文主要是記錄一下筆者在使用hook函數時遇到的一些問題及解決過程。
register_forward_hook
首先看一下一個最簡單的使用register_forward_hook的例子:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = F.relu(self.conv1(x)) #1 out = F.max_pool2d(out, 2) #2out = F.relu(self.conv2(out)) #3out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return outfeatures = []
def hook(module, input, output): # module: model.conv2 # input :in forward function [#2]# output:is [#3 self.conv2(out)]print('*'*100)features.append(output.clone().detach())# output is saved in a list net = LeNet() ## 模型實例化
x = torch.randn(2, 3, 32, 32) ## input
handle = net.conv2.register_forward_hook(hook) ## 獲取整個Lenet模型 conv2的中間結果
y = net(x) ## 獲取的是 關于 input x 的 conv2 結果 print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook刪除 ,防止多次保存hook內容占用空間
輸出
****************************************************************************************************
torch.Size([2, 16, 10, 10])
形狀是我們想要的結果,打印一串*是為了直觀地驗證hook函數被調用了。
其中conv2的名稱,我們可以打印模型的state_dict()來查看自己要的是哪個module
for k in model.state_dict():print(k)
輸出:
conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
我們上面直接拿conv2做例子了。
出現的問題
在實際使用中,我想打印最近的transformer模型alt_gvt_large的位置編碼來看一下,但是遇到了問題。
我查看了一下模型中的module,找到自己想要的
import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transformsfmap_block = []
def forward_hook(module, data_input, data_output):print('*'*100)fmap_block.append(data_output.clone().detach())model = timm.create_model('alt_gvt_large',pretrained=False,num_classes=1000,drop_rate=0.1,drop_path_rate=0.1,drop_block_rate=None,)
pipeline = transforms.Compose([transforms.RandomCrop(224),transforms.ToTensor(),])for k in model.state_dict():print(k)
輸出:
# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...
那肯定就是pos_block嘍。
開始hook:
image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)handle = model.pos_block.register_forward_hook(forward_hook)pred = model(image)
print(fmap_block[0].shape)
handle.remove()
出大問題,根本沒有輸出,連我們設置來驗證hook函數運行的*也沒有出現,hook函數肯定沒有被執行,這是怎么回事呢?
解決過程
經過仔細比對以上兩次成功和失敗hook經歷:
conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
簡單分析不難有如此猜測:只有下面直接能點( . )到weight和bias的module才能被直接hook。
但是直接將輸出結果粘貼過去會出現:
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
直接報語法錯誤,數字肯定是不能直接點的。
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)^
SyntaxError: invalid syntax
于是筆者一層一層查看進去:
for k in model.pos_block:print(k)for _k in k.proj.state_dict():print(_k)breakbreak
print(type(model.pos_block))
發現上面出現數字的地方的類型其實是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一個list,那是不是直接可以用[ ]進行索引。
于是我們可以改為:
handle = model.pos_block[3].proj[0].register_forward_hook(forward_hook)
輸出:
****************************************************************************************************
torch.Size([1, 256, 28, 28])
終于成功。
總結
還是對PyTorch中的Model,Module,childeren_module等理解的不到位啊,只會最基本的使用方法,稍微進階一點的操作就會遇到阻力,以后有時間梳理一下。PyTorch是當今公認比較好用的開源框架了,但是想要隨心所欲地實現自己的想法,還是需要花點時間把其中的各個組件及相互之間的關系都理解到位。