文档
注册

torch_npu.npu_scaled_masked_softmax

接口原型

torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor

功能描述

计算输入张量x缩放并按照mask遮蔽后的Softmax结果。

参数说明

  • x(Tensor)- 输入的logits。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。
  • mask(Tensor)- 输入的掩码。支持数据类型:bool。支持格式:[ND,FRACTAL_NZ]。
  • scale(float,默认值为1.0)- x的缩放系数。
  • fixed_triu_mask(bool,默认值为False)- 为预留参数,功能未完成,当前只支持False。该功能完成后可支持自动生成上三角bool掩码。

约束说明

  • 当前输入x的shape,只支持转为[NCHW]格式后,H和W轴长度大于等于32、小于等于4096、且能被32整除的场景。
  • 输入mask的shape,必须能被broadcast成x的shape。

调用示例

>>> import torch
>>> import torch_npu
>>> 
>>> x = torch.rand(4, 4, 2048, 2048).npu()
>>> mask = torch.rand(1, 1, 2048, 2048).npu()
>>> out = torch_npu.npu_scaled_masked_softmax(x, mask, 1.0, False)
>>> out.shape
torch.size([4, 4, 2048, 2048])
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词