custom_op
函数功能
基于算子原型(IR)实现算子conveter函数,完成PyTorch IR与GE IR的转换,方便自定义算子入图。
converter功能是将PyTorch FX图的节点转换为NPU GE图的节点。对于自定义算子,需要先实现对应的converter,否则会出现报错。
函数原型
def custom_op(op_type: str, *, inputs: Optional[Dict[str, Optional[Union['Tensor', List['Tensor']]]]], outputs: Optional[List[Union[str, Tuple[str, int]]]], attrs: Optional[Dict[str, '_Attr']] = None, node_name: Optional[str] = None):--> OUT: Tensor
参数说明
参数 |
输入/输出 |
说明 |
是否必选 |
备注 |
|
---|---|---|---|---|---|
op_type |
输入 |
自定义算子类型,字符串类型,取值与算子原型REG_OP保持一致,例如"MyOp"。 |
是 |
关于算子原型REG_OP的介绍,请参考 |
|
inputs |
输入 |
算子原型的输入参数,为Dict类型,取值和顺序需与算子原型REG_OP保持一致。 |
是 |
||
outputs |
输入 |
算子原型的输出参数,为List类型。取值和顺序需与算子原型REG_OP保持一致。 |
是 |
||
attrs |
输入 |
算子原型的属性参数,为Dict类型,缺省值为None。取值和顺序需与算子原型REG_OP保持一致。 attrs一般通过ge.attr.Xxx显式确定类型,当前支持如下类型:
|
否 |
||
node_name |
输入 |
常量节点名,支持字符串类型,如“a_1”;若不设置,框架会自动生成节点名。 同一张图中节点名不允许重复。 |
否 |
- |
|
* |
输入 |
预留参数项,用于后续功能拓展。 |
否 |
- |
返回值说明
正常情况下,返回输出Tensor,否则失败报错。
约束说明
无
调用示例
假设算子原型定义如下:
1 2 3 4 5 6 7 8 9 |
REG_OP(MyOp) .INPUT(x, TensorType::ALL()) //x为必选输入,类型为TensorType类。 .OPTIONAL_INPUT(y, TensorType::ALL()) //y为可选输入,类型为TensorType类。 .OPTIONAL_INPUT(z, TensorType::ALL()) //z为可选输入,类型为TensorType类。 .DYNAMIC_INPUT(dy, TensorType::ALL()) //dy为动态输入,类型为TensorType类。构图时该输入为一组Tensor,数量在构图时确定,数量可以为0。 .REQUIRED_ATTR(rattr, ListInt) //rattr为必选属性,类型为ListInt,属性值为整数数组,无缺省值。 .ATTR(oattr, Float) //oattr为可选属性,类型为Float,缺省值为1.0,构图时可不传入,则图上节点该属性值为1.0。 .OUTPUT(m, TensorType::ALL()) //m为单输出,类型为TensorType类,节点的输出为一个Tensor。 .DYNAMIC_OUTPUT(n, TensorType::ALL()) //n为动态数量输出,类型为TensorType类,节点的输出为一组Tensor,构图时这组Tensor数量可以确定。 |
假设构图时,y传入,z不传,rattr值为[1,2,3],oattr值为1(与缺省值相同),动态输出n的数量为3,接口返回的输出数量与outputs列表长度相同,接口调用示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch_npu, torchair def my_op(x, y, z, dy, *, rattr, oattr=1.0): num_of_n = ... # 对于动态输出,这里要根据输入和属性计算出n的输出个数 return torchair.ge.custom_op( "MyOp", # 自定义算子类型名称 inputs={ # 顺序必需与IR定义的输入顺序一致 "x" : x, "y" : y, "z" : z, # None表示该输入未传入 "dy": dy }, attrs={ "rattr" : ge.attr.ListInt(rattr), # 需要通过ge.attr.T显式确定类型 "oattr" : ge.attr.Float(oattr) # 即使与缺省值相同也要传入 }, outputs=[ "m", # 非动态输出写对应IR输出名称即可 ("n", num_of_n) # 动态输出需要(对应IR输出名称, 输出数量)的元组 ] ) t1, t2, t3, t4 = ...... # t1、t2、t3、t4为converter接收到的入参信息 m, n = my_op(t1, t2, None, [t3, t4], rattr=[1,2,3]) |