文档
注册

FlashAttentionScore

算子基础信息

FlashAttentionScore算子新增torch_npu接口,支持torch_npu接口调用。

表1 算子信息

算子名称

FlashAttentionScore

torch_npu api接口

torch_npu.npu_fusion_attention

支持的torch_npu版本

1.11.0, 2.1.0, 2.2.0, 2.3.1

支持的芯片类型

Atlas A2 训练系列产品

支持的数据类型

float16, bfloat16

FlashAttentionScore算子支持原生sdpa接口调用。
表2 原生接口调用

算子名称

FlashAttentionScore

torch_npu api接口

torch.nn.functional.scaled_dot_product_attention

支持的torch_npu版本

2.1.0, 2.2.0, 2.3.1

支持的芯片类型

Atlas A2 训练系列产品

支持的数据类型

float16, bfloat16

算子IR及torch_npu接口参数

算子IR:
REG_OP(FlashAttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16, DT_BF16}))
    .INPUT(key, TensorType({DT_FLOAT16, DT_BF16}))
    .INPUT(value, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(real_shift, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_UINT8}))
    .OPTIONAL_INPUT(padding_mask, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(atten_mask, TensorType({DT_BOOL, DT_UINT8}))
    .OPTIONAL_INPUT(prefix, TensorType({DT_INT64}))
    .OPTIONAL_INPUT(actual_seq_qlen, TensorType({DT_INT64}))
    .OPTIONAL_INPUT(actual_seq_kvlen, TensorType({DT_INT64}))
    .OUTPUT(softmax_max, TensorType({DT_FLOAT32}))
    .OUTPUT(softmax_sum, TensorType({DT_FLOAT32}))
    .OUTPUT(softmax_out, TensorType({DT_FLOAT16, DT_BF16}))
    .OUTPUT(attention_out, TensorType({DT_FLOAT16, DT_BF16}))
    .ATTR(scale_value, Float, 1.0)
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(pre_tockens, Int, 2147483647)
    .ATTR(next_tockens, Int, 2147483647)
    .REQUIRED_ATTR(head_num, Int)
    .REQUIRED_ATTR(input_layout, String)
    .ATTR(inner_precise, Int, 0)
    .ATTR(sparse_mode, Int, 0)
    .OP_END_FACTORY_REG(FlashAttentionScore)
torch_npu接口:
torch_npu.npu_fusion_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, int, int, int)

torch_npu接口中的问号表示这个输入参数是可选的。

实现“Transformer Attention Score”的融合计算,实现的计算公式如下:

参数说明:

  • query:Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16FLOAT32,数据格式支持ND。综合约束请见约束说明
  • key:Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16FLOAT32,数据格式支持ND。综合约束请见约束说明
  • value:Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16FLOAT32,数据格式支持ND。综合约束请见约束说明
  • head_num:Host侧的int,代表head个数,数据类型支持INT64。综合约束请见约束说明
  • input_layout:Host侧的string,代表输入query、key、value的数据排布格式,支持BSH、SBH、BSND、BNSD、TND(actual_seq_qlen/actual_seq_kvlen需传值);后续章节如无特殊说明,S表示query或key、value的sequence length,Sq表示query的sequence length,Skv表示key、value的sequence length,SS表示Sq*Skv。
  • pse:Device侧的Tensor,可选参数,表示位置编码。数据类型支持FLOAT16、BFLOAT16FLOAT32,数据格式支持ND。非varlen场景支持四维输入,包含BNSS格式、BN1Skv格式、1NSS格式。如果非varlen场景Sq大于1024或varlen场景、每个batch的Sq与Skv等长且是sparse_mode为0、2、3的下三角掩码场景,可使能alibi位置编码压缩,此时只需要输入原始PSE最后1024行进行内存优化,即alibi_compress = ori_pse[:, :, -1024:, :],参数每个batch不相同时,输入BNHSkv(H=1024),每个batch相同时,输入1NHSkv(H=1024)。
  • padding_mask:Device侧的Tensor,暂不支持该传参
  • atten_mask:Device侧的Tensor,可选参数,取值为1代表该位不参与计算(不生效),为0代表该位参与计算,数据类型支持BOOL、UINT8,数据格式支持ND格式,输入shape类型支持BNSS格式、B1SS格式、11SS格式、SS格式。varlen场景只支持SS格式,SS分别是maxSq和maxSkv。综合约束请见约束说明
  • scale:Host侧的double,可选参数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE,默认值为1。
  • keep_prob:Host侧的double,可选参数,代表Dropout中1的比例,数据类型支持DOUBLE,默认值为1,表示全部保留。
  • pre_tockens:Host侧的int,用于稀疏计算的参数,可选参数,数据类型支持INT64,默认值为2147483647。综合约束请见约束说明
  • next_tockens:Host侧的int,用于稀疏计算的参数,可选参数,数据类型支持INT64,默认值为2147483647。next_tockens和pre_tockens取值与atten_mask的关系请参见sparse_mode参数,参数取值与atten_mask分布不一致会导致精度问题。综合约束请见约束说明
  • inner_precise:Host侧的int,用于提升精度,数据类型支持INT64,默认值为0。

    当前0、1为保留配置值,2为使能无效行计算,其功能是避免在计算过程中存在整行mask进而导致精度有损失,但是该配置会导致性能下降。

    如果算子可判断出存在无效行场景,会自动使能无效行计算,例如sparse_mode为3,Sq > Skv场景。

  • prefix:Host侧的int array,可选参数,代表prefix稀疏计算场景每个Batch的N值。数据类型支持INT64,数据格式支持ND。综合约束请见约束说明
  • actual_seq_qlen:Host侧的int array,可选参数,varlen场景时需要传入此参数。表示query每个S的累加和长度,数据类型支持INT64,数据格式支持ND。综合约束请见约束说明

    比如真正的S长度列表为:2 2 2 2 2,则actual_seq_qlen传:2 4 6 8 10。

  • actual_seq_kvlen:Host侧的int array,可选参数,varlen场景时需要传入此参数。表示key/value每个S的累加和长度。数据类型支持INT64,数据格式支持ND。综合约束请见约束说明

    比如真正的S长度列表为:2 2 2 2 2,则actual_seq_kvlen传:2 4 6 8 10。

  • sparse_mode:Host侧的int,表示sparse的模式,可选参数。数据类型支持:INT64,默认值为0,支持配置值为0、1、2、3、4、5、6、7、8。当整网的atten_mask都相同且shape小于2048*2048时,建议使用defaultMask模式,来减少内存使用量。综合约束请见约束说明
    表3 sparse_mode不同取值场景说明

    sparse_mode

    含义

    备注

    0

    defaultMask模式。

    -

    1

    allMask模式。

    -

    2

    leftUpCausal模式。

    -

    3

    rightDownCausal模式。

    -

    4

    band模式。

    -

    5

    prefix非压缩模式。

    varlen场景不支持。

    6

    prefix压缩模式。

    -

    7

    varlen外切场景,rightDownCausal模式。

    仅varlen场景支持。

    8

    varlen外切场景,leftUpCausal模式。

    仅varlen场景支持。

    atten_mask的工作原理为,在Mask为True的位置遮蔽query(Q)与key(K)的转置矩阵乘积的值,示意如下:

    QKT矩阵在atten_mask为True的位置会被遮蔽,效果如下:

    说明:下图中的蓝色表示保留该值,atten_mask中,应该配置为False;阴影表示遮蔽该值,atten_mask中应配置为True。
    • sparse_mode为0时,代表defaultMask模式。
      • 不传mask:如果atten_mask未传入则不做mask操作,atten_mask取值为None,忽略pre_tockens和next_tockens取值。Masked QKT矩阵示意如下:

      • next_tockens取值为0,pre_tockens大于等于Sq,表示causal场景sparse,atten_mask应传入下三角矩阵,此时pre_tockens和next_tockens之间的部分需要计算,Masked QKT矩阵示意如下:

        atten_mask应传入下三角矩阵,示意如下:

      • pre_tockens小于Sq,next_tockens小于Skv,且都大于等于0,表示band场景,此时pre_tockens和next_tockens之间的部分需要计算。Masked QKT矩阵示意如下:

        atten_mask应传入band形状矩阵,示意如下:

      • next_tockens为负数,以pre_tockens=9,next_tockens=-3为例,pre_tockens和next_tockens之间的部分需要计算。Masked QKT示意如下:

        说明:next_tockens为负数时,pre_tockens取值必须大于等于next_tockens的绝对值,且next_tockens的绝对值小于Skv。

      • pre_tockens为负数,以next_tockens=7,pre_tockens=-3为例,pre_tockens和next_tockens之间的部分需要计算。Masked QKT示意如下:

        说明:pre_tockens为负数时,next_tockens取值必须大于等于pre_tockens的绝对值,且pre_tockens的绝对值小于Sq。

    • sparse_mode为1时,代表allMask,即传入完整的atten_mask矩阵。

      该场景下忽略next_tockens、pre_tockens取值,Masked QKT矩阵示意如下:

    • sparse_mode为2时,代表leftUpCausal模式的mask,对应以左上顶点划分的下三角场景(参数起点为左上角)。该场景下忽略pre_tockens、next_tockens取值,Masked QKT矩阵示意如下:

      传入的atten_mask为优化后的压缩下三角矩阵(2048*2048),压缩下三角矩阵示意(下同):

    • sparse_mode为3时,代表rightDownCausal模式的mask,对应以右下顶点划分的下三角场景(参数起点为右下角)。该场景下忽略pre_tockens、next_tockens取值。atten_mask为优化后的压缩下三角矩阵(2048*2048),Masked QKT矩阵示意如下:

    • sparse_mode为4时,代表band场景,即计算pre_tockens和next_tockens之间的部分,参数起点为右下角,pre_tockens和next_tockens之间需要有交集。atten_mask为优化后的压缩下三角矩阵(2048*2048)。Masked QKT矩阵示意如下:

    • sparse_mode为5时,代表prefix非压缩场景,即在rightDownCasual的基础上,左侧加上一个长为Sq,宽为N的矩阵,N的值由可选输入prefix获取,例如下图中表示batch=2场景下prefix传入数组[4,5],每个batch轴的N值可以不一样,参数起点为左上角。

      该场景下忽略pre_tockens、next_tockens取值,atten_mask矩阵数据格式须为BNSS或B1SS,Masked QKT矩阵示意如下:

      atten_mask应传入矩阵示意如下:

    • sparse_mode为6时,代表prefix压缩场景,即prefix场景时,attenMask为优化后的压缩下三角+矩形的矩阵(3072*2048):其中上半部分[2048,2048]的下三角矩阵,下半部分为[1024,2048]的矩形矩阵,矩形矩阵左半部分全0,右半部分全1,atten_mask应传入矩阵示意如下。该场景下忽略pre_tockens、next_tockens取值。

    • sparse_mode为7时,表示varlen且为长序列外切场景(即长序列在模型脚本中进行多卡切query的sequence length);用户需要确保外切前为使用sparse_mode 3的场景;当前mode下用户需要设置pre_tockens和next_tockens(起点为右下顶点),且需要保证参数正确,否则会存在精度问题。

      Masked QKT矩阵示意如下,在第二个batch对query进行切分,key和value不切分,4x6的mask矩阵被切分成2x6和2x6的mask,分别在卡1和卡2上计算:

      • 卡1的最后一块mask为band类型的mask,配置pre_tockens=6(保证大于等于最后一个Skv),next_tockens=-2,actual_seq_qlen应传入{3,5},actual_seq_kvlen应传入{3,9}。
      • 卡2的mask类型切分后不变,sparse_mode为3,actual_seq_qlen应传入{2,7,11},actual_seq_kvlen应传入{6,11,15}。

      • 如果配置sparse_mode=7,但实际只存在一个batch,用户需按照band模式的要求来配置参数;sparse_mode=7时,用户需要输入2048x2048的下三角mask作为该融合算子的输入。
      • 基于sparse_mode=3进行外切产生的band模式的sparse的参数应符合以下条件:
        • pre_tockens >= last_Skv。
        • next_tockens <= 0。
        • 当前模式下不支持可选输入pse。
    • sparse_mode为8时,表示varlen且为长序列外切场景;用户需要确保外切前为使用sparse_mode 2的场景;当前mode下用户需要设置pre_tockens和next_tockens(起点为右下顶点),且需要保证参数正确,否则会存在精度问题。

      Masked QKT矩阵示意如下,在第二个batch对query进行切分,key和value不切分,5x4的mask矩阵被切分成2x4和3x4的mask,分别在卡1和卡2上计算:

      • 卡1的mask类型切分后不变,sparse_mode为2,actual_seq_qlen应传入{3,5},actual_seq_kvlen应传入{3,7}。
      • 卡2的第一块mask为band类型的mask,配置pre_tockens=4(保证大于等于第一个Skv),next_tockens=1,actual_seq_qlen应传入{3,8,12},actual_seq_kvlen应传入{4,9,13}。

      • 如果配置sparse_mode=8,但实际只存在一个batch,用户需按照band模式的要求来配置参数;sparse_mode=8时,用户需要输入2048x2048的下三角mask作为该融合算子的输入。
      • 基于sparse_mode=2进行外切产生的band模式的sparse的参数应符合以下条件:
        • pre_tockens >= first_Skv。
        • next_tockens范围无约束,根据实际情况进行配置。
        • 当前模式下不支持可选输入pse。
  • gen_mask_parallel:debug参数,DSA生成dropout随机数向量mask的控制开关,默认值为True:同AI Core计算并行,False:同AI Core计算串行。
  • sync:debug参数,DSA生成dropout随机数向量mask的控制开关,默认值为False:dropout mask异步生成,True:dropout mask同步生成。

