知識點回顧
1.回調函數
2.lambda函數
3.hook函數的模塊鉤子和張量鉤子
4.Grad-CAM的示例
一。回調函數示例
Hook本質是回調函數,所以我們先介紹一下回調函數。回調函數是作為參數傳遞給其他函數的函數,其目的是在某個特定事件發生時被調用執行。這種機制允許代碼在運行時動態指定需要執行的邏輯,實現了代碼的靈活性和可擴展性。
回調函數的核心價值在于:解耦邏輯:將通用邏輯與特定處理邏輯分離,使代碼更模塊化。
? ? ? ? ? 事件驅動編程:在異步操作、事件監聽(如點擊按鈕、網絡請求完成)等場景中廣泛應用。
? ? ? ? ? ? 延遲執行:允許在未來某個時間點執行特定代碼,而不必立即執行。
其中回調函數作為參數傳入,所以在定義的時候一般用callback來命名,在 PyTorch 的 Hook API 中,回調參數通常命名為 hook
# 訓練過程中的回調函數
class Callback:def on_train_begin(self):print("訓練開始")def on_epoch_end(self, epoch, logs=None):print(f"Epoch {epoch} 完成")# 使用示例
callback = Callback()
callback.on_train_begin()
for epoch in range(10):# ...訓練代碼...callback.on_epoch_end(epoch)
二、lambda函數示例
在hook中常常用到lambda函數,它是一種匿名函數(沒有正式名稱的函數),最大特點是用完即棄,無需提前命名和定義。它的語法形式非常簡約,僅需一行即可完成定義,格式如下:
lambda 參數列表: 表達式
參數列表:可以是單個參數、多個參數或無參數。
表達式:函數的返回值(無需 return 語句,表達式結果直接返回)
# 簡單lambda
add = lambda x, y: x + y# 在PyTorch中的使用
data = torch.randn(10)
processed = list(map(lambda x: x*2, data)) # 每個元素乘以2
三、hook函數示例
# 模塊鉤子
model = nn.Sequential(nn.Linear(10,5), nn.ReLU())
def module_hook(module, input, output):print(f"{module.__class__.__name__} 輸出形狀: {output.shape}")
model[0].register_forward_hook(module_hook)# 張量鉤子
x = torch.randn(3, requires_grad=True)
x.register_hook(lambda grad: grad * 0.5) # 梯度修改
四、Grad-CAM示例
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.gradients = Noneself.activations = Nonetarget_layer.register_forward_hook(self.save_activations)target_layer.register_backward_hook(self.save_gradients)def save_activations(self, module, input, output):self.activations = output.detach()def save_gradients(self, module, grad_input, grad_output):self.gradients = grad_output[0].detach()def __call__(self, x, class_idx=None):# ...前向/反向傳播邏輯...cam = torch.relu(torch.sum(self.activations * weights, dim=1))return cam