Grad-CAM與Hook函數
知識點回顧
- 回調函數
- lambda函數
- hook函數的模塊鉤子和張量鉤子
- Grad-CAM的示例
在深度學習中,我們經常需要查看或修改模型中間層的輸出或梯度,但標準的前向傳播和反向傳播過程通常是一個黑盒,很難直接訪問中間層的信息。PyTorch 提供了一種強大的工具——hook 函數,它允許我們在不修改模型結構的情況下,獲取或修改中間層的信息。常用場景如下:
- 調試與可視化中間層輸出
- 特征提取:如在圖像分類模型中提取高層語義特征用于下游任務
- 梯度分析與修改: 在訓練過程中,對某些層進行梯度裁剪或縮放,以改變模型訓練的動態
- 模型壓縮:在推理階段對特定層的輸出應用掩碼(如剪枝后的模型權重掩碼),實現輕量化推理
1、回調函數
Hook本質是回調函數,所以我們先介紹一下回調函數。回調函數是作為參數傳遞給其他函數的函數,其目的是在某個特定事件發生時被調用執行。這種機制允許代碼在運行時動態指定需要執行的邏輯,其中回調函數作為參數傳入,所以在定義的時候一般用callback來命名
在 PyTorch 的 Hook API 中,回調參數通常命名為 hook,PyTorch 的 Hook 機制基于其動態計算圖系統:
- 當你注冊一個 Hook 時,PyTorch 會在計算圖的特定節點(如模塊或張量)上添加一個回調函數
- 當計算圖執行到該節點時(前向或反向傳播),自動觸發對應的 Hook 函數
- Hook 函數可以訪問或修改流經該節點的數據(如輸入、輸出或梯度)
2、lambda函數
在hook中常常用到lambda函數,它是一種匿名函數(沒有正式名稱的函數),最大特點是用完即棄,無需提前命名和定義。它的語法形式非常簡約,僅需一行即可完成定義,格式:lambda 參數列表: 表達式
- 參數列表:可以是單個參數、多個參數或無參數
- 表達式:函數的返回值(無需 return 語句,表達式結果直接返回)
舉個例子
# 定義匿名函數:計算平方
square = lambda x: x ** 2# 調用
print(square(5)) # 輸出: 25
3、hook函數
PyTorch 提供了兩種主要的 hook:
- Module Hooks(模塊鉤子):用于監聽整個模塊的輸入和輸出
- Tensor Hooks:用于監聽張量的梯度
(1)模塊鉤子
允許我們在模塊的輸入或輸出經過時進行監聽。PyTorch 提供了兩種模塊鉤子:
- register_forward_hook:在模塊的前向傳播完成后立即被調用,這個函數可以訪問模塊的輸入和輸出,但不能修改
- register_backward_hook:在反向傳播過程中被調用的,可以用來獲取或修改梯度信息
前向鉤子舉個例子
# 創建模型實例
model = SimpleModel()# 創建一個列表用于存儲中間層的輸出
conv_outputs = []# 定義前向鉤子函數 - 用于在模型前向傳播過程中獲取中間層信息
def forward_hook(module, input, output):print(f"鉤子被調用!模塊類型: {type(module)}")print(f"輸入形狀: {input[0].shape}") # input是一個元組,對應 (image, label)print(f"輸出形狀: {output.shape}")# 保存卷積層的輸出用于后續分析# 使用detach()避免追蹤梯度,防止內存泄漏conv_outputs.append(output.detach())# 在卷積層注冊前向鉤子
# register_forward_hook返回一個句柄,用于后續移除鉤子
hook_handle = model.conv.register_forward_hook(forward_hook)# 創建一個隨機輸入張量 (批次大小=1, 通道=1, 高度=4, 寬度=4)
x = torch.randn(1, 1, 4, 4)# 執行前向傳播 - 此時會自動觸發鉤子函數
output = model(x)# 釋放鉤子 - 重要!防止在后續模型使用中持續調用鉤子造成意外行為或內存泄漏
hook_handle.remove()
反向鉤子
# 定義一個存儲梯度的列表
conv_gradients = []# 定義反向鉤子函數
def backward_hook(module, grad_input, grad_output):print(f"反向鉤子被調用!模塊類型: {type(module)}")print(f"輸入梯度數量: {len(grad_input)}")print(f"輸出梯度數量: {len(grad_output)}")# 保存梯度供后續分析conv_gradients.append((grad_input, grad_output))# 在卷積層注冊反向鉤子
hook_handle = model.conv.register_backward_hook(backward_hook)# 創建一個隨機輸入并進行前向傳播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定義一個簡單的損失函數并進行反向傳播
loss = output.sum()
loss.backward()# 釋放鉤子
hook_handle.remove()
(2)張量鉤子
PyTorch 還提供了張量鉤子,允許我們直接監聽和修改張量的梯度。張量鉤子有兩種:
- register_hook:用于監聽張量的梯度
- register_full_backward_hook:用于在完整的反向傳播過程中監聽張量的梯度
# 創建一個需要計算梯度的張量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定義一個鉤子函數,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如將梯度減半return grad / 2# 在y上注冊鉤子
hook_handle = y.register_hook(tensor_hook)# 計算梯度,梯度會從z反向傳播經過y到x,此時調用鉤子函數
z.backward()print(f"x的梯度: {x.grad}")# 釋放鉤子
hook_handle.remove()
4、Grad-CAM
一個可視化算法,通過梯度信息用熱力圖顯示圖片中哪些區域讓CNN做出了某個分類決定(比如為什么認為這是“貓”),原理:
- 梯度計算:看最后幾層特征圖的梯度,哪個特征圖對預測“貓”的貢獻大
- 加權融合:把重要的特征圖合并成一張熱力圖(重要區域更亮)
- 疊加顯示:把熱力圖蓋在原圖上,一眼看出貓的臉/耳朵等關鍵部位被高亮了
# Grad-CAM實現
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注冊鉤子,用于獲取目標層的前向傳播輸出和反向傳播梯度self.register_hooks()def register_hooks(self):# 前向鉤子函數,在目標層前向傳播后被調用,保存目標層的輸出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向鉤子函數,在目標層反向傳播后被調用,保存目標層的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目標層注冊前向鉤子和反向鉤子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向傳播,得到模型輸出model_output = self.model(input_image)if target_class is None:# 如果未指定目標類別,則取模型預測概率最大的類別作為目標類別target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影響self.model.zero_grad()# 反向傳播,構造one-hot向量,使得目標類別對應的梯度為1,其余為0,然后進行反向傳播計算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 獲取之前保存的目標層的梯度和激活值gradients = self.gradientsactivations = self.activations# 對梯度進行全局平均池化,得到每個通道的權重,用于衡量每個通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加權激活映射,將權重與激活值相乘并求和,得到類激活映射的初步結果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留對目標類別有正貢獻的區域,去除負貢獻的影響cam = F.relu(cam)# 調整大小并歸一化,將類激活映射調整為與輸入圖像相同的尺寸(32x32),并歸一化到[0, 1]范圍cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102 # 選擇測試集中的第101張圖片 (索引從0開始)
image, label = testset[idx]
print(f"選擇的圖像類別: {classes[label]}")# 轉換圖像以便可視化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次維度并移動到設備
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(選擇最后一個卷積層)
grad_cam = GradCAM(model, model.conv3)# 生成熱力圖
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可視化
plt.figure(figsize=(12, 4))# 原始圖像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始圖像: {classes[label]}")
plt.axis('off')# 熱力圖
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM熱力圖: {classes[pred_class]}")
plt.axis('off')# 疊加的圖像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("疊加熱力圖")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()
@浙大疏錦行