torch_npu.npu_fused_infer_attention_score

Function Description

Interface Prototype

npu_fused_infer_attention_score(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, Tensor? atten_mask=None, SymInt[]? actual_seq_lengths=None, SymInt[]? actual_seq_lengths_kv=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? block_table=None, Tensor? query_padding_size=None, Tensor? kv_padding_size=None, Tensor? key_antiquant_scale=None, Tensor? key_antiquant_offset=None, Tensor? value_antiquant_scale=None, Tensor? value_antiquant_offset=None, Tensor? key_shared_prefix=None, Tensor? value_shared_prefix=None, Tensor? actual_shared_prefix_len=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int num_key_value_heads=0, int sparse_mode=0, int inner_precise=0, int block_size=0, int antiquant_mode=0, bool softmax_lse_flag=False, int key_antiquant_mode=0, int value_antiquant_mode=0) -> (Tensor, Tensor)

Parameter Description

Output Description

Constraints

Supported Models

Atlas A2 training series products / Atlas 800I A2 inference products

Call Example
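
The following is a minimal sketch of one possible call, not an official example. It passes only keyword arguments that appear in the prototype above and keeps the default "BSH" layout. The tensor shapes, the float16 dtype, the choice of scale_value = 1/sqrt(head_dim), and the helper names build_call_kwargs and run_example are assumptions made for illustration. The call itself runs only when torch_npu and an NPU device are available; otherwise run_example returns None.

```python
import math


def build_call_kwargs(num_heads, head_dim):
    """Keyword arguments for the call; every key appears in the prototype."""
    return {
        "num_heads": num_heads,
        # Assumption: the usual 1/sqrt(head_dim) attention scaling.
        "scale_value": 1.0 / math.sqrt(head_dim),
        # "BSH" is the default input_layout in the prototype.
        "input_layout": "BSH",
    }


def run_example(batch=1, seq_q=164, seq_kv=1024, num_heads=8, head_dim=128):
    """Call the operator once; return the first output's shape, or None without an NPU."""
    try:
        import torch
        import torch_npu
    except (ImportError, OSError):
        return None
    if not torch_npu.npu.is_available():
        return None

    # Assumption: in "BSH" layout the last dimension is num_heads * head_dim.
    hidden = num_heads * head_dim
    query = torch.randn(batch, seq_q, hidden, dtype=torch.float16).npu()
    key = torch.randn(batch, seq_kv, hidden, dtype=torch.float16).npu()
    value = torch.randn(batch, seq_kv, hidden, dtype=torch.float16).npu()

    # The prototype returns two tensors; only the first is used here.
    attention_out, _second = torch_npu.npu_fused_infer_attention_score(
        query, key, value, **build_call_kwargs(num_heads, head_dim)
    )
    return tuple(attention_out.shape)


if __name__ == "__main__":
    print(run_example())
```

Keeping the keyword arguments in a separate helper makes it easy to check them on a machine without an NPU before running the operator on the target device.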