下载
中文
注册

定义数据节点(Const)

权值、偏置等信息为常量Tensor,可以使用Const算子实现。

Const算子原型定义:

1
2
3
4
5
REG_OP(Const)
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16, DT_INT4, DT_INT8, DT_INT16, DT_UINT16, \
        DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .ATTR(value, Tensor, Tensor())
    .OP_END_FACTORY_REG(Const)

直接构造权重数据

根据Const算子原型定义创建Const算子实例,初始值(即属性value的值)为weighttensor1。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
//构造weighttensor1
TensorDesc weight_desc(ge::Shape({1,3,3,3}), FORMAT_NCHW, DT_INT8);
int bs_size_weight=27;
int8_t * bs_inputData_weight = nullptr;
bs_inputData_weight = new int8_t[bs_size_weight];
for (int i=0; i<bs_size_weight; ++i) {
    *(bs_inputData_weight+i) = 1;
}
Tensor weighttensor1(weight_desc, (uint8_t*)bs_inputData_weight, bs_size_weight*sizeof(int8_t));

//创建Const算子,初始值(即属性value的值)为weighttensor1
auto weight1 = op::Const().set_attr_value(weighttensor1);

如果某个算子的原型输入和输出同名,代表这个是inplace操作算子,即算子的输出会更新输入。该场景下这种输入不能连接Const节点。

从文件读入权重数据

除了直接构造权重数据外,也可以直接从bin文件读入权重数据。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
//构造weight_tensor
auto weight_shape = ge::Shape({ 5,17,1,1 });
TensorDesc desc_weight_1(weight_shape, FORMAT_NCHW, DT_INT8);
Tensor weight_tensor(desc_weight_1);
uint32_t weight_1_len = weight_shape.GetShapeSize();
// PATH+"const_0.bin" 为常量文件的路径
bool res = GetConstTensorFromBin(PATH+"const_0.bin", weight_tensor, weight_l_len*sizeof(int8_t));

//创建Const算子,初始值(即属性value的值)为weight_tensor
auto conv_weight = op::Const("const_0").set_attr_value(weight_tensor);

GetConstTensorFromBin函数实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
bool GetConstTensorFromBin(string path, Tensor &weight, uint32_t len) {
    ifstream in_file(path.c_str(), std::ios::in | std::ios::binary);
    if(!in_file.is_open()) {
        std::cout << "failed to open" << path.c_str() << '\n';
        return false;
    }
    in_file.seekg(0, ios_base::end);
    istream::pos_type file_size = in_file.tellg();
    in_file.seekg(0, ios_base::beg);

    if(len != file_size) {
        cout << "Invalid Param.len:" << len << " is not equal with binary size(" << file_size << ")\n";
        in_file.close();
        return false;
    }
    char* pdata = new(std::nothrow) char[len];
    if(pdata == nullptr) {
        cout << "Invalid Param.len:" << len << " is not equal with binary size(" << file_size << ")\n";
        in_file.close();
        return false;
    }
    in_file.read(reinterpret_cast<char*>(pdata), len);
    auto status = weight.SetData(reinterpret_cast<uint8_t*>(pdata), len);
    if(status != ge::GRAPH_SUCCESS) {
        cout << "Set Tensor Data Failed"<< "\n";
        in_file.close();
        return false;
    }
    in_file.close();
    return true;
}
GetConstTensorFromBin函数参数说明:
  • path:入参,指定权重文件路径,用于到固定目录例如“../data/weight/”下查找权重文件xx.bin,用户需要自行将权重文件解析为bin文件。
  • weight:出参,从权重文件中读取的Tensor类型的权重数据。
  • len:入参,指定权重数据大小。