功能描述
实现“Transformer Attention Score”的融合计算,实现的计算公式如下:

接口原型
torch_npu.npu_fusion_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=1, int[]? prefix=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False ) -> (Tensor, Tensor, Tensor, Tensor, int, int, int)
参数说明
- query:Device侧的aclTensor,公式中输入Q,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
- key:Device侧的aclTensor,公式中输入K,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
- value:Device侧的aclTensor,公式中输入V,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
- pse:Device侧的aclTensor,公式中输入pse,可选参数,表示位置编码。数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
- dropMask:Device侧的aclTensor,可选属性,数据类型支持UINT8(标识8个1bit BOOL),数据格式支持ND。
- paddingMask:Device侧的aclTensor,暂不支持该传参。
- attenMask:Device侧的aclTensor,可选属性,代表下三角全为0上三角全为负无穷的倒三角mask矩阵,数据类型支持BOOL(8bit的BOOL)、UINT8,数据格式支持ND。
- prefix:Device侧的aclTensor,可选属性,代表prefix稀疏计算场景每个Batch的N值。数据类型支持INT64,数据格式支持ND。
- scale:Host侧的double,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE。
- keepProb:Host侧的double,可选参数,代表dropMask中1的比例,数据类型支持DOUBLE。
- preTokens:Host侧的int64_t,用于稀疏计算的参数,可选参数,数据类型支持INT64。
- nextTokens:Host侧的int64_t,用于稀疏计算的参数,可选参数,数据类型支持INT64。
- headNum:Host侧的int64_t,代表head个数,数据类型支持INT64。
- inputLayout:Host侧的string,代表输入query、key、value的数据排布格式,支持BSH、SBH、BSND、BNSD。
- innerPrecise:Host侧的int32_t,数据类型支持INT32,内部计算精度模式,其中0表示为高精度,1表示为高性能。
- sparseMode:Host侧的int,表示sparse的模式。数据类型支持:INT64。
- sparseMode为0时,代表defaultMask模式,如果attenmask未传入则不做mask操作,忽略preTokens和nextTokens(内部赋值为INT_MAX);如果传入,则需要传入完整的attenmask矩阵(S1 * S2),表示preTokens和nextTokens之间的部分需要计算。
- sparseMode为为1时,代表allMask,即传入完整的attenmask矩阵。。
- sparseMode为2时,代表leftUpCausal模式的mask,对应以左顶点为划分的下三角场景,需要传入优化后的attenmask矩阵(2048*2048)。
- sparseMode为3时,代表rightDownCausal模式的mask,对应以右下顶点为划分的下三角场景,需要传入优化后的attenmask矩阵(2048*2048)。
- sparseMode为为4时,代表band场景,即计算preTokens和nextTokens之间的部分。
- sparseMode为为5时,代表prefix场景,即在rightDownCasual的基础上,左侧加上一个长为S1,宽为N的矩阵,N的值由新增的输入prefix获取,且每个Batch轴的N值不一样。
- sparseMode为为6、7、8时,分别代表global、dilated、block_local,均暂不支持。
- gen_mask_parallel:debug参数,DSA生成dropout随机数向量mask的控制开关,默认值为True:同AICORE计算并行,False:同AICORE计算串行
- sync:debug参数,DSA生成dropout随机数向量mask的控制开关,默认值为False:dropout mask异步生成,True:dropout mask同步生成
输出说明
共7个输出
(Tensor, Tensor, Tensor, Tensor, int, int, int)
- 第1个输出为Tensor,计算公式的最终输出y。
- 第2个输出为Tensor,Softmax 计算的Max中间结果,用于反向计算。
- 第3个输出为Tensor,Softmax计算的Sum中间结果,用于反向计算。
- 第4个输出为Tensor,保留参数,暂未使用。
- 第5个输出为int,DSA生成dropoutmask中,Philox算法的seed。
- 第6个输出为int,DSA生成dropoutmask中,Philox算法的offset。
- 第7个输出为int,DSA生成dropoutmask的长度。
约束说明
- 输入query、key、value的B:batchsize必须相等。
- 输入query的N和key/value的N 必须成比例关系,即Nq/Nkv必须是非0整数,当Nq/Nkv > 1时,即为GQA,当Nkv=1时,即为MQA。
- 输入key/value的shape必须一致。
- 输入query、key、value的S:sequence length,取值范围1~32K,且为16的倍数。
- 输入query、key、value的D:head dim,取值范围64、80、96、120、128、256。
- 当pre_tockens<Sq 的时候, 使能band sparse计算,pre_tockens不能小于0。
- 当next_tockens<Skv的时候,使能bandsparse计算,next_tokens不能小于0。
- 当pre_tokens >= Sq,同时next_tokens=0时,使能causal计算。
- 在使能band sparse、causal计算时,必须输入atten_mask。
- 当所有的attenmask的shape小于2048且相同的时候,建议使用default模式,即sparse_mode配置为0,来减少内存使用量;sparse_mode配置为2或3时,禁止配置preTokens、nextTokens。
支持的PyTorch版本
- PyTorch 2.1
- PyTorch 2.0
- PyTorch 1.11.0
支持的型号
- Atlas 训练系列产品
- Atlas A2训练系列产品
调用示例
import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
class TestNPUFlashAttention(TestCase):
def supported_op_exec(self, query, key, value):
qk = torch.matmul(query, key.transpose(2, 3)).mul(0.08838)
softmax_res = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(torch.float16)
output = torch.matmul(softmax_res, value)
output = output.transpose(1, 2)
output = output.reshape(output.shape[0], output.shape[1], -1)
return output
def custom_op_exec(self, query, key, value):
scale = 0.08838
return torch_npu.npu_fusion_attention(
query, key, value, head_num=32, input_layout="BSH", scale=scale)
def trans_BNSD2BSH(self, tensor: torch.Tensor):
tensor = torch.transpose(tensor, 1, 2)
tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
return tensor
@unittest.skipIf(DEVICE_NAME != '<soc_version>',
"OP `FlashAttention` is only supported on <soc_version>, skip this ut for this device type!")
def test_npu_flash_attention(self, device="npu"):
query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
q_FA = self.trans_BNSD2BSH(query).npu()
k_FA = self.trans_BNSD2BSH(key).npu()
v_FA = self.trans_BNSD2BSH(value).npu()
output = self.supported_op_exec(query, key, value)
attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels = self.custom_op_exec(q_FA, k_FA, v_FA)
self.assertRtolEqual(output, attention_score, prec=0.005, prec16=0.005)
if __name__ == "__main__":
run_tests()