torch_npu.npu_incre_flash_attention
功能描述
增量FA实现,实现对应公式:
接口原型
torch_npu.npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? pse_shift=None, SymInt[]? actual_seq_lengths=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? kv_padding_size=None, int num_heads=1, float scale_value=1.0, str input_layout="BSH", int num_key_value_heads=0, int block_size=0, int inner_precise=1) -> Tensor
参数说明
- query:Device侧的Tensor;数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。Atlas 推理系列加速卡产品仅支持FLOAT16。
- key:Device侧的Tensor;数据类型支持FLOAT16、BFLOAT16、INT8,数据格式支持ND。Atlas 推理系列加速卡产品仅支持FLOAT16。
- value:Device侧的Tensor;数据类型支持FLOAT16、BFLOAT16、INT8,数据格式支持ND。Atlas 推理系列加速卡产品仅支持FLOAT16。
- *:代表其之前的变量是位置相关,需要按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- padding_mask:预留参数,暂未使用,默认值为None。
- atten_mask:Device侧的Tensor,取值为1代表该位不参与计算(不生效),为0代表该位参与计算,默认值为None,即全部参与计算;数据类型支持BOOL,INT8,UINT8,数据格式支持ND。
- pse_shift: Device侧的Tensor;表示在attention结构内部的位置编码参数,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
- actual_seq_lengths:Host侧的Input数组,其shape为(B)或(1),形如[1, 2, 3],代表key、value中有效的S序列长度,默认值为None,即全部有效,类型为List int;数据类型为INT64,数据格式支持ND。
- antiquant_scale:Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,表示量化因子,支持per-channel(list),由shape决定,BNSD场景下shape为(2, N, 1, D),BSH场景下shape为(2, H),BSND场景下shape为(2, N, D)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持FLOAT16。
- antiquant_offset:Device侧的Input Tensor,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,表示量化偏移,支持per-channel(list),由shape决定,BNSD场景下shape为(2, N, 1, D),BSH场景下shape为(2, H),BSND场景下shape为(2, N, D)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持FLOAT16。
- block_table:Device侧的Tensor,数据类型支持INT32,数据格式支持ND。block_table为2维Tensor,表示PageAttention中KV存储使用的block映射表,具体约束和使用方法可见约束说明。 如不使用该功能时可不传或传入None。
- dequant_scale1:Device侧的Tensor,数据类型支持FLOAT32,数据格式支持ND,表示BMM1后面反量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
- quant_scale1:Device侧的Tensor,数据类型支持FLOAT32,数据格式支持ND,表示BMM2前面量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
- dequant_scale2:Device侧的Tensor,数据类型支持FLOAT32,数据格式支持ND,表示BMM2后面反量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
- quant_scale2:Device侧的Tensor,数据类型支持FLOAT32、BFLOAT16,数据格式支持ND,表示输出量化的量化因子,支持per-tensor(scalar)和per-channel(list)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
- quant_offset2:Device侧的Tensor,数据类型支持FLOAT32、BFLOAT16,数据格式支持ND,表示输出量化的量化偏移,支持per-tensor(scalar)和per-channel(list)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
- kv_padding_size:Device侧的Tensor,数据类型支持INT64,数据格式支持ND,表示kv左padding场景使能时,最后一个有效token到S的距离。 如不使用该功能时可传入None。
- num_heads:Host侧的attribute,代表query的头数,即query的N,默认值为1;数据类型为INT64。
- scale_value:Host侧的attribute,代表缩放系数,用来约束梯度,其默认值为1.0,典型值为;数据类型为FLOAT32。
- input_layout:Host侧的attribute,代表query、key、value的布局,根据输入的query、key、value的shape确定,三维Tensor是BSH,四维Tensor是BNSD或BSND,默认值为BSH,不支持其他值;数据类型为string。
query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
- num_key_value_heads:Host侧的attribute,代表key、value的头数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,默认值为0,表示与query的头数相同,否则表示key、value的头数,需要能被query的头数(num_heads)整除;num_heads与num_key_value_heads的比值不能大于64。数据类型为INT64。
- block_size (int64_t,计算输入):Host侧的int64_t,PageAttention中KV存储每个block中最大的token个数,默认为0,通常为128、256等值,数据类型支持INT64。
- inner_precise (int64_t,计算输入):Host侧的int64_t,代表高精度/高性能选择,0代表高精度,1代表高性能,默认值为1(高性能), 数据类型支持INT64。
输出说明
- 共一个输出,为计算的最终结果atten_out,类型为Tensor,shape与query保持一致。
- 非量化场景下,输出数据类型与query的数据类型保持一致。
- 量化场景下,若传入quantScale2,则输出数据类型为int8。
约束说明
- query、key、value的维度必须保持一致,key、value的shape必须保持一致。
- num_heads的值要等于query的N。
- input_layout的值与query的shape相关,三维是BSH,四维是BNSD或BSND。
- num_key_value_heads的值要等于key、value的N,需要能被query的头数(num_heads)整除。
- query,key,value输入,功能使用限制如下:
- Atlas A2 训练系列产品支持B轴小于等于65535,支持N轴小于等于256,支持S轴小于等于262144,支持D轴小于等于512。
- Atlas 推理系列加速卡产品支持B轴小于等于256,支持N轴小于等于256,支持S轴小于等于65536,支持D轴小于等于512。
- query、key、value输入均为INT8的场景暂不支持。
- INT8量化相关入参数量与输入、输出数据格式的综合限制:
query、key、value输入为FLOAT16,输出为INT8的场景:入参quant_scale2必填,quant_offset2可选,不能传入dequant_scale1、quant_scale1、dequant_scale2(即为None)参数。
- pse_shift功能使用限制如下:
- pse_shift数据类型需与query数据类型保持一致。
- 仅支持D轴对齐,即D轴可以被16整除。
- page attention使用限制:
- page attention使能必要条件是blocktable存在且有效,且传入每个batch对应的actualSeqLengths。page attention使能场景下,key、value是按照blocktable中的索引在一片连续内存中排布,支持key、value数据类型为FLOAT16、BFLOAT16、INT8。
- page attention使能场景下,输入kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims)或(blocknum, blocksize, H),blocknum不应小于每个batch所需block个数的总和。通常情况下,kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims)时,性能比kv cache排布格式为(blocknum, blocksize, H)时更好。
- page attention使能场景下,支持kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims),但此时query layout仅支持BNSD。
- page attention使能场景下,当输入kv cache排布格式为(blocknum, blocksize, H),且H(H=numKvHeads * headDims)超过64k时,受硬件指令约束,会被拦截报错。
- page attention场景下,必须传入输入actualSeqLengths,每个batch的actualSeqLength表示每个batch对sequence真实长度,该值除以属性输入blocksize即表示每个batch所需block数量。
- page attention场景下,blockTable必须为二维Tensor,第一维长度需等于batch数,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actualSeqLengths对应的block数量)。例如,batch数为2,属性blocksize=128,当每个batch的actualSeqLength为512时,表明每个batch至少需要4个block,因此blockTable的排布可以为(2,4)
- page attention使能场景下,blockSize是用户自定义的参数,该参数的取值会影响page attention的性能,通常为128或256。key、value输入类型为FLOAT16、BFLOAT16时blockSize需要16对齐;key、value 输入类型为INT8时blockSize需要32对齐。通常情况下,page attention可以提高吞吐量,但会带来性能上的下降。
- quant_scale2、quant_offset2为一组参数,其中quant_offset2可选,传入该组参数后算子输出数据类型会推导为int8,若不期望int8输出,请勿传入该组参数。
- kv左padding场景使用限制:
- kvCache的搬运起点计算公式为:Smax - kv_padding_size- actual_seq_lengths。kvCache的搬运终点计算公式为:Smax - kv_padding_size。其中kvCache的搬运起点或终点小于0时,返回数据结果为全0。
- kv左padding场景kv_padding_size小于0时将被置为0。
- kv左padding场景使能需要同时存在kv_padding_size和actual_seq_lengths参数,否则默认为kv右padding场景。
支持的PyTorch版本
- PyTorch 2.3
- PyTorch 2.2
- PyTorch 2.1
支持的型号
- Atlas A2 训练系列产品
- Atlas 推理系列加速卡产品
调用示例
- 单算子调用
# 单算子调用方式 import torch import torch_npu import math # 生成随机数据,并发送到npu q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu() k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu() v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu() scale = 1/math.sqrt(128.0) # 调用IFA算子 out = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale) # 执行上述代码的输出类似如下 tensor([[[ 0.3149, -0.2460, 0.7939, ..., 0.5737, -0.4929, -0.1500]], [[ 0.8115, 1.3789, 0.6484, ..., -0.9092, -0.6206, -0.7412]]], device='npu:0', dtype=torch.float16)
- 图模式调用(图模式目前仅支持PyTorch 2.1版本)
# 入图方式 import torch import torch_npu import math import torchair as tng from torchair.configs.compiler_config import CompilerConfig import torch._dynamo TORCHDYNAMO_VERBOSE=1 TORCH_LOGS="+dynamo" # 支持入图的打印宏 import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) from torch.library import Library, impl # 数据生成 q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu() k = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu() v = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu() atten = torch.randn(2, 1, 1, 2048).bool().npu() scale_value = 1/math.sqrt(128.0) class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self): return torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten) def MetaInfershape(): with torch.no_grad(): model = Model() model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True) graph_output = model() single_op = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten) print("single op output with mask:", single_op, single_op.shape) print("graph output with mask:", graph_output, graph_output.shape) if __name__ == "__main__": MetaInfershape() # 执行上述代码的输出类似如下 single op output with mask: tensor([[[ 0.2488, -0.6572, 1.0928, ..., 0.1694, 0.1142, -2.2266]], [[-0.9595, -0.9609, -0.6602, ..., 0.7959, 1.7920, 0.0783]]], device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120]) graph output with mask: tensor([[[ 0.2488, -0.6572, 1.0928, ..., 0.1694, 0.1142, -2.2266]], [[-0.9595, -0.9609, -0.6602, ..., 0.7959, 1.7920, 0.0783]]], device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
父主题: torch_npu