下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

torch_npu.npu_rotary_mul

接口原型

torch_npu.npu_rotary_mul(Tensor x, Tensor r1, Tensor r2): -> Tensor

功能描述

实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。
x1, x2 = torch.chunk(x, 2, -1)
x_new = torch.cat((-x2, x1), dim=-1)
output = r1 * x + r2 * x_new

参数说明

  • Tensor x:q, k,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D],dtype支持bf16,fp16,fp32。
  • Tensor r1: cos值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D],dtype支持bf16,fp16,fp32。
  • Tensor r2: sin值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D],dtype支持bf16,fp16,fp32。

输出说明

输出为Tensor,shape和dtype同输入Tensor x。

约束说明

shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。

广播场景下,广播轴的总数据量不能超过1024。

调用示例

import torch
import torch_npu
x = torch.rand(2, 2, 5, 128).npu()
r1 = torch.rand(1, 2, 1, 128).npu()
r2 = torch.rand(1, 2, 1, 128).npu()
out = torch_npu.npu_rotary_mul(x, r1, r2)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词