下载
中文
注册

计算实现

计算实现包含依赖python模块导入、算子函数声明、算子入参校验、计算逻辑实现以及调度与编译。计算实现时,可以通过一定的方法进行精度与性能上的提升。

导入Python模块

进行TBE DSL算子开发时,首先需要在算子实现文件中导入昇腾AI软件栈提供的Python模块,代码示例如下所示,算子实现文件的命名请参见算子定义命名规则

from tbe import dsl
from tbe import tvm
from tbe.common.utils import para_check
from tbe.common.utils import shape_util

其中:

  • “tbe.dsl”:引入TBE支持的特定域语言接口,包括常见的计算接口、调度接口以及编译接口。DSL接口的详细使用方法请参见TBE DSL API
  • “tbe.tvm”:引入TVM后端代码生成机制。TVM接口的详细使用方法请参见Link
  • “tbe.common.utils.para_check”:提供了通用的算子参数校验接口,详细的接口定义可参见算子参数校验
  • “tbe.common.utils.shape_util”:提供了一些通用的处理算子shape的接口,详细的接口定义可参见shape相关工具

    如果您进行代码实现时依赖了其他自行引入的Python模块,请自行进行依赖导入。

算子函数声明

算子的代码实现中包括两个函数:算子定义函数与算子compute函数,算子的compute函数会在算子定义函数中被调用。

下面详细介绍这两个函数的声明规则。

  • 算子定义函数声明

    如下所示,一个算子的定义函数中包含了算子输入信息、算子输出信息以及内核名称,函数的声明信息需要与算子原型定义文件中的信息对应。

    def operationname(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="KernelName", impl_mode="high_performance")
    • 算子定义函数名称operationname当前版本请与算子实现文件名称保持一致,命名规则请参见算子定义命名规则
    • input_x1, input_x2:算子的输入tensor,每个tensor需要采用字典的形式进行定义,包含shape、ori_shape、format、ori_format与dtype信息,例如:

      dict input_x1 = {'shape' : (2,2), 'ori_shape' : (2,2), 'format': 'ND', 'ori_format':'ND', 'dtype' : 'float16'}

      输入tensor的顺序及个数需要与算子原型定义保持一致,可选输入也需要在此处定义,在计算逻辑中去判断是否有数据传入,并进行相应处理。

    • output_y:算子的输出tensor,包含shape和dtype等信息,字典格式,此字段为预留位。

      输出tensor的顺序及个数也需要与算子原型定义保持一致,可选输出也需要在此处定义。

    • attribute1attribute2...:算子的属性,算子属性的顺序与个数需要与算子原型定义保持一致。

      若算子无相关属性信息,此参数忽略;若算子的属性为可选值,此处需要为算子的属性赋默认值。

    • kernel_name:算子在内核中的名称(即生成的二进制文件与算子描述文件的名称),用户自定义,保持唯一,只能是大小写字母、数字、“_”的组合,且必须是字母或者“_”开头,长度小于或等于200个字符。
    • impl_mode(可选):String类型,算子运行时选择精度优先还是性能优先模式,该字段仅影响输入数据为float32数据类型时的精度与性能。

      有“high_precision”与“high_performance”两种取值,默认值为“high_performance”。

      当输入数据值的范围不超过float16的最大值65504时,不存在精度问题;当输入数据值的范围超过float16最大值时,会出现exp计算溢出,从而导致计算结果不正确,此时需要选择high_precision模式,但此种场景下性能会下降较多。

      算子在网络中运行时,使用哪种模型的配置规则请参见ATC工具使用指南中的“op_select_implmode”参数。

    不带属性的sqrt算子的定义函数声明如下:

    def sqrt(input_x, output_y, kernel_name="sqrt"):

    带属性的reduce_sum算子的定义函数声明如下:

    def reduce_sum(x, y, axis=None, keep_dims=None, kernel_name="reduce_sum")

    开发者在进行算子定义函数声明时可使用装饰器函数check_op_params或者check_input_type对算子参数进行基本的校验。

    其中check_op_params校验算子输入输出是否满足必选与可选的要求,check_input_type校验算子的参数类型是否合法。

    例如:

    @para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
    def sqrt(input_x, output_y, kernel_name="sqrt")
    @para_check.check_input_type(dict, dict, dict, int, bool, str)
    def sort(x, y1, y2, axis=-1, descending=False, kernel_name="sort")
  • compute函数声明
    @tbe.common.register.register_op_compute("KernelName",op_mode="static")
    def operationname_compute(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="KernelName")
    • 装饰器@tbe.common.register.register_op_compute,其作用是整网运行时支持算子做UB自动融合,使得当前自定义算子可以在UB中根据UB融合规则自动与其他算子的计算进行组装,提升算子运行效率,此接口的详细说明可参见register_op_compute

      若算子实现逻辑中涉及reshape操作,则不支持UB自动融合,算子compute函数声明时不可使用此装饰器函数。

    • input_x1, input_x2:compute函数的入参,为在 算子接口函数声明中声明的输入tensor对应的placeholder,包含shape和dtype等信息。
    • output_y,attribute1=None,xxx等参数,都是从 算子接口函数声明中的算子定义函数中透传过来的,与算子定义函数的声明保持一致即可。

    例如,对于sqrt算子,compute函数定义如下:

    @tbe.common.register.register_op_compute("sqrt",op_mode="static")
    def sqrt_compute(input_data, output_data, kernel_name="sqrt"):

    对于reduce_sum算子,算子接口和计算函数定义如下:

    @tbe.common.register.register_op_compute("reduce_sum",op_mode="static")
    def reduce_sum_compute(x, y, axis=None, keep_dims=None, kernel_name="reduce_sum")

