python打卡day42

Grad-CAM與Hook函數

知識點回顧

  1. 回調函數
  2. lambda函數
  3. hook函數的模塊鉤子和張量鉤子
  4. Grad-CAM的示例

在深度學習中,我們經常需要查看或修改模型中間層的輸出或梯度,但標準的前向傳播和反向傳播過程通常是一個黑盒,很難直接訪問中間層的信息。PyTorch 提供了一種強大的工具——hook 函數,它允許我們在不修改模型結構的情況下,獲取或修改中間層的信息。常用場景如下:

  1. 調試與可視化中間層輸出
  2. 特征提取:如在圖像分類模型中提取高層語義特征用于下游任務
  3. 梯度分析與修改: 在訓練過程中,對某些層進行梯度裁剪或縮放,以改變模型訓練的動態
  4. 模型壓縮:在推理階段對特定層的輸出應用掩碼(如剪枝后的模型權重掩碼),實現輕量化推理

1、回調函數

Hook本質是回調函數,所以我們先介紹一下回調函數。回調函數是作為參數傳遞給其他函數的函數,其目的是在某個特定事件發生時被調用執行。這種機制允許代碼在運行時動態指定需要執行的邏輯,其中回調函數作為參數傳入,所以在定義的時候一般用callback來命名

在 PyTorch 的 Hook API 中,回調參數通常命名為 hook,PyTorch 的 Hook 機制基于其動態計算圖系統:

  1. 當你注冊一個 Hook 時,PyTorch 會在計算圖的特定節點(如模塊或張量)上添加一個回調函數
  2. 當計算圖執行到該節點時(前向或反向傳播),自動觸發對應的 Hook 函數
  3. Hook 函數可以訪問或修改流經該節點的數據(如輸入、輸出或梯度)

2、lambda函數

在hook中常常用到lambda函數,它是一種匿名函數(沒有正式名稱的函數),最大特點是用完即棄,無需提前命名和定義。它的語法形式非常簡約,僅需一行即可完成定義,格式:lambda 參數列表: 表達式

  • 參數列表:可以是單個參數、多個參數或無參數
  • 表達式:函數的返回值(無需 return 語句,表達式結果直接返回)

舉個例子

# 定義匿名函數:計算平方
square = lambda x: x ** 2# 調用
print(square(5))  # 輸出: 25

3、hook函數

PyTorch 提供了兩種主要的 hook:

  • Module Hooks(模塊鉤子):用于監聽整個模塊的輸入和輸出
  • Tensor Hooks:用于監聽張量的梯度

(1)模塊鉤子

允許我們在模塊的輸入或輸出經過時進行監聽。PyTorch 提供了兩種模塊鉤子:

  1. register_forward_hook:在模塊的前向傳播完成后立即被調用,這個函數可以訪問模塊的輸入和輸出,但不能修改
  2. register_backward_hook:在反向傳播過程中被調用的,可以用來獲取或修改梯度信息

前向鉤子舉個例子

# 創建模型實例
model = SimpleModel()# 創建一個列表用于存儲中間層的輸出
conv_outputs = []# 定義前向鉤子函數 - 用于在模型前向傳播過程中獲取中間層信息
def forward_hook(module, input, output):print(f"鉤子被調用!模塊類型: {type(module)}")print(f"輸入形狀: {input[0].shape}") #  input是一個元組,對應 (image, label)print(f"輸出形狀: {output.shape}")# 保存卷積層的輸出用于后續分析# 使用detach()避免追蹤梯度,防止內存泄漏conv_outputs.append(output.detach())# 在卷積層注冊前向鉤子
# register_forward_hook返回一個句柄,用于后續移除鉤子
hook_handle = model.conv.register_forward_hook(forward_hook)# 創建一個隨機輸入張量 (批次大小=1, 通道=1, 高度=4, 寬度=4)
x = torch.randn(1, 1, 4, 4)# 執行前向傳播 - 此時會自動觸發鉤子函數
output = model(x)# 釋放鉤子 - 重要!防止在后續模型使用中持續調用鉤子造成意外行為或內存泄漏
hook_handle.remove()

反向鉤子

# 定義一個存儲梯度的列表
conv_gradients = []# 定義反向鉤子函數
def backward_hook(module, grad_input, grad_output):print(f"反向鉤子被調用!模塊類型: {type(module)}")print(f"輸入梯度數量: {len(grad_input)}")print(f"輸出梯度數量: {len(grad_output)}")# 保存梯度供后續分析conv_gradients.append((grad_input, grad_output))# 在卷積層注冊反向鉤子
hook_handle = model.conv.register_backward_hook(backward_hook)# 創建一個隨機輸入并進行前向傳播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定義一個簡單的損失函數并進行反向傳播
loss = output.sum()
loss.backward()# 釋放鉤子
hook_handle.remove()

(2)張量鉤子

PyTorch 還提供了張量鉤子,允許我們直接監聽和修改張量的梯度。張量鉤子有兩種:

  1. register_hook:用于監聽張量的梯度
  2. register_full_backward_hook:用于在完整的反向傳播過程中監聽張量的梯度
