下载
中文
注册

封装成AscendCL接口的GEMM算子

单算子执行的基本流程依次参见主要接口调用流程算子调用

基本原理

目前,AscendCL已将矩阵-向量乘、矩阵-矩阵乘相关的GEMM算子封装成AscendCL接口,详细接口请参见CBLAS接口,本章以aclblasGemmEx接口为例,该示例中矩阵乘的计算公式为:C = αAB + βC。

示例代码

示例代码如下,您可以从acl_execute_gemm样例中查看完整样例代码。

调用接口后,需增加异常处理的分支,并记录报错日志、提示日志,此处不一一列举。以下是关键步骤的代码示例,不可以直接拷贝编译运行,仅供参考。

//1. AscendCL初始化
//2. 运行管理资源申请,使用默认Context、Stream

//3. 设置单算子模型文件所在的目录
//该目录相对可执行文件所在的目录,例如,编译出来的可执行文件存放在run/out目录下,此处就表示run/out/op_models目录
aclopSetModelDir("op_models");

//4. 申请内存
//申请Device上的内存存放执行算子的输入数据
//对于该矩阵乘示例,依次申请存放矩阵A数据、矩阵B数据、矩阵C数据、标量α数据、标量β数据的内存
aclrtMalloc((void **) &devMatrixA_, sizeA_, ACL_MEM_MALLOC_NORMAL_ONLY);
aclrtMalloc((void **) &devMatrixB_, sizeB_, ACL_MEM_MALLOC_NORMAL_ONLY);
aclrtMalloc((void **) &devMatrixC_, sizeC_, ACL_MEM_MALLOC_NORMAL_ONLY);
aclrtMalloc((void **) &devAlpha_, sizeAlphaBeta_, ACL_MEM_MALLOC_NORMAL_ONLY);
aclrtMalloc((void **) &devBeta_, sizeAlphaBeta_, ACL_MEM_MALLOC_NORMAL_ONLY);

//申请Host上的内存,此处根据软件栈的运行模式判断是否需要申请Host上的内存
//在申请运行管理资源时调用aclrtGetRunMode接口获取软件栈的运行模式
//如果运行模式为ACL_DEVICE,则g_isDevice参数值为true,表示软件栈运行在Device侧,无需申请Host内存,无需传输图片数据或在Device内传输数据 
//如果运行模式为ACL_HOST,则g_isDevice参数值为false,表示软件栈运行在Host侧,需要申请Host内存,涉及Host和Device之间的数据传输
if (g_isDevice) {
        hostMatrixA_ = devMatrixA_;
        hostMatrixB_ = devMatrixB_;
        hostMatrixC_ = devMatrixC_;
    } else {
        aclrtMallocHost((void **) &hostMatrixA_, sizeA_);
        aclrtMallocHost((void **) &hostMatrixB_, sizeB_);
        aclrtMallocHost((void **) &hostMatrixC_, sizeC_);
    }

//5. 准备输入数据,ReadFile为自定义函数,由用户自行管理,从文件中读入数据到内存中
size_t fileSize;
// Read matrix A
char *fileData = ReadFile("test_data/data/matrix_a.bin", fileSize, hostMatrixA_, sizeA_);
// Read matrix B
fileData = ReadFile("test_data/data/matrix_b.bin", fileSize, hostMatrixB_, sizeB_);
// Read matrix C
fileData = ReadFile("test_data/data/matrix_c.bin", fileSize, hostMatrixC_, sizeC_);
//根据软件栈的运行模式判断是否涉及Host与Device之间的数据传输
if (!g_isDevice) {
    aclError ret = aclrtMemcpy(devMatrixA_, sizeA_, hostMatrixA_, sizeA_, ACL_MEMCPY_HOST_TO_DEVICE);
    ret = aclrtMemcpy(devMatrixB_, sizeB_, hostMatrixB_, sizeB_, ACL_MEMCPY_HOST_TO_DEVICE);
    ret = aclrtMemcpy(devMatrixC_, sizeC_, hostMatrixC_, sizeC_, ACL_MEMCPY_HOST_TO_DEVICE);
}

aclrtMemcpyKind kind = g_isDevice ? ACL_MEMCPY_DEVICE_TO_DEVICE : ACL_MEMCPY_HOST_TO_DEVICE;
ret = aclrtMemcpy(devAlpha_, sizeAlphaBeta_, hostAlpha_, sizeAlphaBeta_, kind);
ret = aclrtMemcpy(devBeta_, sizeAlphaBeta_, hostBeta_, sizeAlphaBeta_, kind);

//6. 执行单算子
//对于该示例,调用aclblasGemmEx接口(异步接口)实现矩阵-矩阵的乘法
aclblasGemmEx(ACL_TRANS_N, ACL_TRANS_N, ACL_TRANS_N, m_, n_, k_,
                    devAlpha_, devMatrixA_, k_, inputType_, devMatrixB_, n_, inputType_,
                    devBeta_, devMatrixC_, n_, outputType_, ACL_COMPUTE_HIGH_PRECISION,
                    stream);
//调用aclrtSynchronizeStream接口阻塞Host运行,直到指定Stream中的所有任务都完成
aclrtSynchronizeStream(stream);

//7. 传输算子执行结果,根据软件栈的运行模式判断是否涉及Host与Device之间的数据传输
if (!g_isDevice) {
        auto ret = aclrtMemcpy(hostMatrixC_, sizeC_, devMatrixC_, sizeC_, ACL_MEMCPY_DEVICE_TO_HOST);
}

//8. 是否直接在终端屏幕上显示算子执行结果,由用户自行管理代码逻辑

//9. 释放运行管理资源
//10. AscendCL去初始化

//......