下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

FlashAttentionScore

算子基础信息

FlashAttentionScore算子新增torch_npu接口,支持torch_npu接口调用。

表1 算子信息

算子名称

FlashAttentionScore

torch_npu api接口

torch_npu.npu_fusion_attention

支持的torch_npu版本

2.1.0, 2.3.1

支持的芯片类型

Atlas A2 训练系列产品

支持的数据类型

float16, bfloat16

FlashAttentionScore算子支持原生sdpa接口调用。

在当前版本中,sdpa(scaled_dot_product_attention)接口仅作为一项试用特性,此功能在后续版本中可能会有所调整或改进。请用户在使用过程中关注后续版本的迭代。

表2 原生接口调用

算子名称

FlashAttentionScore

torch_npu api接口

torch.nn.functional.scaled_dot_product_attention

支持的torch_npu版本

2.1.0,2.3.1

支持的芯片类型

Atlas A2 训练系列产品

支持的数据类型

float16, bfloat16

算子IR及torch_npu接口参数

算子IR:
REG_OP(FlashAttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16, DT_BF16}))
    .INPUT(key, TensorType({DT_FLOAT16, DT_BF16}))
    .INPUT(value, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(real_shift, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_UINT8}))
    .OPTIONAL_INPUT(padding_mask, TensorType({DT_FLOAT16, DT_BF16}))
    .OPTIONAL_INPUT(atten_mask, TensorType({DT_BOOL, DT_UINT8}))
    .OPTIONAL_INPUT(prefix, TensorType({DT_INT64}))
    .OPTIONAL_INPUT(actual_seq_qlen, TensorType({DT_INT64}))
    .OPTIONAL_INPUT(actual_seq_kvlen, TensorType({DT_INT64}))
    .OUTPUT(softmax_max, TensorType({DT_FLOAT32}))
    .OUTPUT(softmax_sum, TensorType({DT_FLOAT32}))
    .OUTPUT(softmax_out, TensorType({DT_FLOAT16, DT_BF16}))
    .OUTPUT(attention_out, TensorType({DT_FLOAT16, DT_BF16}))
    .ATTR(scale_value, Float, 1.0)
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(pre_tockens, Int, 2147483647)
    .ATTR(next_tockens, Int, 2147483647)
    .REQUIRED_ATTR(head_num, Int)
    .REQUIRED_ATTR(input_layout, String)
    .ATTR(inner_precise, Int, 0)
    .ATTR(sparse_mode, Int, 0)
    .OP_END_FACTORY_REG(FlashAttentionScore)
torch_npu接口:
torch_npu.npu_fusion_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, int, int, int)

torch_npu接口中的问号表示这个输入参数是可选的。

实现“Transformer Attention Score”的融合计算,实现的计算公式如下:

参数说明、输出说明和约束说明具体请参考《API 参考》中的“torch_npu.npu_fusion_attention”章节

模型中替换代码

当前GPU模式下,调用FA算子的方式有多种,torch调用FA的接口scaled_dot_product_attention,通过flash-attention库中的flash_attn_func、flash_attn_varlen_func等接口调用。NPU模式下除了已经适配的sdpa接口,其余模式需要通过torch_npu接口——npu_fusion_attention接口实现调用。两者之间的适配可能涉及一些脚本迁移工作,以下通过范例说明接口适配方式。

torch原生接口

scaled_dot_product_attention:

当前已适配NPU,训练场景直接调用即可调用到FA,相关规格限制请参考torch原生接口调用FA算子使用限制。若依然需要使用NPU接口,可以按以下方式适配替换,但输入规格依然要满足torch原生接口调用FA算子使用限制要求。
  • 不使能is_causal时,原调用接口代码:
    1
    2
    res = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
                                         dropout_p=0.0, is_causal=False)
    

    替换为:

    if attn_mask.dtype == torch.bool:
        attn_mask_npu = torch.logical_not(attention_mask.bool()).to(device) // atten_mask需要取反
    else:
        attn_mask_npu = attention_mask.bool().to(device)
    head_num = query.shape[1]
    res = torch_npu.npu_fusion_attention(
                           query, key, value, head_num, input_layout="BNSD", 
                           pse=None,
                           atten_mask=attn_mask_npu,
                           scale=1.0 / math.sqrt(query.shape[-1]),
                           pre_tockens=2147483647,
                           next_tockens=2147483647,
                           keep_prob=1
                       )[0]
  • 使能is_causal时,原调用接口代码:
    res = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None,
                                         dropout_p=0.0, is_causal=True)

    替换为:

    attn_mask_npu= torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(device)
    head_num = query.shape[1]
    res = torch_npu.npu_fusion_attention(
                           query, key, value, head_num, input_layout="BNSD", 
                           pse=None,
                           atten_mask=attn_mask_npu,
                           scale=1.0 / math.sqrt(query.shape[-1]),
                           pre_tockens=2147483647,
                           next_tockens=2147483647,
                           keep_prob=1,
                           sparse_mode=2
                       )[0]
