torch_npu.npu_incre_flash_attention

Function Description

Implements incremental flash attention (IFA), corresponding to the attention formula below.
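
As a sketch (assuming standard scaled dot-product attention semantics; scale_value is the scaling factor, and atten_mask, when supplied, masks positions before the softmax):

    atten_out = Softmax(scale_value * Q * K^T) * V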

Interface Prototype

torch_npu.npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? pse_shift=None, SymInt[]? actual_seq_lengths=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? kv_padding_size=None, int num_heads=1, float scale_value=1.0, str input_layout="BSH", int num_key_value_heads=0, int block_size=0, int inner_precise=1) -> Tensor
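
A minimal sketch of how the keyword-only parameters map onto a grouped-query-attention decode step. The shapes, the "BNSD" layout, the num_key_value_heads value, and the actual_seq_lengths list below are illustrative assumptions, not values mandated by the interface.

# Hedged sketch: GQA decode step with BNSD layout (illustrative shapes)
import math
import torch
import torch_npu

B, N, kvN, S, D = 2, 8, 2, 256, 128        # batch, query heads, KV heads, cached length, head dim
q = torch.randn(B, N, 1, D, dtype=torch.float16).npu()    # one new query token per batch entry
k = torch.randn(B, kvN, S, D, dtype=torch.float16).npu()  # key cache
v = torch.randn(B, kvN, S, D, dtype=torch.float16).npu()  # value cache

out = torch_npu.npu_incre_flash_attention(
    q, k, v,
    actual_seq_lengths=[S] * B,            # valid cache length per batch entry
    num_heads=N,
    num_key_value_heads=kvN,               # grouped-query attention: num_heads must be divisible by this
    scale_value=1 / math.sqrt(D),
    input_layout="BNSD",                   # assumed to be a supported layout on the target device
)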

Parameter Description

Output Description

Constraints

Supported PyTorch Versions

Supported Devices

Calling Examples

# Single-operator (eager) invocation
import torch
import torch_npu
import math

# Generate random data and move it to the NPU
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)

# Call the IFA operator
out = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale)

# Running the code above produces output similar to the following
tensor([[[ 0.3149, -0.2460,  0.7939,  ...,  0.5737, -0.4929, -0.1500]],
        [[ 0.8115,  1.3789,  0.6484,  ..., -0.9092, -0.6206, -0.7412]]],
       device='npu:0', dtype=torch.float16)
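
In incremental decoding, key and value carry the whole KV cache while query holds only the newest token. The sketch below shows one such decode step over a pre-allocated cache, with actual_seq_lengths marking the valid prefix; the cache capacity, current length, and update logic are illustrative assumptions.

# Hedged sketch: one decode step over a pre-allocated KV cache (BSH layout)
import math
import torch
import torch_npu

B, H, D = 2, 40, 128
max_seq = 2048                     # assumed pre-allocated cache capacity
cur_len = 17                       # tokens already cached for every batch entry

k_cache = torch.zeros(B, max_seq, H * D, dtype=torch.float16).npu()
v_cache = torch.zeros(B, max_seq, H * D, dtype=torch.float16).npu()
k_cache[:, :cur_len] = torch.randn(B, cur_len, H * D, dtype=torch.float16).npu()
v_cache[:, :cur_len] = torch.randn(B, cur_len, H * D, dtype=torch.float16).npu()

q = torch.randn(B, 1, H * D, dtype=torch.float16).npu()    # the single new query token

out = torch_npu.npu_incre_flash_attention(
    q, k_cache, v_cache,
    actual_seq_lengths=[cur_len] * B,     # attend only to the first cur_len cached positions
    num_heads=H,
    input_layout="BSH",
    scale_value=1 / math.sqrt(D),
)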


# Graph-mode invocation (torch.compile with the torchair backend)

import torch
import torch_npu
import math

import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
# Optional Dynamo debug switches; these are environment variables and should be set
# in the shell before launching, e.g. TORCHDYNAMO_VERBOSE=1 and TORCH_LOGS="+dynamo"

# Enable debug logging for graph capture
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl

# Generate test data and move it to the NPU
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
# Build a random boolean attention mask and cast it back to float16 for the kernel
atten = torch.randn(2, 1, 1, 2048, dtype=torch.float16).npu()
atten = torch.where(atten < 0, 1, 0).to(torch.bool).to(torch.float16)
scale_value = 1/math.sqrt(128.0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        return torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)


def MetaInfershape():
    with torch.no_grad():
        model = Model()
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()
    single_op = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
    print("single op output with mask:", single_op, single_op.shape)
    print("graph output with mask:", graph_output, graph_output.shape)


if __name__ == "__main__":
    MetaInfershape()

# Running the code above produces output similar to the following
single op output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
        [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
       device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
graph output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
        [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
       device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])