下载
中文
注册

定义数据节点(Data)

Graph的输入节点,也就是数据节点,使用Data算子实现。

Data算子原型定义:

1
2
3
4
5
REG_OP(Data)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(index, Int, 0)
    .OP_END_FACTORY_REG(Data)

根据Data算子原型定义创建Data算子实例,名称为data,初始化参数为desc_data。同时通过“update_input_desc_输入名称”和“update_output_desc_输出名称”接口设置Shape、Format和Dtype,和用户需要处理的数据信息保持一致。

1
2
3
4
5
auto shape_data = vector<int64_t>({1,17,2,2});
TensorDesc desc_data(ge::Shape(shape_data), FORMAT_ND, DT_FLOAT);
auto data = op::Data("data");                //创建data算子
data.update_input_desc_x(desc_data);         //设置算子输入描述
data.update_output_desc_y(desc_data);        //设置算子输出描述

需要注意的是,定义数据节点时,必须通过“update_input_desc_输入名称”和“update_output_desc_输出名称”接口设置Shape、Format和Dtype。