下载
中文
注册

代码结构介绍

TBE DSL方式实现的算子代码结构如下所示:

# 导入依赖的Python模块
from tbe import dsl
from tbe import tvm
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
# 若有其他的python依赖,请自行导入

# 算子计算函数
# 装饰器函数tbe.common.register.register_op_compute可选,若算子实现逻辑中涉及reshape操作,不可使用此装饰器函数
@tbe.common.register.register_op_compute("add",op_mode="static")
def add_compute(input_x, input_y, output_z, kernel_name="add"):
    """
    算子计算逻辑实现
    """
# 算子定义函数
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT,para_check.KERNEL_NAME)
def add(input_x, input_y, output_z, kernel_name="add"):

    """
    算子校验(可选)
    为输入tensor占位
    """

    res = add_compute(data_x, data_y, output_z, kernel_name) # 调用算子计算函数

    # 自动调度
    with tvm.target.cce(): 
        schedule = dsl.auto_schedule(res)        
    # 算子编译
    config = {"name": kernel_name,
              "tensor_list": (data_x, data_y, res)}
    dsl.build(schedule, config)

# 可选,若实现此函数,且算子信息库中的input的dtype与format的dynamicFormat.flag配置为true,则可在算子融合阶段调用此函数实现dtype与format的推导。
def op_select_format(input_x, input_y, output_y, kernel_name="add"):
    ...
    ...
# 可选,若实现此函数,且算子信息库中的needCheckSupport的flag参数配置为true,则可在算子融合阶段调用此函数实现算子的dtype与shape的校验。
def check_supported(input_x, input_y, output_y, kernel_name="add"):
    ...
    ...

算子实现代码总体结构包含依赖Python模块的导入,算子定义函数实现,算子计算函数实现。

其中:

  • 算子定义函数包含算子的校验,计算函数的调用以及调度与编译。
  • 算子计算函数是对算子计算逻辑的实现。

下面详细介绍每个代码块的实现。