下载
中文
注册

torch_npu.npu_incre_flash_attention

功能描述

增量FA实现,实现对应公式:

接口原型

1
torch_npu.npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? padding_mask=None, Tensor? pse_shift=None, Tensor? atten_mask=None,, SymInt[]? actual_seq_lengths=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? kv_padding_size=None, int num_heads=1, float scale_value=1.0, str input_layout="BSH", int num_key_value_heads=0, int block_size=0, int inner_precise=1) -> Tensor

参数说明

  • query:Device侧的Tensor;数据格式支持ND。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16。
    • Atlas A3 训练系列产品 :数据类型支持FLOAT16、BFLOAT16。
  • key:Device侧的Tensor;数据格式支持ND。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT16、INT8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16、INT8。
  • value:Device侧的Tensor;数据格式支持ND。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT16、INT8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16、INT8。
  • *:代表其之前的变量是位置相关,需要按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
  • padding_mask:预留参数,暂未使用,默认值为None。
  • pse_shift:Device侧的Tensor;表示在attention结构内部的位置编码参数,数据格式支持ND。如不使用该功能时可不传或传入None。 Atlas 推理系列加速卡产品 仅支持None。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16。
  • atten_mask:Device侧的Tensor,取值为1代表该位不参与计算(不生效),为0代表该位参与计算,默认值为None,即全部参与计算;数据类型支持BOOL,INT8,UINT8,数据格式支持ND。
  • actual_seq_lengths:Host侧的Input数组,其shape为(B)或(1),形如[1, 2, 3],代表key、value中有效的S序列长度,默认值为None,即全部有效,类型为List int;数据类型为INT64,数据格式支持ND。
  • dequant_scale1:Device侧的Tensor,数据类型支持FLOAT32,数据格式支持ND,表示BMM1后面反量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。 Atlas 推理系列加速卡产品 仅支持None。
  • quant_scale1:Device侧的Tensor,数据类型支持FLOAT32,数据格式支持ND,表示BMM2前面量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。 Atlas 推理系列加速卡产品 仅支持None。
  • dequant_scale2:Device侧的Tensor,数据类型支持FLOAT32,数据格式支持ND,表示BMM2后面反量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。 Atlas 推理系列加速卡产品 仅支持None。
  • quant_scale2:Device侧的Tensor,数据格式支持ND,表示输出量化的量化因子,支持per-tensor(scalar)和per-channel(list)。 如不使用该功能时可不传或传入None。 Atlas 推理系列加速卡产品 仅支持None。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT32。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT32、BFLOAT16。
  • quant_offset2:Device侧的Tensor,数据格式支持ND,表示输出量化的量化偏移,支持per-tensor(scalar)和per-channel(list)。 如不使用该功能时可不传或传入None。 Atlas 推理系列加速卡产品 仅支持None。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT32。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT32、BFLOAT16。
  • antiquant_scale:Device侧的Tensor,数据格式支持ND,表示量化因子,支持per-channel(list),由shape决定,BNSD场景下shape为(2, N, 1, D),BSH场景下shape为(2, H),BSND场景下shape为(2, N, D)。 如不使用该功能时可不传或传入None。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16。
  • antiquant_offset:Device侧的Input Tensor,数据格式支持ND,表示量化偏移,支持per-channel(list),由shape决定,BNSD场景下shape为(2, N, 1, D),BSH场景下shape为(2, H),BSND场景下shape为(2, N, D)。 如不使用该功能时可不传或传入None。
    • Atlas 推理系列加速卡产品 :数据类型支持FLOAT16。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、BFLOAT16。
  • block_table:Device侧的Tensor,数据类型支持INT32,数据格式支持ND。block_table为2维Tensor,表示PageAttention中KV存储使用的block映射表,具体约束和使用方法可见约束说明。 如不使用该功能时可不传或传入None。
  • kv_padding_size:Device侧的Tensor,数据类型支持INT64,数据格式支持ND,表示kv左padding场景使能时,最后一个有效token到S的距离。 如不使用该功能时可传入None。
  • num_heads:Host侧的attribute,代表query的头数,即query的N,默认值为1;数据类型为INT64。
  • scale_value:Host侧的attribute,代表缩放系数,用来约束梯度,其默认值为1.0,典型值为;数据类型为FLOAT32。
  • input_layout:Host侧的attribute,代表query、key、value的布局,根据输入的query、key、value的shape确定,三维Tensor是BSH,四维Tensor是BNSD或BSND,默认值为BSH,不支持其他值;数据类型为string。

    query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。

  • num_key_value_heads:Host侧的attribute,代表key、value的头数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,默认值为0,表示与query的头数相同,否则表示key、value的头数,需要能被query的头数(num_heads)整除;num_heads与num_key_value_heads的比值不能大于64。数据类型为INT64。
  • block_size (int64_t,计算输入):Host侧的int64_t,PageAttention中KV存储每个block中最大的token个数,默认为0,通常为128、256等值,数据类型支持INT64。
  • inner_precise (int64_t,计算输入):Host侧的int64_t,代表高精度/高性能选择,0代表高精度,1代表高性能,默认值为1(高性能), 数据类型支持INT64。

输出说明

  • 共一个输出,为计算的最终结果atten_out,类型为Tensor,shape与query保持一致。
  • 非量化场景下,输出数据类型与query的数据类型保持一致。
  • 量化场景下,若传入quantScale2,则输出数据类型为int8

