下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

torch_npu.npu_prompt_flash_attention

功能描述

全量FA实现,实现对应公式:

接口原型

torch_npu.npu_prompt_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, padding_mask=None, Tensor? atten_mask=None, int[]? actual_seq_lengths=None, Tensor? deq_scale1=None, Tensor? quant_scale1=None, Tensor? deq_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147473647, int next_tokens=0, str input_layout="BSH", int num_key_value_heads=0, int[]? actual_seq_lengths_kv=None, int sparse_mode=0) -> Tensor

参数说明

  • query(Tensor*,计算输入):Device侧的Tensor,公式中的输入Q,数据类型支持FLOAT16、BFLOAT16、INT8,数据类型与key的数据类型需满足数据类型推导规则,即保持与key、value的数据类型一致。不支持非连续的Tensor,数据格式支持ND,Atlas 推理系列加速卡产品仅支持FLOAT16。
  • key(Tensor*,计算输入):Device侧的Tensor,公式中的输入K,数据类型支持FLOAT16、BFLOAT16、INT8,数据类型与query的数据类型需满足数据类型推导规则,即保持与query、value的数据类型一致。不支持非连续的Tensor,数据格式支持ND,Atlas 推理系列加速卡产品仅支持FLOAT16。
  • value(Tensor*,计算输入):Device侧的Tensor,公式中的输入V,数据类型支持FLOAT16、BFLOAT16、INT8,数据类型与query的数据类型需满足数据类型推导规则,即保持与query、key的数据类型一致。不支持非连续的Tensor,数据格式支持ND,Atlas 推理系列加速卡产品仅支持FLOAT16。
  • *:代表其之前的变量是位置相关,需要按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
  • pse_shift(Tensor*,计算输入):Device侧的Tensor,可选参数,数据类型支持FLOAT16、BFLOAT16,且在pse_shift为FLOAT16类型时,要求此时的query为FLOAT16或INT8类型,而在pse_shift为BFLOAT16类型时,要求此时的query为BFLOAT16类型。在query、key、value为FLOAT16且pse_shift存在的情况下,默认走高精度模式。不支持非连续的Tensor,数据格式支持ND。输入shape类型需为(B,N,Q_S,KV_S)或(1,N,Q_S,KV_S),其中Q_S为query的shape中的S,KV_S为key和value的shape中的S。对于pse_shift的KV_S为非32字节对齐的场景,建议padding到32字节来提高性能,多余部分的填充值不做要求。如不使用该功能时可传入nullptr。综合约束请见约束说明Atlas 推理系列加速卡产品仅支持nullptr。
  • padding_mask:预留参数,暂未使用,默认值为None。
  • atten_mask(Tensor*,计算输入):Device侧的Tensor,代表下三角全为0上三角全为负无穷的倒三角mask矩阵,不支持非连续的Tensor,数据类型支持BOOL、INT8和UINT8。数据格式支持ND。如果不使用该功能可传入nullptr。通常建议shape输入Q_S,KV_S;B,Q_S,KV_S;1,Q_S,KV_S;B,1,Q_S,KV_S;1,1,Q_S,KV_S,其中Q_S为query的shape中的S,KV_S为key和value的shape中的S,对于attenMask的KV_S为非32字节对齐的场景,建议padding到32字节对齐来提高性能,多余部分填充成1。综合约束请见约束说明
  • actual_seq_lengths(IntArray*,计算输入):Host侧的IntArray,代表不同Batch中query的有效Sequence Length,数据类型支持INT64。如果不指定seqlen可以传入nullptr,表示和query的shape的s长度相同。限制:该入参中每个batch的有效Sequence Length应该不大于query中对应batch的Sequence Length。seqlen的传入长度为1时,每个Batch使用相同seqlen;传入长度大于等于Batch数时取seqlen的前Batch个数。其它长度不支持。Atlas 推理系列加速卡产品仅支持nullptr。
  • deq_scale1(Tensor*,计算输入):Device侧的Tensor,数据类型支持UINT64、FLOAT32。数据格式支持ND(参考),表示BMM1后面的反量化因子,支持per-tensor。 如不使用该功能时可传入nullptr,Atlas 推理系列加速卡产品仅支持nullptr。
  • quant_scale1(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT32。数据格式支持ND(参考),表示BMM2前面的量化因子,支持per-tensor。 如不使用该功能时可传入nullptr,Atlas 推理系列加速卡产品仅支持nullptr。
  • deq_scale2(Tensor*,计算输入):Device侧的Tensor,数据类型支持UINT64、FLOAT32。数据格式支持ND(参考),表示BMM2后面的反量化因子,支持per-tensor。 如不使用该功能时可传入nullptr,Atlas 推理系列加速卡产品仅支持nullptr。
  • quant_scale2(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT32、BFLOAT16。数据格式支持ND(参考),表示输出的量化因子,支持per-tensor,per-channel。

    当输入为BFLOAT16时,同时支持 FLOAT32和BFLOAT16 ,否则仅支持 FLOAT32 。per-channel 格式,当输出layout为BSH时,要求 quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D(建议输出layout为BSH时,quant_scale2 shape传入[1,1,H]或[H];输出为BNSD时,建议传入[1,N,1,D]或[N,D];输出为BSND时,建议传入[1,1,N,D]或[N,D])。

  • 如不使用该功能时可传入nullptr,Atlas 推理系列加速卡产品仅支持nullptr。
  • quant_offset2(Tensor*,计算输入):Device侧的Tensor,数据类型支持FLOAT32、BFLOAT16。数据格式支持ND(参考),表示输出的量化偏移,支持per-tensor,per-channel。若传入 quant_offset2,需保证其类型和shape信息与 quant_scale2一致。
  • 如不使用该功能时可传入nullptr,Atlas 推理系列加速卡产品仅支持nullptr。
  • num_heads(int64_t,计算输入):Host侧的int,代表query的head个数,数据类型支持INT64。
  • scale_value(double,计算输入):Host侧的double,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE。数据类型与query的数据类型需满足数据类型推导规则。用户不特意指定时可传入默认值1.0。
  • pre_tokens(int64_t,计算输入):Host侧的int,用于稀疏计算,表示attention需要和前几个Token计算关联,数据类型支持INT64。用户不特意指定时可传入默认值2147483647,Atlas 推理系列加速卡产品仅支持默认值2147483647。
  • next_tokens(int64_t,计算输入):Host侧的int,用于稀疏计算,表示attention需要和后几个Token计算关联。数据类型支持INT64。用户不特意指定时可传入默认值0,Atlas 推理系列加速卡产品仅支持0和2147483647。
  • input_layout(char*,计算输入):Host侧的字符指针CHAR*,用于标识输入query、key、value的数据排布格式,当前支持BSH、BSND、BNSD、BNSD、BNSD_BSND(输入为BNSD时,输出格式为BSND)。用户不特意指定时可传入默认值"BSH"。
  • num_key_value_heads(int64_t,计算输入):Host侧的int,代表key、value中head个数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,数据类型支持INT64。用户不特意指定时可传入默认值0,表示key/value和query的head个数相等。限制:需要满足num_heads整除num_key_value_heads,num_heads与num_key_value_heads的比值不能大于64,且在BSND、BNSD、BNSD_BSND场景下,需要与shape中的key/value的N轴shape值相同,否则报错,Atlas 推理系列加速卡产品仅支持默认值0。
  • actual_seq_lengths_kv(aclIntArray*,计算输入):Host侧的aclIntArray,可传入nullptr,代表不同batch中key/value的有效Sequence Length。数据类型支持INT64。限制:该入参中每个batch的有效Sequence Length应该不大于key/value中对应batch的Sequence Length,Atlas 推理系列加速卡产品仅支持nullptr。seqlenKV的传入长度为1时,每个Batch使用相同seqlenKV;传入长度大于等于Batch数时取seqlenKV的前Batch个数,其它长度不支持。
  • sparse_mode(int64_t,计算输入):Host侧的int,表示sparse的模式。数据类型支持INT64,Atlas 推理系列加速卡产品仅支持默认值0。
    • sparse_mode为0时,代表defaultMask模式,如果atten_mask未传入则不做mask操作,忽略preTokens和nextTokens(内部赋值为INT_MAX);如果传入,则需要传入完整的atten_mask矩阵(S1 * S2),表示pre_tokens和next_tokens之间的部分需要计算。
    • sparse_mode为1时,代表allMask。
    • sparse_mode为2时,代表leftUpCausal模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
    • sparse_mode为3时,代表rightDownCausal模式的mask,均对应以左顶点为划分的下三角场景,需要传入优化后的atten_mask矩阵(2048*2048)。
    • sparse_mode为4时,代表band模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
    • sparse_mode为5、6、7、8时,分别代表prefix、global、dilated、block_local,均暂不支持。用户不特意指定时可传入默认值0。综合约束请见约束说明

输出说明

共一个输出,为计算的最终结果atten_out,类型为Tensor,shape与query保持一致。·

约束说明

  • 该融合算子仅在推理场景使用。
  • 该接口与pytorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • 入参为空的处理:算子内部需要判断参数query是否为空,如果是空则直接返回。参数query不为空Tensor,参数key、value为空tensor(即S2为0),则填充全零的对应shape的输出(填充attention_out)。attention_out为空Tensor时,AscendCLNN框架会处理。其余在上述参数说明中标注了"可传入nullptr"的入参为空指针时,不进行处理。
  • query,key,value输入,功能使用限制如下:
    • Atlas A2训练系列产品/Atlas 800I A2推理产品
      • 支持B轴小于等于65536(64k),D轴32byte不对齐时仅支持到128。
      • 支持N轴小于等于256。
      • S支持小于等于20971520(20M)。长序列场景下,如果计算量过大可能会导致pfa算子执行超时(aicore error类型报错,errorStr为:timeout or trap error),此场景下建议做S切分处理,注:这里计算量会受B、S、N、D等的影响,值越大计算量越大。典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:

        (1)B=1,Q_N=20,Q_S=1048576,D = 256,KV_N=1,KV_S=1048576。

        (2)B=1,Q_N=2,Q_S=10485760,D = 256,KV_N=2,KV_S=10485760。

        (3)B=20,Q_N=1,Q_S=1048576,D = 256,KV_N=1,KV_S=1048576。

        (4)B=1,Q_N=10,Q_S=1048576,D = 512,KV_N=1,KV_S=1048576。

      • 支持D轴小于等于512。inputLayout为BSH或者BSND时,要求N*D小于65535。
    • Atlas 推理系列加速卡产品
      • 支持B轴小于等于128。
      • 支持N轴小于等于256。
      • 支持S轴小于等于65535(64k)。
      • 支持D轴小于等于512。
  • 参数sparse_mode当前仅支持值为0、1、2、3、4的场景,取其它值时会报错。
    • sparse_mode= 0时,atten_mask如果为空指针,则忽略入参pre_tokens、next_tokens(内部赋值为INT_MAX)。
    • sparse_mode= 2、3、4时,atten_mask的shape需要为S,S或1,S,S或1,1,S,S,其中S的值需要固定为2048,且需要用户保证传入的atten_mask为下三角,不传入atten_mask或者传入的shape不正确报错。
    • sparse_mode= 1、2、3的场景忽略入参pre_tokens、next_tokens并按照相关规则赋值。
    • sparse_mode= 4时,pre_tokens、next_tokens必须为非负数。
  • int8量化相关入参数量与输入、输出数据格式的综合限制:
    • 输入为INT8,输出为INT8的场景:入参deq_scale1、quant_scale1、deq_scale2、quant_scale2需要同时存在,quant_offset2可选,不传时默认为0。
    • 输入为INT8,输出为FLOAT16的场景:入参deq_scale1、quant_scale1、deq_scale2需要同时存在,若存在入参quant_offset2或 quant_scale2(即不为nullptr),则报错并返回。
    • 输入为FLOAT16或BFLOAT16,输出为INT8的场景:入参quant_scale2需存在,quant_offset2可选,不传时默认为0,若存在入参deq_scale1或 quant_scale1或 deq_scale2(即不为nullptr),则报错并返回。
    • 入参 quant_offset2和 quant_scale2支持per-tensor/per-channel两种格式和FLOAT32/BFLOAT16两种数据类型。若传入quant_offset2,需保证其类型和shape信息与quant_scale2一致。当输入为BFLOAT16时,同时支持FLOAT32和BFLOAT16,否则仅支持FLOAT32。per-channel 格式,当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D。(建议输出layout为BSH时,quant_scale2shape传入[1,1,H]或[H];输出为BNSD时,建议传入[1,N,1,D]或[N,D];输出为BSND时,建议传入[1,1,N,D]或[N,D])。per-tensor格式,建议D轴对齐到32Byte。
    • per-channel格式,入参 quant_scale2和 quant_offset2暂不支持左padding、Ring Attention或者D非32Byte对齐的场景。
    • 输出为int8时,暂不支持sparse为band且pre_tokens/next_tokens为负数。
  • pse_shift功能使用限制如下:
    • 支持query数据类型为FLOAT16或BFLOAT16或INT8场景下使用该功能。
    • query,key,value数据类型为FLOAT16且pse_shift存在时,强制走高精度模式,对应的限制继承自高精度模式的限制。
    • Q_S需大于等于query的S长度,KV_S需大于等于key的S长度。
  • 输出为INT8,入参quantOffset2传入非空指针和非空tensor值,并且sparse_mode、pre_tokens和next_tokens满足以下条件,矩阵会存在某几行不参与计算的情况,导致计算结果误差,该场景会拦截:
    • sparseMode = 0,atten_mask如果非空指针,每个batch actual_seq_lengths— actual_seq_lengths_kv - pre_tokens> 0 或 nextTokens < 0 时,满足拦截条件。
    • sparseMode = 1 或 2,不会出现满足拦截条件的情况。
    • sparseMode = 3,每个batch actual_seq_lengths_kv- actual_seq_lengths< 0,满足拦截条件。
    • sparseMode = 4,preTokens < 0 或 每个batch next_tokens + actual_seq_lengths_kv- actual_seq_lengths< 0 时,满足拦截条件。
  • kv伪量化参数分离当前暂不支持。
  • 暂不支持D不对齐场景。

支持的芯片型号

  • Atlas A2训练系列产品/Atlas 800I A2推理产品
  • Atlas 推理系列加速卡产品

调用示例

  • 单算子调用
    # 单算子调用方式
    import torch
    import torch_npu
    import math
    
    # 生成随机数据,并发送到npu
    q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    scale = 1/math.sqrt(128.0)
    actseqlen = [164]
    actseqlenkv = [1024]
    
    # 调用PFA算子
    out = torch_npu.npu_prompt_flash_attention(q, k, v, 
    actual_seq_lengths = actseqlen, actual_seq_lengths_kv = actseqlenkv,
    num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
    
    # 执行上述代码的输出类似如下
    tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.float16)
  • 图模式调用(图模式目前仅支持PyTorch 2.1版本)
    # 入图方式
    
    import torch
    import torch_npu
    import math
    
    import torchair as tng
    
    from torchair.configs.compiler_config import CompilerConfig
    import torch._dynamo
    TORCHDYNAMO_VERBOSE=1
    TORCH_LOGS="+dynamo"
    
    # 支持入图的打印宏
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    config = CompilerConfig()
    config.aoe_config.aoe_mode = "2"
    config.debug.graph_dump.type = "pbtxt"
    npu_backend = tng.get_npu_backend(compiler_config=config)
    from torch.library import Library, impl
    
    # 数据生成
    q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
    scale = 1/math.sqrt(128.0)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self):
            return torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
    def MetaInfershape():
        with torch.no_grad():
            model = Model()
            model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
            graph_output = model()
        single_op = torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
        print("single op output with mask:", single_op, single_op.shape)
        print("graph output with mask:", graph_output, graph_output.shape)
    if __name__ == "__main__":
        MetaInfershape()
    
    # 执行上述代码的输出类似如下
    single op output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
    
    graph output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
            [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
            [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
            ...,
            [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
            [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
            [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
            device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词