# 創建一個需要計算梯度的張量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定義一個鉤子函數,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如將梯度減半return grad / 2# 在y上注冊鉤子
hook_handle = y.register_hook(tensor_hook)# 計算梯度,梯度會從z反向傳播經過y到x,此時調用鉤子函數
z.backward()print(f"x的梯度: {x.grad}")# 釋放鉤子
hook_handle.remove()

4、Grad-CAM

一個可視化算法,通過梯度信息用熱力圖顯示圖片中哪些區域讓CNN做出了某個分類決定(比如為什么認為這是“貓”),原理:

  • 梯度計算:看最后幾層特征圖的梯度,哪個特征圖對預測“貓”的貢獻大
  • 加權融合:把重要的特征圖合并成一張熱力圖(重要區域更亮)
  • 疊加顯示:把熱力圖蓋在原圖上,一眼看出貓的臉/耳朵等關鍵部位被高亮了
# Grad-CAM實現
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注冊鉤子,用于獲取目標層的前向傳播輸出和反向傳播梯度self.register_hooks()def register_hooks(self):# 前向鉤子函數,在目標層前向傳播后被調用,保存目標層的輸出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向鉤子函數,在目標層反向傳播后被調用,保存目標層的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目標層注冊前向鉤子和反向鉤子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向傳播,得到模型輸出model_output = self.model(input_image)if target_class is None:# 如果未指定目標類別,則取模型預測概率最大的類別作為目標類別target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影響self.model.zero_grad()# 反向傳播,構造one-hot向量,使得目標類別對應的梯度為1,其余為0,然后進行反向傳播計算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 獲取之前保存的目標層的梯度和激活值gradients = self.gradientsactivations = self.activations# 對梯度進行全局平均池化,得到每個通道的權重,用于衡量每個通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加權激活映射,將權重與激活值相乘并求和,得到類激活映射的初步結果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留對目標類別有正貢獻的區域,去除負貢獻的影響cam = F.relu(cam)# 調整大小并歸一化,將類激活映射調整為與輸入圖像相同的尺寸(32x32),并歸一化到[0, 1]范圍cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102  # 選擇測試集中的第101張圖片 (索引從0開始)
image, label = testset[idx]
print(f"選擇的圖像類別: {classes[label]}")# 轉換圖像以便可視化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次維度并移動到設備
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(選擇最后一個卷積層)
grad_cam = GradCAM(model, model.conv3)# 生成熱力圖
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可視化
plt.figure(figsize=(12, 4))# 原始圖像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始圖像: {classes[label]}")
plt.axis('off')# 熱力圖
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM熱力圖: {classes[pred_class]}")
plt.axis('off')# 疊加的圖像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("疊加熱力圖")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/83263.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/83263.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/83263.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

中國風展示工作總結商務通用PPT模版

中國風展示工作總結商務通用PPT模版:中國風商務通用PPT 模版https://pan.quark.cn/s/42ad18c010d4

TeleAI發布TeleChat2.5及T1正式版,雙雙開源上線魔樂社區!

5月12日,中國電信開源TeleChat系列四個模型,涵蓋復雜推理和通用問答的多個尺寸模型,包括TeleChat-T1-35B、TeleChat-T1-115B、TeleChat2.5-35B和TeleChat2.5-115B,實測模型性能均有顯著的性能效果。TeleChat系列模型基于昇思MindS…

機器視覺2D定位引導一般步驟

