代码结构介绍
TBE DSL方式实现的算子代码结构如下所示:
# 导入依赖的Python模块 from tbe import dsl from tbe import tvm from tbe.common.utils import para_check from tbe.common.utils import shape_util # 若有其他的python依赖,请自行导入 # 算子计算函数 # 装饰器函数tbe.common.register.register_op_compute可选,若算子实现逻辑中涉及reshape操作,不可使用此装饰器函数 @tbe.common.register.register_op_compute("add",op_mode="static") def add_compute(input_x, input_y, output_z, kernel_name="add"): """ 算子计算逻辑实现 """ # 算子定义函数 @para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT,para_check.KERNEL_NAME) def add(input_x, input_y, output_z, kernel_name="add"): """ 算子校验(可选) 为输入tensor占位 """ res = add_compute(data_x, data_y, output_z, kernel_name) # 调用算子计算函数 # 自动调度 with tvm.target.cce(): schedule = dsl.auto_schedule(res) # 算子编译 config = {"name": kernel_name, "tensor_list": (data_x, data_y, res)} dsl.build(schedule, config) # 可选,若实现此函数,且算子信息库中的input的dtype与format的dynamicFormat.flag配置为true,则可在算子融合阶段调用此函数实现dtype与format的推导。 def op_select_format(input_x, input_y, output_y, kernel_name="add"): ... ... # 可选,若实现此函数,且算子信息库中的needCheckSupport的flag参数配置为true,则可在算子融合阶段调用此函数实现算子的dtype与shape的校验。 def check_supported(input_x, input_y, output_y, kernel_name="add"): ... ...
算子实现代码总体结构包含依赖Python模块的导入,算子定义函数实现,算子计算函数实现。
其中:
- 算子定义函数包含算子的校验,计算函数的调用以及调度与编译。
- 算子计算函数是对算子计算逻辑的实现。
下面详细介绍每个代码块的实现。
父主题: 算子代码实现(TBE DSL)