定义数据节点(Data)
Graph的输入节点,也就是数据节点,使用Data算子实现。
Data算子原型定义:
1 2 3 4 5 |
REG_OP(Data) .INPUT(x, TensorType::ALL()) .OUTPUT(y, TensorType::ALL()) .ATTR(index, Int, 0) .OP_END_FACTORY_REG(Data) |
根据Data算子原型定义创建Data算子实例,名称为data,初始化参数为desc_data。同时通过“update_input_desc_输入名称”和“update_output_desc_输出名称”接口设置Shape、Format和Dtype,和用户需要处理的数据信息保持一致。
1 2 3 4 5 |
auto shape_data = vector<int64_t>({1,17,2,2}); TensorDesc desc_data(ge::Shape(shape_data), FORMAT_ND, DT_FLOAT); auto data = op::Data("data"); //创建data算子 data.update_input_desc_x(desc_data); //设置算子输入描述 data.update_output_desc_y(desc_data); //设置算子输出描述 |
需要注意的是,定义数据节点时,必须通过“update_input_desc_输入名称”和“update_output_desc_输出名称”接口设置Shape、Format和Dtype。
父主题: 各类算子表达