機器視覺的2D定位引導是工業自動化中的核心應用,主要用于精確確定目標物體的位置(X, Y坐標)和角度(旋轉角度θ),并引導機器人或運動機構進行抓取、裝配、對位、檢測等操作。其一般步驟可概括如下: 一、系統規劃與硬件選型 明確需求: 定位精度要求(多少毫米/像素,多少…

兒童節快樂,聊聊數字的規律和同余原理

某年的6月1日是星期日。那么,同一年的6月30日是星期幾? 星期是7天一個循環。所以說,這一天是星期幾,7天之后同樣也是星期幾。而6月30日是在6月1日的29天之后:29 7 4 ... 1用29除以7,可以得出余數為1。而…

最佳實踐|互聯網行業軟件供應鏈安全建設的SCA縱深實踐方案

在數字化轉型的浪潮中,開源組件已成為企業構建云服務與應用的基石,但其引入的安全風險也日益凸顯。某互聯網大廠的核心安全研究團隊,通過深度應用軟件成分分析(SCA)技術,構建了一套覆蓋開源組件全生命周期管…

Docker Compose(容器編排)

目錄 什么是 Docker Compose Docker Compose 的功能 Docker Compose 使用場景 Docker Compose 文件(docker-compose.yml) Docker Compose 命令清單 常見命令說明 操作案例 總結 什么是 Docker Compose docker-compose 是 Docker 官方的開源項…

【網絡安全】輕量敏感路徑掃描工具

訂閱專欄,獲取文末項目源碼。 文章目錄 工具簡介工具特點項目結構使用方法1.環境準備2.配置目標URL3.運行掃描4.結果查看5.自定義擴展項目源碼工具簡介 該工具是一款基于Python的異步敏感路徑掃描工具,用于檢測目標網站是否存在敏感文件或路徑泄露(如配置文件、密鑰、版本控…

SpringAI+DeepSeek大模型應用開發實戰

內容來自黑馬程序員 這里寫目錄標題 認識AI和大模型大模型應用開發模型部署方案對比模型部署-云服務模型部署-本地部署調用大模型什么是大模型應用傳統應用和大模型應用大模型應用 大模型應用開發技術架構 SpringAI對話機器人快速入門會話日志會話記憶 認識AI和大模型 AI的發…

高溫爐制造企業Odoo ERP實施規劃與深度分析報告

摘要 本報告旨在為高溫爐生產企業提供一個基于Odoo 18平臺的企業資源規劃(ERP)系統實施的全面分析與規劃。報告首先系統梳理了高溫爐制造業獨特的業務流程特點,隨后詳細映射了Odoo 18各核心模塊功能與這些業務需求的匹配程度。重點分析了生產…

簡述什么是全局鎖?它的應用場景有哪些?

全局鎖是數據庫管理系統中的一種特殊鎖機制,用于對整個數據庫實例進行加鎖,使數據庫處于只讀狀態,阻止所有數據更新(DML)、數據定義(DDL)及更新類事務提交等操作。 其核心應用場景包括&#xf…

window 顯示驅動開發-呈現開銷改進(二)

對共享表面的紋理格式支持 驅動程序應支持共享資源和可共享的后臺緩沖區,以使用 DXGI_FORMAT 枚舉中的這些附加紋理格式: DXGI_FORMAT_A8_UNORMDXGI_FORMAT_R8_UNORMDXGI_FORMAT_R8G8_UNORMDXGI_FORMAT_BC1_TYPELESS\*DXGI_FORMAT_BC1_UNORMDXGI_FORMAT…

jenkins集成gitlab實現自動構建

jenkins集成gitlab實現自動構建 前面我們已經部署了Jenkins和gitlab,本文介紹將二者結合使用 項目源碼上傳至gitee提供公網訪問:https://gitee.com/ye-xiao-tian/my-webapp 1、創建一個群組和項目 2、添加ssh密鑰 #生成密鑰 [rootgitlab ~]# ssh-keyge…

barker-OFDM模糊函數原理及仿真

文章目錄 前言一、巴克碼序列二、barker-OFDM 信號1、OFDM 信號表達式2、模糊函數表達式 三、MATLAB 仿真1、MATLAB 核心源碼2、仿真結果①、barker-OFDM 模糊函數②、barker-OFDM 距離分辨率③、barker-OFDM 速度分辨率④、barker-OFDM 等高線圖 四、資源自取 前言 本文進行 …

深入解析 Redis Cluster 架構與實現(一)

#作者:stackofumbrella 文章目錄 Redis Cluster特點Redis Cluster與其它集群模式的區別集群目標性能hash tagsMutli-key操作Cluster Bus安全寫入(write safety)集群節點的屬性集群拓撲節點間handshake重定向與reshardingMOVED重定向ASK重定向…

linux centos 服務器性能排查 vmstat、top等常用指令

背景:項目上經常出現系統運行緩慢,由于數據庫服務器是linux服務器,記錄下linux服務器性能排查常用指令 vmstat vmstat介紹 vmstat 命令報告關于內核線程、虛擬內存、磁盤、陷阱和 CPU 活動的統計信息。由 vmstat 命令生成的報告可以用于平衡系統負載活動。系統范圍內的這…

在IIS上無法使用PUT等請求

錯誤來源: chat:1 Access to XMLHttpRequest at http://101.126.139.3:11000/api/receiver/message from origin http://101.126.139.3 has been blocked by CORS policy: No Access-Control-Allow-Origin header is present on the requested resource. 其實我的后…

Python訓練第四十一天

DAY 41 簡單CNN 知識回顧 數據增強卷積神經網絡定義的寫法batch歸一化:調整一個批次的分布,常用與圖像數據特征圖:只有卷積操作輸出的才叫特征圖調度器:直接修改基礎學習率 卷積操作常見流程如下: 1. 輸入 → 卷積層 →…

Linux線程同步實戰:多線程程序的同步與調度

個人主頁:chian-ocean 文章專欄-Linux Linux線程同步實戰:多線程程序的同步與調度 個人主頁:chian-ocean文章專欄-Linux 前言:為什么要實現線程同步線程饑餓(Thread Starvation)示例:搶票問題 …

5.2 初識Spark Streaming

在本節實戰中,我們初步探索了Spark Streaming,它是Spark的流式數據處理子框架,具備高吞吐量、可伸縮性和強容錯能力。我們了解了Spark Streaming的基本概念和運行原理,并通過兩個案例演示了如何利用Spark Streaming實現詞頻統計。…

Go 即時通訊系統:日志模塊重構,并從main函數開始

重構logger 上次寫的logger.go過于繁瑣,有很多沒用到的功能;重構后只提供了簡潔的日志接口,支持日志輪轉、多級別日志記錄等功能,并采用單例模式確保全局只有一個日志實例 全局變量 var (once sync.Once // 用于實現…