關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題

關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題

Hook 是 PyTorch 中一個十分有用的特性。利用它,我們可以不必改變網絡輸入輸出的結構,方便地獲取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 feature、gradient,從而診斷神經網絡中可能出現的問題,分析網絡有效性。

Hook函數機制:不改變主體,實現額外的功能,像一個掛件一樣;

Hook函數本身不是本文介紹的重點,網上介紹的文章頗多,本文主要是記錄一下筆者在使用hook函數時遇到的一些問題及解決過程。

register_forward_hook

首先看一下一個最簡單的使用register_forward_hook的例子:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = F.relu(self.conv1(x))     #1 out = F.max_pool2d(out, 2)      #2out = F.relu(self.conv2(out))   #3out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return outfeatures = []
def hook(module, input, output): # module: model.conv2 # input :in forward function  [#2]# output:is  [#3 self.conv2(out)]print('*'*100)features.append(output.clone().detach())# output is saved  in a list net = LeNet() ## 模型實例化 
x = torch.randn(2, 3, 32, 32) ## input 
handle = net.conv2.register_forward_hook(hook) ## 獲取整個Lenet模型 conv2的中間結果
y = net(x)  ## 獲取的是 關于 input x 的 conv2 結果 print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook刪除 ,防止多次保存hook內容占用空間

輸出

****************************************************************************************************
torch.Size([2, 16, 10, 10])

形狀是我們想要的結果,打印一串*是為了直觀地驗證hook函數被調用了。

其中conv2的名稱,我們可以打印模型的state_dict()來查看自己要的是哪個module

for k in model.state_dict():print(k)

輸出:

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias

我們上面直接拿conv2做例子了。

出現的問題

在實際使用中,我想打印最近的transformer模型alt_gvt_large的位置編碼來看一下,但是遇到了問題。

我查看了一下模型中的module,找到自己想要的

import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transformsfmap_block = []
def forward_hook(module, data_input, data_output):print('*'*100)fmap_block.append(data_output.clone().detach())model = timm.create_model('alt_gvt_large',pretrained=False,num_classes=1000,drop_rate=0.1,drop_path_rate=0.1,drop_block_rate=None,)
pipeline = transforms.Compose([transforms.RandomCrop(224),transforms.ToTensor(),])for k in model.state_dict():print(k)

輸出:

# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...

那肯定就是pos_block嘍。

開始hook:


image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)handle = model.pos_block.register_forward_hook(forward_hook)pred = model(image)
print(fmap_block[0].shape)
handle.remove()

出大問題,根本沒有輸出,連我們設置來驗證hook函數運行的*也沒有出現,hook函數肯定沒有被執行,這是怎么回事呢?

解決過程

經過仔細比對以上兩次成功和失敗hook經歷:

conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias

簡單分析不難有如此猜測:只有下面直接能點( . )到weight和bias的module才能被直接hook。

但是直接將輸出結果粘貼過去會出現:

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)

直接報語法錯誤,數字肯定是不能直接點的。

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)^
SyntaxError: invalid syntax

于是筆者一層一層查看進去:

for k in model.pos_block:print(k)for _k in k.proj.state_dict():print(_k)breakbreak 
print(type(model.pos_block))

發現上面出現數字的地方的類型其實是:<class ‘torch.nn.modules.container.ModuleList’>,也就是一個list,那是不是直接可以用[ ]進行索引。

于是我們可以改為:

handle = model.pos_block[3].proj[0].register_forward_hook(forward_hook)

輸出:

****************************************************************************************************
torch.Size([1, 256, 28, 28])

終于成功。

總結

還是對PyTorch中的Model,Module,childeren_module等理解的不到位啊,只會最基本的使用方法,稍微進階一點的操作就會遇到阻力,以后有時間梳理一下。PyTorch是當今公認比較好用的開源框架了,但是想要隨心所欲地實現自己的想法,還是需要花點時間把其中的各個組件及相互之間的關系都理解到位。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532839.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532839.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532839.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

geoda權重矩陣導入matlab,空間計量經濟學-分析解析.ppt

