torch_npu.contrib.module.MultiheadAttention(nn.Module):
Multi-headed attention.
>>> from torch_npu.testing.common_utils import create_common_tensor
>>> from torch_npu.contrib.module import MultiheadAttention
>>> import numpy as np
>>> model = MultiheadAttention(embed_dim=1024, num_heads=16, dropout=0.1, kdim=1024, vdim=1024, self_attention=True, encoder_decoder_attention=True)
>>> _, query = create_common_tensor([np.float16, 29, (1024, 1024)], -1, 1)
>>> _, key = create_common_tensor([np.float16, 29, (1024, 1024)], -1, 1)
>>> _, value = create_common_tensor([np.float16, 29, (1024, 1024)], -1, 1)
>>> _, key_padding_mask = create_common_tensor([np.float16, 29, (1024, 1024)], -1, 1)
>>> bsz = 16
>>> tgt_len = 64
>>> s_len = 64
>>> model = model.to("npu")
>>> output = model(query, key, value, bsz, tgt_len, s_len, key_padding_mask)
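For reference, a minimal sketch of the same call using plain torch tensors instead of the create_common_tensor test utility. It assumes an available NPU device, the forward signature shown in the example above, and fp16 inputs of shape (tgt_len * bsz, embed_dim) = (1024, 1024) as in the example; it is not a verified recipe.

>>> # Minimal sketch (assumptions: NPU device available, shapes follow the example above)
>>> import torch
>>> import torch_npu
>>> from torch_npu.contrib.module import MultiheadAttention
>>> model = MultiheadAttention(embed_dim=1024, num_heads=16, dropout=0.1, kdim=1024, vdim=1024, self_attention=True, encoder_decoder_attention=True).to("npu")
>>> bsz, tgt_len, s_len = 16, 64, 64
>>> query = torch.rand(tgt_len * bsz, 1024, dtype=torch.float16, device="npu")
>>> key = torch.rand(s_len * bsz, 1024, dtype=torch.float16, device="npu")
>>> value = torch.rand(s_len * bsz, 1024, dtype=torch.float16, device="npu")
>>> key_padding_mask = torch.rand(1024, 1024, dtype=torch.float16, device="npu")
>>> output = model(query, key, value, bsz, tgt_len, s_len, key_padding_mask)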