定义数据节点(Const)
权值、偏置等信息为常量Tensor,可以使用Const算子实现。
Const算子原型定义:
1 2 3 4 5 |
REG_OP(Const) .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16, DT_INT4, DT_INT8, DT_INT16, DT_UINT16, \ DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) .ATTR(value, Tensor, Tensor()) .OP_END_FACTORY_REG(Const) |
直接构造权重数据
根据Const算子原型定义创建Const算子实例,初始值(即属性value的值)为weighttensor1。
1 2 3 4 5 6 7 8 9 10 11 12 |
//构造weighttensor1 TensorDesc weight_desc(ge::Shape({1,3,3,3}), FORMAT_NCHW, DT_INT8); int bs_size_weight=27; int8_t * bs_inputData_weight = nullptr; bs_inputData_weight = new int8_t[bs_size_weight]; for (int i=0; i<bs_size_weight; ++i) { *(bs_inputData_weight+i) = 1; } Tensor weighttensor1(weight_desc, (uint8_t*)bs_inputData_weight, bs_size_weight*sizeof(int8_t)); //创建Const算子,初始值(即属性value的值)为weighttensor1 auto weight1 = op::Const().set_attr_value(weighttensor1); |
如果某个算子的原型输入和输出同名,代表这个是inplace操作算子,即算子的输出会更新输入。该场景下这种输入不能连接Const节点。
从文件读入权重数据
除了直接构造权重数据外,也可以直接从bin文件读入权重数据。
1 2 3 4 5 6 7 8 9 10 |
//构造weight_tensor auto weight_shape = ge::Shape({ 5,17,1,1 }); TensorDesc desc_weight_1(weight_shape, FORMAT_NCHW, DT_INT8); Tensor weight_tensor(desc_weight_1); uint32_t weight_1_len = weight_shape.GetShapeSize(); // PATH+"const_0.bin" 为常量文件的路径 bool res = GetConstTensorFromBin(PATH+"const_0.bin", weight_tensor, weight_l_len*sizeof(int8_t)); //创建Const算子,初始值(即属性value的值)为weight_tensor auto conv_weight = op::Const("const_0").set_attr_value(weight_tensor); |
GetConstTensorFromBin函数实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
bool GetConstTensorFromBin(string path, Tensor &weight, uint32_t len) { ifstream in_file(path.c_str(), std::ios::in | std::ios::binary); if(!in_file.is_open()) { std::cout << "failed to open" << path.c_str() << '\n'; return false; } in_file.seekg(0, ios_base::end); istream::pos_type file_size = in_file.tellg(); in_file.seekg(0, ios_base::beg); if(len != file_size) { cout << "Invalid Param.len:" << len << " is not equal with binary size(" << file_size << ")\n"; in_file.close(); return false; } char* pdata = new(std::nothrow) char[len]; if(pdata == nullptr) { cout << "Invalid Param.len:" << len << " is not equal with binary size(" << file_size << ")\n"; in_file.close(); return false; } in_file.read(reinterpret_cast<char*>(pdata), len); auto status = weight.SetData(reinterpret_cast<uint8_t*>(pdata), len); if(status != ge::GRAPH_SUCCESS) { cout << "Set Tensor Data Failed"<< "\n"; in_file.close(); return false; } in_file.close(); return true; } |
GetConstTensorFromBin函数参数说明:
- path:入参,指定权重文件路径,用于到固定目录例如“../data/weight/”下查找权重文件xx.bin,用户需要自行将权重文件解析为bin文件。
- weight:出参,从权重文件中读取的Tensor类型的权重数据。
- len:入参,指定权重数据大小。
父主题: 各类算子表达