簡單適配torch_npu不支持的ATen算子 一、背景說明 二、實現步驟詳解 2.1 實現前向、反向傳播算子 2.2 編譯生成動態庫 2.3 測試驗證程序 三、關鍵點解析 四、驗證結果
一、背景說明
1.1 PyTorch擴展機制
PrivateUse1
是PyTorch為第三方設備擴展設計的保留設備類型,允許開發者添加新硬件支持當算子在當前設備(如NPU)未實現時,PyTorch會自動回退(fallback)到CPU執行 本文以native_batch_norm
算子為例,演示如何為NPU設備添加自定義實現
1.2 核心概念
ATen :PyTorch的核心張量運算庫,提供超過2000個基礎算子內存格式 :描述張量在內存中的排布方式,如NCHW(批處理x通道x高度x寬度)自動微分 :PyTorch通過記錄計算圖實現反向傳播,需要同時實現前向和反向算子
二、實現步驟詳解
2.1 實現前向、反向傳播算子
cat > native_batch_norm_npu. cpp << - 'EOF'
# include <torch/library.h>
# include <ATen/EmptyTensor.h>
# include <ATen/Device.h>
# include <ATen/Utils.h>
# include <ATen/native/Resize.h>
# include <c10/core/DeviceType.h> std:: tuple< at:: Tensor, at:: Tensor, at:: Tensor> native_batch_norm_npu ( const at:: Tensor& input, const c10:: optional< at:: Tensor> & weight, const c10:: optional< at:: Tensor> & bias, const c10:: optional< at:: Tensor> & running_mean, const c10:: optional< at:: Tensor> & running_var, bool training, double momentum, double eps)
{ at:: Tensor output = at:: empty_like ( input) ; at:: Tensor dummy_mean = at:: empty