封装成AscendCL接口的GEMM算子
基本原理
目前,AscendCL已将矩阵-向量乘、矩阵-矩阵乘相关的GEMM算子封装成AscendCL接口,详细接口请参见CBLAS接口,本章以aclblasGemmEx接口为例,该示例中矩阵乘的计算公式为:C = αAB + βC。
示例代码
示例代码如下,您可以从acl_execute_gemm样例中查看完整样例代码。
调用接口后,需增加异常处理的分支,并记录报错日志、提示日志,此处不一一列举。以下是关键步骤的代码示例,不可以直接拷贝编译运行,仅供参考。
//1. AscendCL初始化 //2. 运行管理资源申请,使用默认Context、Stream //3. 设置单算子模型文件所在的目录 //该目录相对可执行文件所在的目录,例如,编译出来的可执行文件存放在run/out目录下,此处就表示run/out/op_models目录 aclopSetModelDir("op_models"); //4. 申请内存 //申请Device上的内存存放执行算子的输入数据 //对于该矩阵乘示例,依次申请存放矩阵A数据、矩阵B数据、矩阵C数据、标量α数据、标量β数据的内存 aclrtMalloc((void **) &devMatrixA_, sizeA_, ACL_MEM_MALLOC_NORMAL_ONLY); aclrtMalloc((void **) &devMatrixB_, sizeB_, ACL_MEM_MALLOC_NORMAL_ONLY); aclrtMalloc((void **) &devMatrixC_, sizeC_, ACL_MEM_MALLOC_NORMAL_ONLY); aclrtMalloc((void **) &devAlpha_, sizeAlphaBeta_, ACL_MEM_MALLOC_NORMAL_ONLY); aclrtMalloc((void **) &devBeta_, sizeAlphaBeta_, ACL_MEM_MALLOC_NORMAL_ONLY); //申请Host上的内存,此处根据软件栈的运行模式判断是否需要申请Host上的内存 //在申请运行管理资源时调用aclrtGetRunMode接口获取软件栈的运行模式 //如果运行模式为ACL_DEVICE,则g_isDevice参数值为true,表示软件栈运行在Device侧,无需申请Host内存,无需传输图片数据或在Device内传输数据 //如果运行模式为ACL_HOST,则g_isDevice参数值为false,表示软件栈运行在Host侧,需要申请Host内存,涉及Host和Device之间的数据传输 if (g_isDevice) { hostMatrixA_ = devMatrixA_; hostMatrixB_ = devMatrixB_; hostMatrixC_ = devMatrixC_; } else { aclrtMallocHost((void **) &hostMatrixA_, sizeA_); aclrtMallocHost((void **) &hostMatrixB_, sizeB_); aclrtMallocHost((void **) &hostMatrixC_, sizeC_); } //5. 准备输入数据,ReadFile为自定义函数,由用户自行管理,从文件中读入数据到内存中 size_t fileSize; // Read matrix A char *fileData = ReadFile("test_data/data/matrix_a.bin", fileSize, hostMatrixA_, sizeA_); // Read matrix B fileData = ReadFile("test_data/data/matrix_b.bin", fileSize, hostMatrixB_, sizeB_); // Read matrix C fileData = ReadFile("test_data/data/matrix_c.bin", fileSize, hostMatrixC_, sizeC_); //根据软件栈的运行模式判断是否涉及Host与Device之间的数据传输 if (!g_isDevice) { aclError ret = aclrtMemcpy(devMatrixA_, sizeA_, hostMatrixA_, sizeA_, ACL_MEMCPY_HOST_TO_DEVICE); ret = aclrtMemcpy(devMatrixB_, sizeB_, hostMatrixB_, sizeB_, ACL_MEMCPY_HOST_TO_DEVICE); ret = aclrtMemcpy(devMatrixC_, sizeC_, hostMatrixC_, sizeC_, ACL_MEMCPY_HOST_TO_DEVICE); } aclrtMemcpyKind kind = g_isDevice ? ACL_MEMCPY_DEVICE_TO_DEVICE : ACL_MEMCPY_HOST_TO_DEVICE; ret = aclrtMemcpy(devAlpha_, sizeAlphaBeta_, hostAlpha_, sizeAlphaBeta_, kind); ret = aclrtMemcpy(devBeta_, sizeAlphaBeta_, hostBeta_, sizeAlphaBeta_, kind); //6. 执行单算子 //对于该示例,调用aclblasGemmEx接口(异步接口)实现矩阵-矩阵的乘法 aclblasGemmEx(ACL_TRANS_N, ACL_TRANS_N, ACL_TRANS_N, m_, n_, k_, devAlpha_, devMatrixA_, k_, inputType_, devMatrixB_, n_, inputType_, devBeta_, devMatrixC_, n_, outputType_, ACL_COMPUTE_HIGH_PRECISION, stream); //调用aclrtSynchronizeStream接口阻塞Host运行,直到指定Stream中的所有任务都完成 aclrtSynchronizeStream(stream); //7. 传输算子执行结果,根据软件栈的运行模式判断是否涉及Host与Device之间的数据传输 if (!g_isDevice) { auto ret = aclrtMemcpy(hostMatrixC_, sizeC_, devMatrixC_, sizeC_, ACL_MEMCPY_DEVICE_TO_HOST); } //8. 是否直接在终端屏幕上显示算子执行结果,由用户自行管理代码逻辑 //9. 释放运行管理资源 //10. AscendCL去初始化 //......
父主题: 单算子调用