1. 使用自定義函數的雙重反向傳播
有時候,在反向計算圖中運行兩次反向傳播是有用的,例如計算高階梯度。然而,支持雙重反向傳播需要對自動求導(autograd)有一定的理解,并且需要小心處理。支持單次反向傳播的函數不一定能夠支持雙重反向傳播。在本教程中,我們將展示如何編寫一個支持雙重反向傳播的自定義自動求導函數,并指出一些需要注意的事項。
在編寫一個支持兩次反向傳播的自定義自動求導函數時,了解自定義函數中的操作何時被自動求導記錄、何時不被記錄,以及最重要的是,save_for_backward 如何與這些機制配合工作,是非常關鍵的。
自定義函數以兩種方式隱式影響梯度模式:
-
在前向傳播期間,自動求導不會記錄在前向函數中執行的任何操作的計算圖。當前向傳播完成時,自定義函數的反向函數將成為每個前向輸出的 grad_fn。
-
在反向傳播期間,如果指定了 create_graph,自動求導會記錄用于計算反向傳播的計算圖。
接下來,為了理解 save_for_backward 如何與上述機制交互,我們可以通過幾個示例來探討。
1.1保存輸入
考慮這個簡單的平方函數。它保存了一個輸入張量以便用于反向傳播。雙重反向傳播會在 autograd 能夠記錄反向傳播中的操作時自動