鉤子函數的作用(register_hook)

鉤子函數僅在backward()時才會觸發。其中,鉤子函數接受梯度作為輸入,返回操作后的梯度,操作后的梯度必須要輸入的梯度同類型、同形狀,否則報錯。

主要功能包括:

  • 監控當前的梯度(不返回值);
  • 對當前的梯度進行操作,返回新的梯度以覆蓋原梯度;
  • 在模型中對梯度進行監控或者修改。

案例 1:監控梯度值

import torch# 創建一個張量,并啟用梯度追蹤
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定義鉤子函數
def hook_fn(grad):'''作用:打印梯度'''print("Hook triggered, gradient:", grad)# 注冊鉤子:將鉤子函數注冊到x上,反向傳播計算x梯度時自動觸發鉤子函數
x.register_hook(hook_fn)# 觸發反向傳播和鉤子函數
y.backward()             

結果:

Hook triggered, gradient: tensor([2.])

案例 2:修改梯度值

import torch# 創建一個張量,并啟用梯度追蹤
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定義鉤子函數
def hook_fn(grad):'''作用:修改輸入的梯度'''print('原梯度:',grad)return grad * 3# 注冊鉤子:將鉤子函數注冊到x上,反向傳播計算x梯度時自動觸發鉤子函數
x.register_hook(hook_fn)# 觸發反向傳播和鉤子函數
y.backward()          print("修改后的梯度:", x.grad)            

結果:

原梯度: tensor([2.])
修改后的梯度: tensor([6.])

案例 3:在模型中使用 register_hook

import torch
import torch.nn as nnmodel = nn.Linear(1, 1)
weight = model.weight # 模型權重# 定義鉤子函數
def hook_fn(grad):'''作用:打印梯度'''print("Gradient of weight:", grad)# 注冊鉤子:將鉤子函數注冊到weight上,反向傳播計算weight梯度時自動觸發鉤子函數
weight.register_hook(hook_fn)# 輸入數據
x = torch.tensor([[1.0]])
target = torch.tensor([[3.0]])# 前向傳播
output = model(x)
print(output)# 損失函數
loss = (output - target).pow(2)# 觸發反向傳播和鉤子函數
loss.backward()           

結果:

Gradient of weight: tensor([[-6.1532]])

注意:
在實際使用中,必須使用clone()來確保梯度操作的安全性和計算圖完整性,例如:

def hook_fn(grad):return grad.clone() * 3
  • 通過 grad.clone() 創建梯度副本后進行操作,所有修改僅作用于副本,不會觸碰原始梯度存儲。不采用克隆,直接對原始梯度進行操作,PyTorch 會檢測到對計算圖中張量的潛在原地修改(in-place operation),并拋出異常。
  • 不采用克隆,會破壞計算圖路徑,導致梯度回傳中斷或錯誤。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/82278.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/82278.shtml
英文地址,請注明出處:http://en.pswp.cn/web/82278.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【頭歌實驗】Keras機器翻譯實戰

【頭歌實驗】Keras機器翻譯實戰 第1關:加載原始數據 編程要求 根據提示,在右側編輯器補充代碼,實現load_data函數,該函數需要加載path所代表的文件中的數據,并將文件中所有的內容按\n分割,轉換成一個列表…

python中使用高并發分布式隊列庫celery的那些坑

python中使用高并發分布式隊列庫celery的那些坑 🌟 簡單理解🛠? 核心功能🚀 工作機制📦 示例代碼(使用 Redis 作為 broker)🔗 常見搭配📦 我的環境📦第一個問題&#x1…

截圖工具 Snipaste V2.10.7(2025.06.2更新)

—————【下 載 地 址】——————— 【?本章下載一】:https://pan.xunlei.com/s/VORklK9hcuoI6n_qgx25jSq2A1?pwde7bi# 【?本章下載二】:https://pan.quark.cn/s/7c62f8f86735 【百款黑科技】:https://ucnygalh6wle.feishu.cn/wiki/…

batch_size 參數最優設置

在深度學習訓練中,batch_size(批量大小)的選擇是一個需要權衡的問題,既不是越大越好,也不是越小越好,而是需要根據硬件資源、數據規模、模型復雜度和優化目標等因素綜合決定。以下是詳細分析:

【agent開發】部署LLM(一)

本周基本就是在踩坑,沒什么實質性的進展 下載模型文件 推薦一個網站,可以簡單計算下模型推理需要多大顯存:https://apxml.com/tools/vram-calculator 我的顯卡是RTX 4070,有12GB的顯存,部署一個1.7B的Qwen3應該問題…

大數據-274 Spark MLib - 基礎介紹 機器學習算法 剪枝 后剪枝 ID3 C4.5 CART

點一下關注吧!!!非常感謝!!持續更新!!! 大模型篇章已經開始! 目前已經更新到了第 22 篇:大語言模型 22 - MCP 自動操作 FigmaCursor 自動設計原型 Java篇開…