约束说明

  • 该接口仅在推理场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • query、key、value的维度必须保持一致,key、value的shape必须保持一致。
  • num_heads的值要等于query的N。
  • input_layout的值与query的shape相关,三维是BSH,四维是BNSD或BSND。
  • num_key_value_heads的值要等于key、value的N,需要能被query的头数(num_heads)整除。
  • query,key,value输入,功能使用限制如下:
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品 支持B轴小于等于65535,支持N轴小于等于256,支持S轴小于等于262144,支持D轴小于等于512。
    • Atlas 推理系列加速卡产品 支持B轴小于等于256,支持N轴小于等于256,支持S轴小于等于65536,支持D轴小于等于512。
    • query、key、value输入均为INT8的场景暂不支持。
  • INT8量化相关入参数量与输入、输出数据格式的综合限制:

    query、key、value输入为FLOAT16,输出为INT8的场景:入参quant_scale2必填,quant_offset2可选,不能传入dequant_scale1、quant_scale1、dequant_scale2(即为None)参数。

  • pse_shift功能使用限制如下:
    • pse_shift数据类型需与query数据类型保持一致。
    • 仅支持D轴对齐,即D轴可以被16整除。
  • page attention使用限制:
    • page attention使能必要条件是blocktable存在且有效,且传入每个batch对应的actualSeqLengths。page attention使能场景下,key、value是按照blocktable中的索引在一片连续内存中排布,支持key、value数据类型为FLOAT16、BFLOAT16、INT8。
    • page attention使能场景下,输入kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims)或(blocknum, blocksize, H),blocknum不应小于每个batch所需block个数的总和。通常情况下,kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims)时,性能比kv cache排布格式为(blocknum, blocksize, H)时更好。
    • page attention使能场景下,支持kv cache排布格式为(blocknum, numKvHeads, blocksize, headDims),但此时query layout仅支持BNSD。
    • page attention使能场景下,当输入kv cache排布格式为(blocknum, blocksize, H),且H(H=numKvHeads * headDims)超过64k时,受硬件指令约束,会被拦截报错。
    • page attention场景下,必须传入输入actualSeqLengths,每个batch的actualSeqLength表示每个batch对sequence真实长度,该值除以属性输入blocksize即表示每个batch所需block数量。
    • page attention场景下,blockTable必须为二维Tensor,第一维长度需等于batch数,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actualSeqLengths对应的block数量)。例如,batch数为2,属性blocksize=128,当每个batch的actualSeqLength为512时,表明每个batch至少需要4个block,因此blockTable的排布可以为(2,4)
    • page attention使能场景下,blockSize是用户自定义的参数,该参数的取值会影响page attention的性能,通常为128或256。key、value输入类型为FLOAT16、BFLOAT16时blockSize需要16对齐;key、value 输入类型为INT8时blockSize需要32对齐。通常情况下,page attention可以提高吞吐量,但会带来性能上的下降。
  • quant_scale2、quant_offset2为一组参数,其中quant_offset2可选,传入该组参数后算子输出数据类型会推导为int8,若不期望int8输出,请勿传入该组参数。
  • kv左padding场景使用限制:
    • kvCache的搬运起点计算公式为:Smax - kv_padding_size- actual_seq_lengths。kvCache的搬运终点计算公式为:Smax - kv_padding_size。其中kvCache的搬运起点或终点小于0时,返回数据结果为全0。
    • kv左padding场景kv_padding_size小于0时将被置为0。
    • kv左padding场景使能需要同时存在kv_padding_size和actual_seq_lengths参数,否则默认为kv右padding场景。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
  • Atlas 推理系列加速卡产品

调用示例

  • 单算子调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    # 单算子调用方式
    import torch
    import torch_npu
    import math
    
    # 生成随机数据,并发送到npu
    q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
    k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
    v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
    scale = 1/math.sqrt(128.0)
    
    # 调用IFA算子
    out = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale)
    
    # 执行上述代码的输出类似如下
    tensor([[[ 0.3149, -0.2460,  0.7939,  ...,  0.5737, -0.4929, -0.1500]],
            [[ 0.8115,  1.3789,  0.6484,  ..., -0.9092, -0.6206, -0.7412]]],
           device='npu:0', dtype=torch.float16)
    
  • 图模式调用(图模式目前仅支持PyTorch 2.1版本)
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    # 入图方式
    
    import torch
    import torch_npu
    import math
    
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    import torch._dynamo
    TORCHDYNAMO_VERBOSE=1
    TORCH_LOGS="+dynamo"
    
    # 支持入图的打印宏
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    config = CompilerConfig()
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    from torch.library import Library, impl
    
    # 数据生成
    q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
    k = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
    v = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
    atten = torch.randn(2, 1, 1, 2048).bool().npu()
    scale_value = 1/math.sqrt(128.0)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self):
            return torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
    def MetaInfershape():
        with torch.no_grad():
            model = Model()
            model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
            graph_output = model()
        single_op = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
        print("single op output with mask:", single_op, single_op.shape)
        print("graph output with mask:", graph_output, graph_output.shape)
    if __name__ == "__main__":
        MetaInfershape()
    
    # 执行上述代码的输出类似如下
    single op output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
            [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
           device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
    graph output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
            [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
           device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])