torch_npu.npu_fused_infer_attention_score
功能描述
接口原型
npu_fused_infer_attention_score(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_lengths=None, SymInt[]? actual_seq_lengths_kv=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? query_padding_size=None, Tensor? kv_padding_size=None, Tensor? key_antiquant_scale=None, Tensor? key_antiquant_offset=None, Tensor? value_antiquant_scale=None, Tensor? value_antiquant_offset=None, Tensor? key_shared_prefix=None, Tensor? value_shared_prefix=None, Tensor? actual_shared_prefix_len=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int num_key_value_heads=0, int sparse_mode=0, int inner_precise=0, int block_size=0, int antiquant_mode=0, bool softmax_lse_flag=False, int key_antiquant_mode=0, int value_antiquant_mode=0) -> (Tensor, Tensor)
参数说明
- query(Tensor*,计算输入):Device侧的Tensor,attention结构的Query输入,数据类型支持FLOAT16、BFLOAT16、INT8,不支持非连续的Tensor,数据格式支持ND。
- key(Tensor*,计算输入):Device侧的Tensor,attention结构的Key输入,数据类型支持FLOAT16、BFLOAT16、INT8,不支持非连续的Tensor,数据格式支持ND。
- value(Tensor*,计算输入):Device侧的Tensor,attention结构的Value输入,数据类型支持FLOAT16、BFLOAT16、INT8,不支持非连续的Tensor,数据格式支持ND。
- *:代表其之前的变量是位置相关,需要按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- pse_shift(Tensor*,计算输入):Device侧的Tensor,在attention结构内部的位置编码参数,数据类型支持FLOAT16、BFLOAT16,数据类型与query的数据类型需满足数据类型推导规则。不支持非连续的Tensor,数据格式支持ND。如不使用该功能时可传入None。
- Q_S不为1,要求在pse_shift为FLOAT16类型时,此时的query为FLOAT16或INT8类型,而在pse_shift为BFLOAT16类型时,要求此时的query为BFLOAT16类型。输入shape类型需为(B,N,Q_S,KV_S)或(1,N,Q_S,KV_S),其中Q_S为query的shape中的S,KV_S为key和value的shape中的S。对于pse_shift的KV_S为非32对齐的场景,建议padding到32字节来提高性能,多余部分的填充值不做要求。
- Q_S为1,要求在pse_shift为FLOAT16类型时,此时的query为FLOAT16类型,而在pse_shift为BFLOAT16类型时,要求此时的query为BFLOAT16类型。输入shape类型需为(B,N,1,KV_S)或(1,N,1,KV_S),其中KV_S为key和value的shape中的S。对于pse_shift的KV_S为非32对齐的场景,建议padding到32字节来提高性能,多余部分的填充值不做要求。
- atten_mask(Tensor*,计算输入):Device侧的Tensor,对QK的结果进行mask,用于指示是否计算Token间的相关性,不支持非连续的Tensor,数据类型支持BOOL、INT8和UINT8。数据格式支持ND。如果不使用该功能可传入None。
- Q_S不为1时建议shape输入Q_S,KV_S;B,Q_S,KV_S;1,Q_S,KV_S;B,1,Q_S,KV_S;1,1,Q_S,KV_S。
- Q_S为1时建议shape输入B,KV_S;B,1,KV_S;B,1,1,KV_S。
其中Q_S为query的shape中的S,KV_S为key和value的shape中的S,但如果Q_S、KV_S非16或32对齐,可以向上取到对齐的S。综合约束请见约束说明。
- actual_seq_lengths(IntArray*,计算输入):Host侧的IntArray,代表不同Batch中query的有效Sequence Length,数据类型支持INT64。如果不指定seqlen可以转入None,表示和query的shape的s长度相同。限制:该入参中每个batch的有效Sequence Length应该不大于query中对应batch的Sequence Length,Q_S为1时该参数无效。seqlen的传入长度为1时,每个Batch使用相同seqlen;传入长度大于等于Batch数时取seqlen的前Batch个数。其它长度不支持。
- actual_seq_lengths_kv(IntArray*,计算输入):Host侧的IntArray,可传入nullptr,代表不同Batch中key/value的有效Sequence Length。数据类型支持INT64。如果不指定seqlen可以转入None,表示和key/value的shape的s长度相同。限制:该入参中每个batch的有效Sequence Length应该不大于key/value中对应batch的Sequence Length。seqlenKV的传入长度为1时,每个Batch使用相同seqlenKV;传入长度大于等于Batch数时取seqlenKV的前Batch个数。其它长度不支持。
- dequant_scale1(Tensor*,计算输入):Device侧的Tensor,数据类型支持UINT64、FLOAT32。数据格式支持ND(参考),表示BMM1后面的反量化因子,支持per-tensor。 如不使用该功能时可传入nullptr。
- quant_scale1(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT32。数据格式支持ND(参考),表示BMM2前面的量化因子,支持per-tensor。 如不使用该功能时可传入None,综合约束请见约束说明。
- dequant_scale2(Tensor*,计算输入):Device侧的Tensor,数据类型支持UINT64、FLOAT32。数据格式支持ND(参考),表示BMM2后面的反量化因子,支持per-tensor。 如不使用该功能时可传入nullptr。
- quant_scale2(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT32、BFLOAT16。数据格式支持ND(参考),表示输出的量化因子,支持per-tensor,per-channel。 当输入为BFLOAT16时,同时支持 FLOAT32和BFLOAT16 ,否则仅支持FLOAT32 。per-channel格式,当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D(建议输出layout为BSH时,quant_scale2shape传入[1,1,H]或[H];输出为BNSD时,建议传入[1,N,1,D]或[N,D];输出为BSND时,建议传入[1,1,N,D]或[N,D])。如不使用该功能时可传入None,综合约束请见约束说明。
- quant_offset2(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT32、BFLOAT16。数据格式支持ND(参考),表示输出的量化偏移,支持per-tensor,per-channel。 若传入 quant_offset2,需保证其类型和shape信息与quantScale2 一致。如不使用该功能时可传入None,综合约束请见约束说明。
- antiquant_scale(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16。数据格式支持ND(参考),表示伪量化因子,支持per-tensor,per-channel,Q_S为1时只支持per-channel,Q_S大于等于2时只支持FLOAT16,如不使用该功能时可传入None,综合约束请见约束说明。
- antiquant_offset(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16。数据格式支持ND(参考),表示伪量化偏移,支持per-tensor,per-channel,Q_S为1时只支持per-channel,Q_S大于等于2时只支持FLOAT16,如不使用该功能时可传入None,综合约束请见约束说明。
- block_table(Tensor*,计算输入):Device侧的Tensor,数据类型支持INT32。数据格式支持ND(参考)。表示PageAttention中KV存储使用的block映射表,如不使用该功能可传入None。
- query_padding_size(Tensor*,计算输入):Device侧的Tensor,数据类型支持INT64。数据格式支持ND(参考)。表示Query中每个batch的数据是否右对齐,且右对齐的个数是多少。仅支持Q_S大于1,其余场景该参数无效。用户不特意指定时可传入默认值None。
- kv_padding_size(Tensor*,计算输入):Device侧的Tensor,数据类型支持INT64。数据格式支持ND(参考)。表示key、value中每个batch的数据是否右对齐,且右对齐的个数是多少。表示key、value中每个batch的数据是否右对齐,且右对齐的个数是多少。用户不特意指定时可传入默认值None。
- key_antiquant_scale(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT16、BFLOAT16、FLOAT32。数据格式支持ND(参考),kv伪量化参数分离时表示key的反量化因子,支持per-tensor,per-channel,per-token。Q_S大于等于2时仅支持per-token模式,如不使用该功能时可传入None,综合约束请见约束说明。
- key_antiquant_offset(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT16、BFLOAT16、FLOAT32。数据格式支持ND(参考),kv伪量化参数分离时表示key的反量化偏移,支持per-tensor,per-channel,per-token。Q_S大于等于2时仅支持per-token模式,如不使用该功能时可传入None,综合约束请见约束说明。
- value_antiquant_scale(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT16、BFLOAT16、FLOAT32。数据格式支持ND(参考),kv伪量化参数分离时表示value的反量化因子,支持per-tensor,per-channel,per-token。Q_S大于等于2时仅支持per-token模式,如不使用该功能时可传入None,综合约束请见约束说明。
- value_antiquant_offset(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT16、BFLOAT16、FLOAT32。数据格式支持ND(参考),kv伪量化参数分离时表示value的反量化偏移,支持per-tensor,per-channel,per-token。Q_S大于等于2时仅支持per-token模式,如不使用该功能时可传入None,综合约束请见约束说明。
- keySharedPrefixOptional(aclTensor*,计算输入):Device侧的aclTensor,attention结构中Key的系统前缀部分的参数,数据类型支持FLOAT16、BFLOAT16、INT8,不支持非连续的Tensor,数据格式支持ND。如不使用该功能时可传入nullptr。综合约束请见约束说明。
- valueSharedPrefixOptional(aclTensor*,计算输入):Device侧的aclTensor,attention结构中Value的系统前缀部分的输入,数据类型支持FLOAT16、BFLOAT16、INT8,不支持非连续的Tensor,数据格式支持ND。如不使用该功能时可传入nullptr。综合约束请见约束说明。
- actualSharedPrefixLenOptional(aclIntArray*,计算输入):Host侧的aclIntArray,可传入nullptr,代表keySharedPrefix/valueSharedPrefix的有效Sequence Length。数据类型支持:INT64。如果不指定seqlen可以传入nullptr,表示和keySharedPrefix/valueSharedPrefix的s长度相同。限制:该入参中的有效Sequence Length应该不大于keySharedPrefix/valueSharedPrefix中的Sequence Length。
- num_heads(int64_t,计算输入):Host侧的int,代表query的head个数,数据类型支持INT64,在BNSD场景下,需要与shape中的query的N轴shape值相同,否则执行异常。
- scale(double,计算输入):Host侧的double,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE。数据类型与query的数据类型需满足数据类型推导规则。用户不特意指定时可传入默认值1.0。
- pre_tokens(int64_t,计算输入):Host侧的int,用于稀疏计算,表示attention需要和前几个Token计算关联,数据类型支持INT64。用户不特意指定时可传入默认值2147483647,Q_S为1时该参数无效。
- next_tokens(int64_t,计算输入):Host侧的int,用于稀疏计算,表示attention需要和后几个Token计算关联。数据类型支持INT64。用户不特意指定时可传入默认值2147483647,Q_S为1时该参数无效。
- input_layout(char*,计算输入):Host侧的字符指针CHAR*,用于标识输入query、key、value的数据排布格式,当前支持BSH、BSND、BNSD、BNSD_BSND(输入为BNSD时,输出格式为BSND,仅支持Q_S大于1)。用户不特意指定时可传入默认值"BSH"。
query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
- num_key_value_heads(int64_t,计算输入):Host侧的int,代表key、value中head个数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,数据类型支持INT64。用户不特意指定时可传入默认值0,表示key/value和query的head个数相等,需要满足num_heads整除num_key_value_heads,num_heads与num_key_value_heads的比值不能大于64。在BSND、BNSD、BNSD_BSND(仅支持Q_S大于1)场景下,还需要与shape中的key/value的N轴shape值相同,否则执行异常。
- sparse_mode(int64_t,计算输入):Host侧的int,表示sparse的模式。数据类型支持INT64。Q_S为1时该参数无效。
- sparse_mode为0时,代表defaultMask模式,如果atten_mask未传入则不做mask操作,忽略pre_tokens和next_tokens(内部赋值为INT_MAX);如果传入,则需要传入完整的atten_mask矩阵(S1 * S2),表示pre_tokens和next_tokens之间的部分需要计算。
- sparse_mode为1时,代表allMask,必须传入完整的attenmask矩阵(S1 * S2)。
- sparse_mode为2时,代表leftUpCausal模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
- sparse_mode为3时,代表rightDownCausal模式的mask,均对应以左顶点为划分的下三角场景,需要传入优化后的atten_mask矩阵(2048*2048)。
- sparse_mode为4时,代表band模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
- sparse_mode为5、6、7、8时,分别代表prefix、global、dilated、block_local,均暂不支持。用户不特意指定时可传入默认值0。综合约束请见约束说明。
- inner_precise(int64_t,计算输入):Host侧的int,一共4种模式:0、1、2、3。一共两位bit位,第0位(bit0)表示高精度或者高性能选择,第1位(bit1)表示是否做行无效修正。数据类型支持INT64。Q_S>1时,sparse_mode为0或1,并传入用户自定义mask的情况下,建议开启行无效;Q_S为1时该参数仅支持innerPrecise为0和1。综合约束请见约束说明。
- inner_precise为0时,代表开启高精度模式,且不做行无效修正。
- inner_precise为1时,代表高性能模式,且不做行无效修正。
- inner_precise为2时,代表开启高精度模式,且做行无效修正。
- inner_precise为3时,代表高性能模式,且做行无效修正。
BFLOAT16和INT8不区分高精度和高性能,行无效修正对FLOAT16、BFLOAT16和INT8均生效。当前0、1为保留配置值,当计算过程中“参与计算的mask部分”存在某整行全为1的情况时,精度可能会有损失。此时可以尝试将该参数配置为2或3来使能行无效功能以提升精度,但是该配置会导致性能下降。
- block_size(int64_t,计算输入):Host侧的int64_t,PageAttention中KV存储每个block中最大的token个数,默认为0,数据类型支持INT64。
- antiquant_mode(int64,计算输入):伪量化的方式,传入0时表示为per-channel(per-channel包含per-tensor),传入1时表示per-token。Q_S大于等于2时该参数无效,用户不特意指定时可传入默认值0,传入0和1之外的其他值会执行异常。
- softmax_lse_flag(bool,计算输入):是否输出softmax_lse,支持S轴外切(增加输出)。true表示输出softmax_lse,false表示不输出;用户不特意指定时可传入默认值false。
- key_antiquant_mode(int64,计算输入):key的伪量化的方式。Q_S大于等于2时仅支持传入值为1,用户不特意指定时可传入默认值0,传入0、1、2、3和4之外的其他值会执行异常,取值需要与value_antiquant_mode一致。综合约束请见约束说明。
- key_antiquant_mode为0时,代表per-channel模式(per-channel包含per-tensor)。
- key_antiquant_mode为1时,代表per-token模式。
- key_antiquant_mode为2时,代表per-tensor叠加per-head模式。
- key_antiquant_mode为3时,代表per-token叠加per-head模式。
- key_antiquant_mode为4时,代表per-token叠加使用page attention模式管理scale、offset模式。
- value_antiquant_mode(int64,计算输入):value的伪量化的方式,模式编号与key_antiquant_mode一致。Q_S大于等于2时仅支持传入值为1,用户不特意指定时可传入默认值0,传入0、1、2、3和4之外的其他值会执行异常,取值需要与key_antiquant_mode一致。综合约束请见约束说明。
输出说明
- attentionOut(Tensor*,计算输出):Device侧的Tensor,公式中的输出,数据类型支持FLOAT16、BFLOAT16、INT8。数据格式支持ND。限制:当input_layout为BNSD_BSND时,输入query的shape是BNSD,输出shape为BSND;其余情况该参数的shape需要与入参query的shape保持一致。
- softmaxLse(Tensor*,计算输出):ring attention算法对query乘key的结果,先取max得到softmax_max。query乘key的结果减去softmax_max,再取exp,最后取sum,得到softmax_sum,最后对softmax_sum取log,再加上softmax_max得到的结果。数据类型支持FLOAT32,softmax_lse_flag为True时,输出shape为[B,N,Q_S,1]的Tensor;softmax_lse_flag为False时,则输出shape为[1]的值为0的Tensor。
约束说明
- 该融合算子仅在推理场景使用。
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
- 入参为空的处理:算子内部需要判断参数query是否为空,如果是空则直接返回。参数query不为空Tensor,参数key、value为空tensor(即S2为0),则填充全零的对应shape的输出(填充attention_out)。attention_out为空Tensor时,AscendCLNN框架会处理。其余在上述参数说明中标注了"可传入nullptr"的入参为空指针时,不进行处理。
- 参数key、value中对应tensor的shape需要完全一致;非连续场景下 key、value的tensorlist中的batch只能为1,个数等于query的B,N和D需要相等。
- int8量化相关入参数量与输入、输出数据格式的综合限制:
- 输入为INT8,输出为INT8的场景:入参dequant_scale1、quant_scale1、dequant_scale2、quant_scale2需要同时存在,quant_offset2可选,不传时默认为0。
- 输入为INT8,输出为FLOAT16的场景:入参dequant_scale1、quant_scale1、dequant_scale2需要同时存在,若存在入参quant_offset2或 quantScale2(即不为nullptr),则报错并返回。
- 输入全为FLOAT16或BFLOAT16,输出为INT8的场景:入参quant_scale2需存在,quant_offset2可选,不传时默认为0,若存在入参dequant_scale1或 quant_scale1或 dequant_scale2(即不为nullptr),则报错并返回。
- 入参 quant_offset2和 quant_scale2支持 per-tensor 或 per-channel 格式和 FLOAT32、BFLOAT16 两种数据类型。
- antiquant_scale和antiquant_offset参数约束:
- 支持per-channel、per-tensor和per-token三种模式:
- per-channel模式:两个参数BNSD场景下shape为(2, N, 1, D),BSND场景下shape为(2, N, D),BSH场景下shape为(2, H),N为num_key_value_heads。参数数据类型和query数据类型相同,antiquant_mode置0。
- per-tensor模式:两个参数的shape均为(2),数据类型和query数据类型相同, antiquant_mode置0。
- per-token模式:两个参数的shape均为(2, B, S), 数据类型固定为FLOAT32, antiquant_mode置1。
算子运行在何种模式根据参数的shape进行判断,dim为1时运行per-tensor模式,否则运行per-channel模式。
- 支持对称量化和非对称量化:
- 非对称量化模式下, antiquant_scale和antiquant_offset参数需同时存在。
- 对称量化模式下,antiquant_offset可以为空(即nullptr);当antiquant_offset参数为空时,执行对称量化,否则执行非对称量化。
- 支持per-channel、per-tensor和per-token三种模式:
- 当Q_S大于1时:
- query,key,value输入,功能使用限制如下:
- 支持B轴小于等于65536,D轴32byte不对齐时仅支持到128。
- 支持N轴小于等于256,支持D轴小于等于512;inputLayout为BSH或者BSND时,要求N*D小于65535。
- S支持小于等于20971520(20M)。部分长序列场景下,如果计算量过大可能会导致pfa算子执行超时(aicore error类型报错,errorStr为:timeout or trap error),此场景下建议做S切分处理(注:这里计算量会受B、S、N、D等的影响,值越大计算量越大),典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
- B=1,Q_N=20,Q_S=2097152,D = 256,KV_N=1,KV_S=2097152。
- B=1,Q_N=2,Q_S=20971520,D = 256,KV_N=2,KV_S=20971520。
- B=20,Q_N=1,Q_S=2097152,D = 256,KV_N=1,KV_S=2097152。
- B=1,Q_N=10,Q_S=2097152,D = 512,KV_N=1,KV_S=2097152。
- query、key、value输入类型包含INT8时,D轴需要32对齐;输入类型全为FLOAT16、BFLOAT16时,D轴需16对齐。
- 参数sparse_mode当前仅支持值为0、1、2、3、4的场景,取其它值时会报错。
- sparse_mode=0时,atten_mask如果为空指针,或者在左padding场景传入atten_mask,则忽略入参pre_tokens、next_tokens(内部赋值为INT_MAX)。
- sparse_mode=2、3、4时,atten_mask的shape需要为S,S或1,S,S或1,1,S,S,其中S的值需要固定为2048,且需要用户保证传入的atten_mask为下三角,不传入atten_mask或者传入的shape不正确报错。
- sparse_mode=1、2、3的场景忽略入参pre_tokens、next_tokens并按照相关规则赋值。
- kvCache反量化的合成参数场景仅支持INT8反量化到FLOAT16。入参key、value的datarange与入参antiquant_scale的datarange乘积范围在(-1,1)内,高性能模式可以保证精度,否则需要开启高精度模式来保证精度。
- page attention场景:
- page attention的使能必要条件是blockTable存在且有效,同时key、value是按照blockTable中的索引在一片连续内存中排布,支持key、value dtype为FLOAT16/BFLOAT16/INT8,在该场景下key、value的inputLayout参数无效。blockTable中填充的是blockid,当前不会对blockid的合法性进行校验,需用户自行保证。
blockSize是用户自定义的参数,该参数的取值会影响page attention的性能,在使能page attention场景下,blockSize最小为128,最大为512,且要求是128的倍数。通常情况下,page attention可以提高吞吐量,但会带来性能上的下降。
- page attention场景下,当输入kv cache排布格式为(blocknum, blocksize, H),且 KV_N * D 超过65535时,受硬件指令约束,会被拦截报错。可通过使能GQA(减小 KV_N)或调整kv cache排布格式为(blocknum, KV_N, blocksize, D)解决。当query的inputLayout为BNSD时,kv cache排布支持(blocknum, blocksize, H)和(blocknum, KV_N, blocksize, D)两种格式,当query的inputLayout为BSH、BSND时,kv cache排布只支持(blocknum, blocksize, H)一种格式。blocknum不能小于根据actualSeqLengthsKv和blockSize计算的每个batch的block数量之和。且key和value的shape需保证一致。
- page attention不支持伪量化场景,不支持tensorlist场景,不支持左padding场景。
- page attention场景下,必须传入actualSeqLengthsKv。
- page attention场景下,blockTable必须为二维,第一维长度需等于B,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为不同batch中最大actualSeqLengthsKv对应的block数量)。
- page attention场景下,不支持输入query为int8的场景。
- page attention使能场景下,以下场景输入KV_S需要大于等于maxBlockNumPerSeq * blockSize:
- 传入attenMask时,如 mask shape为 (B, 1, Q_S, KV_S)。
- 传入pseShift时,如 pseShift shape为(B, N, Q_S, KV_S)。
- query左padding场景:
- query左padding场景query的搬运起点计算公式为:Q_S - query_padding_size- actual_seq_lengths。query的搬运终点计算公式为:Q_S - query_padding_size。其中query的搬运起点不能小于0,终点不能大于Q_S,否则结果将不符合预期。
- query左padding场景kv_padding_size小于0时将被置为0。
- query左padding场景需要与actual_seq_lengths参数一起使能,否则默认为query右padding场景。
- query左padding场景不支持PageAttention,不能与blocktable参数一起使能。
- kv左padding场景:
- kv左padding场景key和value的搬运起点计算公式为:KV_S - kv_padding_size- actual_seq_lengths_kv。key和value的搬运终点计算公式为:KV_S - kv_padding_size。其中key和value的搬运起点不能小于0,终点不能大于KV_S,否则结果将不符合预期。
- kv左padding场景kv_padding_size小于0时将被置为0。
- kv左padding场景需要与actual_seq_lengths_kv参数一起使能,否则默认为kv右padding场景。
- kv左padding场景不支持PageAttention,不能与blocktable参数一起使能。
- 入参quant_scale2和quant_offset2支持per-tensor/per-channel 两种格式和FLOAT32/BFLOAT16 两种数据类型。若传入quant_offset2,需保证其类型和shape信息与quant_scale2一致。当输入为BFLOAT16时,同时支持FLOAT32和BFLOAT16 ,否则仅支持FLOAT32。per-channel 格式,当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D。(建议输出layout为BSH时,quant_scale2shape传入[1,1,H]或[H];输出为BNSD时,建议传入[1,N,1,D]或[N,D];输出为BSND时,建议传入[1,1,N,D]或[N,D])。
- 输出为int8,quant_scale2和quant_offset2为per-channel 时,暂不支持左padding、Ring Attention或者D非32Byte对齐的场景。
- 输出为int8时,暂不支持sparse为band且preTokens/nextTokens为负数。
- pse_shift功能使用限制如下:
- 支持query数据类型为FLOAT16或BFLOAT16或INT8场景下使用该功能。
- query,key,value数据类型为FLOAT16且pse_shift存在时,强制走高精度模式,对应的限制继承自高精度模式的限制。
- Q_S需大于等于query的S长度,KV_S需大于等于key的S长度。prefix场景KV_S需大于等于actualSharedPrefixLen与key的S长度之和。
- 输出为INT8,入参quant_offset2传入非空指针和非空tensor值,并且sparse_mode、pre_tokens和next_tokens满足以下条件,矩阵会存在某几行不参与计算的情况,导致计算结果误差,该场景会拦截:
- sparse_mode= 0,atten_mask如果非空指针,每个batch actual_seq_lengths- actual_seq_lengths_kv-pre_tokens> 0 或 next_tokens< 0 时,满足拦截条件。
- sparse_mode= 1 或 2,不会出现满足拦截条件的情况。
- sparse_mode= 3,每个batch actual_seq_lengths_kv- actual_seq_lengths< 0,满足拦截条件。
- sparse_mode= 4,preTokens < 0 或 每个batch next_tokens+ actual_seq_lengths_kv- actual_seq_lengths< 0 时,满足拦截条件。
- prefix相关参数约束:
- keySharedPrefix和valueSharedPrefix要么都为空,要么都不为空。
- keySharedPrefix和valueSharedPrefix都不为空时,keySharedPrefix、valueSharedPrefix、key、value的维度相同、dtype保持一致。
- keySharedPrefix和valueSharedPrefix都不为空时,keySharedPrefix的shape第一维batch必须为1,layout为BNSD和BSND情况下N、D轴要与key一致、BSH情况下H要与key一致,valueSharedPrefix同理。keySharedPrefix和valueSharedPrefix的S应相等。
- 当actualSharedPrefixLen存在时,actualSharedPrefixLen的shape需要为[1],值不能大于keySharedPrefix和valueSharedPrefix的S。
- 公共前缀的S加上key或value的S的结果,要满足原先key或value的S的限制。
- prefix不支持PageAttention场景、不支持左padding场景、不支持tensorlist场景。
- prefix场景不支持query,key,value数据类型同时为int8。
- prefix场景,sparse为0或1时,如果传入attenmask,则S2需大于等于actualSharedPrefixLen与key的S长度之和。
- prefix场景,不支持输入qkv全部为int8的场景。
- kv伪量化参数分离:
- keyAntiquantMode和valueAntiquantMode需要保持一致。
- keyAntiquantScale和valueAntiquantScale要么都为空,要么都不为空;keyAntiquantOffset和valueAntiquantOffset要么都为空,要么都不为空。
- KeyAntiquantScale和valueAntiquantScale都不为空时,其shape需要保持一致;keyAntiquantOffset和valueAntiquantOffset都不为空时,其shape需要保持一致。
- 仅支持per-token模式,且该模式下要求两个参数的shape均为(B, S),数据类型固定为FLOAT32。
- 当伪量化参数和KV分离量化参数同时传入时,以KV分离量化参数为准。
- keyAntiquantScale与valueAntiquantScale非空场景,要求query的s小于等于16。
- keyAntiquantScale与valueAntiquantScale非空场景,要求query的dtype为BFLOAT16,key、value的dtype为INT8,输出的dtype为BFLOAT16。
- keyAntiquantScale与valueAntiquantScale非空场景,不支持tensolist、左padding、page attention、prefix特性。
- query,key,value输入,功能使用限制如下:
- 当Q_S等于1时:
- query,key,value输入,功能使用限制如下:
- 支持B轴小于等于65536,支持N轴小于等于256,支持S轴小于等于262144,支持D轴小于等于512。
- query、key、value输入类型均为INT8的场景暂不支持。
- page attention场景:
- 使能必要条件是blocktable存在且有效,同时key、value是按照blocktable中的索引在一片连续内存中排布,支持key、value dtype为FLOAT16、BFLOAT16、INT8,在该场景下key、value的inputLayout参数无效。
- 该场景下,blockSize是用户自定义的参数,该参数的取值会影响page attention的性能。key、value输入类型为FLOAT16、BFLOAT16时需要16对齐,key、value 输入类型为INT8时需要32对齐,推荐使用128。通常情况下,page attention可以提高吞吐量,但会带来性能上的下降。
- 参数key、value各自对应tensor的shape所有维度相乘不能超过int32的表示范围。
- kv左padding场景:
- kvCache的搬运起点计算公式为:Smax - kv_padding_size- actual_seq_lengths。kvCache的搬运终点计算公式为:Smax - kv_padding_size。其中kvCache的搬运起点或终点小于0时,返回数据结果为全0。
- kv_padding_size小于0时将被置为0。
- 使能需要同时存在actual_seq_lengths参数,否则默认为kv右padding场景。
- kv伪量化参数分离:
- key_antiquant_scale和value_antiquant_scale要么都为空,要么都不为空;key_antiquant_offset和value_antiquant_offset要么都为空,要么都不为空。
- key_antiquant_scale和value_antiquant_scale都不为空时,其shape需要保持一致;key_antiquant_offset和value_antiquant_offset都不为空时,其shape需要保持一致。
- 支持per-channel、per-tensor、per-token、per-tensor叠加per-head、per-token叠加per-head和per-token叠加使用page attention模式管理scale、offset六种模式,以下N均为numKeyValueHeads:
- per-channel模式:两个参数的shape可支持(1, N, 1, D),(1, N, D),(1, H)。参数数据类型和query数据类型相同。
- per-tensor模式:两个参数的shape均为(1),数据类型和query数据类型相同。
- per-token模式:两个参数的shape均为(1, B, S),数据类型固定为FLOAT32。
- per-tensor叠加per-head模式:两个参数的shape均为(N),数据类型和query数据类型相同。
- per-token叠加per-head模式:两个参数的shape均为(B, N, S),数据类型固定为FLOAT32。
- per-token叠加使用page attention模式管理scale、offset模式:两个参数的shape均为(blocknum, blocksize),数据类型固定为FLOAT32。
- pse_shift功能使用限制如下:
- pse_shift数据类型需与query数据类型保持一致。
- 仅支持D轴对齐,即D轴可以被16整除。
- query,key,value输入,功能使用限制如下:
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
调用示例
- 单算子调用
# 单算子调用方式 import torch import torch_npu import math # 生成随机数据, 并发送到npu q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu() k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() scale = 1/math.sqrt(128.0) actseqlen = [164] actseqlenkv = [1024] # 调用FIA算子 out, _ = torch_npu.npu_fused_infer_attention_score(q, k, v, actual_seq_lengths = actseqlen, actual_seq_lengths_kv = actseqlenkv, num_heads = 8, input_layout = "BNSD", scale = scale, pre_tokens=65535, next_tokens=65535) # 执行上述代码的输出out类似如下 tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.float16)
- 图模式调用
# 入图方式 import torch import torch_npu import math import torchair as tng from torchair.configs.compiler_config import CompilerConfig import torch._dynamo TORCHDYNAMO_VERBOSE=1 TORCH_LOGS="+dynamo" # 支持入图的打印宏 import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) from torch.library import Library, impl # 数据生成 q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu() k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() scale = 1/math.sqrt(128.0) class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self): return torch_npu.npu_fused_infer_attention_score(q, k, v, num_heads = 8, input_layout = "BNSD", scale=scale, pre_tokens=65535, next_tokens=65535) def MetaInfershape(): with torch.no_grad(): model = Model() model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True) graph_output = model() single_op = torch_npu.npu_fused_infer_attention_score(q, k, v, num_heads = 8, input_layout = "BNSD", scale=scale, pre_tokens=65535, next_tokens=65535) print("single op output with mask:", single_op[0], single_op[0].shape) print("graph output with mask:", graph_output[0], graph_output[0].shape) if __name__ == "__main__": MetaInfershape() # 执行上述代码的输出类似如下 single op output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128]) graph output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
父主题: torch_npu