torch.autograd.Function自定義前向傳播和反向傳播

torch.autograd.Function 是 PyTorch 提供的一個接口,用于自定義前向傳播和反向傳播的操作。自定義操作需要繼承 torch.autograd.Function 并重載 forward 和 backward 方法。

下面是一個簡單的示例,展示如何自定義一個平方操作的前向傳播和反向傳播。

示例一:

import torch
from torch.autograd import Function
class SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一個上下文對象,用于存儲反向傳播所需的信息ctx.save_for_backward(input)return input * input@staticmethoddef backward(ctx, grad_output):# 從上下文對象中取回前向傳播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input
# 輸入張量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定義的 SquareFunction
output = SquareFunction.apply(input)# 進行反向傳播
output.backward(torch.tensor([1.0, 1.0]))# 打印梯度
print(input.grad)  # 輸出:tensor([4., 6.])

示例二:

import torchclass SignWithSigmoidGrad(torch.autograd.Function):@staticmethoddef forward(ctx, x):result = (x > 0).float()sigmoid_result = torch.sigmoid(x)ctx.save_for_backward(sigmoid_result)return result@staticmethoddef backward(ctx, grad_result):(sigmoid_result,) = ctx.saved_tensorsif ctx.needs_input_grad[0]:grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)else:grad_input = Nonereturn grad_input

這段代碼定義了一個自定義的 PyTorch autograd 函數 SignWithSigmoidGrad,這個函數在前向傳播中計算輸入張量 x 的符號函數(sign function),在反向傳播中計算與 sigmoid 函數有關的梯度。

示例三:

import torch
from torch.autograd import Functionclass SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一個上下文對象,用于存儲反向傳播所需的信息ctx.save_for_backward(input)return torch.sum(input)@staticmethoddef backward(ctx, grad_output):# 從上下文對象中取回前向傳播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input# 輸入張量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定義的 SquareFunction
output = SquareFunction.apply(input)# 進行反向傳播
output.backward(torch.tensor(2.0))# 打印梯度
print(input.grad)  # 輸出:tensor([8., 12.])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/43949.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/43949.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/43949.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

idea創建dynamic web project

由于網課老師用的是eclipse,所以又得自己找教程了…… 解決方案: https://blog.csdn.net/Awt_FuDongLai/article/details/115523552

20240709每日后端--------最優解決Invalid bound statement (not found)

目標 最優解決Invalid bound statement (not found) 步驟 1、打包 2、查看target下是否成雙成對出現 3、核對無誤后,即可解決問題。

軟考高級里《系統架構設計師》容易考嗎?

我還是22年通過的架構考試。系統架構設計師屬于軟考高級科目,難度比初級和中級都要大,往年的通過率也比較低,一般在10-20%左右。從總體來說,這門科目確實是不好過的,大家如果想要備考系統架構設計師的話,還…

Kithara和OpenCV (一)

Kithara使用 OpenCV 目錄 Kithara使用 OpenCV簡介需求和支持的環境構建 OpenCV 庫使用 CMake 進行配置以與 Kithara 一起工作 使用 OpenCV 庫設置項目運行 OpenCV 代碼圖像采集和 OpenCV自動并行化限制和局限性1.系統建議2.實時限制3.不支持的功能和缺失的功能4.顯示 OpenCV 對…

【技術選型】FastDFS、OSS如何選擇

【技術選型】FastDFS、OSS如何選擇 開篇詞:干貨篇:FastDFS:OSS(如阿里云OSS): 總結篇:我是杰叔叔,一名滬漂的碼農,下期再會! 開篇詞: 文件存儲該選…

簡談設計模式之原型模式

原型模式是一種創建型設計模式, 用于創建對象, 而不必指定它們所屬的具體類. 它通過復制現有對象 (即原型) 來創建新對象. 原型模式適用于當創建新對象的過程代價較高或復雜時, 通過克隆現有對象來提高性能 原型模式結構 原型接口. 聲明一個克隆自身的接口具體原型. 實現克隆…

【鴻蒙學習筆記】屬性學習迭代筆記