flutter常用動畫

Flutter 動畫基礎概念 術語解釋Animation表示動畫的值,通常是一個 double (0.0 ~ 1.0) 或其他數值。AnimationController管理動畫的時間進度和狀態。需要 Ticker (vsync) 來驅動。Tween定義動畫的取值范圍,如從 0.0 到 1.0,從紅色到藍色。Cu…

Python打卡DAY43

復習日 作業: kaggle找到一個圖像數據集,用cnn網絡進行訓練并且用grad-cam做可視化 進階:并拆分成多個文件 我選擇ouIntel Image Classification | Kagglezz,該數據集分為六類,包含建筑、森林、冰川、山脈、海洋和街道…

從多巴胺的誘惑到內啡肽的力量 | 個體成長代際教育的成癮困局與破局之道

注:本文為“多巴胺,內啡肽”相關文章合輯。 圖片清晰度受引文原圖所限。 略作重排,未整理去重。 如有內容異常,請看原文。 年少偏愛多巴胺,中年才懂內啡肽 摘要 :本文通過生活實例與科學研究相結合的方式…

【音視頻】H265 NALU分析

1 H265 概述 H264 與 H265 的區別 傳輸碼率:H264 由于算法優化,可以低于 2Mbps 的速度實現標清數字圖像傳送;H.265 High Profile 可實現低于 1.5Mbps 的傳輸帶寬下,實現 1080p 全高清視頻傳輸。 編碼架構:H.265/HEVC…

Python訓練營打卡 Day26

知識點回顧: 函數的定義變量作用域:局部變量和全局變量函數的參數類型:位置參數、默認參數、不定參數傳遞參數的手段:關鍵詞參數傳遞參數的順序:同時出現三種參數類型時 ——————————————————————…

PH熱榜 | 2025-05-29

1. Tapflow 2.0 標語:將你的文檔轉化為可銷售的指導手冊、操作手冊和工作流程。 介紹:Tapflow 2.0將各類知識(包括人工智能、設計、開發、營銷等)轉化為有條理且可銷售的產品。現在你可以導入文件,讓人工智能快速為你…

GitHub 趨勢日報 (2025年05月30日)

📊 由 TrendForge 系統生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日報中的項目描述已自動翻譯為中文 📈 今日獲星趨勢圖 今日獲星趨勢圖 833 agenticSeek 789 prompt-eng-interactive-tutorial 466 ai-agents-for-beginn…

Cesium 8 ,在 Cesium 上實現雷達動畫和車輛動畫效果,并控制顯示和隱藏

目錄 ?前言 一、功能背景 1.1 核心功能概覽 1.2 技術棧與工具 二、車輛動畫 2.1 模型坐標 2.2 組合渲染 2.3 顯隱狀態 2.4 模型文件 三、雷達動畫 3.1 創建元素 3.2 動畫解析 3.3 坐標聯動 3.4 交互事件 四、完整代碼 4.1 屬性參數 4.2 邏輯代碼 加載車輛動畫…

相機--相機標定

教程 相機標定分類 相機標定分為內參標定和外參標定。 內參標定 目的 作用 原理 外參標定

JS手寫代碼篇---手寫類型判斷函數

9、手寫類型判斷函數 手寫完成這個函數:輸入一個對象(value),返回它的類型 js中的數據類型: 值類型:String、Number、Boolean、Null、Undefied、Symbol引用類型:Object、Array、Function、RegExp、Date 使用typeOf…

量子物理:初步認識量子物理

核心特點——微觀世界與宏觀世界的差異 量子物理(又稱量子力學)是物理學中描述微觀世界(原子、電子、光子等尺度)基本規律的理論框架。它與我們熟悉的經典物理(牛頓力學、電磁學等)有根本性的不同,因為微觀粒子的行為展現出許多奇特且反直覺的現象。 簡單來說,量子物…

springboot配置cors攔截器與cors解釋

文章目錄 cors?代碼 cors? CORS(跨域資源共享)的核心機制是 由后端服務器(bbb.com)決定是否允許前端(aaa.com)的跨域請求 當瀏覽器訪問 aaa.com 的頁面,并向 bbb.com/list 發起請求時&#…

國芯思辰| 同步降壓轉換器CN2020應用于智能電視,替換LMR33620

在智能電視不斷向高畫質、多功能、智能化發展的當下,其內部電源管理系統的性能至關重要。同步降壓轉換器可以為智能電視提供穩定、高效的運行。 國芯思辰CN2020是一款脈寬調制式同步降壓轉換器。內部集成兩個功率MOS管,在4.5~18V寬輸入電壓范圍內可以持…

API 版本控制:使用 ABP vNext 實現版本化 API 系統

🚀API 版本控制:使用 ABP vNext 實現版本化 API 系統 📚 目錄 🚀API 版本控制:使用 ABP vNext 實現版本化 API 系統一、背景切入 🧭二、核心配置規則 📋2.1 前置準備:NuGet 包與 usi…