廈門大學 鄧明 空間截面回歸模型 地理加權回歸模型 地理加權回歸模型擴展了普通線性回歸模型。在GWR模型中&#xff0c;特定區位的回歸系數不再是利用全部信息獲得的假定常數&#xff0c;而是利用鄰近觀測值的子樣本數據信息進行局域(Local)回歸估計而得&#xff0c;并隨著空間…

樹莓派攝像頭基礎配置及測試

樹莓派攝像頭基礎配置 step 1 硬件連接 硬件連接&#xff0c;注意不要接反了&#xff0c;排線藍色一段朝向網口的方向。&#xff08;筆者的設備是樹莓派4B&#xff09; step 2 安裝raspi-config 安裝 raspi-config raspi-config在raspbian中是預裝的&#xff0c;而在kali、…

matlab sobel銳化,sobel銳化 - yirui wu.ppt

sobel銳化 - yirui wu第六章 圖像銳化 圖像銳化的概念 圖像銳化的目的是加強圖像中景物的細節邊緣和輪廓。 銳化的作用是使灰度反差增強。 因為邊緣和輪廓都位于灰度突變的地方。所以銳化算法的實現是基于微分作用。 圖像銳化方法 圖像的景物細節特征&#xff1b; 一階微分銳化…

使用百度云智能SDK和樹莓派搭建簡易的人臉識別系統 Python語言版

硬件 樹莓派4B一個CSI攝像頭一個 筆者使用的是樹莓派4B和CSI攝像頭&#xff0c;但是樹莓派3和USB攝像頭等相似設備均可。 百度云智能設置 Step 1 登錄 百度云智能 網址https://cloud.baidu.com/ 首先登錄百度賬號&#xff0c;與百度云、百度貼吧等互通&#xff0c;可直接…

php 5.6 引用傳遞,升級到5.6.x后如何在php中修復引用傳遞

我最近將fom php 5.2升級到5.6,并且有一些代碼我無法修復&#xff1a;//Finds users with the same ip- or email-addressfunction find_related_users($user_id) {global $pdo;//print_R($pdo);//Let SQL do the magic!$sth $pdo->prepare(CALL find_related_users(?));$…

RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip arc

RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip archive: failed finding central directory 原因分析 這個報錯是出現在PyTorch在讀入模型參數時&#xff1a; checkpoint torch.load(epoch_15.pth, map_locationcpu)…

xp搭建 php環境,windows xp 下 LAMP環境搭建

1. apache安裝步驟如下圖在瀏覽器中輸入&#xff1a;localhost&#xff0c;出現下面頁面說明已成功安裝apache。2. mysql安裝如下圖顯示在運行里面輸入cmd &#xff0c;然后連接測試mysql &#xff0c;如圖所示&#xff1a;3. php安裝(1)將php壓縮包解壓到安裝路徑中的php目錄…

C++中的虛函數(表)實現機制以及用C語言對其進行的模擬實現

C中的虛函數(表)實現機制以及用C語言對其進行的模擬實現 聲明&#xff1a;本文非博主原創&#xff0c;轉自https://blog.twofei.com/496/&#xff0c;博主讀后受益良多&#xff0c;特地轉載&#xff0c;一是希望好文能有更多人看到&#xff0c;二是為了日后自己查閱。 前言 …

php 前端模板 yii,php – Yii2高級模板:添加獨立網頁