算子函数实现

完成算子函数声明后,就要具体实现算子定义函数和compute函数。

  1. 首先在算子定义函数operatorname( )中,获取算子输入tensor的shape以及dtype,并可自行实现基本的校验功能。
    1. 获取算子输入tensor的shape以及dtype,为后续定义输入tensor的张量占位符做准备。
      def add(input_x, input_y, output_z, kernel_name="add"):
          shape_x = input_x.get("shape")
          shape_y = input_y.get("shape")
          input_data_type = input_x.get("dtype").lower()
          input_data_type_y = input_y.get("dtype").lower()
    2. (可选)在算子实现函数中添加算子输入/输出及属性基本信息校验,有助于在算子编译阶段,提前发现问题。

      例如,对于Add算子,首先校验两个输入的dtype是否一致,然后校验输入的数据类型是否在允许的数据类型列表中,代码实现如下所示。

      1
      2
      3
      4
      5
          if input_data_type != input_data_type_y:
              raise RuntimeError(
                  "the input_x and input_y should have the same data type.")
          check_tuple = ("float16", "float32", "int32")
          para_check.check_dtype(input_data_type, check_tuple, param_name="input_x")
      

      开发者可使用TBE Utils API中的通用校验函数进行算子参数的基本校验。

  2. 然后根据shape与dtype定义好输入tensor的张量占位符。
    例如:
    data_input = tvm.placeholder(shape, name="data_input", dtype=dtype)

    使用TVM的placeholder接口对输入tensor进行占位,返回一个tensor对象,此位置中的数据在程序运行时才被指定。

    若算子中的某个输入为可选输入,在进行输入tensor占位时需要对此输入进行判断,若输入为空,则无需占位。

    调度与编译中的tensor_list的输入tensor需要是tvm.placeholder接口返回的tensor对象,所以此对象在后续计算过程实现中不能被替换。

    如下所示:

        #返回占位的data_input
        data_input = tvm.placeholder(shape, name='data', dtype=dtype)      
         if dtype == "float16":
             #将data_input的类型转换为float32,然后重新赋值给data_input,此时data_input的内容已经被改变
             data_input = dsl.cast_to(data_input, "float32")        
            ......
        with tvm.target.cce():
            schedule = dsl.auto_schedule(res)
        config = {"need_build":need_build,
                  "name":kernel_name, 
                  "tensor_list":[data_input,res]} 
        dsl.build(schedule,config)

    以上代码中,通过data_input = dsl.cast_to(data_input, "float32")转换数据类型后,placeholder返回的data_input对象已经被覆盖,编译配置tensor_list中的data_input已经不是原placeholder接口返回的tensor,此时算子实现代码编译时会出现以下错误:

    所以可重新定义一个tensor用于存储转换数据类型后的输入进行计算,如下所示:

    data_input1 = dsl.cast_to(data_input, "float32")

    或者如步骤3所示,计算过程在compute函数中进行,将placeholder返回的输入tensor通过形参传入compute函数进行计算,会生成新的地址用于计算,也可避免placeholder返回的tensor对象被覆盖的情况。

  3. 在算子接口定义函数中调用compute函数进行计算过程的描述。

    例如:

    res = add_compute(data_x, data_y, output_z, kernel_name)

    输入tensor为使用tvm.placeholder定义的占位tensor,其他为算子定义函数透传的参数。

  4. 算子compute函数的实现。

    在compute函数中,完成算子的计算过程,计算过程的实现主要根据算子分析中的TBE DSL API进行代码开发。

    入门示例中介绍了简单的add算子的实现过程,下面我们以计算公式较复杂的relu算子为例,讲解算子的计算过程的实现以及部分DSL接口在使用过程中的注意事项。

    假设通过进行算子分析,得到relu算子的计算公式如下:

    计算实现代码如下所示:

    @fusion_manager.register("relu")
    def relu_compute(x, y, kernel_name="relu"):
        inp_dtype = x.dtype     # 获取输入数据的数据类型
        shape = x.shape         # 获取输入数据的形状
    
        # 若数据类型为float32、int32,使用vmax操作,避免精度损失
        if inp_dtype in ("float32", "int32"):
            tensor_zero = dsl.broadcast(tvm.const(CONST_ZERO, inp_dtype),shape)     # 返回形状与输入数据相同,每一个元素都为0,每一个元素的数据类型都为输入数据的数据类型的tensor
            data_res = dsl.vmax(x, tensor_zero)    # 取x与tensor_zero中的大值
        else:
            data_res = dsl.vrelu(x)      # 若数据类型为int8、float16,直接做relu操作。
    
        data_res = dsl.cast_to(data_res, inp_dtype)
    
        return data_res
    • 由于tbe.dsl.vrelu( )接口会将int8、uint8、int32、float32的数据类型转换为float16,而int32、float32进行数据类型转换时会造成精度损失,所以为了避免精度损失,对于这两种数据类型的输入,采用tbe.dsl.vmax( )接口取输入数据与0之间的大值。
    • TBE DSL中vmax接口要求两个输入tensor的shape相同,一般使用tbe.dsl.broadcast接口将输入tensor的shape广播到相同shape,一般取两个输入tensor的shape中每个维度的大值组成的shape。

    算子计算函数实现中的其他小技巧:

    • 若输入tensor数据类型不是float32,可以将其转换为float32进行计算,可以提高中间计算结果的精度,最后的结果输出时需要将数据类型转换成原数据类型。
    • 当算子的计算过程比较繁琐时,可以通过抽调内部函数的方法保持每个模块的简洁性和可读性。

