RotaryMul & RotaryMulGrad

算子基础信息

表1 算子信息

算子名称

RotaryMul & RotaryMulGrad

torch_npu api接口

torch_npu.npu_rotary_mul(x, r1, r2)

支持的torch_npu版本

2.1.0, 2.3.1, 2.4.0

支持的芯片类型

Atlas 训练系列产品 Atlas A2 训练系列产品 Atlas A3 训练系列产品

支持的数据类型

float16, bfloat16, float

算子IR及torch_npu接口参数

算子IR:
1
2
3
4
5
6
REG_OP(RotaryMul)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16}))
    .INPUT(r1, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16}))
    .INPUT(r2, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16}))
    .OP_END_FACTORY_REG(RotaryMul)

torch_npu接口:
torch_npu.npu_rotary_mul(x, r1, r2)

参数说明:
  • x:q,k,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D]。
  • r1:cos值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。
  • r2:sin值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。

模型中替换代码及算子计算逻辑

模型中替换代码:

算子的计算逻辑如下:
1
2
3
x1, x2 = torch.chunk(x, 2, -1)
x_new = torch.cat((-x2, x1), dim=-1)
output = r1 * x + r2 * x_new
图3 计算流程图

算子替换的模型中小算子

使用限制

目前算子仅支持r1、r2需要broadcast为x的shape的情形,且算子shape中最后一维D必须是128的倍数。

已支持模型典型case