文档
注册
评分
提单
论坛
小AI

torch_npu.npu_incre_flash_attention

功能描述

增量FA实现,实现对应公式:

接口原型

torch_npu.npu_incre_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? pse_shift=None, SymInt[]? actual_seq_lengths=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? kv_padding_size=None, int num_heads=1, float scale_value=1.0, str input_layout="BSH", int num_key_value_heads=0, int block_size=0, int inner_precise=1) -> Tensor

参数说明

  • query:Device侧的Tensor;数据类型支持FLOAT16、BFLOAT16、INT8,数据格式支持ND。Atlas 推理系列加速卡产品仅支持FLOAT16。
  • key:Device侧的Tensor;数据类型支持FLOAT16、BFLOAT16、INT8,数据格式支持ND。Atlas 推理系列加速卡产品仅支持FLOAT16。
  • value:Device侧的Tensor;数据类型支持FLOAT16、BFLOAT16、INT8,数据格式支持ND。Atlas 推理系列加速卡产品仅支持FLOAT16。
  • *:代表其之前的变量是位置相关,需要按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
  • padding_mask:预留参数,暂未使用,默认值为None。
  • atten_mask:Device侧的Tensor,取值为1代表该位不参与计算(不生效),为0代表该位参与计算,默认值为None,即全部参与计算;数据类型支持BOOL,INT8,UINT8,数据格式支持ND。
  • pse_shift:Device侧的Input Tensor;表示在attention结构内部的位置编码参数,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
  • actual_seq_lengths:Host侧的Input数组,其shape为(B)或(1),形如[1, 2, 3],代表key、value中有效的S序列长度,默认值为None,即全部有效,类型为List int;数据类型为INT64,数据格式支持ND。
  • antiquant_scale:Device侧的Tensor,数据类型支持:FLOAT16、BFLOAT16。数据格式支持ND,表示量化因子,支持per-channel(list),由shape决定,BNSD场景下shape为(2, N, 1, D),BSH场景下shape为(2, H),BSND场景下shape为(2, N, D)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持FLOAT16。
  • antiquant_offset:Device侧的Input Tensor,数据类型支持:FLOAT16、BFLOAT16。数据格式支持ND,表示量化偏移,支持per-channel(list),由shape决定,BNSD场景下shape为(2, N, 1, D),BSH场景下shape为(2, H),BSND场景下shape为(2, N, D)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持FLOAT16。
  • block_table:Device侧的Tensor,数据类型支持:INT32。数据格式支持ND,表示PageAttention中KV存储使用的block映射表。 如不使用该功能时可不传或传入None。
  • dequant_scale1:Device侧的Tensor,数据类型支持:FLOAT32。数据格式支持ND,表示BMM1后面反量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
  • quant_scale1:Device侧的Tensor,数据类型支持:FLOAT32。数据格式支持ND,表示BMM2前面量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
  • dequant_scale2:Device侧的Tensor,数据类型支持:FLOAT32。数据格式支持ND,表示BMM2后面反量化的量化因子,支持per-tensor(scalar)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
  • quant_scale2:Device侧的Tensor,数据类型支持:FLOAT32、BFLOAT16。数据格式支持ND,表示输出量化的量化因子,支持per-tensor(scalar)和per-channel(list)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
  • quant_offset2:Device侧的Tensor,数据类型支持:FLOAT32、BFLOAT16。数据格式支持ND,表示输出量化的量化偏移,支持per-tensor(scalar)和per-channel(list)。 如不使用该功能时可不传或传入None。Atlas 推理系列加速卡产品仅支持None。
  • kv_padding_size:Device侧的Tensor,数据类型支持:INT64。数据格式支持ND,表示kv左padding场景使能时,最后一个有效token到S的距离。 如不使用该功能时可传入None。
  • num_heads:Host侧的attribute,代表query的头数,即query的N,默认值为1;数据类型为INT64。
  • scale_value:Host侧的attribute,代表缩放系数,用来约束梯度,其默认值为1.0,典型值为;数据类型为FLOAT32。
  • input_layout:Host侧的attribute,代表query、key、value的布局,根据输入的query、key、value的shape确定,三维Tensor是BSH,四维Tensor是BNSD或BSND,默认值为BSH,不支持其他值;数据类型为string。

    query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。

  • num_key_value_heads:Host侧的attribute,代表key、value的头数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,默认值为0,表示与query的头数相同,否则表示key、value的头数,需要能被query的头数(num_heads)整除;num_heads与num_key_value_heads的比值不能大于64。数据类型为INT64。
  • block_size (int64_t,计算输入):Host侧的int64_t,PageAttention中KV存储每个block中最大的token个数,默认为0,数据类型支持INT64。
  • inner_precise (int64_t,计算输入):Host侧的int64_t,代表高精度/高性能选择,0代表高精度,1代表高性能,默认值为1(高性能), 数据类型支持INT64。

