下载
中文
注册

beta)torch_npu.npu_multi_head_attention

接口原型

torch_npu.npu_multi_head_attention(Tensor query, Tensor key, Tensor value, Tensor query_weight, Tensor key_weight, Tensor value_weight, Tensor attn_mask, Tensor out_proj_weight, Tensor query_bias, Tensor key_bia, Tensor value_bias, Tensor out_proj_bias, Tensor dropout_mask, int attn_head_num, int attn_dim_per_head, int src_len, int tgt_len, float dropout_prob, bool softmax_use_float) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)

功能描述

实现Transformer模块中的MultiHeadAttention计算逻辑。

参数说明

  • query: Tensor类型,仅支持float16
  • key: Tensor类型,仅支持float16
  • value: Tensor类型,仅支持float16
  • query_weight: Tensor类型,仅支持float16
  • key_weight: Tensor类型,仅支持float16
  • value_weight: Tensor类型,仅支持float16
  • attn_mask: Tensor类型,仅支持float16
  • out_proj_weight: Tensor类型,仅支持float16
  • query_bias: Tensor类型,仅支持float16
  • key_bias: Tensor类型,仅支持float16
  • value_bias: Tensor类型,仅支持float16
  • out_proj _bias: Tensor类型,仅支持float16
  • dropout_mask_input: Tensor类型,仅支持float16
  • attn_head_num: Attention Head numbers,Int型
  • attn_dim_per_head:Attention dim of a Head,Int型
  • src_len:source length,Int型
  • tgt_len:target length,Int型
  • keep_prob:dropout keep probability,Float型
  • softmax_use_float:SoftMax Use Float32 to keep precision,Bool型

输出说明

  • y: Tensor类型,仅支持float16
  • dropout_mask: Tensor类型,仅支持float16
  • query_res: Tensor类型,仅支持float16
  • key_res: Tensor类型,仅支持float16
  • value_res: Tensor类型,仅支持float16
  • attn_scores: Tensor类型,仅支持float16
  • attn_res: Tensor类型,仅支持float16
  • context: Tensor类型,仅支持float16

约束说明

Attr attn_head_num:需16对齐

Attr attn_dim_per_head:需16对齐

Attr src_len:需16对齐

tgt_len:需16对齐

支持的型号

  • Atlas 训练系列产品
  • Atlas A2 训练系列产品
  • Atlas A3 训练系列产品
  • Atlas 推理系列产品

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch_npu
import numpy as np
 
batch = 8
attn_head_num = 16
attn_dim_per_head = 64
src_len = 64
tgt_len = 64
dropout_prob = 0.0
softmax_use_float = True
 
weight_col = attn_head_num * attn_dim_per_head
query = torch.from_numpy(np.random.uniform(-1, 1, (batch * tgt_len, weight_col)).astype("float16")).npu()
key = torch.from_numpy(np.random.uniform(-1, 1, (batch * src_len, weight_col)).astype("float16")).npu()
value = torch.from_numpy(np.random.uniform(-1, 1, (batch * tgt_len, weight_col)).astype("float16")).npu()
query_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
key_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
value_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
out_proj_weight = torch.from_numpy(np.random.uniform(-1, 1, (weight_col, weight_col)).astype("float16")).npu()
attn_mask = torch.from_numpy(np.random.uniform(-1, 1, (batch, attn_head_num, tgt_len, src_len)).astype("float16")).npu()
query_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
key_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
value_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
out_proj_bias = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
dropout_mask_input = torch.from_numpy(np.random.uniform(-1, 1, (weight_col,)).astype("float16")).npu()
            
npu_result, npu_dropout_mask, npu_query_res, npu_key_res, npu_value_res, npu_attn_scores, npu_attn_res, npu_context = torch_npu.npu_multi_head_attention (query, key, value, query_weight, key_weight, value_weight, attn_mask, out_proj_weight, query_bias, key_bias, value_bias, out_proj_bias,  dropout_mask_input, attn_head_num, attn_dim_per_head, src_len, tgt_len, dropout_prob, softmax_use_float)
 
print(npu_result)
 
 
 
tensor([[ 623.5000,   75.5000,  307.0000,  ...,   25.3125, -418.7500,
           35.9688],
        [-254.2500, -165.6250,  176.2500,  ...,   87.3750,   78.0000,
           65.2500],
        [ 233.2500,  207.3750,  324.7500,  ...,   38.6250, -264.2500,
          153.7500],
        ...,
        [-110.2500,  -92.5000,  -74.0625,  ...,  -68.0625,  195.6250,
         -157.6250],
        [ 300.0000, -184.6250,   -6.0039,  ...,  -15.7969, -299.0000,
          -93.1875],
        [  -2.5996,   36.8750,  100.0625,  ...,  112.7500,  202.0000,
         -166.3750]], device='npu:0', dtype=torch.float16)