PyTorch擴展自定義PyThonC++(CUDA)算子的若干方法總結

PyTorch擴展自定義PyThon/C++(CUDA)算子的若干方法總結

轉自:https://zhuanlan.zhihu.com/p/158643792

作者:奔騰的黑貓

在做畢設的時候需要實現一個PyTorch原生代碼中沒有的并行算子,所以用到了這部分的知識,再不總結就要忘光了= =,本文內容主要是PyTorch的官方教程的各種傳送門,這些官方教程寫的都很好,以后就可以不用再浪費時間在百度上了。由于圖神經網絡計算框架PyG的代碼實現也是采用了擴展的方法,因此也可以當成下面總結PyG源碼文章的前導知識吧 。

第一種情況:使用PyThon擴展PyTorch

使用PyThon擴展PyTorch準確的來說是在PyTorch的Python前端實現自定義算子或者模型,不涉及底層C++的實現。這種擴展方式是所有擴展方式中最簡單的,也是官方首先推薦的,這是因為PyTorch在NVIDIA cuDNN,Intel MKL或NNPACK之類的庫的支持下已經對可能出現的CPU和GPU操作進行了高度優化,因此用Python擴展的代碼通常足夠快。

比如要擴展一個新的PyThon算子(torch.nn)只需要繼承torch**.nn.Module并實現其forward方法即可。詳細的過程請參考**官方教程傳送門:

Extending PyTorchpytorch.org/docs/master/notes/extending.html

第二種情況:使用**pybind11**構建共享庫形式的C++和CUDA擴展

但是如果我們想對代碼進行進一步優化,比如對自己的算子添加并行的CUDA實現或者連接個OpenCV的庫什么的,那么僅僅使用Python進行擴展就不能滿足需求;其次如果我們想序列化模型,在一個沒有Python環境的生產環境下部署,也需要我們使用C++重寫算法;最后考慮到考慮到多線程執行和性能原因,一般Python代碼也并不適合做部署。因此在對性能有要求或者需要序列化模型的場景下我們還是會用到C++擴展。

下面我先把官方教程傳送門放在這里:

CUSTOM C++ AND CUDA EXTENSIONSpytorch.org/tutorials/advanced/cpp_extension.html

對于一種典型的擴展情況,比如我們要設計一個全新的C++底層算子,其過程其實就三步:

第一步:使用C++編寫算子的forward函數和backward函數

第二步:將該算子的forward函數和backward函數使用**pybind11**綁定到python上

第三步:使用setuptools/JIT/CMake編譯打包C++工程為so文件

注意到在第一步中,我們不僅僅要實現forward函數也要實現backward函數,這是因為在C++端PyTorch目前不支持自動根據forward函數推導出backward函數,所以我們必須要對自己算子的反向傳播過程完全清楚。一個需要注意的地方是,你可以選擇直接在C++中繼承torch::autograd類進行擴展;也可以像官方教程中那樣在C++代碼中實現forward和backward的核心過程,而在python端繼承PyTorch的torch**.autograd.**Function類。

在C++端擴展forward函數和backward函數的需要注意以下規則:

(1)首先無論是forward函數還是backward函數都需要聲明為靜態函數

(2)forward函數可以接受任意多的參數并且應該返回一個 variable list或者variable;forward函數需要將torch::autograd::AutogradContext 作為自己的第一個參數。Variables可以被使用ctx->save_for_backward保存,而其他數據類型可以使用ctx->saved_data以<std::string,at::IValue>pairs的形式保存在一個map中。

(3)backward函數第一個參數同樣需要為torch::autograd::AutogradContext,其余的參數是一個variable_list,包含的變量數量與forward輸出的變量數量相等。它應該返回和forward輸入一樣多的變量。保存在forward中的Variable變量可以通過ctx->get_saved_variables而其他的數據類型可以通過ctx->saved_data獲取。

請注意,backward的輸入參數是自動微分系統反傳回來的參數梯度值,其需要和forward函數的返回值位置一一對應的;而backward的返回值是對各參數根據自動微分規則求導后的梯度值,其需要和forward函數的輸入參數位置一一對應,對于不需要求導的參數也需要使用空Variable占位。

