【pytorch擴展】CUDA自定義pytorch算子(簡單demo入手)

Pytorch作為一款優秀的AI開發平臺,提供了完備的自定義算子的規范。我們用torch開發時,經常會因為現有算子的不足限制我們idea的迸發。于是,CUDA/C++自定義pytorch算子是不得不磕了。

今天通過一個小實驗來梳理自定義pytorch算子都需要做哪些準備。比如,我們做一個張量加法。
vim test_add.py

from add import sum_double_op
import torch
import timeclass Timer:def __init__(self, op_name):self.begin_time = 0self.end_time = 0self.op_name = op_namedef __enter__(self):torch.cuda.synchronize()self.begin_time = time.time()def __exit__(self, exc_type, exc_val, exc_tb):torch.cuda.synchronize()self.end_time = time.time()print(f"Average time cost of {self.op_name} is {(self.end_time - self.begin_time) * 1000:.4f} ms")if __name__ == '__main__':n = 1000000device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")tensor1 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)tensor2 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)with Timer("sum_double"):ans = sum_double_op(tensor1, tensor2)

這里的"sum_double_op"就是我們用CUDA寫的算子。那這個可以直接調用,并且可以傳遞梯度的算子,需要怎么做呢?


眾所周知,CUDA/C++都是編譯性語言,編譯以后再調用會比python這種解釋性語言更快。所以,我們需要對CUDA有一個編譯過程。這個編譯過程用setuptools來實現(可以pip安裝)。
先vim setup.py

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtensionsetup(name='myAdd',packages=find_packages(),version='0.1.0',author='muzhan',ext_modules=[CUDAExtension('sum_double',['./add/add.cpp','./add/add_cuda.cu',]),],cmdclass={'build_ext': BuildExtension}
)

直接“python setup.py install”即可完成cuda算子的編譯和安裝。等等,你的add.cpp和add_cuda.cu還沒呢?
vim add_cuda.cu

#include <cstdio>
#define THREADS_PER_BLOCK 256
#define WARP_SIZE 32
#define DIVUP(m, n) ((m + n - 1) / n)__global__ void two_sum_kernel(const float* a, const float* b, float * c, int n){int idx = blockIdx.x * blockDim.x + threadIdx.x;if (idx < n){c[idx] = a[idx] + b[idx];}
}void two_sum_launcher(const float* a, const float* b, float* c, int n){dim3 blockSize(DIVUP(n, THREADS_PER_BLOCK));dim3 threadSize(THREADS_PER_BLOCK);two_sum_kernel<<<blockSize, threadSize>>>(a, b, c, n);
}

vim add.cpp

#include <torch/extension.h>
#include <torch/serialize/tensor.h>#define CHECK_CUDA(x) \TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \CHECK_CUDA(x);       \CHECK_CONTIGUOUS(x)void two_sum_launcher(const float* a, const float* b, float* c, int n);void two_sum_gpu(at::Tensor a_tensor, at::Tensor b_tensor, at::Tensor c_tensor){CHECK_INPUT(a_tensor);CHECK_INPUT(b_tensor);CHECK_INPUT(c_tensor);const float* a = a_tensor.data_ptr<float>();const float* b = b_tensor.data_ptr<float>();float* c = c_tensor.data_ptr<float>();int n = a_tensor.size(0);two_sum_launcher(a, b, c, n);
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &two_sum_gpu, "sum two arrays (CUDA)");
}

我們看一下文件結構:

.
├── add
│   ├── add.cpp
│   ├── add_cuda.cu
│   ├── __init__.py
│   └── sum.py
├── README.md
├── setup.py
└── test_add.py

有了add.cpp和add_cuda.cu以后,我們就可以用"python setup.py install"來進行編譯和安裝了。編譯和安裝以后,我們需要用python類封裝一下:

vim __init__.py
from .sum import *

vim sum.py

from torch.autograd import Function
import sum_doubleclass SumDouble(Function):@staticmethoddef forward(ctx, array1, array2):"""sum_double function forward.Args:array1 (torch.Tensor): [n,]array2 (torch.Tensor): [n,]Returns:ans (torch.Tensor): [n,]"""array1 = array1.float()array2 = array2.float()ans = array1.new_zeros(array1.shape)sum_double.forward(array1.contiguous(), array2.contiguous(), ans)# ctx.mark_non_differentiable(ans) # if the function is no need for backpropogationreturn ans@staticmethoddef backward(ctx, g_out):# return None, None   # if the function is no need for backpropogationg_in1 = g_out.clone()g_in2 = g_out.clone()return g_in1, g_in2sum_double_op = SumDouble.apply

