torch_npu.npu_scaled_masked_softmax
接口原型
torch_npu.npu_scaled_masked_softmax(x, mask, scale=1.0, fixed_triu_mask=False) -> Tensor
功能描述
计算输入张量x缩放并按照mask遮蔽后的Softmax结果。
参数说明
- x(Tensor)- 输入的logits。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。
- mask(Tensor)- 输入的掩码。支持数据类型:bool。支持格式:[ND,FRACTAL_NZ]。
- scale(float,默认值为1.0)- x的缩放系数。
- fixed_triu_mask(bool,默认值为False)- 为预留参数,功能未完成,当前只支持False。该功能完成后可支持自动生成上三角bool掩码。
输出说明
一个Tensor类型的输出,输入x经过mask后在最后一维的Softmax结果,输出shape与x一致。支持数据类型:float16、float32、bfloat16。支持格式:[ND,FRACTAL_NZ]。
约束说明
- 当前输入x的shape,只支持转为[NCHW]格式后,H和W轴长度大于等于32、小于等于4096、且能被32整除的场景。
- 输入mask的shape,必须能被broadcast成x的shape。
支持的型号
Atlas 训练系列产品 Atlas A2 训练系列产品 Atlas A3 训练系列产品 Atlas 推理系列产品
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 | >>> import torch >>> import torch_npu >>> >>> shape = [4, 4, 2048, 2048] >>> x = torch.rand(shape).npu() >>> mask = torch.zeros_like(x).bool() >>> scale = 1.0 >>> fixed_triu_mask = False >>> >>> output = torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask) >>> output.shape torch.size([4, 4, 2048, 2048]) |
父主题: torch_npu