【Pytorch學習筆記】模型模塊06——hook函數

hook函數

什么是hook函數

hook函數相當于插件,可以實現一些額外的功能,而又不改變主體代碼。就像是把額外的功能掛在主體代碼上,所有叫hook(鉤子)。下面介紹Pytorch中的幾種主要hook函數。

torch.Tensor.register_hook

torch.Tensor.register_hook()是一個用于注冊梯度鉤子函數的方法。它主要用于獲取和修改張量在反向傳播過程中的梯度。

語法格式:

hook = tensor.register_hook(hook_fn)
# hook_fn的格式為:
def hook_fn(grad):# 處理梯度return new_grad  # 可選

主要特點:

  • hook函數在反向傳播計算梯度時被調用
  • hook函數接收梯度作為輸入參數
  • 可以返回修改后的梯度,或者不返回(此時使用原始梯度)
  • 可以注冊多個hook函數,按照注冊順序依次調用

使用示例:

import torch# 創建需要跟蹤梯度的張量
x = torch.tensor([1., 2., 3.], requires_grad=True)# 定義hook函數
def hook_fn(grad):print('梯度值:', grad)return grad * 2  # 將梯度翻倍# 注冊hook函數
hook = x.register_hook(hook_fn)# 進行一些運算
y = x.pow(2).sum()
y.backward()# 移除hook函數(可選)
hook.remove()

注意事項:

  • 只能在requires_grad=True的張量上注冊hook函數
  • hook函數在不需要時應該及時移除,以免影響后續計算
  • 不建議在hook函數中修改梯度的形狀,可能導致錯誤
  • 主要用于調試、可視化和梯度修改等場景

torch.nn.Module.register_forward_hook

torch.nn.Module.register_forward_hook()是一個用于注冊前向傳播鉤子函數的方法。它允許我們在模型的前向傳播過程中獲取和處理中間層的輸出

語法格式:

hook = module.register_forward_hook(hook_fn)
# hook_fn的格式為:
def hook_fn(module, input, output):# 處理輸入和輸出return modified_output  # 可選

主要特點:

  • hook函數在前向傳播過程中被調用
  • 可以訪問模塊的輸入和輸出數據
  • 可以用于監控和修改中間層的特征
  • 不影響反向傳播過程

使用示例:

import torch
import torch.nn as nn# 創建一個簡單的神經網絡
class Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)def forward(self, x):x = self.conv1(x)x = self.conv2(x)return x# 創建模型實例
model = Net()# 定義hook函數
def hook_fn(module, input, output):print('模塊:', module)print('輸入形狀:', input[0].shape)print('輸出形狀:', output.shape)# 注冊hook函數
hook = model.conv1.register_forward_hook(hook_fn)# 前向傳播
x = torch.randn(1, 1, 32, 32)
output = model(x)# 移除hook函數
hook.remove()

注意事項:

  • hook函數在每次前向傳播時都會被調用
  • 可以同時注冊多個hook函數,按注冊順序調用
  • 適用于特征可視化、調試網絡結構等場景
  • 建議在不需要時移除hook函數,以提高性能

torch.nn,Module.register_forward_pre_hook

torch.nn.Module.register_forward_pre_hook()是一個用于注冊前向傳播預處理鉤子函數的方法。它允許我們在模型的前向傳播開始之前對輸入數據進行處理或修改。

語法格式:

hook = module.register_forward_pre_hook(hook_fn)
# hook_fn的格式為:
def hook_fn(module, input):# 處理輸入return modified_input  # 可選

主要特點:

  • hook函數在前向傳播開始前被調用
  • 可以訪問和修改輸入數據
  • 常用于輸入預處理和數據轉換
  • 在實際計算前執行,可以改變輸入特征

使用示例:

import torch
import torch.nn as nn# 創建一個簡單的神經網絡
class Net(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(10, 5)def forward(self, x):return self.linear(x)# 創建模型實例
model = Net()# 定義pre-hook函數
def pre_hook_fn(module, input_data):print('模塊:', module)print('原始輸入形狀:', input_data[0].shape)# 對輸入數據進行處理,例如標準化modified_input = input_data[0] * 2.0return modified_input# 注冊pre-hook函數
hook = model.linear.register_forward_pre_hook(pre_hook_fn)# 前向傳播
x = torch.randn(32, 10)  # 批次大小為32,特征維度為10
output = model(x)# 移除hook函數
hook.remove()

注意事項:

  • pre-hook函數在每次前向傳播前都會被調用
  • 可以用于數據預處理、特征轉換等操作
  • 返回值會替換原始輸入,影響后續計算
  • 建議在不需要時及時移除,以免影響模型性能

與register_forward_hook的區別:

  • pre-hook在模塊計算之前執行,forward_hook在計算之后執行
  • pre-hook只能訪問輸入數據,forward_hook可以同時訪問輸入和輸出
  • pre-hook更適合做輸入預處理,forward_hook更適合做特征分析

torch.nn.Module.register_full_backward_hook

torch.nn.Module.register_full_backward_hook()是一個用于注冊完整反向傳播鉤子函數的方法。它允許我們在模型的反向傳播過程中訪問和修改梯度信息

語法格式:

hook = module.register_full_backward_hook(hook_fn)
# hook_fn的格式為:
def hook_fn(module, grad_input, grad_output):# 處理梯度return modified_grad_input  # 可選

主要特點:

  • hook函數在反向傳播過程中被調用
  • 可以同時訪問輸入梯度和輸出梯度
  • 可以修改反向傳播的梯度流
  • 比register_backward_hook更強大,提供更完整的梯度信息

使用示例:

import torch
import torch.nn as nn# 創建一個簡單的神經網絡
class Net(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(5, 3)def forward(self, x):return self.linear(x)# 創建模型實例
model = Net()# 定義backward hook函數
def backward_hook_fn(module, grad_input, grad_output):print('模塊:', module)print('輸入梯度形狀:', [g.shape if g is not None else None for g in grad_input])print('輸出梯度形狀:', [g.shape if g is not None else None for g in grad_output])# 可以返回修改后的輸入梯度return grad_input# 注冊backward hook函數
hook = model.linear.register_full_backward_hook(backward_hook_fn)# 前向和反向傳播
x = torch.randn(2, 5, requires_grad=True)
output = model(x)
loss = output.sum()
loss.backward()# 移除hook函數
hook.remove()

注意事項:

  • hook函數可能會影響模型的訓練過程,使用時需要謹慎
  • 建議僅在調試和分析梯度流時使用
  • 返回值會替換原始輸入梯度,可能影響模型收斂
  • 在不需要時應及時移除hook函數

與register_backward_hook的區別:

  • register_full_backward_hook提供更完整的梯度信息
  • 更適合處理復雜的梯度修改場景
  • 建議使用register_full_backward_hook替代已廢棄的register_backward_hook

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/83304.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/83304.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/83304.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SQL進階之旅 Day 11:復雜JOIN查詢優化

【SQL進階之旅 Day 11】復雜JOIN查詢優化 在數據處理日益復雜的今天,JOIN操作作為SQL中最強大的功能之一,常常成為系統性能瓶頸。今天我們進入"SQL進階之旅"系列的第11天,將深入探討復雜JOIN查詢的優化策略。通過本文學習&#xf…

Spring AI 之檢索增強生成(Retrieval Augmented Generation)

檢索增強生成(RAG)是一種技術,有助于克服大型語言模型在處理長篇內容、事實準確性和上下文感知方面的局限性。 Spring AI 通過提供模塊化架構來支持 RAG,該架構允許自行構建自定義的 RAG 流程,或者使用 Advisor API 提…

前端開源JavaScrip庫

以下內容仍在持續完善中,如有遺漏或需要補充之處,歡迎在評論區指出。感謝支持,如果覺得有幫助,歡迎點贊鼓勵。感謝支持 JavaScript 框架Vue.jsVue.js - 漸進式 JavaScript 框架 | Vue.jsReactReactAngularHome ? AngularjQueryj…

什么是 CPU 緩存模型?

導語: CPU 緩存模型是后端性能調優、并發編程乃至分布式系統設計中一個繞不開的核心概念。它不僅關系到指令執行效率,還影響鎖機制、內存可見性等多個面試高頻點。本文將以資深面試官視角,詳解緩存模型的原理、常見面試題及實戰落地&#xff…

海外tk抓包簡單暴力方式

將地址替換下面代碼就可以 function hook_dlopen(module_name, fun) {var android_dlopen_ext Module.findExportByName(null, "android_dlopen_ext");if (android_dlopen_ext) {Interceptor.attach(android_dlopen_ext, {onEnter: function (args) {var pathptr …

多模態大語言模型arxiv論文略讀(103)

Are Bigger Encoders Always Better in Vision Large Models? ?? 論文標題:Are Bigger Encoders Always Better in Vision Large Models? ?? 論文作者:Bozhou Li, Hao Liang, Zimo Meng, Wentao Zhang ?? 研究機構: 北京大學 ?? 問題背景&…

代碼隨想錄算法訓練營 Day61 圖論ⅩⅠ Floyd A※ 最短路徑算法

圖論 題目 97. 小明逛公園 本題是經典的多源最短路問題。 在這之前我們講解過,dijkstra樸素版、dijkstra堆優化、Bellman算法、Bellman隊列優化(SPFA) 都是單源最短路,即只能有一個起點。 而本題是多源最短路,即求多…

【機器學習】集成學習與梯度提升決策樹

目錄 一、引言 二、自舉聚合與隨機森林 三、集成學習器 四、提升算法 五、Python代碼實現集成學習與梯度提升決策樹的實驗 六、總結 一、引言 在機器學習的廣闊領域中,集成學習(Ensemble Learning)猶如一座閃耀的明星,它通過組合多個基本學習器的力量,創造出…

yarn、pnpm、npm

非常好,這樣從“問題驅動 → 工具誕生 → 優化演進”的角度來講,更清晰易懂。下面我按時間線和動機,把 npm → yarn → pnpm 的演變脈絡講清楚。 🧩 一、npm 為什么一開始不夠好? 早期(npm v4 及之前&…

如何用AI寫作?

過去半年,我如何用AI高效寫作,節省數倍時間 過去六個月,我幾乎所有文章都用AI輔助完成。我的朋友——大多是文字工作者,對語言極為敏感——都說看不出我的文章是AI寫的還是親手創作的。 我的AI寫作靈感部分來自丘吉爾。這位英國…

什么是trace,分布式鏈路追蹤(Distributed Tracing)

在你提到的 “個人免費版” 套餐中,“Trace 上報量:5 萬條 / 月,存儲 3 天” 里的 Trace 仍然是指 分布式鏈路追蹤記錄,但需要結合具體產品的場景來理解其含義和限制。以下是更貼近個人用戶使用場景的解釋: 一、這里的…

[免費]微信小程序網上花店系統(SpringBoot后端+Vue管理端)【論文+源碼+SQL腳本】

大家好,我是java1234_小鋒老師,看到一個不錯的微信小程序網上花店系統(SpringBoot后端Vue管理端)【論文源碼SQL腳本】,分享下哈。 項目視頻演示 【免費】微信小程序網上花店系統(SpringBoot后端Vue管理端) Java畢業設計_嗶哩嗶哩_bilibili 項…

PyTorch——DataLoader的使用

batch_size, drop_last 的用法 shuffle shuffleTrue 各批次訓練的圖像不一樣 shuffleFalse 在第156step順序一致

【Linux】基礎文件IO

🌟🌟作者主頁:ephemerals__ 🌟🌟所屬專欄:Linux 前言 無論是日常使用還是系統管理,文件是Linux系統中最核心的概念之一。對于初學者來說,理解文件是如何被創建、讀取、寫入以及存儲…

【JAVA后端入門基礎001】Tomcat 是什么?通俗易懂講清楚!

📚博客主頁:代碼探秘者 ?專欄:《JavaSe》 其他更新ing… ??感謝大家點贊👍🏻收藏?評論?🏻,您的三連就是我持續更新的動力?? 🙏作者水平有限,歡迎各位大佬指點&…

TDengine 的 AI 應用實戰——電力需求預測

作者: derekchen Demo數據集準備 我們使用公開的UTSD數據集里面的電力需求數據,作為預測算法的數據來源,基于歷史數據預測未來若干小時的電力需求。數據集的采集頻次為30分鐘,單位與時間戳未提供。為了方便演示,按…

D2000平臺上Centos使用mmap函數遇到的陷阱

----------原創不易,歡迎點贊收藏。廣交嵌入式開發的朋友,討論技術和產品------------- 在飛騰D2000平臺上,安裝了麒麟linux系統,我寫了個GPIO點燈的程序,在應用層利用mmap函數將內核空間映射到用戶態,然后…

深入了解linux系統—— 進程間通信之管道

前言 本篇博客所涉及到的代碼一同步到本人gitee:testfifo 遲來的grown/linux - 碼云 - 開源中國 一、進程間通信 什么是進程間通信 在之前的學習中,我們了解到了進程具有獨立性,就算是父子進程,在修改數據時也會進行寫時拷貝&…

設計模式——模版方法設計模式(行為型)

摘要 模版方法設計模式是一種行為型設計模式,定義了算法的步驟順序和整體結構,將某些步驟的具體實現延遲到子類中。它通過抽象類定義模板方法,子類實現抽象步驟,實現代碼復用和算法流程控制。該模式適用于有固定流程但部分步驟可…

Python使用

Python學習,從安裝,到簡單應用 前言 Python作為膠水語言在web開發,數據分析,網絡爬蟲等方向有著廣泛的應用 一、Python入門 相關基礎語法直接使用相關測試代碼 Python編譯器版本使用3以后,安裝參考其他教程&#xf…