(beta) torch_npu.contrib.npu_fused_attention

API Prototype

torch_npu.contrib.npu_fused_attention(hidden_states, attention_mask, query_kernel, key_kernel, value_kernel, query_bias, key_bias, value_bias, scale=1, keep_prob=0)

Function Description

Fused implementation of BERT self-attention.

Parameter Description

  • hidden_states (Tensor): the hidden states of the last layer
  • attention_mask (Tensor): the attention mask
  • query_kernel (Tensor): the weight of query
  • key_kernel (Tensor): the weight of key
  • value_kernel (Tensor): the weight of value
  • query_bias (Tensor): the bias of query
  • key_bias (Tensor): the bias of key
  • value_bias (Tensor): the bias of value
  • scale=1 (double): scaling coefficient applied to the attention scores
  • keep_prob=0: probability of keeping items; should equal 1 - dropout rate

Output Description

torch.Tensor: the result of self-attention.

Supported Products

  • Atlas Training Series Products
  • Atlas A2 Training Series Products
  • Atlas Inference Series Products
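
Example

A minimal usage sketch follows. Only the call signature follows the prototype above; the tensor shapes, float16 dtype, and the scale/keep_prob values are assumptions chosen for illustration and must be adapted to the actual model and the operator's shape constraints.

import torch
import torch_npu

# Hypothetical dimensions for illustration only: batch 8, sequence 128, hidden size 1024.
batch, seq_len, hidden = 8, 128, 1024

# Assumed shapes and dtype; check the operator's actual constraints before use.
hidden_states = torch.rand(batch * seq_len, hidden).half().npu()
attention_mask = torch.zeros(batch, 1, seq_len, seq_len).half().npu()

# Projection weights and biases for query/key/value.
query_kernel = torch.rand(hidden, hidden).half().npu()
key_kernel = torch.rand(hidden, hidden).half().npu()
value_kernel = torch.rand(hidden, hidden).half().npu()
query_bias = torch.zeros(hidden).half().npu()
key_bias = torch.zeros(hidden).half().npu()
value_bias = torch.zeros(hidden).half().npu()

# scale is commonly 1/sqrt(per-head dimension); keep_prob = 1 - dropout rate.
context = torch_npu.contrib.npu_fused_attention(
    hidden_states, attention_mask,
    query_kernel, key_kernel, value_kernel,
    query_bias, key_bias, value_bias,
    scale=0.125, keep_prob=1.0)
print(context.shape)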