输出说明:

共7个输出

(Tensor, Tensor, Tensor, Tensor, int, int, int)
  • 第1个输出为Tensor,计算公式的最终输出y,数据类型支持:FLOAT16、BFLOAT16。
  • 第2个输出为Tensor,Softmax 计算的Max中间结果,用于反向计算,数据类型支持:FLOAT。
  • 第3个输出为Tensor,Softmax计算的Sum中间结果,用于反向计算,数据类型支持:FLOAT。
  • 第4个输出为Tensor,保留参数,暂未使用。
  • 第5个输出为int,DSA生成dropoutmask中,Philox算法的seed。
  • 第6个输出为int,DSA生成dropoutmask中,Philox算法的offset。
  • 第7个输出为int,DSA生成dropoutmask的长度。

约束说明:

  • 输入query、key、value的B:batchsize必须相等;非varlen场景B取值范围1~2M;varlen场景B取值范围1~2K。
  • 输入query、key、value、pse的数据类型必须一致。
  • 输入query、key、value的input_layout必须一致。
  • 支持输入query的N和key/value的N不相等,但必须成比例关系,即Nq/Nkv必须是非0整数,Nq取值范围1~256。当Nq/Nkv > 1时,即为GQA(grouped-query attention);当Nq/Nkv=1时,即为MQA(multi-query attention)。本文如无特殊说明,N表示的是Nq。
  • 输入key/value的shape必须一致。
  • 输入query、key、value的S:sequence length,取值范围1~1M。
  • 部分场景下,如果计算量过大可能会导致算子执行超时(aicore error类型报错,errorStr为:timeout or trap error),此时建议做轴切分处理,注:这里的计算量会受B、S、N、D等参数的影响,值越大计算量越大。
  • 输入query、key、value的D:head dim必须相等,取值范围1~512。
  • varlen场景T(B*S)取值范围1~1M。
  • keep_prob的取值范围为(0, 1] 。
  • sparse_mode为1、2、3、4、5、6、7、8时,应传入对应正确的atten_mask,否则将导致计算结果错误。当atten_mask输入为None时,sparse_mode,pre_tockens,next_tockens参数不生效,固定为全计算。
  • sparse_mode配置为1、2、3、5、6时,用户配置的pre_tockens、next_tockens不会生效。
  • sparse_mode配置为0、4时,须保证atten_mask与pre_tockens、next_tockens的范围一致。
  • prefix稀疏计算场景B不大于32,varlen场景不支持非压缩prefix,即不支持sparse_mode=5;当Sq>Skv时,prefix的N值取值范围[0, Skv],当Sq<=Skv时,prefix的N值取值范围[Skv-Sq, Skv]。
  • sparse_mode=7或者8时,不支持可选输入pse。
  • varlen场景下,atten_mask输入不支持补pad,即atten_mask中不能存在某一行全1的场景。
  • 支持actual_seq_qlen中某个Batch上的S长度为0;如果存在S为0的情况,不支持pse输入, 假设真实的S长度为[2,2,0,2,2],则传入的actual_seq_qlen为[2,4,4,6,8];不支持某个batch中Sq不为0,但是Skv为0的场景。

