(Optional) Format Inference and Parameter Validation
Implementing the op_select_format Function
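Developers can implement the op_select_format function in the operator implementation file to infer the dtypes and formats supported by the operator's inputs and outputs. In that case, when the operator information library is defined later, it is enough to set dynamicFormat.flag to true for the dtype and format of the inputs and outputs. During operator fusion, the op_select_format function in the operator implementation file is called automatically to set the dtype and format, so no fixed dtype and format need to be configured. If the operator implementation file does not implement this function, the lists of dtypes and formats supported by the inputs and outputs must be configured when the operator information library is defined.
For how to configure the operator information library definition, see TBE Operator Information Library.
The op_select_format function is declared as follows: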
def op_select_format(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):
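The input parameters of op_select_format are the same as those of the operator interface function (that is, the operator's inputs, outputs, attributes, and kernel_name). The return value is a string containing the lists of formats and data types supported by the operator's inputs and outputs, in the following form: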
{
    "input0": {
        "name": "x",
        "dtype": "float16,float16,int8,int8",
        "format": "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"
    },
    "input1": {
        "name": "y",
        "dtype": "float16,float16,int8,int8",
        "format": "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"
    },
    "output0": {
        "name": "z",
        "dtype": "float16,float16,int32,int32",
        "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
    }
}
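One way to read this string is column by column: the i-th entry of every dtype list and every format list together describes one supported combination. The sketch below assumes this positional alignment holds, and uses a hypothetical helper named list_combinations (not part of TBE) to parse the string into one dictionary per combination.

```python
import json


def list_combinations(select_format_result):
    """Split an op_select_format result string into per-column combinations.

    list_combinations is a hypothetical helper for illustration only.
    """
    info = json.loads(select_format_result)
    columns = {}
    for key, desc in info.items():
        dtypes = [d.strip() for d in desc["dtype"].split(",")]
        formats = [f.strip() for f in desc["format"].split(",")]
        if len(dtypes) != len(formats):
            raise ValueError("dtype/format length mismatch in %s" % key)
        columns[key] = list(zip(dtypes, formats))
    lengths = {len(pairs) for pairs in columns.values()}
    if len(lengths) != 1:
        raise ValueError("inputs and outputs list different numbers of entries")
    count = lengths.pop()
    # One dictionary per column: {"input0": (dtype, format), ...}
    return [{key: pairs[i] for key, pairs in columns.items()}
            for i in range(count)]


example = json.dumps({
    "input0": {"name": "x", "dtype": "float16,float16,int8,int8",
               "format": "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"},
    "input1": {"name": "y", "dtype": "float16,float16,int8,int8",
               "format": "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"},
    "output0": {"name": "z", "dtype": "float16,float16,int32,int32",
                "format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"},
})
combos = list_combinations(example)
for combo in combos:
    print(combo)
```

Checking that all lists have the same length before zipping them is a cheap way to catch a missing or extra entry in a hand-written result string.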
For example, the op_select_format function of the conv2d operator is implemented as follows:
import json


def op_select_format(inputs, weights, bias, offset_w, outputs, strides, pads,
                     dilations, groups=1, data_format='NHWC', offset_x=0,
                     kernel_name="conv2d"):
    # Map each axis letter of the original format to its dimension size.
    shape_x = inputs.get("ori_shape")
    format_x = inputs.get("ori_format")
    shape_y = weights.get("ori_shape")
    format_y = weights.get("ori_format")
    x_dict = dict(zip(list(format_x), shape_x))
    y_dict = dict(zip(list(format_y), shape_y))

    # Offer the C04 formats only when the input has at most 4 channels
    # and the filter is not 1x1.
    use_c04 = x_dict["C"] <= 4 and (y_dict["W"] != 1 or y_dict["H"] != 1)

    res = {}
    if use_c04:
        res["input0"] = {
            "name": "x",
            "dtype": "float16, float16, int8, int8",
            "format": "NC1HWC0, NC1HWC0_C04, NC1HWC0, NC1HWC0_C04"
        }
        res["input1"] = {
            "name": "filter",
            "dtype": "float16, float16, int8, int8",
            "format": "FRACTAL_Z, FRACTAL_Z_C04, FRACTAL_Z, FRACTAL_Z_C04"
        }
        res["input2"] = {
            "name": "bias",
            "dtype": "float16, float16, int32, int32",
            "format": "ND, ND, ND, ND"
        }
        res["input3"] = {
            "name": "offset_w",
            "dtype": "int8, int8, int8, int8",
            "format": "ND, ND, ND, ND"
        }
        res["output0"] = {
            "name": "y",
            "dtype": "float16, float16, int8, int8",
            "format": "NC1HWC0, NC1HWC0, NC1HWC0, NC1HWC0"
        }
    else:
        res["input0"] = {
            "name": "x",
            "dtype": "float16, int8",
            "format": "NC1HWC0, NC1HWC0"
        }
        res["input1"] = {
            "name": "filter",
            "dtype": "float16, int8",
            "format": "FRACTAL_Z, FRACTAL_Z"
        }
        res["input2"] = {
            "name": "bias",
            "dtype": "float16, int32",
            "format": "ND, ND"
        }
        res["input3"] = {
            "name": "offset_w",
            "dtype": "int8, int8",
            "format": "ND, ND"
        }
        res["output0"] = {
            "name": "y",
            "dtype": "float16, int8",
            "format": "NC1HWC0, NC1HWC0"
        }

    # Return the supported dtype/format lists as a JSON string.
    return json.dumps(res, indent=4)
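The branch in the conv2d example depends on pairing each axis letter of ori_format with the matching entry of ori_shape. The following sketch isolates that step so it can be tried on its own. The helper name needs_c04 and the sample shapes are assumptions made for illustration; only the condition itself is taken from the example above.

```python
def needs_c04(ori_format, ori_shape, filter_format, filter_shape):
    """Reproduce the C04 decision from the conv2d example.

    needs_c04 is a hypothetical helper, not a TBE API.
    """
    # "NHWC" with (1, 224, 224, 3) becomes {"N": 1, "H": 224, "W": 224, "C": 3}
    x_dict = dict(zip(list(ori_format), ori_shape))
    y_dict = dict(zip(list(filter_format), filter_shape))
    return x_dict["C"] <= 4 and (y_dict["W"] != 1 or y_dict["H"] != 1)


# 3-channel input with a 7x7 filter: the C04 formats are offered.
small_c = needs_c04("NHWC", (1, 224, 224, 3), "HWCN", (7, 7, 3, 64))
# 64-channel input: only the regular formats are offered.
large_c = needs_c04("NCHW", (1, 64, 56, 56), "NCHW", (64, 64, 3, 3))
# 3-channel input with a 1x1 filter: only the regular formats are offered.
pointwise = needs_c04("NHWC", (1, 224, 224, 3), "HWCN", (1, 1, 3, 64))
print(small_c, large_c, pointwise)
```

Because the lookup is keyed by axis letter, the same condition works whether the original layout is NHWC or NCHW.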
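Implementing the check_supported Function
If operator parameters need to be validated during the operator fusion phase, implement the check_supported function in the operator implementation file and set the flag parameter of the needCheckSupport item to true in the operator information library definition file. For how to configure the operator information library definition, see TBE Operator Information Library.
If check_supported passes, AI Core supports these operator parameters and the corresponding AI Core operator (that is, the TBE operator) is selected for execution. Otherwise, the AI CPU operator is selected.
In check_supported you can implement custom validation of the dtypes and shapes of the operator's inputs and outputs. The function is declared as follows: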
def check_supported(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):
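The input parameters of check_supported are the same as those of the operator interface function (that is, the operator's inputs, outputs, attributes, and kernel_name).
The function returns True if validation succeeds and False if it fails.
For example, the check_supported function of the InTopK operator is implemented as follows. It validates the data types of the input parameters.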
def check_supported(predictions, targets, precision, k, kernel_name='in_top_k'):
    prediction_dtype = predictions.get("dtype").lower()
    target_dtype = targets.get("dtype").lower()
    # Only float32 predictions with int32 targets are supported.
    if prediction_dtype != "float32":
        return False
    if target_dtype != "int32":
        return False
    return True
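To see how this check behaves, it can be called directly with hand-built dictionaries. The sketch below repeats the function so that it runs on its own. The dictionaries are assumptions for illustration: they contain only the dtype key that the function reads, whereas the descriptors passed in by the framework carry more fields.

```python
def check_supported(predictions, targets, precision, k, kernel_name='in_top_k'):
    prediction_dtype = predictions.get("dtype").lower()
    target_dtype = targets.get("dtype").lower()
    if prediction_dtype != "float32":
        return False
    if target_dtype != "int32":
        return False
    return True


# Minimal stand-in descriptors (assumed shape of the input dictionaries).
float_preds = {"dtype": "float32"}
half_preds = {"dtype": "float16"}
int_targets = {"dtype": "INT32"}  # .lower() makes the check case-insensitive
precision = {"dtype": "bool"}

supported = check_supported(float_preds, int_targets, precision, k=1)
rejected = check_supported(half_preds, int_targets, precision, k=1)
print(supported, rejected)
```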
The check_supported function of the InplaceUpdate operator is implemented as follows. It validates both the data type and the shape of the input parameters.
def check_supported(x, indices, v, y, kernel_name="inplace_update"):
    shape_indices = indices.get("shape")
    shape_v = v.get("shape")
    dtype_v = v.get("dtype").lower()

    # Number of elements in v, excluding the first dimension.
    reg_v_len = 1
    for i in range(1, len(shape_v)):
        reg_v_len = reg_v_len * shape_v[i]

    # Element size in bytes: 4 for float32/int32, otherwise 2.
    if dtype_v in ("float32", "int32"):
        dtype_size = 4
    else:
        dtype_size = 2
    reg_v_size = reg_v_len * dtype_size

    # indices must be 1-D and reg_v_size must be a multiple of 32 bytes.
    try:
        if len(shape_indices) != 1 or (reg_v_size % 32 != 0):
            return False
    except RuntimeError:
        return False
    return True
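The size check in this example is plain arithmetic: multiply all dimensions of v except the first, multiply by the element size, and test whether the result is a multiple of 32. The sketch below works through three cases. The helper name row_bytes and the sample shapes are assumptions for illustration.

```python
def row_bytes(shape_v, dtype_v):
    """Bytes covered by all dimensions of v except the first.

    row_bytes is a hypothetical helper mirroring reg_v_size in the example.
    """
    elems = 1
    for dim in shape_v[1:]:
        elems *= dim
    dtype_size = 4 if dtype_v in ("float32", "int32") else 2
    return elems * dtype_size


# (4, 8) float32: 8 elements * 4 bytes = 32 bytes, a multiple of 32.
a = row_bytes((4, 8), "float32")
# (4, 3) float16: 3 elements * 2 bytes = 6 bytes, not a multiple of 32.
b = row_bytes((4, 3), "float16")
# (2, 4, 4) float16: 16 elements * 2 bytes = 32 bytes, a multiple of 32.
c = row_bytes((2, 4, 4), "float16")

for size in (a, b, c):
    print(size, size % 32 == 0)
```

Only the second case would make the InplaceUpdate check return False on the size condition.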
For optional inputs and outputs, first check whether the argument is None, and only then validate its shape and dtype.
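A minimal sketch of this pattern follows. The operator demo_op, its optional bias input, and the specific rules being checked are all hypothetical; the point is only the order of the checks, with the None test guarding every access to the optional argument.

```python
def check_supported(x, bias, y, kernel_name="demo_op"):
    # demo_op and its parameters are hypothetical, for illustration only.
    x_dtype = x.get("dtype").lower()
    if x_dtype not in ("float16", "float32"):
        return False

    # bias is optional: test for None before reading shape or dtype.
    if bias is not None:
        if bias.get("dtype").lower() != x_dtype:
            return False
        if len(bias.get("shape")) != 1:
            return False
    return True


x_desc = {"dtype": "float16", "shape": (2, 16)}
y_desc = {"dtype": "float16", "shape": (2, 16)}

without_bias = check_supported(x_desc, None, y_desc)
with_bias = check_supported(x_desc, {"dtype": "float16", "shape": (16,)}, y_desc)
bad_bias = check_supported(x_desc, {"dtype": "float32", "shape": (16,)}, y_desc)
print(without_bias, with_bias, bad_bias)
```

Without the None test, the call that omits bias would fail with an AttributeError on bias.get instead of returning a result.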