torch_npu.npu_scaled_masked_softmax

API接口

torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor

功能描述

计算输入张量x缩放并按照mask遮蔽后的Softmax结果。

参数说明

约束说明

示例

>>> import torch
>>> import torch_npu
>>> 
>>> x = torch.rand(4, 4, 2048, 2048).npu()
>>> mask = torch.rand(1, 1, 2048, 2048).npu()
>>> out = torch_npu.npu_scaled_masked_softmax(x, mask, 1.0, False)
>>> out.shape
torch.size([4, 4, 2048, 2048])