权重更新
接口调用流程
对于权重更新的场景,为便于用户一次编译模型后,在模型执行阶段能动态更新权重,可通过以下接口配合使用实现该功能:
- 基于Ascend Graph方式编译并保存模型,模型中包含推理图、权重初始化图、权重更新图三部分。
此处是调用aclgrphBundleBuildModel接口编译模型、调用aclgrphBundleSaveModel接口保存模型,接口详细描述参见《Ascend Graph开发指南》。
权重初始化是可选步骤,根据业务场景由用户判断是否需要包含权重初始化图,不包含的情况下,可节省模型加载所需的Device内存。
- 调用aclmdlBundleLoadFromFile或aclmdlBundleLoadFromMem接口加载模型。
- 调用aclmdlBundleGetModelId接口获取三个图的ID。
- 根据权重初始化图ID,调用模型执行接口(例如aclmdlExecute)执行权重初始化图。
- 若需更新权重,在执行权重更新图前,调用aclmdlSetDatasetTensorDesc接口设置图的tensor描述信息。
- 根据权重更新图ID,调用模型执行接口(例如aclmdlExecute)执行权重更新图。
- 根据推理图ID,调用模型执行接口(例如aclmdlExecute)执行推理图。
- 推理结束后,调用aclmdlBundleUnload接口卸载模型。
示例代码
本节中的示例重点介绍模型推理的代码逻辑,AscendCL初始化和去初始化请参见AscendCL初始化,运行管理资源申请与释放请参见运行管理资源申请与释放。
调用接口后,需增加异常处理的分支,并记录报错日志、提示日志,此处不一一列举。以下是关键步骤的代码示例,不可以直接拷贝编译运行,仅供参考。
// 1. 初始化资源 aclInit(nullptr); aclrtSetDevice(0); // 2. 加载基于Ascend Graph方式构建出来的模型,模型中包含推理图、权重初始化图、权重更新图,模型文件名以bundle.om为例 uint32_t bundle_id = 0; aclmdlBundleLoadFromFile("./bundle.om", &bundle_id); // 3. 获取模型中各个图的ID size_t modelNum = 0; aclmdlBundleGetModelNum(bundle_id, &modelNum); // 此处aclgrphBundleBuildModel接口入参是3张图,各个图的索引是固定的 uint32_t infer_id= 0; aclmdlBundleGetModelId(bundle_id, 0, &infer_id); uint32_t init_id= 0; aclmdlBundleGetModelId(bundle_id, 1, &init_id); uint32_t update_id= 0; aclmdlBundleGetModelId(bundle_id, 2, &update_id); // 若不需要更新权重,就执行执行权重初始化图和推理图 // 4.执行权重初始化图,准备模型输入、输出请参见模型推理下其它推理特性章节的示例代码 aclmdlExecute(init_id, init_mdl_input, init_mdl_output); // 5. 执行推理图,准备模型输入、输出请参见模型推理下其它推理特性章节的示例代码 aclmdlExecute(infer_id, infer_mdl_input, infer_mdl_output); // 若需要更新权重,则需要执行权重更新图之后,再执行推理图 // 6. 执行权重更新图 // 如果不需要更新某一个权重,比如第0个,shape可以传入空tensor,但device内存必须有效。 size_t no_need_refresh_index = 0; std::vector<int64_t> dims{0}; // dims数组中的元素为0,表示空tensor auto tensorDesc = aclCreateTensorDesc(ACL_FLOAT, dims.size(), dims.data(), ACL_FORMAT_ND); aclmdlSetDatasetTensorDesc(update_mdl_input, tensorDesc, no_need_refresh_index); // 若需要更新某一个权重,此处以更新第1个权重为例 size_t need_refresh_index = 1; std::vector<int64_t> dims{1, 3, 224, 224}; auto tensorDesc = aclCreateTensorDesc(ACL_FLOAT, dims.size(), dims.data(), ACL_FORMAT_ND); aclmdlSetDatasetTensorDesc(update_mdl_input, tensorDesc, need_refresh_index); // 执行权重更新图,准备模型输入、输出请参见模型推理下其它推理特性章节的示例代码 aclmdlExecute(update_id, update_mdl_input, update_mdl_output); // 8. 执行推理图,准备模型输入、输出请参见模型推理下其它推理特性章节的示例代码 aclmdlExecute(infer_id, infer_mdl_input, infer_mdl_output); // 9. 卸载捆绑模型 aclmdlBundleUnload(bundle_id); // 10. 释放资源 aclrtResetDevice(0); aclFinalize();
父主题: 模型推理