ScaledMaskedSoftmax & ScaledMaskedSoftmaxGrad
算子基础信息
算子名称 |
ScaledMaskedSoftmax & ScaledMaskedSoftmaxGrad |
---|---|
torch_npu api接口 |
torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask) |
支持的torch_npu版本 |
2.1.0, 2.3.1, 2.4.0 |
支持的芯片类型 |
Atlas 训练系列产品,Atlas A2 训练系列产品 |
支持的数据类型 |
float16, bfloat16, float |
算子IR及torch_npu接口参数
算子IR:
REG_OP(ScaledMaskedSoftmax) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_BF16})) .OPTIONAL_INPUT(mask, TensorType({DT_BOOL, DT_UINT1})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_BF16})) .ATTR(scale, Float, 1.0) .ATTR(fixed_triu_mask, Bool, false) .OP_END_FACTORY_REG(ScaledMaskedSoftmax)
torch_npu接口:
torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
参数说明:
名称 |
类型 |
Dtype |
Shape要求 |
默认值 |
---|---|---|---|---|
x |
输入 |
bfloat16, float16, float32 |
必须为4维, 且后两维都需要在[32, 4096]范围内,且能被32整除 |
- |
mask |
输入 |
bool |
必须为4维,且后两维和x一致,且能被广播成x的shape |
- |
scale |
属性 |
float |
对输入x缩放 |
1.0 |
fixed_triu_mask |
属性 |
bool |
是否生成可用的上三角bool掩码 |
False |
模型中替换代码及算子计算逻辑
模型中替换代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 |
if self.input_in_float16 and self.softmax_in_fp32: input = input.float() if self.scale is not None: input = input * self.scale mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: probs = probs.half() else: probs = probs.bfloat16() |
替换为:
probs = torch_npu.npu_scaled_masked_softmax(input , mask, self.scale, fixed_triu_mask )
算子的计算逻辑如下:
1 2 3 |
if fixed_triu_mask: mask = torch.triu(mask.shape, diagonal=1) y = torch.softmax((x * scale).masked_fill(mask, -inf), dim=-1) |
图1 计算流程图

算子替换的模型中小算子
使用限制
- 输入x的shape限制如下:
- 必须为4维
- 第三维的取值需要在[32, 4096]范围内
- 第四维的取值需要在[32, 4096]范围内
- 第三维的取值需要能被32整除
- 第四维的取值需要能被32整除
- 输入mask的shape限制如下:
- 必须为4维
- 后两维必须与x的后两维相等
- 前两维需要能被广播成x的前两维
已支持模型典型case
id |
x |
mask |
---|---|---|
1 |
[1, 8, 4096, 4096] |
[1, 1, 4096, 4096] |
2 |
[4, 32, 2048, 2048] |
[4, 1, 2048, 2048] |
3 |
[8, 16, 512, 2048] |
[8, 16, 512, 2048] |
4 |
[8, 16, 512, 1536] |
[8, 16, 512, 1536] |
5 |
[8, 16, 512, 1024] |
[8, 16, 512, 1024] |
6 |
[8, 16, 512, 512] |
[8, 16 512, 512] |
7 |
[8, 16, 512, 256] |
[8, 16, 512, 256] |
8 |
[4, 4, 2048, 2048] |
[4, 4, 2048, 2048] |
父主题: 融合算子替换