
(Optional) Format Inference and Parameter Validation

Implementing the op_select_format function

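Developers can implement the op_select_format function in the operator implementation file to infer the dtype and format values supported by the operator's inputs and outputs. When the operator information library is defined later, it is then sufficient to set dynamicFormat.flag to true for the dtype and format of the inputs and outputs. During operator fusion, the op_select_format function in the operator implementation file is called automatically to set the dtype and format, so no fixed dtype and format need to be configured. If the operator implementation file does not implement this function, the operator information library definition must list the dtype and format values supported by the inputs and outputs.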

For how to configure the operator information library definition, see TBE Operator Information Library.

The op_select_format function is declared as follows:

def op_select_format(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):

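The input parameters of op_select_format are the same as those of the operator interface function (that is, the operator's inputs, outputs, attributes, and kernel_name). The return value is a string containing the lists of formats and data types supported by the inputs and outputs of the current operator, in the following format: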

{
"input0": {
"name": "x",
"dtype": "float16,float16,int8,int8",
"format": "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"
},
"input1": {
"name": "y",
"dtype": "float16,float16,int8,int8",
"format": "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"
},
"output0": {
"name": "z",
"dtype": "float16,float16,int32,int32",
"format": "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
}
}
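As a reading aid, the returned string can be parsed back and viewed one position at a time. The sketch below assumes that position i of every dtype and format list belongs to the same supported combination; that column-wise reading is an assumption made here, and list_supported_combinations is a hypothetical helper, not part of TBE.

```python
import json


def list_supported_combinations(select_format_result):
    """Return one {tensor name: (dtype, format)} dict per list position."""
    info = json.loads(select_format_result)
    columns = {}
    for entry in info.values():
        # Split the comma-separated lists and drop optional spaces.
        dtypes = [d.strip() for d in entry["dtype"].split(",")]
        formats = [f.strip() for f in entry["format"].split(",")]
        if len(dtypes) != len(formats):
            raise ValueError("dtype and format lists differ in length")
        columns[entry["name"]] = list(zip(dtypes, formats))

    lengths = {len(pairs) for pairs in columns.values()}
    if len(lengths) != 1:
        raise ValueError("all tensors must list the same number of entries")
    count = lengths.pop()
    return [{name: pairs[i] for name, pairs in columns.items()}
            for i in range(count)]


# Small sample in the same layout as the string shown above.
sample = json.dumps({
    "input0": {"name": "x", "dtype": "float16,int8",
               "format": "NC1HWC0,NC1HWC0"},
    "output0": {"name": "z", "dtype": "float16,int32",
                "format": "NC1HWC0,NC1HWC0"},
})
combos = list_supported_combinations(sample)
for combo in combos:
    print(combo)
```

The length checks catch the most common mistake when writing these strings by hand: a dtype list and a format list that no longer line up.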

For example, the op_select_format function of the conv2d operator is implemented as follows:

import json
def op_select_format(inputs, weights, bias, offset_w, outputs, strides,
                     pads, dilations, groups=1, data_format='NHWC',
                     offset_x=0, kernel_name="conv2d"):
    shape_x = inputs.get("ori_shape")
    format_x = inputs.get("ori_format")
    shape_y = weights.get("ori_shape")
    format_y = weights.get("ori_format")
    x_dict = dict(zip(list(format_x), shape_x))
    y_dict = dict(zip(list(format_y), shape_y))

    use_c04 = False
    if x_dict["C"] <= 4 and (y_dict["W"] != 1 or y_dict["H"] != 1):
        use_c04 = True
    res = {}
    if use_c04:
        res["input0"] = {
            "name":"x",
            "dtype":"float16, float16, int8, int8",
            "format": "NC1HWC0, NC1HWC0_C04, NC1HWC0, NC1HWC0_C04"
        }
        res["input1"] = {
            "name":"filter",
            "dtype":"float16, float16, int8, int8",
            "format": "FRACTAL_Z, FRACTAL_Z_C04, FRACTAL_Z, FRACTAL_Z_C04"
        }
        res["input2"] = {
            "name":"bias",
            "dtype":"float16, float16, int32, int32",
            "format": "ND, ND, ND, ND"
        }
        res["input3"] = {
            "name":"offset_w",
            "dtype":"int8, int8, int8, int8",
            "format": "ND, ND, ND, ND"
        }
        res["output0"] = {
            "name":"y",
            "dtype":"float16, float16, int8, int8",
            "format": "NC1HWC0, NC1HWC0, NC1HWC0, NC1HWC0"
        }
    else:
        res["input0"] = {
            "name":"x",
            "dtype":"float16, int8",
            "format": "NC1HWC0, NC1HWC0"
        }
        res["input1"] = {
            "name":"filter",
            "dtype":"float16, int8",
            "format": "FRACTAL_Z, FRACTAL_Z"
        }
        res["input2"] = {
            "name":"bias",
            "dtype":"float16, int32",
            "format": "ND, ND"
        }
        res["input3"] = {
            "name":"offset_w",
            "dtype":"int8, int8",
            "format": "ND, ND"
        }
        res["output0"] = {
            "name":"y",
            "dtype":"float16, int8",
            "format": "NC1HWC0, NC1HWC0"
        }

    return json.dumps(res, indent=4)
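The decision at the top of the conv2d example relies on pairing each letter of ori_format with the matching entry of ori_shape, so that the channel and kernel sizes can be looked up by axis name whatever the layout. A minimal standalone sketch of that step follows; needs_c04 is a hypothetical helper name, and the sample shapes are made up for illustration.

```python
def needs_c04(ori_format_x, ori_shape_x, ori_format_w, ori_shape_w):
    """Mirror the use_c04 condition from the conv2d example."""
    # "NHWC" zipped with (1, 224, 224, 3) gives {"N": 1, "H": 224, "W": 224, "C": 3}.
    x_dict = dict(zip(ori_format_x, ori_shape_x))
    w_dict = dict(zip(ori_format_w, ori_shape_w))
    # Small channel count and a kernel that is not 1x1.
    return x_dict["C"] <= 4 and (w_dict["W"] != 1 or w_dict["H"] != 1)


# 3-channel input with a 7x7 kernel: the C04 formats are offered.
print(needs_c04("NHWC", (1, 224, 224, 3), "HWCN", (7, 7, 3, 64)))  # True
# 64-channel input: only the regular formats are offered.
print(needs_c04("NCHW", (1, 64, 56, 56), "NCHW", (64, 64, 3, 3)))  # False
# 3-channel input with a 1x1 kernel: only the regular formats are offered.
print(needs_c04("NHWC", (1, 224, 224, 3), "HWCN", (1, 1, 3, 64)))  # False
```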

Implementing the check_supported function

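If operator parameters need to be validated during the operator fusion phase, implement the check_supported function in the operator implementation file and set the flag parameter of the needCheckSupport configuration item to true in the operator information library definition file. For how to configure the operator information library definition, see TBE Operator Information Library.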

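If the check_supported function passes, AI Core supports these operator parameters, and the corresponding AI Core operator (that is, the TBE operator) is selected for execution. Otherwise, the AI CPU operator is selected.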

In check_supported you can implement custom validation of the dtype and shape of the operator's inputs and outputs. The function is declared as follows:

def check_supported(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):

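The input parameters of check_supported are the same as those of the operator interface function (that is, the operator's inputs, outputs, attributes, and kernel_name).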

Return True if validation succeeds, and False if it fails.

For example, the check_supported function of the InTopK operator is implemented as follows. It validates the data types of the input parameters.

def check_supported(predictions,targets,precision,k,kernel_name='in_top_k'):
    prediction_dtype = predictions.get("dtype").lower()
    target_dtype = targets.get("dtype").lower()
    if prediction_dtype != "float32":
        return False
    if target_dtype != "int32":
        return False

    return True
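Because the inputs are accessed with .get("dtype"), a function of this kind can be exercised outside the framework with plain dictionaries. The sketch below uses a compact rewrite of the InTopK check so that it runs on its own; the dictionary contents (shapes, the int64 case, the bool dtype of precision) are illustrative assumptions, not values taken from the operator definition.

```python
def check_supported(predictions, targets, precision, k, kernel_name="in_top_k"):
    """Compact equivalent of the InTopK dtype check shown above."""
    return (predictions.get("dtype").lower() == "float32"
            and targets.get("dtype").lower() == "int32")


# Hypothetical tensor descriptions, written as plain dictionaries.
predictions = {"shape": (8, 100), "dtype": "float32"}
targets_ok = {"shape": (8,), "dtype": "int32"}
targets_bad = {"shape": (8,), "dtype": "int64"}
precision = {"shape": (8,), "dtype": "bool"}

ok = check_supported(predictions, targets_ok, precision, 5)
bad = check_supported(predictions, targets_bad, precision, 5)
print(ok)   # True: both dtypes are accepted
print(bad)  # False: targets is not int32
```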

The check_supported function of the InplaceUpdate operator is implemented as follows. It validates both the data type and the shape of the input parameters.

def check_supported(x, indices, v, y, kernel_name="inplace_update"):
    shape_indices = indices.get("shape")
    shape_v = v.get("shape")
    dtype_v = v.get("dtype").lower()
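    # Number of elements in one row of v (all dimensions except the first).
    reg_v_len = 1
    for i in range(1, len(shape_v)):
        reg_v_len = reg_v_len * shape_v[i]

    # Element size in bytes: 4 for float32/int32, otherwise treated as 2.
    if dtype_v in ("float32", "int32"):
        dtype_size = 4
    else:
        dtype_size = 2
    reg_v_size = reg_v_len * dtype_size

    try:
        # indices must be 1-D and one row of v must be a multiple of 32 bytes.
        if len(shape_indices) != 1 or (reg_v_size % 32 != 0):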
            return False

    except RuntimeError:
        return False

    return True
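The shape check in the InplaceUpdate example comes down to one piece of arithmetic: the byte size of a single row of v must be a multiple of 32. A small worked sketch of that calculation is given below; row_bytes is a hypothetical helper that repeats the size rule from the example (4 bytes for float32 and int32, 2 bytes for everything else).

```python
def row_bytes(shape_v, dtype_v):
    """Bytes occupied by one row of v, using the size rule from the example."""
    elems = 1
    for dim in shape_v[1:]:  # skip the first (row index) dimension
        elems *= dim
    dtype_size = 4 if dtype_v in ("float32", "int32") else 2
    return elems * dtype_size


# 16 float16 elements per row: 16 * 2 = 32 bytes, a multiple of 32.
print(row_bytes((4, 16), "float16") % 32 == 0)  # True
# 10 float32 elements per row: 10 * 4 = 40 bytes, not a multiple of 32.
print(row_bytes((4, 10), "float32") % 32 == 0)  # False
# 8 float32 elements per row: 8 * 4 = 32 bytes, a multiple of 32.
print(row_bytes((4, 8), "float32") % 32 == 0)   # True
```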

For optional inputs and outputs, first check whether they are None, and only then check their shape, dtype, and so on.
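A minimal sketch of this None-first pattern is shown below. The operator, its optional bias input, and the accepted dtype are all hypothetical and chosen only to illustrate the order of the checks.

```python
def check_supported(x, bias, y, kernel_name="hypothetical_op"):
    """Validate a required input x and an optional input bias."""
    if x.get("dtype").lower() != "float16":
        return False

    # Optional input: inspect it only when it was actually provided.
    if bias is not None:
        if bias.get("dtype").lower() != "float16":
            return False
        if len(bias.get("shape")) != 1:
            return False

    return True


x = {"shape": (2, 16), "dtype": "float16"}
y = {"shape": (2, 16), "dtype": "float16"}
print(check_supported(x, None, y))                                      # True
print(check_supported(x, {"shape": (16,), "dtype": "float16"}, y))      # True
print(check_supported(x, {"shape": (2, 16), "dtype": "float16"}, y))    # False
```

Without the `bias is not None` guard, the first call would fail with an AttributeError on `None.get` instead of returning a result.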