最后,直接

python test_add.py

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/40774.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/40774.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/40774.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

軟設之類的繼承與泛化,多重繼承

在類中&#xff0c;假如父類已經寫好屬性或方法&#xff0c;子類想要實現相同的功能&#xff0c;不用專門寫代碼&#xff0c;直接用專門的繼承語言繼承就可以了。 比如說有一個動物類&#xff0c;有毛色和叫這兩個屬性和方法&#xff0c;又寫了一個子類是貓類&#xff0c;貓類…

騰訊云COS分布式對象存儲

騰訊云COS分布式對象存儲 騰訊云對象存儲&#xff08;Cloud Object Storage&#xff0c;COS&#xff09;是騰訊云提供的一種用于存儲海量文件的分布式存儲服務。 騰訊云 COS 適用于多種場景&#xff0c;如靜態網站托管、大規模數據備份和歸檔、多媒體存儲和處理、移動應用數據存…

Kafka搭建(單機版)

部署前提 VMware環境 : 兩臺centos系統 Jdk包:jdk-8u202-linux-x64.tar.gz Kafka包:kafka_2.12-3.5.0.tgz Zookeeper包:apache-zookeeper-3.7.2-bin.tar.gz 百度網盤自取: 鏈接: https://pan.baidu.com/s/11EWuhBoSmH3musd_3Rgodw?pwde32t 提取碼: e32t Kafka搭建&#xff08;…

Camtasia 2024新功能 Camtasia2024更新介紹:AI剪輯助力微課制作 Camtasia2024密鑰 Camtasia2023免費升級更新

Camtasia 是一款功能強大的屏幕錄制和視頻編輯軟件&#xff0c;廣泛應用于教育、商業和娛樂領域。無論是創建教學視頻、產品演示、教程還是營銷內容&#xff0c;Camtasia都能提供專業的工具和功能&#xff0c;幫助用戶制作高質量的視頻內容。 Camtasia 2024 中文免費安裝包百度…

暑假學習DevEco Studio第2天

學習目標&#xff1a; 掌握頁面跳轉 學習內容&#xff1a; 跳轉頁面 創建頁面&#xff1a; 在“project”窗口。打開“entry>src>main>ets”,右擊“pages”&#xff0c;選擇“New>ArkTS File”,命名“Second”&#xff0c;點擊回車鍵。 在頁面的路由&#xff0…

昇思25天學習打卡營第16天|文本解碼原理——以MindNLP為例

在大模型中&#xff0c;文本解碼通常是指在自然語言處理&#xff08;NLP&#xff09;任務中使用的大型神經網絡模型&#xff08;如Transformer架構的模型&#xff09;將編碼后的文本數據轉換回可讀的原始文本的過程。這些模型在處理自然語言時&#xff0c;首先將輸入文本&#…

【Unix/Linux】Unix/Linux如何查看系統版本

Unix和Linux查看系統版本的指令有些區別&#xff0c;下面分別介紹: 一.Unix查看系統版本 在Unix系統中&#xff0c;查看系統版本的方法可能會根據具體的Unix操作系統而有所不同。以下是一些通用的方法&#xff0c;適用于多種Unix系統&#xff0c;包括但不限于Solaris、AIX、H…

vienna整流器過零畸變原因分析

Vienna整流器是一種常見的三電平功率因數校正&#xff08;PFC&#xff09;整流器&#xff0c;廣泛應用于電源和電能質量控制領域。由于其高效率、高功率密度和低諧波失真的特點&#xff0c;Vienna整流器在工業和電力電子應用中具有重要地位。然而&#xff0c;在實際應用中&…

ssh:(xshell)遠程連接失敗

項目場景&#xff1a; 提示&#xff1a;這里簡述項目相關背景&#xff1a; 云服務器遠程連接失敗 xshell 遠程連接失敗 xshell (ssh客戶端&#xff09; ---------------------------------------------安全組----------防火墻-------黑白名單-----SSH服務 問題排查 1. 安全…

Playwright之錄制腳本轉Page Object類