输出说明

  • 共一个输出,为计算的最终结果atten_out,类型为Tensor,shape与query保持一致。
  • 非量化场景下,输出数据类型与query的数据类型保持一致。
  • 量化场景下,若传入quant_scale2,则输出数据类型为int8;若不传入quant_scale2,且query、key、value类型为int8,则输出数据类型为FLOAT16。

约束说明

  • query、key、value的维度必须保持一致,key、value的shape必须保持一致。
  • num_heads的值要等于query的N,key、value的N和num_key_value_heads值相等,并且num_heads是num_key_value_heads的倍数关系。
  • input_layout的值与query的shape相关,三维是BSH,四维是BNSD或BSND。
  • num_key_value_heads的值要等于key、value的N,需要能被query的头数(num_heads)整除。
  • query,key,value输入,功能使用限制如下:
    • Atlas A2 训练系列产品支持B轴小于等于65535,支持N轴小于等于256,支持S轴小于等于262144,支持D轴小于等于512。
    • Atlas 推理系列加速卡产品支持B轴小于等于256,支持N轴小于等于256,支持S轴小于等于65536,支持D轴小于等于512。page attention使能的场景下,仅支持S轴小于等于8192。
    • query、key、value输入均为INT8的场景暂不支持
  • INT8量化相关入参数量与输入、输出数据格式的综合限制:

    query、key、value输入为FLOAT16,输出为INT8的场景:入参quant_scale2必填,quant_offset2可选,不能传入dequant_scale1、quant_scale1、dequant_scale2(即为None)参数。

  • pse_shift功能使用限制如下:
    • pseShift数据类型需与query数据类型保持一致。
    • 仅支持D轴对齐,即D轴可以被16整除。
  • page attention场景:
    • 使能必要条件是blocktable存在且有效,同时key、value是按照blocktable中的索引在一片连续内存中排布,支持key、value dtype为FLOAT16/BFLOAT16/INT8。
    • 该场景下,blockSize是用户自定义的参数,该参数的取值会影响page attention的性能。key、value输入类型为FLOAT16/BFLOAT16时需要16对齐,key、value 输入类型为INT8时需要32对齐,推荐使用128。通常情况下,page attention可以提高吞吐量,但会带来性能上的下降。
    • 参数key、value各自对应tensor的shape所有维度相乘不能超过int32的表示范围。
    • 进行per-channel后量化时,不能同时使能page attention。
  • blockTable当前支持的maxBlockNumPerSeq最大为16k,超过16k会被拦截报错;如果遇到S超大导致maxBlockNumPerSeq超过16k,可以调大blockSize解决。
  • quantScale2、quantOffset2为一组参数,其中quantOffset2可选,传入该组参数后算子输出数据类型会推导为int8,若不期望int8输出,请勿传入该组参数。
  • kv左padding场景:
    • kvCache的搬运起点计算公式为:Smax - kv_padding_size- actual_seq_lengths。kvCache的搬运终点计算公式为:Smax - kv_padding_size。其中kvCache的搬运起点或终点小于0时,返回数据结果为全0。
    • kv_padding_size小于0时将被置为0。
    • 使能需要同时存在actual_seq_lengths参数,否则默认为kv右padding场景。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1

支持的型号

  • Atlas A2 训练系列产品
  • Atlas 推理系列加速卡产品

调用示例

# 单算子调用方式
import torch
import torch_npu
import math

# 生成随机数据,并发送到npu
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
scale = 1/math.sqrt(128.0)

# 调用IFA算子
out = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale)

# 执行上述代码的输出类似如下
tensor([[[ 0.3149, -0.2460,  0.7939,  ...,  0.5737, -0.4929, -0.1500]],
        [[ 0.8115,  1.3789,  0.6484,  ..., -0.9092, -0.6206, -0.7412]]],
       device='npu:0', dtype=torch.float16)


# 入图方式

import torch
import torch_npu
import math

import torchair as tng
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"

# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl

# 数据生成
q = torch.randn(2, 1, 40 * 128, dtype=torch.float16).npu()
k = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
v = torch.randn(2, 2048, 40 * 128, dtype=torch.float16).npu()
atten = torch.randn(2, 1, 1, 2048, dtype=torch.float16).npu()
atten = torch.where(atten < 0, 1, 0).to(torch.bool).to(torch.float16)
scale_value = 1/math.sqrt(128.0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self):
        return torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
def MetaInfershape():
    with torch.no_grad():
        model = Model()
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()
    single_op = torch_npu.npu_incre_flash_attention(q, k, v, num_heads=40, input_layout="BSH", scale_value=scale_value, atten_mask=atten)
    print("single op output with mask:", single_op, single_op.shape)
    print("graph output with mask:", graph_output, graph_output.shape)
if __name__ == "__main__":
    MetaInfershape()

# 执行上述代码的输出类似如下
single op output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
        [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
       device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
graph output with mask: tensor([[[ 0.2488, -0.6572,  1.0928,  ...,  0.1694,  0.1142, -2.2266]],
        [[-0.9595, -0.9609, -0.6602,  ...,  0.7959,  1.7920,  0.0783]]],
       device='npu:0', dtype=torch.float16) torch.Size([2, 1, 5120])
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词