下载
中文
注册

定义复杂计算算子(Conv2D)

下面以一个较复杂的Conv2D为例,介绍如何进行算子定义。

Conv2D算子原型定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
REG_OP(Conv2D)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_BF16}))
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
    .ATTR(dilations, ListInt, {1, 1, 1, 1})
    .ATTR(groups, Int, 1)
    .ATTR(data_format, String, "NHWC")
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(Conv2D)
从Conv2D算子原型定义可以看到,Conv2D算子包括:两个必选输入x和filter,两个可选输入bias和offset_w,两个必选属性strides、pads,四个可选属性dilations、groups、data_format、offset_x。则Conv2D算子定义的代码为:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
auto conv2d = op::Conv2D("Conv2d")
    // quant, conv_weight, conv_bias为三个输入节点
    .set_input_x(quant)
    .set_input_filter(conv_weight)
    .set_input_bias(conv_bias)
    .set_attr_strides({ 1, 1, 1, 1 })
    .set_attr_pads({ 0, 0, 0, 0 })
    .set_attr_dilations({ 1, 1, 1, 1 });

TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);

主要过程为:

  1. 调用算子类型构造函数,例如“Conv2D(const char* name)”创建算子实例,并传算子名称(例如Conv2d)作为入参。
    1
    auto conv2d1 = op::Conv2D("Conv2d")
    
  2. 调用“set_input_输入名称”接口设置算子的输入。
    1
    2
    3
        .set_input_x(data)
        .set_input_filter(conv_weight)
        .set_input_bias(conv_bias)
    

    data为整个graph的输入节点,通过Data算子构造,具体请参考定义数据节点(Data)

    conv_weight为常量数据,通过Const算子构造,具体请参考定义数据节点(Const)

    conv_bias为常量数据,通过Const算子构造,具体请参考定义数据节点(Const)

  3. 调用“set_attr_属性名称”接口设置算子的属性。
    1
    2
    3
    .set_attr_strides({1, 1, 1, 1})       //设置strides属性值
    .set_attr_pads({0, 0, 0, 0})          //设置pads属性值
    .set_attr_dilations({1, 1, 1, 1});    //设置dilations属性值
    
  4. 对于Conv2D等卷积类或对C轴处理敏感的算子,建议通过“update_input_desc_输入名称”接口将Format信息设置为NCHW或者NHWC等,具体和用户需要处理的Format格式保持一致。
    1
    2
    3
    4
    5
    6
    7
    8
    TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
    TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
    TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
    TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
    conv2d.update_input_desc_x(conv2d_input_desc_x);
    conv2d.update_input_desc_filter(conv2d_input_desc_filter);
    conv2d.update_input_desc_bias(conv2d_input_desc_bias);
    conv2d.update_output_desc_y(conv2d_output_desc_y);
    

    IR构图不支持输入以下FORMAT:

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    FORMAT_NC1HWC0
    FORMAT_FRACTAL_Z
    FORMAT_NC1C0HWPAD
    FORMAT_NHWC1C0
    FORMAT_FRACTAL_DECONV
    FORMAT_C1HWNC0
    FORMAT_FRACTAL_DECONV_TRANSPOSE
    FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS
    FORMAT_NC1HWC0_C04
    FORMAT_FRACTAL_Z_C04
    FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS
    FORMAT_NC1KHKWHWC0
    FORMAT_C1HWNCoC0
    FORMAT_FRACTAL_ZZ
    FORMAT_FRACTAL_NZ
    FORMAT_NDC1HWC0
    FORMAT_FRACTAL_Z_3D
    FORMAT_FRACTAL_Z_3D_TRANSPOSE
    FORMAT_FRACTAL_ZN_LSTM
    FORMAT_FRACTAL_Z_G
    FORMAT_ND_RNN_BIAS
    FORMAT_FRACTAL_ZN_RNN
    FORMAT_NYUV
    FORMAT_NYUV_A