Incremental FA (IFA) implementation, corresponding to the formula:

atten_out = softmax(scale_value * (query @ key^T) + atten_mask) @ value
torch_npu.npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? pse_shift=None, SymInt[]? actual_seq_lengths=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? kv_padding_size=None, int num_heads=1, float scale_value=1.0, str input_layout="BSH", int num_key_value_heads=0, int block_size=0, int inner_precise=1) -> Tensor
The query, key, and value data layouts can be read in terms of several dimensions: B (Batch) is the batch size of the input samples, S (Seq-Length) is the input sequence length, H (Head-Size) is the hidden size, N (Head-Num) is the number of attention heads, and D (Head-Dim) is the per-head hidden dimension, with D = H / N.
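To make the dimension relationship concrete, here is a small sketch (not part of the original sample code; the concrete values H = 5120, N = 40 are assumptions taken from the examples further below) showing how a BSH-shaped query relates to a per-head (B, N, S, D) view:

# Minimal illustration of the layout terms above (assumed values, not from the original doc)
import torch

B, S, N, D = 2, 1, 40, 128
H = N * D  # 5120, so D = H / N = 128
q_bsh = torch.randn(B, S, H, dtype=torch.float16)                 # "BSH" layout: (Batch, Seq-Length, Head-Size)
q_per_head = q_bsh.view(B, S, N, D).transpose(1, 2).contiguous()  # equivalent per-head (B, N, S, D) view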
Scenario where the query, key, and value inputs are FLOAT16 and the output is INT8: the quant_scale2 argument is required and quant_offset2 is optional; dequant_scale1, quant_scale1, and dequant_scale2 must not be passed (i.e., they must remain None).
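A hedged sketch of this quantized-output scenario follows. The shapes and dtypes of quant_scale2 and quant_offset2 (per-tensor float32 tensors of shape (1,)) are assumptions for illustration only and are not specified in this excerpt; check the parameter reference before relying on them.

# Sketch: FLOAT16 q/k/v with INT8 output (quantization parameter shapes are assumed)
import torch
import torch_npu
import math

q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
quant_scale2 = torch.tensor([1.0], dtype=torch.float32).npu()   # required in this scenario
quant_offset2 = torch.tensor([0.0], dtype=torch.float32).npu()  # optional
# dequant_scale1 / quant_scale1 / dequant_scale2 are left as None, as required above
out_int8 = torch_npu.npu_incre_flash_attention(
    q, k, v,
    num_heads=40,
    input_layout="BSH",
    scale_value=1 / math.sqrt(128.0),
    quant_scale2=quant_scale2,
    quant_offset2=quant_offset2)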
# Single-operator invocation
import torch
import torch_npu
import math

# Generate random data and move it to the NPU
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
scale = 1 / math.sqrt(128.0)

# Call the IFA operator
out = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale)

# Running the code above prints output similar to:
# tensor([[[ 0.3149, -0.2460,  0.7939,  ...,  0.5737, -0.4929, -0.1500]],
#         [[ 0.8115,  1.3789,  0.6484,  ..., -0.9092, -0.6206, -0.7412]]],
#        device='npu:0', dtype=torch.float16)

# Graph-mode invocation
import torch
import torch_npu
import math
import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE = 1
TORCH_LOGS = "+dynamo"

# Debug logging switches for graph capture
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl

# Generate input data
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
atten = torch.randn(2, 1, 1, 2048, dtype=torch.float16).npu()
atten = torch.where(atten < 0, 1, 0).to(torch.bool).to(torch.float16)
scale_value = 1 / math.sqrt(128.0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        return torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH",
                                                   scale_value=scale_value, atten_mask=atten)

def MetaInfershape():
    with torch.no_grad():
        model = Model()
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()
    single_op = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH",
                                                    scale_value=scale_value, atten_mask=atten)
    print("single op output with mask:", single_op, single_op.shape)
    print("graph output with mask:", graph_output, graph_output.shape)

if __name__ == "__main__":
    MetaInfershape()

# Running the code above prints output similar to:
# single op output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
#         [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
#        device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
# graph output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
#         [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
#        device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
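Beyond the two examples above, the signature also exposes actual_seq_lengths, which in incremental decoding is typically used to restrict attention to the valid prefix of a pre-allocated key/value cache. The sketch below is an assumption-based illustration only (the exact semantics and constraints of actual_seq_lengths are not described in this excerpt; the value cur_len is hypothetical):

# Sketch: querying a pre-allocated KV cache of length 2048 where only the first
# cur_len positions per batch entry are assumed to be valid
import torch
import torch_npu
import math

b, n, d, max_s = 2, 40, 128, 2048
q = torch.randn(b, 1, n * d, dtype=torch.float16).npu()
k_cache = torch.randn(b, max_s, n * d, dtype=torch.float16).npu()
v_cache = torch.randn(b, max_s, n * d, dtype=torch.float16).npu()
cur_len = 1024  # hypothetical number of already-decoded tokens
out = torch_npu.npu_incre_flash_attention(
    q, k_cache, v_cache,
    num_heads=n,
    input_layout="BSH",
    scale_value=1 / math.sqrt(d),
    actual_seq_lengths=[cur_len] * b)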