计算实现

计算实现包含依赖python模块导入、算子函数声明、算子入参校验、计算逻辑实现以及调度与编译。计算实现时,可以通过一定的方法进行精度与性能上的提升。

导入Python模块

进行TBE DSL算子开发时,首先需要在算子实现文件中导入昇腾AI软件栈提供的Python模块,代码示例如下所示,算子实现文件的命名请参见算子定义命名规则

from tbe import dsl
from tbe import tvm
from tbe.common.utils import para_check
from tbe.common.utils import shape_util

其中:

算子函数声明

算子的代码实现中包括两个函数:算子定义函数与算子compute函数,算子的compute函数会在算子定义函数中被调用。

下面详细介绍这两个函数的声明规则。

算子函数实现

完成算子函数声明后,就要具体实现算子定义函数和compute函数。

  1. 首先在算子定义函数operatorname( )中,获取算子输入tensor的shape以及dtype,并可自行实现基本的校验功能。

    1. 获取算子输入tensor的shape以及dtype,为后续定义输入tensor的张量占位符做准备。
      def add(input_x, input_y, output_z, kernel_name="add"):
          shape_x = input_x.get("shape")
          shape_y = input_y.get("shape")
          input_data_type = input_x.get("dtype").lower()
          input_data_type_y = input_y.get("dtype").lower()
    2. (可选)在算子实现函数中添加算子输入/输出及属性基本信息校验,有助于在算子编译阶段,提前发现问题。

      例如,对于Add算子,首先校验两个输入的dtype是否一致,然后校验输入的数据类型是否在允许的数据类型列表中,代码实现如下所示。

      1
      2
      3
      4
      5
          if input_data_type != input_data_type_y:
              raise RuntimeError(
                  "the input_x and input_y should have the same data type.")
          check_tuple = ("float16", "float32", "int32")
          para_check.check_dtype(input_data_type, check_tuple, param_name="input_x")
      

      开发者可使用TBE Utils API中的通用校验函数进行算子参数的基本校验。

  2. 然后根据shape与dtype定义好输入tensor的张量占位符。

    例如:
    data_input = tvm.placeholder(shape, name="data_input", dtype=dtype)

    使用TVM的placeholder接口对输入tensor进行占位,返回一个tensor对象,此位置中的数据在程序运行时才被指定。

    若算子中的某个输入为可选输入,在进行输入tensor占位时需要对此输入进行判断,若输入为空,则无需占位。

    调度与编译中的tensor_list的输入tensor需要是tvm.placeholder接口返回的tensor对象,所以此对象在后续计算过程实现中不能被替换。

    如下所示:

        #返回占位的data_input
        data_input = tvm.placeholder(shape, name='data', dtype=dtype)      
         if dtype == "float16":
             #将data_input的类型转换为float32,然后重新赋值给data_input,此时data_input的内容已经被改变
             data_input = dsl.cast_to(data_input, "float32")        
            ......
        with tvm.target.cce():
            schedule = dsl.auto_schedule(res)
        config = {"need_build":need_build,
                  "name":kernel_name, 
                  "tensor_list":[data_input,res]} 
        dsl.build(schedule,config)

    以上代码中,通过data_input = dsl.cast_to(data_input, "float32")转换数据类型后,placeholder返回的data_input对象已经被覆盖,编译配置tensor_list中的data_input已经不是原placeholder接口返回的tensor,此时算子实现代码编译时会出现以下错误:

    所以可重新定义一个tensor用于存储转换数据类型后的输入进行计算,如下所示:

    data_input1 = dsl.cast_to(data_input, "float32")

    或者如步骤3所示,计算过程在compute函数中进行,将placeholder返回的输入tensor通过形参传入compute函数进行计算,会生成新的地址用于计算,也可避免placeholder返回的tensor对象被覆盖的情况。

  3. 在算子接口定义函数中调用compute函数进行计算过程的描述。

    例如:

    res = add_compute(data_x, data_y, output_z, kernel_name)

    输入tensor为使用tvm.placeholder定义的占位tensor,其他为算子定义函数透传的参数。

  4. 算子compute函数的实现。

    在compute函数中,完成算子的计算过程,计算过程的实现主要根据算子分析中的TBE DSL API进行代码开发。

    入门示例中介绍了简单的add算子的实现过程,下面我们以计算公式较复杂的relu算子为例,讲解算子的计算过程的实现以及部分DSL接口在使用过程中的注意事项。

    假设通过进行算子分析,得到relu算子的计算公式如下:

    计算实现代码如下所示:

    @fusion_manager.register("relu")
    def relu_compute(x, y, kernel_name="relu"):
        inp_dtype = x.dtype     # 获取输入数据的数据类型
        shape = x.shape         # 获取输入数据的形状
    
        # 若数据类型为float32、int32,使用vmax操作,避免精度损失
        if inp_dtype in ("float32", "int32"):
            tensor_zero = dsl.broadcast(tvm.const(CONST_ZERO, inp_dtype),shape)     # 返回形状与输入数据相同,每一个元素都为0,每一个元素的数据类型都为输入数据的数据类型的tensor
            data_res = dsl.vmax(x, tensor_zero)    # 取x与tensor_zero中的大值
        else:
            data_res = dsl.vrelu(x)      # 若数据类型为int8、float16,直接做relu操作。
    
        data_res = dsl.cast_to(data_res, inp_dtype)
    
        return data_res
    • 由于tbe.dsl.vrelu( )接口会将int8、uint8、int32、float32的数据类型转换为float16,而int32、float32进行数据类型转换时会造成精度损失,所以为了避免精度损失,对于这两种数据类型的输入,采用tbe.dsl.vmax( )接口取输入数据与0之间的大值。
    • TBE DSL中vmax接口要求两个输入tensor的shape相同,一般使用tbe.dsl.broadcast接口将输入tensor的shape广播到相同shape,一般取两个输入tensor的shape中每个维度的大值组成的shape。

    算子计算函数实现中的其他小技巧:

    • 若输入tensor数据类型不是float32,可以将其转换为float32进行计算,可以提高中间计算结果的精度,最后的结果输出时需要将数据类型转换成原数据类型。
    • 当算子的计算过程比较繁琐时,可以通过抽调内部函数的方法保持每个模块的简洁性和可读性。

broadcast类算子

broadcast类算子指算子对多个输入的shape要求一致,所以涉及broadcast操作。此类算子在进行计算函数实现时,需要首先调用tbe.common.utils.shape_util下的broadcast_shapes或unify_broadcast_shapes接口,计算出输出shape大小,然后再调用tbe.dsl.broadcast接口将参与broadcast的每个输入广播到输出shape的大小。否则可能会出现类似如下错误:

'Compile operator failed, cause: Parameters check failed, detailed information: The lhs shape[(dim 0 0, dim 1 0)] must be equal to the rhs[(dim 0 0, dim 1 1)].'

vcmp/vsel接口使用注意事项

在使用tbe.dsl.vcmp与tbe.dsl.vsel接口时,会出现某些shape的情况下输出结果不符合预期,这是因为vcmp接口中mode的默认值为bool,表示按照8bit进行存储。而config配置中若不配置"bool_storage_as_1bit"参数,此参数默认值为True,表示按照1bit进行存储,与mode为bool不匹配,所以要在算子接口实现函数的编译配置config中加入配置项:"bool_storage_as_1bit": False,例如:

with tvm.target.cce():
    schedule = dsl.auto_schedule(res)
config = {"name": kernel_name,
          "tensor_list": [data_x, data_y, res],
          "bool_storage_as_1bit": False}
dsl.build(schedule, config)