适配插件开发(PyTorch框架)
简介
PyTorch1.8.1及以上版本官方提供的native_functions.yaml文件定义了PyTorch Native Functions的具体算子定义和分发细节。因此在NPU设备上适配官方已定义算子时,我们不需要重新定义,只需要注册NPU分发即可。因为我们可以根据已支持的算子(supported,autograd)对应解析官方yaml文件得到每个函数的具体格式,所以对应的函数声明和注册分发可以自动化完成,算子迁移和开发的时候只需要关注对应的实现细节即可。
对于自定义算子,由于没有具体的算子定义,我们需要在npu_native_functions.yaml文件中给出定义,以便对算子进行结构化解析从而实现自动化注册和Python接口绑定。
npu_native_functions.yaml文件介绍:
backend: NPU # Backend类型 cpp_namespace: at_npu::native # 插件中开发算子的命名空间 supported: # 已支持的和PyTorch Native Functions对齐的算子 - add.Tensor - add.Scalar - slow_conv3d.out - slow_conv3d_forward.output - slow_conv3d_forward - convolution - _convolution - _convolution_nogroup - addcdiv - addcdiv_ - addcdiv.out autograd: # 已支持的和PyTorch Native Functions对齐的继承自Function的具有前反向操作的算子 - maxpool2d custom: # 自定义算子,需要提供算子格式定义 - func: npu_dtype_cast(Tensor self, ScalarType dtype) -> Tensor variants: function, method - func: npu_dtype_cast_(Tensor(a!) self, Tensor src) -> Tensor(a!) variants: method - func: npu_alloc_float_status(Tensor self) -> Tensor variants: function, method - func: npu_get_float_status(Tensor self) -> Tensor variants: function, method custom_autograd: # 自定义继承自Function的自定义算子 - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, ...) -> Tensor
注册算子
本步骤目的是将算子分发到昇腾AI处理器。当前制定的NPU适配派发原则是:NPU算子的派发不经过框架公共函数,直接派发成NPU适配的函数;即算子执行调用栈中只包含NPU适配的函数调用,不包含框架公共函数。PyTorch框架在编译时,会根据native_functions.yaml的定义,按框架中定义的类型和设备分发原则,生成相应的新算子的中间层的调用说明。
对于NPU,PyTorch1.8.1及以上会生成在torch_npu/csrc/aten/RegisterNPU.cpp。该目录为在Ascend PyTorch源码目录下获取原生PyTorch代码并生成适配NPU的PyTorch全量代码后的路径。
- 打开native_functions.yaml文件。
native_functions.yaml文件中,定义了所有算子函数原型,包括函数名称和参数等信息。每个算子函数支持不同硬件平台的派发信息。该文件所在路径为:
- 1.8.1及以上:scripts/codegen/native_functions.yaml
- 确定需要派发函数。
算子适配插件开发
本步骤目的是将基于第三方框架的算子属性映射成适配昇腾AI处理器的算子属性。用户通过开发算子适配插件,实现PyTorch原生算子的输入参数、输出参数和属性的格式转换,使转换后的格式与TBE算子的输入参数、输出参数和属性的格式相同。适配昇腾AI处理器的PyTorch源代码中提供了适配关联、类型转换和判别、处理动态shape等相关的方法供用户使用。
- 在npu_native_functions.yaml中添加算子信息。
- 1.8.1及以上:参考npu_native_functions.yaml文件介绍,在文件中添加对应类型需要适配的算子信息。
- 创建适配插件文件。
- 1.8.1及以上:NPU TBE算子适配文件保存在torch_npu/csrc/aten/ops目录下,命名风格采用大驼峰,命名格式:<算子名> + <KernelNpu>.cpp,如:AddKernelNpu.cpp。
- 引入依赖头文件。
适配昇腾AI处理器的PyTorch源代码在torch_npu/csrc/framework/utils中提供适配常用的工具供用户使用。
工具的功能和使用方法,请查看头文件和.cpp源码。
- 定义算子适配主体函数。
根据注册算子开发中的分发函数确定自定义算子适配主体函数。
- 实现适配主体函数。
实现算子适配主题函数,根据TBE算子原型构造得到对应的input、output、attr。
- 重编译PyTorch框架或插件。
- 1.8.1及以上:请重新编译生成torch_npu插件安装包并安装
参考样例
以torch.add()算子在PyTorch1.8.1框架下为例,介绍注册算子开发过程以及算子适配开发过程。
- 打开native_functions.yaml文件,搜索相关函数。
在yaml中搜索add,找到与add算子相关的函数描述func。由于add是PyTorch内置算子,不需要手动添加func。若是自定义算子,需要手动添加func。
- 确定算子相关函数名称及其类型的func描述,此处只列出一个函数的描述。add.Tensor的函数分发描述:
- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor structured_delegate: add.out variants: function, method dispatch: SparseCPU, SparseCUDA: add_sparse MkldnnCPU: mkldnn_add
- 在npu_native_functions.yaml中添加算子信息,以add.Tensor为例。
supported: # 已支持的和PyTorch Native Functions对齐的算子 - add.Tensor #添加add.Tensor
- 进入Ascend PyTorch源码目录并创建适配插件文件,创建AddKernelNpu.cpp。
cd ${pytorch_code_path}/torch_npu/csrc/aten/ops
- 引入依赖头文件。
#include <ATen/Tensor.h> #include <c10/util/SmallVector.h> #include "torch_npu/csrc/core/npu/register/OptionsManager.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
- 定义算子适配主体函数。根据注册算子开发中的分发函数确定自定义算子适配主体函数。
at::Tensor NPUNativeFunctions::add(const at::Tensor &self, const at::Tensor &other, at::Scalar alpha)
- 实现适配主体函数。
// 输入参数为Tensor和Tensor时 at::Tensor NPUNativeFunctions::add(const at::Tensor &self, const at::Tensor &other, at::Scalar alpha) { alpha_check_npu(self.scalar_type(), alpha); if ((!(self.is_contiguous() && other.is_contiguous())) && (NpuUtils::check_5d_5d_match(self) || NpuUtils::check_5d_5d_match(other)) && check_size(self, other)) { int64_t c0_len = 16; at::Tensor self_use = stride_add_tensor_get(self); at::Scalar self_c1_offset( self.storage_offset() / (self.size(2) * self.size(3) * c0_len)); at::Tensor other_use = stride_add_tensor_get(other); at::Scalar other_c1_offset( other.storage_offset() / (other.size(2) * other.size(3) * c0_len)); at::Scalar stride_len(self.size(1) / c0_len); at::Tensor result = NPUNativeFunctions::npu_stride_add( self_use, other_use, self_c1_offset, other_c1_offset, stride_len); return result; } // calculate the output size at::Tensor outputTensor = add_dest_output(self, other); auto outputSize = broadcast_ops_npu_output_size(self, other); // construct the output tensor of the NPU at::Tensor result = OpPreparation::ApplyTensorWithFormat( outputSize, outputTensor.options(), CalcuOpUtil::get_tensor_npu_format(outputTensor)); // calculate the output result of the NPU add_out_npu_nocheck(result, self, other, alpha); return result; }
AddKernelNpu.cpp的详细实现代码请参见PyTorch源代码pytorch/torch_npu/csrc/aten/ops中的AddKernelNpu.cpp文档。