模型中替换代码

当前GPU模式下,调用FA算子的方式有多种,torch调用FA的接口scaled_dot_product_attention,通过flash-attention库中的flash_attn_func、flash_attn_varlen_func等接口调用。NPU模式下除了已经适配的sdpa接口,其余模式需要通过torch_npu接口——npu_fusion_attention接口实现调用。两者之间的适配可能涉及一些脚本迁移工作,以下通过范例说明接口适配方式。

torch原生接口

scaled_dot_product_attention:

当前已适配NPU,训练场景直接调用即可调用到FA,相关规格限制请参考torch原生接口调用FA算子使用限制。若依然需要使用NPU接口,可以按以下方式适配替换,但输入规格依然要满足torch原生接口调用FA算子使用限制要求。
  • 不使能is_causal时,原调用接口代码:
    res = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
                                         dropout_p=0.0, is_causal=False)

    替换为:

    if attn_mask.dtype == torch.bool:
        attn_mask_npu = torch.logical_not(attention_mask.bool()).to(device) // atten_mask需要取反
    else:
        attn_mask_npu = attention_mask.bool().to(device)
    head_num = query.shape[1]
    res = torch_npu.npu_fusion_attention(
                           query, key, value, head_num, input_layout="BNSD", 
                           pse=None,
                           atten_mask=attn_mask_npu,
                           scale=1.0 / math.sqrt(query.shape[-1]),
                           pre_tockens=2147483647,
                           next_tockens=2147483647,
                           keep_prob=1
                       )[0]
  • 使能is_causal时,原调用接口代码:
    res = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None,
                                         dropout_p=0.0, is_causal=True)

    替换为:

    atten_mask = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(device)
    head_num = query.shape[1]
    res = torch_npu.npu_fusion_attention(
                           query, key, value, head_num, input_layout="BNSD", 
                           pse=None,
                           atten_mask=attn_mask_npu,
                           scale=1.0 / math.sqrt(query.shape[-1]),
                           pre_tockens=2147483647,
                           next_tockens=2147483647,
                           keep_prob=1,
                           sparse_mode=2
                       )[0]
