torch_npu.npu_rotary_mul
接口原型
torch_npu.npu_rotary_mul(Tensor x, Tensor r1, Tensor r2): -> Tensor
在模型训练场景中,正向算子的输入Tensor x将被保留以供反向计算时使用。在r1,r2不需要计算反向梯度场景下(requires_grad=False),使用该接口相较融合前小算子使用的设备内存占用会有所增加。
功能描述
实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。
x1, x2 = torch.chunk(x, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * x + r2 * x_new
参数说明
- Tensor x:q, k,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D],dtype支持bf16,fp16,fp32。
- Tensor r1: cos值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D],dtype支持bf16,fp16,fp32。
- Tensor r2: sin值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D],dtype支持bf16,fp16,fp32。
输出说明
输出为Tensor,shape和dtype同输入Tensor x。
约束说明
shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。
广播场景下,广播轴的总数据量不能超过1024。
当输入layout为BNSD时,不支持B*N > 8*S的场景计算。此时建议模型中使用小算子拼接计算。
支持的型号
Atlas 训练系列产品 Atlas A2 训练系列产品 Atlas A3 训练系列产品 Atlas 推理系列产品
调用示例
1 2 3 4 5 6 | import torch import torch_npu x = torch.rand(2, 2, 5, 128).npu() r1 = torch.rand(1, 2, 1, 128).npu() r2 = torch.rand(1, 2, 1, 128).npu() out = torch_npu.npu_rotary_mul(x, r1, r2) |
父主题: torch_npu