class MultiheadAttention(nn.Module):
Multi-headed attention.
>>> from torch_npu.contrib.module import MultiheadAttention
>>> model = MultiheadAttention(embed_dim=1024, num_heads=16, dropout=0.1, kdim=1024, vdim=1024, self_attention=True, encoder_decoder_attention=True)
>>> _, query = create_common_tensor([np.float16, FORMAT_NZ, (1024, 1024)], -1, 1)
>>> _, key = create_common_tensor([np.float16, FORMAT_NZ, (1024, 1024)], -1, 1)
>>> _, value = create_common_tensor([np.float16, FORMAT_NZ, (1024, 1024)], -1, 1)
>>> _, key_padding_mask = create_common_tensor([np.float16, FORMAT_NZ, (16, 16, 64, 64)], -65504, 65504)
>>> bsz = 16
>>> tgt_len = 64
>>> s_len = 64
>>> model = model.to("npu")
>>> output = model(query, key, value, bsz, tgt_len, s_len, key_padding_mask)
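The example above relies on create_common_tensor and FORMAT_NZ, which are torch_npu test utilities and are not defined in this document. The sketch below re-creates the same call with plain torch tensors instead; the input layout (query/key/value flattened to (tgt_len * bsz, embed_dim)) and the fp16 mask shape (bsz, num_heads, tgt_len, src_len) are inferred from the numbers in the example, not from a documented contract, so treat them as assumptions.

    # Minimal sketch: same forward call as the example, using plain torch tensors.
    # Shape and dtype choices are assumptions derived from the example above.
    import torch
    import torch_npu
    from torch_npu.contrib.module import MultiheadAttention

    embed_dim, num_heads = 1024, 16
    bsz, tgt_len, s_len = 16, 64, 64

    model = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,
                               dropout=0.1, kdim=embed_dim, vdim=embed_dim,
                               self_attention=True,
                               encoder_decoder_attention=True).to("npu").half()

    # fp16 activations on the NPU, flattened to (seq_len * bsz, embed_dim)
    query = torch.randn(tgt_len * bsz, embed_dim, dtype=torch.float16).npu()
    key = torch.randn(s_len * bsz, embed_dim, dtype=torch.float16).npu()
    value = torch.randn(s_len * bsz, embed_dim, dtype=torch.float16).npu()

    # Additive fp16 mask; all zeros here means nothing is masked out
    key_padding_mask = torch.zeros(bsz, num_heads, tgt_len, s_len,
                                   dtype=torch.float16).npu()

    output = model(query, key, value, bsz, tgt_len, s_len, key_padding_mask)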