// PyG的C++擴展就選擇的是直接繼承PyTorch的C++端的torch::autograd類進行擴展
// 下面是PyG的一個ScatterSum算子的擴展示例
// 不用糾結這個算子的具體內容,對擴展的算子的結構有一個大致了解即可
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:// AutogradContext *ctx指針可以操作static variable_list forward(AutogradContext *ctx, Variable src,Variable index, int64_t dim,torch::optional<Variable> optional_out,torch::optional<int64_t> dim_size) {dim = dim < 0 ? src.dim() + dim : dim;ctx->saved_data["dim"] = dim;ctx->saved_data["src_shape"] = src.sizes();index = broadcast(index, src, dim);auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");auto out = std::get<0>(result);ctx->save_for_backward({index});// 如果在擴展的C++代碼中使用非Aten內建操作修改了tensor的值,需要對其進行臟標記if (optional_out.has_value())ctx->mark_dirty({optional_out.value()});  return {out};}// grad_outs是out參數反傳回來的梯度值static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {auto grad_out = grad_outs[0];auto saved = ctx->get_saved_variables();auto index = saved[0];auto dim = ctx->saved_data["dim"].toInt();auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());auto grad_in = torch::gather(grad_out, dim, index, false);// 不需要求導的參數需要空Variable占位return {grad_in, Variable(), Variable(), Variable(), Variable()};}
};

由于涉及到在C++環境下操作張量和反向傳播等操作,因此我們需要對PyTorch的C++后端的庫有所了解,主要就是Torch和Aten這兩個庫,下面我簡要介紹一下這兩兄弟。

img

其中Torch是PyTorch的C++底層實現(PS:其實是先有的Torch后有的PyTorch,從名字也能看出來),FB在編碼PyTorch的時候就有意將PyTorch的接口和Torch的接口設計的十分類似,因此如果你對PyTorch很熟悉的話那么你也會很快的對Torch上手。

Torch官方文檔傳送門:

The C++ Frontendpytorch.org/cppdocs/frontend.html

安裝PyTorch的C++前端的官方教程:

INSTALLING C++ DISTRIBUTIONS OF PYTORCHpytorch.org/cppdocs/installing.html

而Aten是ATen從根本上講是一個張量庫,在PyTorch中幾乎所有其他Python和C ++接口都在其上構建。它提供了一個核心Tensor類,在其上定義了數百種操作。這些操作大多數都具有CPU和GPU實現,Tensor該類將根據其類型向其動態調度。和Torch相比Aten更接近底層和核心邏輯。

Aten源代碼傳送門:

https://github.com/zdevito/ATen/tree/master/aten/srcgithub.com/zdevito/ATen/tree/master/aten/src

使用Aten聲明和操作張量的教程:

TENSOR BASICSpytorch.org/cppdocs/notes/tensor_basics.html

由于Pyorch的C++后端文檔比較少,因此要多參考官方的例子,嘗試去模仿官方教程的代碼,同時可以通過Python前端的接口猜測后端接口的功能,如果沒有文檔了就讀一讀源碼,還是有不少注釋的,還能理解實現的邏輯。

第三種情況:為TORCHSCRIPT添加C++和CUDA擴展

首先簡單解釋一下TorchScript是什么,如果用官方的定義來說:“TorchScript是一種從PyTorch代碼創建可序列化和可優化模型的方法。任何TorchScript程序都可以從一個Python進程中保存并可以在一個沒有Python環境的進程中被加載。”通俗來說TorchScript就是一個序列化模型(即Inference)的工具,它可以讓你的PyTorch代碼方便的在生產環境中部署,同時在將PyTorch代碼轉化TorchScript代碼時還會對你的模型進行一些性能上的優化。使用TorchScript完成模型的部署要比我們之前提到的使用C++重寫要簡單的多,因為是自動生成的。

TorchScript包含兩種序列化模型的方法:tracingscript,兩種方法各有其適用場景,由于和本文關系不大就不詳細展開了,具體的官方教程傳送門在此:

INTRODUCTION TO TORCHSCRIPTpytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

但是,TorchScript只能自動化的構造PyTorch的原生代碼,如果我們需要序列化自定義的C++擴展算子,則需要我們顯式的將這些自定義算子注冊到TorchScript中,所幸的是,這一過程其實非常簡單,整個過程和第二小節中使用pybind11構建共享庫的形式的C++和CUDA擴展十分類似。官方教程傳送門如下:

EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORSpytorch.org/tutorials/advanced/torch_script_custom_ops.html

而對于自定義的C++類,如果要注冊到TorchScript要稍微復雜一些,官方教程傳送門如下:

EXTENDING TORCHSCRIPT WITH CUSTOM C++ CLASSESpytorch.org/tutorials/advanced/torch_script_custom_classes.html?highlight=registeroperators

另外需要注意的是,如果想要編寫能夠被TorchScript編譯器理解的代碼,需要注意在C++自定義擴展算子參數中的數據類型,目前被TorchScript支持的參數數據類型有torch::Tensortorch::Scalar(標量類型),doubleint64_tstd::vector,而像float,int,short這些是不能作為自定義擴展算子的參數數據類型的。

