torch_npu.npu_scaled_masked_softmax

接口原型

torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor

功能描述

计算输入张量x缩放并按照mask遮蔽后的Softmax结果。

参数说明

输出说明

一个Tensor类型的输出,输入x经过mask后在最后一维的Softmax结果,输出shape与x一致。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。

约束说明

支持的型号

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
>>> import torch
>>> import torch_npu
>>> 
>>> shape = [4, 4, 2048, 2048]
>>> x = torch.rand(shape).npu()
>>> mask = torch.zeros_like(x).bool()
>>> scale = 1.0
>>> fixed_triu_mask = False
>>> 
>>> output = torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
>>> output.shape
torch.size([4, 4, 2048, 2048])