下载
中文
注册

算子converter注册和实现

完成算子注册PyTorch后,支持以Eager模式调用自定义算子,但不支持图模式调用。自定义算子还需向TorchAir注册并实现对应converter函数,完成Aten IR向GE IR转换,实现算子在NPU上入图。具体实现步骤如下:

  1. 获取自定义算子原型。

    获取算子开发和部署中定义的算子原型,假设AddCustom算子原型定义如下:

    缺省情况下,基于Ascend C开发的算子原型文件在${INSTALL_DIR}/opp/vendors/op_proto/inc/op_proto.h中。其中${INSTALL_DIR}请替换为CANN软件安装后文件存储路径。例如,$HOME/Ascend/latest$HOME/Ascend/ascend-toolkit/latest/x86_64-linux。例如,若安装的Ascend-cann-toolkit软件包,则安装后文件存储路径为:$HOME/Ascend/ascend-toolkit/latest。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 根据实际情况,替换为自定义的算子原型
    REG_OP(AddCustom)
        .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_BF16, DT_INT16,
                                   DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
        .INPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_BF16, DT_INT16,
                                   DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
        .OUTPUT(z, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_BF16, DT_INT16,
                                   DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))        
        .OP_END_FACTORY_REG(AddCustom)
    
  2. 自定义算子converter注册和实现。

    converter功能是将PyTorch FX图的节点转换为NPU GE图的节点。对于自定义算子,需要先实现对应的converter,否则会出现报错。

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch_npu
    from torchair import register_fx_node_ge_converter
    from torchair.ge import Tensor
    
    # 将PyTorch aten算子注册到TorchAir
    # 装饰器固定,torch.ops.npu_define.custom_op.default为自定义算子生成的Python函数签名。
    @register_fx_node_ge_converter(torch.ops.npu.npu_add_custom.default)        
    # 实现converter
    def convert_npu_add_custom(x: Tensor, y: Tensor, z: Tensor = None, meta_outputs: Any = None):
        return torchair.ge.custom_op(                  # 根据算子原型定义设置变量名,顺序保持一致
            "AddCustom",
            inputs={
                "x": x,
                "y": y,
            },
            outputs=['z']
        )