Playwright之錄制腳本轉Page Object類 設計思路 &#xff1a; 我們今天UI自動化設計的時候&#xff0c;通常會遵循一些設計模式&#xff0c;例如Page Object模式。但是自己找元素再去填寫有一些麻煩&#xff0c;所以我們可以通過拆解錄制的腳本&#xff0c;將其中的元素提取出來…

DALL-E、Stable Diffusion 等 20+ 圖像生成模型綜述

二、任務場景 2.1. 無條件生成 無條件生成是指生成模型在生成圖像時不受任何額外條件或約束的影響。模型從學習的數據分布中生成圖像&#xff0c;而不需要關注輸入條件。 2.2. 有條件生成 有條件生成是指生成模型在生成圖像時受到額外條件或上下文的影響。這些條件可以是類別…

Vscode 保存代碼,代碼自動格式化

我這里使用的插件是Prettier-Code formatter&#xff1a;自動縮進整理代碼的格式&#xff0c;使用方法如下&#xff1a; 先在vscode商店找到插件并安裝&#xff1a;安裝插件之后&#xff0c;隨便找到一個項目文件&#xff0c;右鍵選擇格式化文檔&#xff1a;選中我們安裝的插件…

掌握Vim的會話之道:深度解析會話管理功能

掌握Vim的會話之道&#xff1a;深度解析會話管理功能 在高效的文本編輯工作流中&#xff0c;能夠保存和恢復編輯會話是極其重要的。Vim&#xff0c;作為一個功能強大的文本編輯器&#xff0c;提供了會話管理功能&#xff0c;允許用戶保存當前的工作狀態&#xff0c;并在之后重…

spring6框架解析(by尚硅谷)

文章目錄 spring61. 一些基本的概念、優勢2. 入門案例實現maven聚合工程創建步驟分析實現過程 3. IoC&#xff08;Inversion of Control&#xff09;基于xml的bean環境搭建獲取bean獲取接口創建實現類依賴注入 setter注入 和 構造器注入原生方式的setter注入原生方式的構造器注…

Java 多線程stream流按行讀取文件

stream并行流快&#xff08;文件11g&#xff09; try (Stream<String> lines Files.lines(filePath)) {lines.parallel().forEach(str -> operatePartData(str, allDataList)); } catch (IOException e) {throw new RuntimeException(e); }線程池慢&#xff08;文件…

PyPDF2合并PDF文件的高級應用:指定合并方式

本文目錄 前言一、合并PDF的高級應用1、邏輯講解2、合并效果圖3、完整代碼二、異常校驗1、合并過程中的錯誤校驗前言 本文我們主要來講解一下PyPDF2合并PDF文件的高級應用,就是指定合并方式進行合并,構建函數支持模式選擇,主要不管咋折騰,其實就是不想去付費買那個PDF編輯…

PDF怎么分割成一頁一頁的?原來可以這么輕松

PDF怎么分割成一頁一頁的&#xff1f;PDF文檔因其跨平臺兼容性和可打印性而被廣泛使用&#xff0c;但有時為了便于發送電子郵件、管理文檔或保護敏感信息&#xff0c;我們需要將一個大型的PDF文件分割成多個小文件。幸運的是&#xff0c;分割PDF文件并不復雜。下文中就介紹了三…

webp2jpg網頁在線圖片格式轉換源碼

源碼介紹 webp2jpg-免費在線圖片格式轉化器, 可將jpeg、jpg、png、gif、 webp、svg、ico、bmp文件轉化為jpeg、png、webp、webp動畫、gif文件。 無需上傳文件&#xff0c;本地即可完成轉換! 源碼特點&#xff1a; 無需上傳&#xff0c;使用瀏覽器自身進行轉換批量轉換輸出we…

easyexcel使用小結-未完待續

官網&#xff1a;https://easyexcel.opensource.alibaba.com/docs/current/ <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>4.0.1</version></dependency>一、讀 1.1簡單讀 Getter…

系統安全體系架構規劃框架

安全技術體系架構是對組織機構信息技術系統的安全體系結構的整體描述。安全技術體系架構框架是擁有信息技術系統的組織機構根據其策略的要求和風險評估的結果&#xff0c;參考相關技術體系構架的標準和最佳實踐&#xff0c;結合組織機構信息技術系統的具體現狀和需求&#xff0…