我在backend / views / site下添加了help.php,并在SiteController.php下聲明了一個能夠識別鏈接的函數public function behaviors(){return [access > [class > AccessControl::className(),rules > [[actions > [login, error],allow > true,],[actions > […

C++中數組和指針的關系(區別)詳解

C中數組和指針的關系&#xff08;區別&#xff09;詳解 本文轉自&#xff1a;http://c.biancheng.net/view/1472.html 博主在閱讀后將文中幾個知識點提出來放在前面&#xff1a; 沒有方括號和下標的數組名稱實際上代表數組的起始地址&#xff0c;這意味著數組名稱實際上就是…

安裝php獨立環境,0507-php獨立環境的安裝與配置 Web程序 - 貪吃蛇學院-專業IT技術平臺...

1.在一個純英文目錄下新建三個文件夾2.安裝apache(選擇好版本)過程中該填的按格式填好&#xff0c;其余的只更改安裝目錄即可如果報錯1901是安裝版本的問題。檢查&#xff1a;安裝完成后localhost打開為It works!添加到電腦屬性環境變量&#xff1a;3.將php文件解壓文檔放到AMP…

linux中PATH變量-詳細介紹

轉自&#xff1a;https://blog.csdn.net/haozhepeng/article/details/100584451 轉載者勘誤 原文最后提到的 echo 命令對于環境變量的修改無影響。這是肯定的&#xff0c;echo 命令相當于只是一個打印的函數&#xff08;比如 Python 中的 print&#xff09;。這里要修改環境變…

php assert eval,代碼執行函數之一句話木馬

前言大家好&#xff0c;我是阿里斯&#xff0c;一名IT行業小白。非常抱歉&#xff0c;昨天的內容出現瑕疵比較多&#xff0c;今天重新整理后再次發出&#xff0c;修改并添加了細節&#xff0c;另增加了常見的命令執行函數如果哪里不足&#xff0c;還請各位表哥指出。eval和asse…

顯卡、顯卡驅動、CUDA、CUDA Toolkit、cuDNN 梳理

顯卡、顯卡驅動、CUDA、CUDA Toolkit、cuDNN 梳理 轉自&#xff1a;https://www.cnblogs.com/marsggbo/p/11838823.html#nvccnvidia-smi GPU型號含義 顯卡&#xff1a; 簡單理解這個就是我們前面說的GPU&#xff0c;尤其指NVIDIA公司生產的GPU系列&#xff0c;因為后面介紹的…

php中msubstr,PHP學習:thinkphp中字符截取函數msubstr()用法分析

《PHP學習&#xff1a;thinkphp中字符截取函數msubstr()用法分析》要點&#xff1a;本文介紹了PHP學習&#xff1a;thinkphp中字符截取函數msubstr()用法分析&#xff0c;希望對您有用。如果有疑問&#xff0c;可以聯系我們。本文實例講述了thinkphp中字符截取函數msubstr()用法…

VS Code的Error: Running the contributed command: ‘_workbench.downloadResource‘ failed解決

VS Code的Error: Running the contributed command: _workbench.downloadResource failed解決 轉自&#xff1a;https://blog.csdn.net/ibless/article/details/118610776 1 問題描述 此前&#xff0c;本人參考網上教程在VS Code中配置了“Remote SSH”插件&#xff08;比如這…

Oracle閃回報錯,oracle 閃回區滿了,ORA-19815

oracle 閃回區滿了&#xff0c;查看日志報錯&#xff1a;ORA-19815&#xff0c;命令行輸入&#xff1a;sqlplus / as sysdbastartup mount //如果你的數據庫出現了無法連接的情況時&#xff0c;可以加上這句select file_type, percent_space_used as used,percent_space_rec…

[2021-ICCV] MUSIQ Multi-scale Image Quality Transformer 論文簡析

[2021-ICCV] MUSIQ: Multi-scale Image Quality Transformer 論文簡析 論文&#xff1a;https://arxiv.org/abs/2108.05997 代碼&#xff1a;https://github.com/google-research/google-research/tree/master/musiq 概述 當前SOTA的IQA&#xff08;圖像質量評估&#xff0…

安裝oracle不動了,windows2008安裝ORACLE到2%不動的問題 | 信春哥,系統穩,閉眼上線不回滾!...

最近又有網友遇到在windows2008服務器上安裝ORACLE軟件時到2%就卡住不動的問題&#xff0c;下面是該網友的描述&#xff1a;oralce 11g r2 windows server 2008 R2安裝到最后一步復制數據文件時卡到2% 不走了內存一直飆升求解決這個問題前段時間也有人遇到過&#xff0c;但是他…

手把手教你入門Git --- Git使用指南(Linux)

手把手教你入門Git — Git使用指南&#xff08;Linux&#xff09; 系統&#xff1a;ubuntu 18.04 LTS 本文所有git命令操作實驗具有連續性&#xff0c;git小白完全可以從頭到尾跟著本文所有給出的命令走一遍&#xff0c;就會對git有一個初步的了解&#xff0c;應當能做到會用并…