计算实现
计算实现包含依赖python模块导入、算子函数声明、算子入参校验、计算逻辑实现以及调度与编译。计算实现时,可以通过一定的方法进行精度与性能上的提升。
导入Python模块
进行TBE DSL算子开发时,首先需要在算子实现文件中导入昇腾AI软件栈提供的Python模块,代码示例如下所示,算子实现文件的命名请参见算子定义命名规则。
from tbe import dsl from tbe import tvm from tbe.common.utils import para_check from tbe.common.utils import shape_util
其中:
- “tbe.dsl”:引入TBE支持的特定域语言接口,包括常见的计算接口、调度接口以及编译接口。DSL接口的详细使用方法请参见TBE DSL API。
- “tbe.tvm”:引入TVM后端代码生成机制。TVM接口的详细使用方法请参见Link。
- “tbe.common.utils.para_check”:提供了通用的算子参数校验接口,详细的接口定义可参见算子参数校验。
- “tbe.common.utils.shape_util”:提供了一些通用的处理算子shape的接口,详细的接口定义可参见shape相关工具。
如果您进行代码实现时依赖了其他自行引入的Python模块,请自行进行依赖导入。
算子函数声明
算子的代码实现中包括两个函数:算子定义函数与算子compute函数,算子的compute函数会在算子定义函数中被调用。
下面详细介绍这两个函数的声明规则。
- 算子定义函数声明
如下所示,一个算子的定义函数中包含了算子输入信息、算子输出信息以及内核名称,函数的声明信息需要与算子原型定义文件中的信息对应。
def operationname(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="KernelName", impl_mode="high_performance")
- 算子定义函数名称operationname当前版本请与算子实现文件名称保持一致,命名规则请参见算子定义命名规则。
- input_x1, input_x2:算子的输入tensor,每个tensor需要采用字典的形式进行定义,包含shape、ori_shape、format、ori_format与dtype信息,例如:
dict input_x1 = {'shape' : (2,2), 'ori_shape' : (2,2), 'format': 'ND', 'ori_format':'ND', 'dtype' : 'float16'}
输入tensor的顺序及个数需要与算子原型定义保持一致,可选输入也需要在此处定义,在计算逻辑中去判断是否有数据传入,并进行相应处理。
- output_y:算子的输出tensor,包含shape和dtype等信息,字典格式,此字段为预留位。
输出tensor的顺序及个数也需要与算子原型定义保持一致,可选输出也需要在此处定义。
- attribute1,attribute2...:算子的属性,算子属性的顺序与个数需要与算子原型定义保持一致。
若算子无相关属性信息,此参数忽略;若算子的属性为可选值,此处需要为算子的属性赋默认值。
- kernel_name:算子在内核中的名称(即生成的二进制文件与算子描述文件的名称),用户自定义,保持唯一,只能是大小写字母、数字、“_”的组合,且必须是字母或者“_”开头,长度小于或等于200个字符。
- impl_mode(可选):String类型,算子运行时选择精度优先还是性能优先模式,该字段仅影响输入数据为float32数据类型时的精度与性能。
有“high_precision”与“high_performance”两种取值,默认值为“high_performance”。
当输入数据值的范围不超过float16的最大值65504时,不存在精度问题;当输入数据值的范围超过float16最大值时,会出现exp计算溢出,从而导致计算结果不正确,此时需要选择high_precision模式,但此种场景下性能会下降较多。
算子在网络中运行时,使用哪种模型的配置规则请参见《ATC工具使用指南》中的“op_select_implmode”参数。
不带属性的sqrt算子的定义函数声明如下:
def sqrt(input_x, output_y, kernel_name="sqrt"):
带属性的reduce_sum算子的定义函数声明如下:
def reduce_sum(x, y, axis=None, keep_dims=None, kernel_name="reduce_sum")
开发者在进行算子定义函数声明时可使用装饰器函数check_op_params或者check_input_type对算子参数进行基本的校验。
其中check_op_params校验算子输入输出是否满足必选与可选的要求,check_input_type校验算子的参数类型是否合法。
例如:
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME) def sqrt(input_x, output_y, kernel_name="sqrt")
@para_check.check_input_type(dict, dict, dict, int, bool, str) def sort(x, y1, y2, axis=-1, descending=False, kernel_name="sort")
- compute函数声明
@tbe.common.register.register_op_compute("KernelName",op_mode="static") def operationname_compute(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="KernelName")
- 装饰器@tbe.common.register.register_op_compute,其作用是整网运行时支持算子做UB自动融合,使得当前自定义算子可以在UB中根据UB融合规则自动与其他算子的计算进行组装,提升算子运行效率,此接口的详细说明可参见register_op_compute。
若算子实现逻辑中涉及reshape操作,则不支持UB自动融合,算子compute函数声明时不可使用此装饰器函数。
- input_x1, input_x2:compute函数的入参,为在 算子接口函数声明中声明的输入tensor对应的placeholder,包含shape和dtype等信息。
- output_y,attribute1=None,xxx等参数,都是从 算子接口函数声明中的算子定义函数中透传过来的,与算子定义函数的声明保持一致即可。
例如,对于sqrt算子,compute函数定义如下:
@tbe.common.register.register_op_compute("sqrt",op_mode="static") def sqrt_compute(input_data, output_data, kernel_name="sqrt"):
对于reduce_sum算子,算子接口和计算函数定义如下:
@tbe.common.register.register_op_compute("reduce_sum",op_mode="static") def reduce_sum_compute(x, y, axis=None, keep_dims=None, kernel_name="reduce_sum")
- 装饰器@tbe.common.register.register_op_compute,其作用是整网运行时支持算子做UB自动融合,使得当前自定义算子可以在UB中根据UB融合规则自动与其他算子的计算进行组装,提升算子运行效率,此接口的详细说明可参见register_op_compute。
算子函数实现
完成算子函数声明后,就要具体实现算子定义函数和compute函数。
- 首先在算子定义函数operatorname( )中,获取算子输入tensor的shape以及dtype,并可自行实现基本的校验功能。
- 获取算子输入tensor的shape以及dtype,为后续定义输入tensor的张量占位符做准备。
def add(input_x, input_y, output_z, kernel_name="add"): shape_x = input_x.get("shape") shape_y = input_y.get("shape") input_data_type = input_x.get("dtype").lower() input_data_type_y = input_y.get("dtype").lower()
- (可选)在算子实现函数中添加算子输入/输出及属性基本信息校验,有助于在算子编译阶段,提前发现问题。
例如,对于Add算子,首先校验两个输入的dtype是否一致,然后校验输入的数据类型是否在允许的数据类型列表中,代码实现如下所示。
1 2 3 4 5
if input_data_type != input_data_type_y: raise RuntimeError( "the input_x and input_y should have the same data type.") check_tuple = ("float16", "float32", "int32") para_check.check_dtype(input_data_type, check_tuple, param_name="input_x")
开发者可使用TBE Utils API中的通用校验函数进行算子参数的基本校验。
- 获取算子输入tensor的shape以及dtype,为后续定义输入tensor的张量占位符做准备。
- 然后根据shape与dtype定义好输入tensor的张量占位符。例如:
data_input = tvm.placeholder(shape, name="data_input", dtype=dtype)
使用TVM的placeholder接口对输入tensor进行占位,返回一个tensor对象,此位置中的数据在程序运行时才被指定。
若算子中的某个输入为可选输入,在进行输入tensor占位时需要对此输入进行判断,若输入为空,则无需占位。
调度与编译中的tensor_list的输入tensor需要是tvm.placeholder接口返回的tensor对象,所以此对象在后续计算过程实现中不能被替换。
如下所示:
#返回占位的data_input data_input = tvm.placeholder(shape, name='data', dtype=dtype) if dtype == "float16": #将data_input的类型转换为float32,然后重新赋值给data_input,此时data_input的内容已经被改变 data_input = dsl.cast_to(data_input, "float32") ...... with tvm.target.cce(): schedule = dsl.auto_schedule(res) config = {"need_build":need_build, "name":kernel_name, "tensor_list":[data_input,res]} dsl.build(schedule,config)
以上代码中,通过data_input = dsl.cast_to(data_input, "float32")转换数据类型后,placeholder返回的data_input对象已经被覆盖,编译配置tensor_list中的data_input已经不是原placeholder接口返回的tensor,此时算子实现代码编译时会出现以下错误:
所以可重新定义一个tensor用于存储转换数据类型后的输入进行计算,如下所示:
data_input1 = dsl.cast_to(data_input, "float32")
或者如步骤3所示,计算过程在compute函数中进行,将placeholder返回的输入tensor通过形参传入compute函数进行计算,会生成新的地址用于计算,也可避免placeholder返回的tensor对象被覆盖的情况。
- 在算子接口定义函数中调用compute函数进行计算过程的描述。
例如:
res = add_compute(data_x, data_y, output_z, kernel_name)
输入tensor为使用tvm.placeholder定义的占位tensor,其他为算子定义函数透传的参数。
- 算子compute函数的实现。
在compute函数中,完成算子的计算过程,计算过程的实现主要根据算子分析中的TBE DSL API进行代码开发。
入门示例中介绍了简单的add算子的实现过程,下面我们以计算公式较复杂的relu算子为例,讲解算子的计算过程的实现以及部分DSL接口在使用过程中的注意事项。
假设通过进行算子分析,得到relu算子的计算公式如下:
计算实现代码如下所示:
@fusion_manager.register("relu") def relu_compute(x, y, kernel_name="relu"): inp_dtype = x.dtype # 获取输入数据的数据类型 shape = x.shape # 获取输入数据的形状 # 若数据类型为float32、int32,使用vmax操作,避免精度损失 if inp_dtype in ("float32", "int32"): tensor_zero = dsl.broadcast(tvm.const(CONST_ZERO, inp_dtype),shape) # 返回形状与输入数据相同,每一个元素都为0,每一个元素的数据类型都为输入数据的数据类型的tensor data_res = dsl.vmax(x, tensor_zero) # 取x与tensor_zero中的大值 else: data_res = dsl.vrelu(x) # 若数据类型为int8、float16,直接做relu操作。 data_res = dsl.cast_to(data_res, inp_dtype) return data_res
- 由于tbe.dsl.vrelu( )接口会将int8、uint8、int32、float32的数据类型转换为float16,而int32、float32进行数据类型转换时会造成精度损失,所以为了避免精度损失,对于这两种数据类型的输入,采用tbe.dsl.vmax( )接口取输入数据与0之间的大值。
- TBE DSL中vmax接口要求两个输入tensor的shape相同,一般使用tbe.dsl.broadcast接口将输入tensor的shape广播到相同shape,一般取两个输入tensor的shape中每个维度的大值组成的shape。
算子计算函数实现中的其他小技巧:
- 若输入tensor数据类型不是float32,可以将其转换为float32进行计算,可以提高中间计算结果的精度,最后的结果输出时需要将数据类型转换成原数据类型。
- 当算子的计算过程比较繁琐时,可以通过抽调内部函数的方法保持每个模块的简洁性和可读性。
broadcast类算子
broadcast类算子指算子对多个输入的shape要求一致,所以涉及broadcast操作。此类算子在进行计算函数实现时,需要首先调用tbe.common.utils.shape_util下的broadcast_shapes或unify_broadcast_shapes接口,计算出输出shape大小,然后再调用tbe.dsl.broadcast接口将参与broadcast的每个输入广播到输出shape的大小。否则可能会出现类似如下错误:
'Compile operator failed, cause: Parameters check failed, detailed information: The lhs shape[(dim 0 0, dim 1 0)] must be equal to the rhs[(dim 0 0, dim 1 1)].'
vcmp/vsel接口使用注意事项
在使用tbe.dsl.vcmp与tbe.dsl.vsel接口时,会出现某些shape的情况下输出结果不符合预期,这是因为vcmp接口中mode的默认值为bool,表示按照8bit进行存储。而config配置中若不配置"bool_storage_as_1bit"参数,此参数默认值为True,表示按照1bit进行存储,与mode为bool不匹配,所以要在算子接口实现函数的编译配置config中加入配置项:"bool_storage_as_1bit": False,例如:
with tvm.target.cce(): schedule = dsl.auto_schedule(res) config = {"name": kernel_name, "tensor_list": [data_x, data_y, res], "bool_storage_as_1bit": False} dsl.build(schedule, config)