broadcast类算子

broadcast类算子指算子对多个输入的shape要求一致,所以涉及broadcast操作。此类算子在进行计算函数实现时,需要首先调用tbe.common.utils.shape_util下的broadcast_shapes或unify_broadcast_shapes接口,计算出输出shape大小,然后再调用tbe.dsl.broadcast接口将参与broadcast的每个输入广播到输出shape的大小。否则可能会出现类似如下错误:

'Compile operator failed, cause: Parameters check failed, detailed information: The lhs shape[(dim 0 0, dim 1 0)] must be equal to the rhs[(dim 0 0, dim 1 1)].'

vcmp/vsel接口使用注意事项

在使用tbe.dsl.vcmp与tbe.dsl.vsel接口时,会出现某些shape的情况下输出结果不符合预期,这是因为vcmp接口中mode的默认值为bool,表示按照8bit进行存储。而config配置中若不配置"bool_storage_as_1bit"参数,此参数默认值为True,表示按照1bit进行存储,与mode为bool不匹配,所以要在算子接口实现函数的编译配置config中加入配置项:"bool_storage_as_1bit": False,例如:

with tvm.target.cce():
    schedule = dsl.auto_schedule(res)
config = {"name": kernel_name,
          "tensor_list": [data_x, data_y, res],
          "bool_storage_as_1bit": False}
dsl.build(schedule, config)