下载
中文
注册

torch_npu.npu_scaled_masked_softmax

接口原型

torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor

功能描述

计算输入张量x缩放并按照mask遮蔽后的Softmax结果。

参数说明

  • x(Tensor)- 输入的logits。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。
  • mask(Tensor)- 输入的掩码。支持数据类型:bool。支持格式:[ND,FRACTAL_NZ]。
  • scale(float,默认值为1.0)- x的缩放系数。
  • fixed_triu_mask(bool,默认值为False)- 为预留参数,功能未完成,当前只支持False。该功能完成后可支持自动生成上三角bool掩码。

输出说明

一个Tensor类型的输出,输入x经过mask后在最后一维的Softmax结果,输出shape与x一致。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。

约束说明

  • 当前输入x的shape,只支持转为[NCHW]格式后,H和W轴长度大于等于32、小于等于4096、且能被32整除的场景。
  • 输入mask的shape,必须能被broadcast成x的shape。

支持的型号

  • Atlas 训练系列产品
  • Atlas A2 训练系列产品
  • Atlas A3 训练系列产品
  • Atlas 推理系列产品

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
>>> import torch
>>> import torch_npu
>>> 
>>> shape = [4, 4, 2048, 2048]
>>> x = torch.rand(shape).npu()
>>> mask = torch.zeros_like(x).bool()
>>> scale = 1.0
>>> fixed_triu_mask = False
>>> 
>>> output = torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
>>> output.shape
torch.size([4, 4, 2048, 2048])