aclnnLayerNormBackward
支持的产品型号
- Atlas 推理系列产品。
- Atlas 训练系列产品。
- Atlas A2训练系列产品/Atlas 800I A2推理产品。
接口原型
每个算子分为两段式接口,必须先调用“aclnnLayerNormBackwardGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnLayerNormBackward”接口执行计算。
aclnnStatus aclnnLayerNormBackwardGetWorkspaceSize(const aclTensor *gradOut, const aclTensor *input, const aclIntArray *normalizedShape, const aclTensor *mean, const aclTensor *rstd, const aclTensor *weightOptional, const aclTensor *biasOptional, const aclBoolArray *outputMask, aclTensor *gradInputOut, aclTensor *gradWeightOut, aclTensor *gradBiasOut, uint64_t *workspaceSize, aclOpExecutor **executor)
aclnnStatus aclnnLayerNormBackward(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
功能描述
算子功能:native_layer_norm的反向传播。
aclnnLayerNormBackwardGetWorkspaceSize
参数说明:
gradOut(aclTensor*,计算输入):反向计算的梯度tensor,与输入input的数据类型相同。shape与input的shape相等,为[A1,...,Ai,R1,...,Rj], shape长度大于等于normalizedShape的长度。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
input(aclTensor*,计算输入):正向计算的首个输入,与输入gradOut的数据类型相同。shape与gradOut的shape相等,为[A1,...,Ai,R1,...,Rj], shape长度大于等于normalizedShape的长度。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
normalizedShape(aclIntArray*,计算输入):表示需要进行norm计算的维度,数据类型支持INT64,shape为[R1,...,Rj], 长度小于等于输入input的长度,不支持为空。
mean(aclTensor*,计算输入):正向计算的第二个输出,表示input的均值,与输入rstd的数据类型相同且位宽不低于输入input的数据类型。shape与rstd的shape相等,为[A1,...,Ai,1,...,1],Ai后共有j个1,与需要norm的轴长度保持相同。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
rstd(aclTensor*,计算输入):正向计算的第三个输出,表示input的标准差的倒数,与输入mean的数据类型相同且位宽不低于输入input的数据类型。shape与mean的shape相等,为[A1,...,Ai,1,...,1],Ai后共有j个1,与需要norm的轴长度保持相同。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
weightOptional(aclTensor*,计算输入):权重tensor,可选参数。weightOptional非空时,数据类型与输入input一致或为FLOAT类型,且当biasOptional存在时与biasOptional的数据类型相同。shape与normalizedShape相等,为[R1,...,Rj]。支持非连续的Tensor,数据格式支持ND。weightOptional为空时,需要构造一个shape为[R1,...,Rj],数据类型与输入input相同,数据全为1的tensor。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
biasOptional(aclTensor*,计算输入):偏置tensor,可选参数。biasOptional非空时,数据类型与输入input一致或为FLOAT类型,且当weightOptional存在时与weightOptional的数据类型相同。shape与normalizedShape相等,为[R1,...,Rj]。支持非连续的Tensor,数据格式支持ND。biasOptional为空时,不做任何处理。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
outputMask(aclBoolArray*,计算输入):数据类型支持BOOL,长度固定为3,取值为True时表示对应位置的输出非空。
gradInputOut(aclTensor*,可选输出):由outputMask的第0个元素控制是否输出,outputMask第0个元素为True时会进行输出,此,与输入input的数据类型相同。shape与input的shape相等,为[A1,...,Ai,R1,...,Rj]。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
gradWeightOut(aclTensor*,可选输出):由outputMask的第1个元素控制是否输出,outputMask第1个元素为True时会进行输出,此,与输入weightOptional的数据类型相同。shape与gradBiasOut的shape相等,为[R1,...,Rj]。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
gradBiasOut(aclTensor*,可选输出):由outputMask的第2个元素控制是否输出,outputMask第2个元素为True时会进行输出,此,与输入weightOptional的数据类型相同。shape与gradWeightOut的shape相等,为[R1,...,Rj]。支持非连续的Tensor,数据格式支持ND。
- Atlas A2训练系列产品/Atlas 800I A2推理产品:FLOAT、FLOAT16、BFLOAT16
- Atlas 推理系列产品、Atlas 训练系列产品:FLOAT、FLOAT16
workspaceSize(uint64_t*, 出参):返回需要在Device侧申请的workspace大小。
executor(aclOpExecutor**, 出参):返回op执行器,包含了算子计算流程。
返回值:
aclnnStatus:返回状态码,具体参见aclnn返回码。
161001 (ACLNN_ERR_PARAM_NULLPTR): 1. 传入的gradOut、input、normalizedShape、mean、rstd、outputMask为空指针。
2. outputMask[0]为True且gradInputOut为空指针。
3. outputMask[1]为True且gradWeightOut为空指针。
4. outputMask[2]为True且gradBiasOut为空指针。
161002 (ACLNN_ERR_PARAM_INVALID): 1. gradOut、input、mean、rstd、weightOptional非空或biasOptional非空时的数据类型不在支持范围内。
2. gradOut的shape与input的shape不相等。
3. normalizedShape维度小于1维。
4. mean的shape乘积与input从第0根轴到第len(input) - len(normalizedShape)轴的乘积不相等。
5. rstd的shape乘积与input从第0根轴到第len(input) - len(normalizedShape)轴的乘积不相等。
6. weightOptional非空且shape与normalizedShape不相等。
7. biasOptional非空且shape与normalizedShape不相等。
8. input的维度小于normalizedShape的维度。
9. input的shape与normalizedShape右对齐时对应维度shape不相等。
10.outputMask的长度不为3。
aclnnLayerNormBackward
参数说明:
workspace(void*, 入参):在Device侧申请的workspace内存地址。
workspaceSize(uint64_t, 入参):在Device侧申请的workspace大小,由第一段接口aclnnLayerNormBackwardGetWorkspaceSize获取。
executor(aclOpExecutor*, 入参):op执行器,包含了算子计算流程。
stream(aclrtStream, 入参):指定执行任务的AscendCL Stream流。
返回值:
aclnnStatus:返回状态码,具体参见aclnn返回码。
约束与限制
无。
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <iostream>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_layer_norm_backward.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
int Init(int32_t deviceId, aclrtStream* stream) {
// 固定写法,AscendCL初始化
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
ret = aclrtCreateStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return 0;
}
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
aclDataType dataType, aclTensor** tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
// 调用aclrtMalloc申请device侧内存
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
// 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
// 计算连续tensor的strides
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
// 调用aclCreateTensor接口创建aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
return 0;
}
int main() {
// 1.(固定写法)device/stream初始化,参考AscendCL对外接口列表
// 根据自己的实际device填写deviceId
int32_t deviceId = 0;
aclrtStream stream;
auto ret = Init(deviceId, &stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
// 2. 构造输入与输出,需要根据API的接口自定义构造
std::vector<int64_t> xShape = {2, 2};
std::vector<int64_t> meanShape = {2, 1};
std::vector<int64_t> normShape = {2};
void* dyDeviceAddr = nullptr;
void* xDeviceAddr = nullptr;
void* meanDeviceAddr = nullptr;
void* rstdDeviceAddr = nullptr;
void* weightDeviceAddr = nullptr;
void* biasDeviceAddr = nullptr;
void* outDeviceAddr = nullptr;
void* dwDeviceAddr = nullptr;
void* dbDeviceAddr = nullptr;
aclTensor* dy = nullptr;
aclTensor* x = nullptr;
aclIntArray* norm = nullptr;
aclTensor* mean = nullptr;
aclTensor* rstd = nullptr;
aclTensor* weight = nullptr;
aclTensor* bias = nullptr;
aclBoolArray* mask = nullptr;
aclTensor* out = nullptr;
aclTensor* dw = nullptr;
aclTensor* db = nullptr;
std::vector<float> dyHostData = {2, 3, 4, 5};
std::vector<float> xHostData = {2, 3, 4, 5};
std::vector<int64_t> normData = {2};
std::vector<float> meanHostData = {2, 3};
std::vector<float> rstdHostData = {4, 5};
std::vector<float> weightHostData = {1, 1};
std::vector<float> biasHostData = {0, 0};
std::vector<float> outHostData(4, 0);
std::vector<float> dwHostData(2, 0);
std::vector<float> dbHostData(2, 0);
// 创建dy aclTensor
ret = CreateAclTensor(dyHostData, xShape, &dyDeviceAddr, aclDataType::ACL_FLOAT, &dy);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建x aclTensor
ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_FLOAT, &x);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建normalizedShape aclIntArray
norm = aclCreateIntArray(normData.data(), 1);
CHECK_RET(ret == ACL_SUCCESS, return false);
// 创建mean aclTensor
ret = CreateAclTensor(meanHostData, meanShape, &meanDeviceAddr, aclDataType::ACL_FLOAT, &mean);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建rstd aclTensor
ret = CreateAclTensor(rstdHostData, meanShape, &rstdDeviceAddr, aclDataType::ACL_FLOAT, &rstd);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建weight aclTensor
ret = CreateAclTensor(weightHostData, normShape, &weightDeviceAddr, aclDataType::ACL_FLOAT, &weight);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建bias aclTensor
ret = CreateAclTensor(biasHostData, normShape, &biasDeviceAddr, aclDataType::ACL_FLOAT, &bias);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建outputMask aclBoolArray
bool maskData[3] = {true, true, true};
mask = aclCreateBoolArray(&(maskData[0]), 3);
// 创建out aclTensor
ret = CreateAclTensor(outHostData, xShape, &outDeviceAddr, aclDataType::ACL_FLOAT, &out);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建dw aclTensor
ret = CreateAclTensor(dwHostData, normShape, &dwDeviceAddr, aclDataType::ACL_FLOAT, &dw);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 创建db aclTensor
ret = CreateAclTensor(dbHostData, normShape, &dbDeviceAddr, aclDataType::ACL_FLOAT, &db);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 3. 调用CANN算子库API,需要修改为具体的Api名称
uint64_t workspaceSize = 0;
aclOpExecutor* executor;
// 调用aclnnLayerNormBackward第一段接口
ret = aclnnLayerNormBackwardGetWorkspaceSize(dy, x, norm, mean, rstd, weight, bias, mask, out, dw, db, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnLayerNormBackwardGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
// 根据第一段接口计算出的workspaceSize申请device内存
void* workspaceAddr = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
}
// 调用aclnnLayerNormBackward第二段接口
ret = aclnnLayerNormBackward(workspaceAddr, workspaceSize, executor, stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnLayerNormBackward failed. ERROR: %d\n", ret); return ret);
// 4.(固定写法)同步等待任务执行结束
ret = aclrtSynchronizeStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);
// 5. 获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改
auto size = GetShapeSize(xShape);
std::vector<float> resultData(size, 0);
ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), outDeviceAddr,
size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
for (int64_t i = 0; i < size; i++) {
LOG_PRINT("out result[%ld] is: %f\n", i, resultData[i]);
}
auto size1 = GetShapeSize(normShape);
std::vector<float> resultData1(size1, 0);
ret = aclrtMemcpy(resultData1.data(), resultData1.size() * sizeof(resultData1[0]), dwDeviceAddr,
size1 * sizeof(resultData1[0]), ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
for (int64_t i = 0; i < size1; i++) {
LOG_PRINT("dw result[%ld] is: %f\n", i, resultData1[i]);
}
auto size2 = GetShapeSize(normShape);
std::vector<float> resultData2(size2, 0);
ret = aclrtMemcpy(resultData2.data(), resultData2.size() * sizeof(resultData2[0]), dbDeviceAddr,
size2 * sizeof(resultData2[0]), ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
for (int64_t i = 0; i < size2; i++) {
LOG_PRINT("db result[%ld] is: %f\n", i, resultData2[i]);
}
// 6. 释放aclTensor、aclIntArray和aclBoolArray,需要根据具体API的接口定义修改
aclDestroyTensor(dy);
aclDestroyTensor(x);
aclDestroyIntArray(norm);
aclDestroyTensor(mean);
aclDestroyTensor(rstd);
aclDestroyTensor(weight);
aclDestroyTensor(bias);
aclDestroyBoolArray(mask);
aclDestroyTensor(out);
aclDestroyTensor(dw);
aclDestroyTensor(db);
// 7. 释放device 资源
aclrtFree(dyDeviceAddr);
aclrtFree(xDeviceAddr);
aclrtFree(meanDeviceAddr);
aclrtFree(rstdDeviceAddr);
aclrtFree(weightDeviceAddr);
aclrtFree(biasDeviceAddr);
aclrtFree(outDeviceAddr);
aclrtFree(dwDeviceAddr);
aclrtFree(dbDeviceAddr);
if (workspaceSize > 0) {
aclrtFree(workspaceAddr);
}
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}