flash-attention库
  • flash_attn_func

    不使能causal时,模型中替换代码:

    out= flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)

    替换为:

    head_num = q.shape[2]
    out = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=1.0,
                                         scale=scale)

    使能causal时,模型中替换代码:

    out= flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True)

    替换为:

    atten_mask_npu= torch.from_numpy(np.triu(np.ones(seqlen_q, seqlen_k), k=1)).bool().to(device)
    head_num = q.shape[2]
    out = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=1.0,
                                         scale=scale, atten_mask=attn_mask_npu,)
  • flash_attn_varlen_func

    不使能causal时,原调用接口代码:

    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False
    )

    替换为:

    head_num = q.shape[1]
    output = torch_npu.npu_fusion_attention(
                 q, k, v, head_num,
                 pse=None,
                 padding_mask=None,
                 atten_mask=None,
                 scale=1.0 / math.sqrt(query.shape[-1]),
                 keep_prob=1,
                 input_layout="TND",
                 actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
                 actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
                 pre_tockens=2147483647,
                 next_tockens=2147483647,
                 sparse_mode=0)[0]

    使能causal时,GPU调用接口代码:

    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=None,
        causal=True
    )

    替换为:

    atten_mask_npu= torch.from_numpy(np.triu(np.ones([max_seqlen_q, max_seqlen_k]), k=1))
    head_num = q.shape[1]
    output = torch_npu.npu_fusion_attention(
                 q, k, v, head_num,
                 pse=None,
                 padding_mask=None,
                 atten_mask=atten_mask_npu,
                 scale=1.0 / math.sqrt(query.shape[-1]),
                 keep_prob=1,
                 input_layout="TND",
                 actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
                 actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
                 pre_tockens=2147483647,
                 next_tockens=0)[0]