目前就先總結這么多吧,這點東西居然寫了一天,好累啊(*  ̄︿ ̄)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532457.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532457.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532457.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

給 Python 算法插上性能的翅膀——pybind11 落地實踐

給 Python 算法插上性能的翅膀——pybind11 落地實踐 轉自&#xff1a;https://zhuanlan.zhihu.com/p/444805518 作者&#xff1a;jesonxiang&#xff08;向乾彪&#xff09;&#xff0c;騰訊 TEG 后臺開發工程師 1. 背景 目前 AI 算法開發特別是訓練基本都以 Python 為主&…

chrome自動提交文件_收集文檔及提交名單統計

知乎文章若有排版問題請見諒&#xff0c;原文放在個人博客中【歡迎互踩&#xff01;】文叔叔文檔收集使用動機在我們的學習工作中&#xff0c;少不了要讓大家集體提交文件的情況&#xff0c;舉個最簡單的例子&#xff1a;收作業。 傳統的文件收集流程大致是&#xff1a;群內發出…

Pytorch自定義C++/CUDA擴展

Pytorch自定義C/CUDA擴展 翻譯自&#xff1a;官方文檔 PyTorch 提供了大量與神經網絡、張量代數、數據整理和其他操作。但是&#xff0c;我們有時會需要更加定制化的操作。例如&#xff0c;想要使用論文中找到的一種新型的激活函數&#xff0c;或者實現自己設計的算子。 在 Py…

惠普800g1支持什么內存_惠普黑白激光打印機哪種好 惠普黑白激光打印機推薦【圖文詳解】...

打印機的出現讓我們在生活和日常工作中變得越來越方便&#xff0c;不過隨著科技的發展&#xff0c;打印機的類型也變得非常多&#xff0c;其中就有黑白激光打印機&#xff0c;而黑白激光打印機的品牌也有很多&#xff0c;比如我們的惠普黑白激光打印機&#xff0c;今天小編就給…

控制臺輸出顏色控制

控制臺輸出顏色控制 轉自&#xff1a;https://cloud.tencent.com/developer/article/1142372 前端時間&#xff0c;寫了一篇 PHP 在 Console 模式下的進度顯示 &#xff0c;正好最近的一個數據合并項目需要用到控制臺顏色輸出&#xff0c;所以就把相關的信息整理下&#xff0c;…

idea連接跳板機_跳板機服務(jumpserver)

一、跳板機服務作用介紹1、有效管理用戶權限信息2、有效記錄用戶登錄情況3、有效記錄用戶操作行為二、跳板機服務架構原理三、跳板機服務安裝過程第一步&#xff1a;安裝跳板機依賴軟件yum -y install git python-pip mariadb-devel gcc automake autoconf python-devel readl…

【詳細圖解】再次理解im2col

【詳細圖解】再次理解im2col 轉自&#xff1a;https://mp.weixin.qq.com/s/GPDYKQlIOq6Su0Ta9ipzig 一句話&#xff1a;im2col是將一個[C,H,W]矩陣變成一個[H,W]矩陣的一個方法&#xff0c;其原理是利用了行列式進行等價轉換。 為什么要做im2col? 減少調用gemm的次數。 重要…

反思 大班 快樂的機器人_幼兒園大班教案《快樂的桌椅》含反思

大班教案《快樂的桌椅》含反思適用于大班的體育主題教學活動當中&#xff0c;讓幼兒提高協調性和靈敏性&#xff0c;創新桌椅的玩法&#xff0c;正確爬的方法&#xff0c;學會匍匐前進&#xff0c;快來看看幼兒園大班《快樂的桌椅》含反思教案吧。幼兒園大班教案《快樂的桌椅》…

DCN可形變卷積實現1:Python實現

DCN可形變卷積實現1&#xff1a;Python實現 我們會先用純 Python 實現一個 Pytorch 版本的 DCN &#xff0c;然后實現其 C/CUDA 版本。 本文主要關注 DCN 可形變卷積的代碼實現&#xff0c;不會過多的介紹其思想&#xff0c;如有興趣&#xff0c;請參考論文原文&#xff1a; …

藍牙耳機聲音一頓一頓的_線控耳機黨陣地轉移成功,OPPO這款TWS耳機體驗滿分...

“你看到我手機里3.5mm的耳機孔了嗎”&#xff0c;這可能是許多線控耳機黨最想說的話了。確實&#xff0c;如今手機在做“減法”&#xff0c;而廠商們首先就拿3.5mm耳機孔“開刀”&#xff0c;我們也喪失了半夜邊充電邊戴耳機打游戲的樂趣。竟然如此&#xff0c;那如何在耳機、…

