aclnnRmsNormGrad
支持的产品型号
- Atlas 推理系列产品。
- Atlas A2训练系列产品/Atlas 800I A2推理产品。
接口原型
每个算子分为两段式接口,必须先调用“aclnnRmsNormGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnRmsNormGrad”接口执行计算。
aclnnStatus aclnnRmsNormGradGetWorkspaceSize( const aclTensor *dy, const aclTensor *x, const aclTensor *rstd, const aclTensor *gamma, const aclTensor *dxOut, const aclTensor *dgammaOut, uint64_t *workspaceSize, aclOpExecutor **executor)
aclnnStatus aclnnRmsNormGrad( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
功能描述
算子功能:aclnnRmsNorm的反向计算。
算子公式:
- 正向公式:
- 反向推导:
aclnnRmsNormGradGetWorkspaceSize
参数说明:
- dy(aclTensor*,计算输入):Device侧的aclTensor,表示反向传回的梯度。数据格式支持ND,shape支持1-8维度。
- Atlas 推理系列产品:数据类型支持FLOAT32,FLOAT16。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- x(aclTensor*,计算输入):Host侧的aclTensor,正向算子的输入,表示被标准化的数据。数据格式支持ND,shape支持1-8维度,且与入参
dy
的shape一致。- Atlas 推理系列产品:数据类型支持FLOAT32,FLOAT16。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- rstd(aclTensor*,计算输入):Host侧的aclTensor,正向算子的中间计算结果。数据类型支持FLOAT32。数据格式支持ND,shape支持1-8维度,shape需要满足rstd_shape = x_shape[0:n],n < x_shape.dims(),n与gamma一致。
- gamma(aclTensor*,计算输入):Host侧的aclTensor,正向算子的输入。数据格式支持ND,shape支持1-8维度,shape需要满足gamma_shape = x_shape[n:], n < x_shape.dims()。
- Atlas 推理系列产品:数据类型支持FLOAT32,FLOAT16。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- dxOut(aclTensor*,计算输出):Host侧的aclTensor,表示输入
x
的梯度。数据格式支持ND,shape支持1-8维度,shape与入参dy
的shape保持一致。- Atlas 推理系列产品:数据类型支持FLOAT32,FLOAT16。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:数据类型支持FLOAT32,FLOAT16,BFLOAT16。
- dgammaOut(aclTensor*,计算输出):Device侧的aclTensor,表示
gamma
的梯度。数据类型支持FLOAT32。数据格式支持ND,shape支持1-8维度,shape与入参gamma
的shape保持一致。 - workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。
- executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
- dy(aclTensor*,计算输入):Device侧的aclTensor,表示反向传回的梯度。数据格式支持ND,shape支持1-8维度。
返回值:
aclnnStatus:返回状态码,具体参见aclnn返回码。
第一段接口完成入参校验,出现以下场景时报错:
返回161001 (ACLNN_ERR_PARAM_NULLPTR):如果传入参数是必选输入,输出或者必选属性,且是空指针,则返回161001。
aclnnRmsNormGrad
参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存地址。
- workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnRmsNormGradGetWorkspaceSize获取。
- executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
- stream(aclrtStream,入参):指定执行任务的AscendCL Stream流。
返回值: aclnnStatus:返回状态码。(具体参见aclnn返回码)
约束与限制
- 功能维度
- 数据类型支持
- 入参
dy
、x
、gamma
支持:- Atlas 推理系列产品:数据类型支持FLOAT32、FLOAT16。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:数据类型支持FLOAT32、FLOAT16、BFLOAT16。
- 入参
rstd
支持:FLOAT32。
- 入参
- Atlas 推理系列产品:
x
、dy
、gamma
输入的尾轴长度必须大于等于 32 Bytes。 - 数据格式支持:ND。
- 数据类型支持
- 未支持类型说明
- DOUBLE:指令不支持DOUBLE。
- 是否支持空Tensor:不支持空进空出。
- 是否非连续的Tensor:不支持输入非连续,不支持数据非连续。
- 边界值场景说明
- Atlas 推理系列产品:输入不支持包含inf和nan 各产品支持数据类型说明
- Atlas A2训练系列产品/Atlas 800I A2推理产品:
dy
数据类型x
数据类型rstd
数据类型gamma
数据类型dx
数据类型dgamma
数据类型float16 float16 float32 float32 float16 float32 bfloat16 bfloat16 float32 float32 bfloat16 float32 float16 float16 float32 float16 float16 float32 float32 float32 float32 float32 float32 float32 bfloat16 bfloat16 float32 bfloat16 bfloat16 float32 - Atlas 推理系列产品:
dy
数据类型x
数据类型rstd
数据类型gamma
数据类型dx
数据类型dgamma
数据类型float16 float16 float32 float16 float16 float32 float32 float32 float32 float32 float32 float32
调用示例
#include <iostream>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_rms_norm_grad.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
int64_t shape_size = 1;
for (auto i : shape) {
shape_size *= i;
}
return shape_size;
}
int Init(int32_t deviceId, aclrtStream* stream) {
// 固定写法,AscendCL初始化
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
ret = aclrtCreateStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return 0;
}
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
aclDataType dataType, aclTensor** tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
// 调用aclrtMalloc申请device侧内存
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
// 调用aclrtMemcpy将host侧数据复制到device侧内存上
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
// 计算连续tensor的strides
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
// 调用aclCreateTensor接口创建aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_NCDHW,
shape.data(), shape.size(), *deviceAddr);
return 0;
}
int main() {
// 1. (固定写法)device/stream初始化, 参考acl对外接口列表
// 根据自己的实际device填写deviceId
int32_t deviceId = 0;
aclrtStream stream;
auto ret = Init(deviceId, &stream);
// check根据自己的需要处理
CHECK_RET(ret == 0, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
// 2. 构造输入与输出,需要根据API的接口自定义构造
std::vector<int64_t> gradInputShape = {2, 16};
std::vector<int64_t> xInputShape = {2, 16};
std::vector<int64_t> rstdInputShape = {2};
std::vector<int64_t> gammaInputShape = {16};
std::vector<int64_t> dxOutputShape = {2, 16};
std::vector<int64_t> dgammaOutputShape = {16};
void* gradInputDeviceAddr = nullptr;
void* xInputDeviceAddr = nullptr;
void* rstdInputDeviceAddr = nullptr;
void* gammaInputDeviceAddr = nullptr;
void* dxOutDeviceAddr = nullptr;
void* dgammaOutDeviceAddr = nullptr;
aclTensor* gradInput = nullptr;
aclTensor* xInput = nullptr;
aclTensor* rstdInput = nullptr;
aclTensor* gammaInput = nullptr;
aclTensor* dxOut = nullptr;
aclTensor* dgammaOut = nullptr;
std::vector<float> gradInputHostData = {1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32};
std::vector<float> xInputHostData = {1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32};
std::vector<float> rstdInputHostData = {1,2};
std::vector<float> gammaInputHostData = {1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16};
std::vector<float> dxOutHostData = {1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32};
std::vector<float> dgammaOutHostData = {1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16};
std::vector<int64_t> output1SizeData = {2, 16};
std::vector<int64_t> output2SizeData = {16};
std::vector<int64_t> input1SizeData = {2, 16};
std::vector<int64_t> input2SizeData = {2};
std::vector<int64_t> input3SizeData = {16};
ret = CreateAclTensor(gradInputHostData, input1SizeData , &gradInputDeviceAddr, aclDataType::ACL_FLOAT, &gradInput);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(xInputHostData , input1SizeData , &xInputDeviceAddr, aclDataType::ACL_FLOAT, &xInput );
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(rstdInputHostData , input2SizeData , &rstdInputDeviceAddr, aclDataType::ACL_FLOAT, &rstdInput );
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gammaInputHostData , input3SizeData , &gammaInputDeviceAddr , aclDataType::ACL_FLOAT, &gammaInput );
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(dxOutHostData , output1SizeData , &dxOutDeviceAddr , aclDataType::ACL_FLOAT, &dxOut );
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(dgammaOutHostData , output2SizeData , &dgammaOutDeviceAddr , aclDataType::ACL_FLOAT, &dgammaOut );
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 3. 调用CANN算子库API,需要修改为具体的API
uint64_t workspaceSize = 0;
aclOpExecutor* executor;
// 调用aclnnRmsNormGrad第一段接口
ret = aclnnRmsNormGradGetWorkspaceSize(gradInput, xInput ,rstdInput , gammaInput ,dxOut ,dgammaOut, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnRmsNormGradGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
// 根据第一段接口计算出的workspaceSize申请device内存
void* workspaceAddr = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret;);
}
// 调用aclnnRmsNormGrad第二段接口
ret = aclnnRmsNormGrad(workspaceAddr, workspaceSize, executor, stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnRmsNormGrad failed. ERROR: %d\n", ret); return ret);
// 4. (固定写法)同步等待任务执行结束
ret = aclrtSynchronizeStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);
// 5. 获取输出的值,将device侧内存上的结果复制至host侧,需要根据具体API的接口定义修改
auto size_dx = GetShapeSize(gradInputShape);
std::vector<float> resultData1(size_dx, 0);
ret = aclrtMemcpy(resultData1.data(), resultData1.size() * sizeof(resultData1[0]), dxOutDeviceAddr, size_dx * sizeof(float),
ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
for (int64_t i = 0; i < size_dx; i++) {
LOG_PRINT("result[%ld] is: %f\n", i, resultData1[i]);
}
auto size_dgamma = GetShapeSize(gammaInputShape);
std::vector<float> resultData2(size_dgamma, 1);
ret = aclrtMemcpy(resultData2.data(), resultData2.size() * sizeof(resultData2[0]), dgammaOutDeviceAddr , size_dgamma * sizeof(float),
ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
for (int64_t i = 0; i < size_dgamma; i++) {
LOG_PRINT("result[%ld] is: %f\n", i, resultData2[i]);
}
// 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
aclDestroyTensor(gradInput);
aclDestroyTensor(xInput);
aclDestroyTensor(rstdInput);
aclDestroyTensor(gammaInput);
aclDestroyTensor(dxOut);
aclDestroyTensor(dgammaOut);
// 7. 释放device资源,需要根据具体API的接口定义修改
aclrtFree(gradInputDeviceAddr);
aclrtFree(xInputDeviceAddr);
aclrtFree(rstdInputDeviceAddr);
aclrtFree(gammaInputDeviceAddr);
aclrtFree(dxOutDeviceAddr);
aclrtFree(dgammaOutDeviceAddr);
if (workspaceSize > 0) {
aclrtFree(workspaceAddr);
}
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}