定义复杂计算算子(Conv2D)
下面以一个较复杂的Conv2D为例,介绍如何进行算子定义。
Conv2D算子原型定义:
1 2 3 4 5 6 7 8 9 10 11 12 13 | REG_OP(Conv2D) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16})) .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16})) .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_BF16})) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1}) .ATTR(groups, Int, 1) .ATTR(data_format, String, "NHWC") .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Conv2D) |
从Conv2D算子原型定义可以看到,Conv2D算子包括:两个必选输入x和filter,两个可选输入bias和offset_w,两个必选属性strides、pads,四个可选属性dilations、groups、data_format、offset_x。则Conv2D算子定义的代码为:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | auto conv2d = op::Conv2D("Conv2d") // quant, conv_weight, conv_bias为三个输入节点 .set_input_x(quant) .set_input_filter(conv_weight) .set_input_bias(conv_bias) .set_attr_strides({ 1, 1, 1, 1 }) .set_attr_pads({ 0, 0, 0, 0 }) .set_attr_dilations({ 1, 1, 1, 1 }); TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8); TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8); TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32); TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32); conv2d.update_input_desc_x(conv2d_input_desc_x); conv2d.update_input_desc_filter(conv2d_input_desc_filter); conv2d.update_input_desc_bias(conv2d_input_desc_bias); conv2d.update_output_desc_y(conv2d_output_desc_y); |
主要过程为:
- 调用算子类型构造函数,例如“Conv2D(const char* name)”创建算子实例,并传算子名称(例如Conv2d)作为入参。
1
auto conv2d1 = op::Conv2D("Conv2d")
- 调用“set_input_输入名称”接口设置算子的输入。
1 2 3
.set_input_x(data) .set_input_filter(conv_weight) .set_input_bias(conv_bias)
data为整个graph的输入节点,通过Data算子构造,具体请参考定义数据节点(Data)。
conv_weight为常量数据,通过Const算子构造,具体请参考定义数据节点(Const)。
conv_bias为常量数据,通过Const算子构造,具体请参考定义数据节点(Const)。
- 调用“set_attr_属性名称”接口设置算子的属性。
1 2 3
.set_attr_strides({1, 1, 1, 1}) //设置strides属性值 .set_attr_pads({0, 0, 0, 0}) //设置pads属性值 .set_attr_dilations({1, 1, 1, 1}); //设置dilations属性值
- 对于Conv2D等卷积类或对C轴处理敏感的算子,建议通过“update_input_desc_输入名称”接口将Format信息设置为NCHW或者NHWC等,具体和用户需要处理的Format格式保持一致。
1 2 3 4 5 6 7 8
TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8); TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8); TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32); TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32); conv2d.update_input_desc_x(conv2d_input_desc_x); conv2d.update_input_desc_filter(conv2d_input_desc_filter); conv2d.update_input_desc_bias(conv2d_input_desc_bias); conv2d.update_output_desc_y(conv2d_output_desc_y);
IR构图不支持输入以下FORMAT:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
FORMAT_NC1HWC0 FORMAT_FRACTAL_Z FORMAT_NC1C0HWPAD FORMAT_NHWC1C0 FORMAT_FRACTAL_DECONV FORMAT_C1HWNC0 FORMAT_FRACTAL_DECONV_TRANSPOSE FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS FORMAT_NC1HWC0_C04 FORMAT_FRACTAL_Z_C04 FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS FORMAT_NC1KHKWHWC0 FORMAT_C1HWNCoC0 FORMAT_FRACTAL_ZZ FORMAT_FRACTAL_NZ FORMAT_NDC1HWC0 FORMAT_FRACTAL_Z_3D FORMAT_FRACTAL_Z_3D_TRANSPOSE FORMAT_FRACTAL_ZN_LSTM FORMAT_FRACTAL_Z_G FORMAT_ND_RNN_BIAS FORMAT_FRACTAL_ZN_RNN FORMAT_NYUV FORMAT_NYUV_A
父主题: 各类算子表达