flash-attention库
  • flash_attn_func

    不使能causal时,模型中替换代码:

    out= flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)

    替换为:

    head_num = q.shape[2]
    out = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=1.0,
                                         scale=scale)[0]

    使能causal时,模型中替换代码:

    out= flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True)

    替换为:

    atten_mask_npu= torch.from_numpy(np.triu(np.ones(seqlen_q, seqlen_k), k=1)).bool().to(device)
    head_num = q.shape[2]
    out = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=1.0,
                                         scale=scale, atten_mask=attn_mask_npu)[0]
  • flash_attn_varlen_func

    不使能causal时,原调用接口代码:

    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False
    )

    替换为:

    head_num = q.shape[1]
    output = torch_npu.npu_fusion_attention(
                 q, k, v, head_num,
                 pse=None,
                 padding_mask=None,
                 atten_mask=None,
                 scale=1.0 / math.sqrt(q.shape[-1]),
                 keep_prob=1,
                 input_layout="TND",
                 actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
                 actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
                 pre_tockens=2147483647,
                 next_tockens=2147483647,
                 sparse_mode=0)[0]

    使能causal时,GPU调用接口代码:

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=None,
        causal=True
    )
    

    替换为:

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    atten_mask_npu= torch.from_numpy(np.triu(np.ones([max_seqlen_q, max_seqlen_k]), k=1)).bool().to(device)
    head_num = q.shape[1]
    output = torch_npu.npu_fusion_attention(
                 q, k, v, head_num,
                 pse=None,
                 padding_mask=None,
                 atten_mask=atten_mask_npu,
                 scale=1.0 / math.sqrt(q.shape[-1]),
                 keep_prob=1,
                 input_layout="TND",
                 actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
                 actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
                 pre_tockens=2147483647,
                 next_tockens=0)[0]
    

xFormers库

memory_efficient_attention

模型中替换代码:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
def __init__(self, attention_op: Optional[Callable] = None):
       self.attention_op = attention_op

def head_to_batch_dim(self, tensor):
       tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
       return tensor

...

query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
替换为:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def head_to_batch_dim(self, tensor, out_dim=3):
       head_size = self.heads
       batch_size, seq_len, dim = tensor.shape
       tensor = torch_npu.npu_confusion_transpose(tensor, [0, 2, 1, 3], (batch_size, seq_len, head_size, dim // head_size), False)

       if out_dim == 3:
           tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

       return tensor

...

query = attn.head_to_batch_dim(query, out_dim=4)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
hidden_states = torch_npu.npu_fusion_attention(
            query, key, value, heads, input_layout="BNSD",
            pse=None,
            atten_mask=attention_mask,
            scale=scale,
            pre_tockens=2147483647,
            next_tockens=2147483647,
            keep_prob=1.,
            sync=False
        )[0]

算子计算逻辑

算子的计算逻辑如下:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
	if pse is None:
		qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
	else:
		qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale)
	if atten_mask is None:
		qk = qk
	else:
		qk = qk + atten_mask * torch.finfo(torch.float32).min
	softmax_res, softmax_max, softmax_sum = softmax(qk)
	if drop_mask:
		drop_res = softmax_res
	else:
		drop_res = softmax_res * drop_mask * (1.0 / (keep_prob))
	attention_out = torch.matmul(drop_res, v)
	return attention_out
图1 计算流程图

已支持模型典型case

如下case均包含fp16、bf16。

表3 case列表

id

BNSD

1

[1, 8, 4096, 128]

2

[4, 32, 2048, 64]

3

[8, 16, 512, 128]

4

[8, 16, 512, 128]

5

[8, 16, 512, 128]

6

[8, 16, 512, 64]

7

[8, 16, 512, 64]

8

[4, 4, 2048, 64]

torch原生接口调用FA算子使用限制

接口和参数说明:

参考开源说明可单击Link

使用限制:

  • 参数输入符合规格:
    • 输入query、key、value的N:batch size,当前只支持[N,head_num, S(L), E(Ev)]的排布方式,取值范围1~2K
    • 输入query的head num和key/value的head num必须成比例关系,即Nq/Nkv必须是非0整数,取值范围1~256
    • 输入query的L:Target sequence length,取值范围1~512K
    • 输入key、value的S:Source sequence length,取值范围1~512K
    • 输入query、key、value的E:Embedding dimension of the query and key,取值范围1~512
    • 输入value的Ev:Embedding dimension of the value,必须与E相等
    • 输入attn_mask:当前支持[N, 1, L, S]、[N, head_num, L, S]、[1, 1, L, S]、[L, S]共4种排布方式
    • 在使能is_causal计算时,attn_mask必须为None;不使能is_causal时,若attn_mask输入有效数据,输入数据类型必须是Bool类型
  • 与原接口除了规格限制之外存在差异点
    • NPU用DSA硬件实现,算法在DSA引擎固化存在跟GPU算法实现差异,导致dropout功能和GPU结果不一致
    • 当前接口支持输入query的head num和key/value的head num不等长,而原生PyTorch接口不支持
    • 输入query、key、value的数据类型bf16、fp16、fp32并且使能requires_grad时,执行FA算子
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词