這里寫目錄標題 TextImageColumnRow Text Entry Component struct PracExample {build() {Row() {Text(文本描述).fontSize(40)// 字體大小.fontWeight(FontWeight.Bold)// 加粗.fontColor(Color.Blue)// 字體顏色.backgroundColor(Color.Red)// 背景顏色.width(50%)// 組件寬…

展開說說:Android服務之實現AIDL跨應用通信

前面幾篇總結了Service的使用和源碼執行流程,這里再簡單分析一下如果需要Service跨進程通信該怎樣做。AIDL(Android Interface Definition Language)Android接口定義語言,用于實現 Android 兩個進程之間進行進程間通信&#xff08…

Clickhouse的聯合索引

Clickhouse 有了單獨的鍵索引,為什么還需要有聯合索引呢?了解過mysql的兄弟們應該都知道這個事。 對sql比較熟悉的兄弟們估計看見這個聯合索引心里大概有點數了,不過clickhouse的聯合索引相比mysql的又有些不一樣了,mysql 很遵循最…

深入解析Spring Boot的application.yml配置文件

目錄 引言Spring Boot配置文件簡介 application.yml的優點 基本結構與語法 YAML語法基礎Spring Boot中application.yml的基本結構 常見配置項詳解 服務器配置數據源配置日志配置其他常見配置 環境配置與Profile 多環境配置激活Profile 高級配置與技巧 屬性的占位符替換自定義配…

Spring源碼二十:Bean實例化流程三

上一篇Spring源碼十九:Bean實例化流程二中,我們主要討論了單例Bean創建對象的主要方法getSingleton了解到了他的核心流程無非是:通過一個簡單工廠的getObject方法來實例化bean,當然spring在實例化前后提供了擴展如:bef…

第5章-組合序列類型

#全部是重點知識,必須會。 了解序列和索引|的相關概念 掌握序列的相關操作 掌握列表的相關操作 掌握元組的相關操作 掌握字典的相關操作 掌握集合的相關操作1,序列和索引 1,序列是一個用于存儲多個值的連續空間,每一個值都對應一…

升級之道:精通Conda的自我升級藝術

升級之道:精通Conda的自我升級藝術 引言 Conda是Python和其他科學計算語言的強大包管理器,它不僅管理著包的安裝和依賴,還負責自身的更新。隨著開源社區的不斷發展,Conda定期發布新版本以修復已知問題、增加新功能和提高性能。本…

[面試愛問] https 的s是什么意思,有什么作用?

HTTPS 中的 "S" 代表 "Secure",即安全的意思。HTTPS(全稱是 HyperText Transfer Protocol Secure)是HTTP(HyperText Transfer Protocol)的安全版本,主要作用是為互聯網通信提供安全保護…

靈活多變的對象創建——工廠方法模式(Python實現)

1. 引言 大家好,又見面了!在上一篇文章中,我們聊了聊簡單工廠模式,今天,我們要進一步探討一種更加靈活的工廠設計模式——工廠方法模式。如果說簡單工廠模式是“萬能鑰匙”,那工廠方法模式就是“變形金剛”…

生成式人工智能:助攻開發者還是取代開發者?

引言 近年來,生成式人工智能(AIGC)在軟件開發領域掀起了一場革命,為開發者帶來了全新的工具和可能性。從代碼生成、錯誤檢測到自動化測試,AI正在以各種方式改變著開發者的工作方式。然而,這也引發了人們對開…

Python采集京東標題,店鋪,銷量,價格,SKU,評論,圖片

京東的許多數據是通過 JavaScript 動態加載的,包括銷量、價格、評論和評論時間等信息。我們無法僅通過傳統的靜態網頁爬取方法獲取到這些數據。需要使用到如 Selenium 或 Pyppeteer 等能夠模擬瀏覽器行為的工具。 另外,京東的評論系統是獨立的一個系統&a…

offer題目33:判斷是否是二叉搜索樹的后序遍歷序列

題目描述:輸入一個整數數組,判斷該數組是不是某二叉搜索樹的后序遍歷結果。如果是則返回true,否則返回false。假設輸入的數組的任意兩個數字都互不相同。例如,輸入數組{5,7,6,9,11,10,8},則返回true,,因為這個整數是下圖二叉搜索樹…

c++內存管理(上)

目錄 引入 分析 說明 C語言中動態內存管理方式 C內存管理方式 new/delete操作內置類型 new和delete操作自定義類型 引入 我們先來看下面的一段代碼和相關問題 int globalVar 1; static int staticGlobalVar 1; void Test() { static int staticVar 1; int localVar 1…

集訓day3:并查集

一、目錄 1.并查集模版 2.并查集的理解和應用 二、正文 1.并查集模版 P3367 【模板】并查集 - 洛谷 | 計算機科學教育新生態 (luogu.com.cn) 2.并查集的理解與應用 (1).并查集與聯通塊數量 P1197 [JSOI2008] 星球大戰 - 洛谷 | 計算機科學教育新生態 (luogu.com.cn) P1656 炸…