AI移動端優化之Im2Col+Pack+Sgemm

AI移動端優化之Im2ColPackSgemm 轉自&#xff1a;https://blog.csdn.net/just_sort/article/details/108412760 這篇文章是基于NCNN的Sgemm卷積為大家介紹Im2ColPackSgemm的原理以及算法實現&#xff0c;希望對算法優化感興趣或者做深度學習模型部署的讀者帶來幫助。 1. 前言 …

elementui的upload組件怎么獲取上傳的文本流、_抖音feed流直播間引流你還不會玩?實操講解...

本文由艾奇在線明星優化師寫作計劃出品在這個全民驚恐多災多難且帶有魔幻的2020&#xff0c;一場突如其來的疫情改變了人們很多消費習慣&#xff0c;同時加速了直播電商的發展&#xff0c;現在直播已經成為商家必爭的營銷之地&#xff0c;直播雖然很火&#xff0c;但如果沒有流…

FFmpeg 視頻處理入門教程

FFmpeg 視頻處理入門教程 轉自&#xff1a;https://www.ruanyifeng.com/blog/2020/01/ffmpeg.html 作者&#xff1a; 阮一峰 日期&#xff1a; 2020年1月14日 FFmpeg 是視頻處理最常用的開源軟件。 它功能強大&#xff0c;用途廣泛&#xff0c;大量用于視頻網站和商業軟件&…

checkbox wpf 改變框的大小_【論文閱讀】傾斜目標范圍框(標注)的終極方案

前言最常用的斜框標注方式是在正框的基礎上加一個旋轉角度θ&#xff0c;其代數表示為(x_c,y_c,w,h,θ)&#xff0c;其中(x_c,y_c )表示范圍框中心點坐標&#xff0c;(w,h)表示范圍框的寬和高[1,2,7]。對于該標注方式&#xff0c;如果將w和h的值互換&#xff0c;再將θ加上或者…

徹底理解BP之手寫BP圖像分類你也行

徹底理解BP之手寫BP圖像分類你也行 轉自&#xff1a;https://zhuanlan.zhihu.com/p/397963213 第一節&#xff1a;用矩陣的視角&#xff0c;看懂BP的網絡圖 1.1、什么是BP反向傳播算法 BP(Back Propagation)誤差反向傳播算法&#xff0c;使用反向傳播算法的多層感知器又稱為B…

h5頁面禁止復制_H5移動端頁面禁止復制技巧

前言&#xff1a;業務需要&#xff0c;需要對整個頁面禁止彈出復制菜單。在禁止的頁面中加入以下css樣式定義* {-webkit-touch-callout:none;/*系統默認菜單被禁用*/-webkit-user-select:none;/*webkit瀏覽器*/-khtml-user-select:none;/*早起瀏覽器*/-moz-user-select:none;/*…

梯度下降法和牛頓法計算開根號

梯度下降法和牛頓法計算開根號 本文將介紹如何不調包&#xff0c;只能使用加減乘除法實現對根號x的求解。主要介紹梯度下降和牛頓法者兩種方法&#xff0c;并給出 C 實現。 梯度下降法 思路/步驟 轉化問題&#xff0c;將 x\sqrt{x}x? 的求解轉化為最小化目標函數&#xff…

匯博工業機器人碼垛機怎么寫_全自動碼垛機器人在企業生產中的地位越來越重要...

全自動碼垛機器人在企業生產中的地位越來越重要在智能化的各種全自動生產線中&#xff0c;全自動碼垛機器人成了全自動生產線的重要機械設備&#xff0c;在各種生產中發揮著不可忽視的作用。全自動碼垛機器人主要用于生產線上的包裝過程中&#xff0c;不僅能夠提高企業的生產率…

kmeans手寫實現與sklearn接口

kmeans手寫實現與sklearn接口 kmeans簡介 K 均值聚類是最基礎的一種聚類方法。它是一種迭代求解的聚類分析算法。 kmeans的迭代步驟 給各個簇中心 μ1,…,μc\mu_1,\dots,\mu_cμ1?,…,μc? 以適當的初值&#xff1b; 更新樣本 x1,…,xnx_1,\dots,x_nx1?,…,xn? 對應的…

小說中場景的功能_《流浪地球》:從小說到電影

2019年春節賀歲檔冒出一匹黑馬&#xff1a;國產科幻片《流浪地球》大年初一上映后口碑、票房雙豐收&#xff1a;截至9日下午&#xff0c;票房已破15億&#xff0c;并獲得9.2的高評分。著名導演詹姆斯卡梅隆通過社交媒體對我國春節期間上映的科幻影片《流浪地球》發出的祝愿&…