xFormers库

memory_efficient_attention

模型中替换代码:
def __init__(self, attention_op: Optional[Callable] = None):
       self.attention_op = attention_op

def head_to_batch_dim(self, tensor):
       tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
       return tensor

...

query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
替换为:
def head_to_batch_dim(self, tensor, out_dim=3):
       head_size = self.heads
       batch_size, seq_len, dim = tensor.shape
       tensor = torch_npu.npu_confusion_transpose(tensor, [0, 2, 1, 3], (batch_size, seq_len, head_size, dim // head_size), False)

       if out_dim == 3:
           tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

       return tensor

...

query = attn.head_to_batch_dim(query, out_dim=4)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
hidden_states = torch_npu.npu_fusion_attention(
            query, key, value, heads, input_layout="BNSD",
            pse=None,
            atten_mask=attention_mask,
            scale=scale,
            pre_tockens=2147483647,
            next_tockens=2147483647,
            keep_prob=1.,
            sync=False
        )[0]

算子计算逻辑

算子的计算逻辑如下:
def forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
	if pse is None:
		qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
	else:
		qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
	if atten_mask is None:
		qk = qk
	else:
		qk = qk + atten_mask * torch.finfo(torch.float32).min
	softmax_res, softmax_max, softmax_sum = softmax(qk)
	if drop_mask:
		drop_res = softmax_res
	else:
		drop_res = softmax_res * drop_mask * (1.0 / (keep_prob))
	attention_out = torch.matmul(drop_res, v)
	return attention_out
图1 计算流程图

已支持模型典型case

如下case均包含fp16、bf16。

表4 case列表

id

BNSD

1

[1, 8, 4096, 128]

2

[4, 32, 2048, 64]

3

[8, 16, 512, 128]

4

[8, 16, 512, 128]

5

[8, 16, 512, 128]

6

[8, 16, 512, 64]

7

[8, 16, 512, 64]

8

[4, 4, 2048, 64]

torch原生接口调用FA算子使用限制

接口和参数说明:

参考开源说明可单击Link

使用限制:

  • 参数输入符合规格:
    • 输入query、key、value的N:batch size,当前只支持[N,head_num, S(L), E(Ev)]的排布方式,取值范围1~2K
    • 输入query的head num和key/value的head num必须成比例关系,即Nq/Nkv必须是非0整数,取值范围1~256
    • 输入query的L:Target sequence length,取值范围1~512K
    • 输入key、value的S:Source sequence length,取值范围1~512K
    • 输入query、key、value的E:Embedding dimension of the query and key,取值范围1~512
    • 输入value的Ev:Embedding dimension of the value,必须与E相等
    • 输入attn_mask:当前支持[N, 1, L, S]、[N, head_num, L, S]、[1, 1, L, S]、[L, S]共4种排布方式
    • 在使能is_causal计算时,attn_mask必须为None;不使能is_causal时,若attn_mask输入有效数据,输入数据类型必须是Bool类型
  • 与原接口除了规格限制之外存在差异点
    • NPU用DSA硬件实现,算法在DSA引擎固化存在跟GPU算法实现差异,导致dropout功能和GPU结果不一致
    • 当前接口支持输入query的head num和key/value的head num不等长,而原生PyTorch接口不支持
    • 输入query、key、value的数据类型bf16、fp16、fp32并且使能requires_grad时,执行FA算子
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词