下载
中文
注册

custom_op

函数功能

基于算子原型(IR)实现算子conveter函数,完成PyTorch IR与GE IR的转换,方便自定义算子入图。

converter功能是将PyTorch FX图的节点转换为NPU GE图的节点。对于自定义算子,需要先实现对应的converter,否则会出现报错。

函数原型

def custom_op(op_type: str, *, inputs: Optional[Dict[str, Optional[Union['Tensor', List['Tensor']]]]], outputs: Optional[List[Union[str, Tuple[str, int]]]], attrs: Optional[Dict[str, '_Attr']] = None, node_name: Optional[str] = None):--> OUT: Tensor

参数说明

参数

输入/输出

说明

是否必选

备注

op_type

输入

自定义算子类型,字符串类型,取值与算子原型REG_OP保持一致,例如"MyOp"。

关于算子原型REG_OP的介绍,请参考

《CANN Ascend Graph开发指南》中“什么是算子原型”章节

inputs

输入

算子原型的输入参数,为Dict类型,取值和顺序需与算子原型REG_OP保持一致。

outputs

输入

算子原型的输出参数,为List类型。取值和顺序需与算子原型REG_OP保持一致。

attrs

输入

算子原型的属性参数,为Dict类型,缺省值为None。取值和顺序需与算子原型REG_OP保持一致。

attrs一般通过ge.attr.Xxx显式确定类型,当前支持如下类型:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
"torchair.ge.attr": [
"Bool",
"DataType",
"Float",
"Int",
"ListBool",
"ListDataType",
"ListFloat",
"ListInt",
"ListListFloat",
"ListListInt",
"ListStr",
"Str"]

node_name

输入

常量节点名,支持字符串类型,如“a_1”;若不设置,框架会自动生成节点名。

同一张图中节点名不允许重复。

-

*

输入

预留参数项,用于后续功能拓展。

-

返回值说明

正常情况下,返回输出Tensor,否则失败报错。

约束说明

调用示例

假设算子原型定义如下:

1
2
3
4
5
6
7
8
9
REG_OP(MyOp)
 .INPUT(x, TensorType::ALL())                    //x为必选输入,类型为TensorType类。
 .OPTIONAL_INPUT(y, TensorType::ALL())           //y为可选输入,类型为TensorType类。
 .OPTIONAL_INPUT(z, TensorType::ALL())          //z为可选输入,类型为TensorType类。
 .DYNAMIC_INPUT(dy, TensorType::ALL())          //dy为动态输入,类型为TensorType类。构图时该输入为一组Tensor,数量在构图时确定,数量可以为0。
 .REQUIRED_ATTR(rattr, ListInt)                //rattr为必选属性,类型为ListInt,属性值为整数数组,无缺省值。
 .ATTR(oattr, Float)                           //oattr为可选属性,类型为Float,缺省值为1.0,构图时可不传入,则图上节点该属性值为1.0。
 .OUTPUT(m, TensorType::ALL())                 //m为单输出,类型为TensorType类,节点的输出为一个Tensor。
 .DYNAMIC_OUTPUT(n, TensorType::ALL())         //n为动态数量输出,类型为TensorType类,节点的输出为一组Tensor,构图时这组Tensor数量可以确定。

假设构图时,y传入,z不传,rattr值为[1,2,3],oattr值为1(与缺省值相同),动态输出n的数量为3,接口返回的输出数量与outputs列表长度相同,接口调用示例如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch_npu, torchair
def my_op(x, y, z, dy, *, rattr, oattr=1.0):
    num_of_n = ...                                               # 对于动态输出,这里要根据输入和属性计算出n的输出个数
    return torchair.ge.custom_op(
        "MyOp",                                                  # 自定义算子类型名称
        inputs={                                                 # 顺序必需与IR定义的输入顺序一致
            "x" : x,
            "y" : y,
            "z" : z,                                             # None表示该输入未传入
            "dy": dy
        },
        attrs={
            "rattr" : ge.attr.ListInt(rattr),                    # 需要通过ge.attr.T显式确定类型
            "oattr" : ge.attr.Float(oattr)                       # 即使与缺省值相同也要传入
        },
        outputs=[
            "m",                                                  # 非动态输出写对应IR输出名称即可
            ("n", num_of_n)                                       # 动态输出需要(对应IR输出名称, 输出数量)的元组
        ]
    )
t1, t2, t3, t4 = ......                                           # t1、t2、t3、t4为converter接收到的入参信息
m, n = my_op(t1, t2, None, [t3, t4], rattr=[1,2,3])