torch_npu.npu_scaled_masked_softmax
接口原型
torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor
功能描述
计算输入张量x缩放并按照mask遮蔽后的Softmax结果。
参数说明
- x(Tensor)- 输入的logits。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。
- mask(Tensor)- 输入的掩码。支持数据类型:bool。支持格式:[ND,FRACTAL_NZ]。
- scale(float,默认值为1.0)- x的缩放系数。
- fixed_triu_mask(bool,默认值为False)- 为预留参数,功能未完成,当前只支持False。该功能完成后可支持自动生成上三角bool掩码。
约束说明
- 当前输入x的shape,只支持转为[NCHW]格式后,H和W轴长度大于等于32、小于等于4096、且能被32整除的场景。
- 输入mask的shape,必须能被broadcast成x的shape。
调用示例
>>> import torch >>> import torch_npu >>> >>> x = torch.rand(4, 4, 2048, 2048).npu() >>> mask = torch.rand(1, 1, 2048, 2048).npu() >>> out = torch_npu.npu_scaled_masked_softmax(x, mask, 1.0, False) >>> out.shape torch.size([4, 4, 2048, 2048])
父主题: torch_npu