文档
注册

torch_npu.npu_rotary_mul

接口原型

torch_npu.npu_rotary_mul(Tensor x, Tensor r1, Tensor r2): -> Tensor

功能描述

实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。
x1, x2 = torch.chunk(x, 2, -1)
x_new = torch.cat((-x2, x1), dim=-1)
output = r1 * x + r2 * x_new

参数说明

  • Tensor x:q, k,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D],dtype支持bf16,fp16,fp32。
  • Tensor r1: cos值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D],dtype支持bf16,fp16,fp32。
  • Tensor r2: sin值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D],dtype支持bf16,fp16,fp32。

约束说明

shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。

调用示例

    >>>x = torch.rand(2, 2, 5, 128).npu()
    >>>tensor([[[[0.8594, 0.4914, 0.9075,  ..., 0.2126, 0.6520, 0.2206],
          [0.5515, 0.3353, 0.6568,  ..., 0.3686, 0.1457, 0.8528],
          [0.0504, 0.2687, 0.4036,  ..., 0.3032, 0.8262, 0.6302],
          [0.0537, 0.5141, 0.7016,  ..., 0.4948, 0.9778, 0.8535],
          [0.3602, 0.7874, 0.9913,  ..., 0.1474, 0.3422, 0.6830]],

        [[0.4641, 0.6254, 0.7415,  ..., 0.1834, 0.1067, 0.7171],
          [0.8084, 0.7570, 0.4728,  ..., 0.4603, 0.4991, 0.1723],
          [0.0483, 0.6931, 0.0935,  ..., 0.7522, 0.0054, 0.1736],
          [0.6196, 0.1028, 0.7076,  ..., 0.2745, 0.9943, 0.6971],
          [0.3267, 0.3748, 0.1232,  ..., 0.0507, 0.4302, 0.6249]]],
 
 
         [[[0.2783, 0.8262, 0.6014,  ..., 0.8040, 0.7986, 0.2831],
          [0.6035, 0.2955, 0.7711,  ..., 0.7464, 0.3739, 0.6637],
          [0.6282, 0.7243, 0.5445,  ..., 0.3755, 0.0533, 0.9468],
          [0.5179, 0.3967, 0.6558,  ..., 0.0267, 0.5549, 0.9707],
          [0.4388, 0.7458, 0.2065,  ..., 0.6080, 0.4242, 0.8879]],
 
         [[0.3428, 0.6976, 0.0970,  ..., 0.9552, 0.3663, 0.2139],
          [0.2019, 0.2452, 0.1142,  ..., 0.3651, 0.6993, 0.5257],
          [0.9636, 0.1691, 0.4807,  ..., 0.9137, 0.3510, 0.0905],
          [0.0177, 0.9496, 0.1560,  ..., 0.7437, 0.9043, 0.0131],
          [0.9699, 0.5352, 0.9763,  ..., 0.1850, 0.2056, 0.0368]]]],
       device='npu:0')
    >>>r1 = torch.rand(1, 2, 1, 128).npu()
       tensor([[[[0.8433, 0.5262, 0.2608, 0.8501, 0.7187, 0.6944, 0.0193, 0.1507,
           0.0450, 0.2257, 0.4679, 0.8309, 0.4740, 0.8715, 0.7443, 0.3354,
           0.5533, 0.9151, 0.4215, 0.4631, 0.9076, 0.3093, 0.0270, 0.7681,
           0.1800, 0.0847, 0.6965, 0.2059, 0.8806, 0.3987, 0.8446, 0.6225,
           0.1375, 0.8765, 0.5965, 0.3092, 0.0193, 0.9220, 0.4997, 0.8170,
           0.8575, 0.5525, 0.8528, 0.7262, 0.4026, 0.5704, 0.0390, 0.9240,
           0.9780, 0.3927, 0.7343, 0.3922, 0.5004, 0.8561, 0.6021, 0.6530,
           0.6565, 0.9988, 0.4238, 0.0092, 0.5131, 0.5257, 0.1649, 0.0272,
           0.9103, 0.2476, 0.7573, 0.8500, 0.9348, 0.4306, 0.3612, 0.5378,
           0.7141, 0.3559, 0.6620, 0.3335, 0.4000, 0.2479, 0.3490, 0.7000,
           0.5321, 0.3485, 0.9162, 0.9207, 0.3262, 0.7929, 0.1258, 0.6689,
           0.1023, 0.1938, 0.3887, 0.6893, 0.0849, 0.3700, 0.5747, 0.9674,
           0.4520, 0.5313, 0.0377, 0.1202, 0.9326, 0.0442, 0.4651, 0.7036,
           0.3994, 0.9332, 0.5104, 0.0930, 0.4481, 0.8753, 0.5597, 0.6068,
           0.9895, 0.5833, 0.6771, 0.4255, 0.4513, 0.6330, 0.9070, 0.3103,
           0.0609, 0.8202, 0.6031, 0.3628, 0.1118, 0.2747, 0.4521, 0.8347]],
 
         [[0.6759, 0.8744, 0.3595, 0.2361, 0.4899, 0.3769, 0.6809, 0.0101,
           0.0730, 0.0576, 0.5242, 0.5510, 0.9780, 0.4704, 0.9607, 0.1699,
           0.3613, 0.6096, 0.0246, 0.6088, 0.4984, 0.9788, 0.2026, 0.1484,
           0.3086, 0.9697, 0.8166, 0.9566, 0.9874, 0.4547, 0.5250, 0.2041,
           0.7784, 0.4269, 0.0110, 0.6878, 0.6575, 0.3382, 0.1889, 0.8344,
           0.9608, 0.6153, 0.4812, 0.0547, 0.2978, 0.3610, 0.5285, 0.6162,
           0.2123, 0.1364, 0.6027, 0.7450, 0.2485, 0.2149, 0.7849, 0.8886,
           0.0514, 0.9511, 0.4865, 0.8380, 0.6947, 0.2378, 0.5839, 0.8434,
           0.0871, 0.4179, 0.1669, 0.8703, 0.1946, 0.0302, 0.9516, 0.1208,
           0.5780, 0.6859, 0.2405, 0.5083, 0.3872, 0.7649, 0.1329, 0.0252,
           0.2404, 0.5456, 0.7009, 0.6524, 0.7623, 0.5965, 0.0437, 0.4080,
           0.8390, 0.4172, 0.4781, 0.2405, 0.1502, 0.2020, 0.4192, 0.8185,
           0.0899, 0.1961, 0.7368, 0.4798, 0.4303, 0.9281, 0.5410, 0.0620,
           0.8945, 0.3589, 0.5637, 0.4875, 0.1523, 0.9478, 0.9040, 0.3410,
           0.3591, 0.2702, 0.5949, 0.3337, 0.3578, 0.8890, 0.6608, 0.6578,
           0.4953, 0.7975, 0.2891, 0.9552, 0.0092, 0.1293, 0.2362, 0.7821]]]],
       device='npu:0')
 
   >>>r2 = torch.rand(1, 2, 1, 128).npu()
 
 
 
       tensor([[[[6.4270e-01, 1.3050e-01, 9.6509e-01, 1.4090e-01, 1.8660e-01,
           8.7512e-01, 3.8567e-01, 4.1776e-01, 9.7718e-01, 5.6305e-01,
           6.3091e-01, 4.6385e-01, 1.8142e-01, 3.7779e-01, 3.8012e-01,
           8.1807e-01, 3.3292e-01, 5.8488e-01, 5.8188e-01, 5.7776e-01,
           5.1828e-01, 9.6087e-01, 7.2219e-01, 8.5045e-02, 3.6623e-01,
           3.3758e-01, 7.9666e-01, 6.9932e-01, 9.9202e-01, 2.5493e-01,
           2.3017e-01, 7.9396e-01, 5.0109e-01, 6.5580e-01, 3.2200e-01,
           7.8023e-01, 4.4098e-01, 1.0576e-01, 8.0548e-01, 2.2453e-01,
           1.4705e-01, 8.7682e-02, 4.7264e-01, 8.9034e-02, 8.5720e-01,
           4.7576e-01, 2.8438e-01, 8.6523e-01, 8.1707e-02, 3.0075e-01,
           4.9069e-01, 9.7404e-01, 9.3865e-01, 5.7160e-01, 1.6332e-01,
           4.3868e-01, 5.8658e-01, 5.3993e-01, 3.8271e-02, 9.9662e-01,
           2.2892e-01, 7.8558e-01, 9.4502e-01, 9.7633e-01, 1.7877e-01,
           2.6446e-02, 2.3411e-01, 6.7531e-01, 1.5023e-01, 4.4280e-02,
           1.4457e-01, 3.6683e-01, 4.3424e-01, 7.4145e-01, 8.2433e-01,
           6.8660e-01, 6.7477e-01, 5.5000e-02, 5.1344e-01, 9.3115e-01,
           3.8280e-01, 9.2177e-01, 4.5470e-01, 2.5540e-01, 4.6632e-01,
           8.3960e-01, 4.4320e-01, 1.0808e-01, 7.5544e-01, 4.6372e-01,
           1.4322e-01, 1.9141e-01, 5.5918e-02, 7.0804e-01, 1.8789e-01,
           9.4276e-01, 9.1742e-01, 9.1980e-01, 6.2728e-01, 4.1787e-01,
           7.9545e-01, 9.0569e-01, 7.9123e-01, 9.7596e-01, 7.2507e-01,
           2.3772e-01, 8.2560e-01, 5.9359e-01, 7.1134e-01, 5.1029e-01,
           6.1601e-01, 2.9094e-01, 3.4174e-01, 9.0532e-01, 5.0960e-01,
           3.4441e-01, 7.0498e-01, 4.2729e-01, 7.6714e-01, 6.3755e-01,
           8.4604e-01, 5.9109e-01, 7.9137e-01, 7.5149e-01, 2.2092e-01,
           9.5235e-01, 3.6915e-01, 6.4961e-01]],
 
         [[8.7862e-01, 1.1325e-01, 2.4575e-01, 9.7429e-01, 1.9362e-01,
           8.2297e-01, 3.5184e-02, 5.2755e-01, 7.6429e-01, 2.4700e-01,
           6.2860e-01, 2.4555e-01, 4.4557e-01, 7.0955e-03, 2.0326e-01,
           8.6354e-02, 3.5959e-01, 3.4059e-01, 8.6852e-01, 1.3858e-01,
           6.8500e-01, 1.3601e-01, 7.3152e-01, 8.3474e-01, 2.7017e-01,
           9.8078e-01, 6.1084e-01, 1.6540e-01, 4.3081e-01, 8.5738e-01,
           4.1890e-01, 6.6872e-01, 3.1698e-01, 4.2576e-02, 1.5236e-01,
           2.0526e-01, 1.9493e-01, 6.6122e-03, 1.8332e-01, 5.6981e-01,
           5.4090e-01, 6.0783e-01, 5.8742e-01, 9.1761e-04, 2.0904e-01,
           6.6419e-01, 9.9559e-01, 5.8233e-01, 6.8562e-01, 8.6456e-01,
           9.9931e-01, 3.5637e-01, 2.4642e-01, 2.3428e-02, 6.9037e-01,
           1.7560e-01, 1.8703e-01, 3.5244e-01, 6.3031e-01, 1.8450e-01,
           9.2194e-01, 9.3016e-02, 9.0488e-01, 2.4294e-02, 5.1122e-01,
           5.0793e-01, 7.9585e-01, 7.9594e-02, 5.2137e-01, 9.8359e-01,
           7.5022e-01, 4.1925e-01, 3.3284e-01, 4.7939e-01, 9.9081e-01,
           3.3931e-01, 2.6461e-01, 5.3063e-01, 1.0328e-01, 8.0720e-01,
           9.9480e-01, 3.0833e-01, 5.6780e-01, 3.9551e-01, 6.7176e-01,
           4.8049e-01, 1.5653e-01, 1.7595e-02, 6.6493e-02, 5.1989e-01,
           8.2691e-01, 7.3295e-01, 5.7169e-01, 4.9911e-01, 1.0260e-01,
           5.2307e-01, 7.4247e-01, 1.1682e-01, 5.8123e-01, 7.3496e-02,
           6.4274e-02, 2.4704e-01, 6.0424e-02, 2.6161e-01, 7.7966e-01,
           7.1244e-01, 2.2077e-01, 5.0723e-01, 9.6665e-01, 6.0933e-01,
           8.1332e-01, 3.0749e-01, 2.1297e-02, 3.6734e-01, 9.2542e-01,
           1.3554e-01, 9.7240e-01, 4.4344e-01, 4.2534e-01, 4.6205e-01,
           6.1811e-01, 5.8800e-01, 5.4673e-01, 1.2535e-01, 2.9959e-01,
           4.4890e-01, 2.7185e-01, 5.0243e-01]]]], device='npu:0')
 
    >>>out = torch_npu.npu_rotary_mul(x, r1, r2)
       tensor([[[[ 0.1142,  0.1891, -0.4689,  ...,  0.5704,  0.5375,  0.6079],
          [ 0.2264,  0.1155, -0.7678,  ...,  0.9857,  0.3382,  0.9441],
          [-0.1902,  0.1329, -0.3613,  ...,  0.9793,  0.5628,  0.8669],
          [-0.3349,  0.1532,  0.1124,  ...,  0.3125,  0.6741,  1.1248],
          [-0.0473,  0.2978, -0.6940,  ...,  0.2753,  0.2604,  1.0379]],
 
         [[ 0.0136,  0.4723,  0.1371,  ...,  0.1081,  0.2462,  0.6316],
          [ 0.0769,  0.6558, -0.0734,  ...,  0.2714,  0.2221,  0.2195],
          [-0.3755,  0.5364, -0.1131,  ...,  0.3105,  0.1225,  0.6166],
          [ 0.3535,  0.0164,  0.0095,  ...,  0.1361,  0.2570,  0.5811],
          [-0.2992,  0.2981,  0.0242,  ...,  0.2881,  0.2367,  0.9582]]],
 
 
        [[[ 0.1699,  0.3589, -0.7443,  ...,  0.4751,  0.7291,  0.2717],
          [ 0.3657,  0.0397,  0.1818,  ...,  0.9113,  0.4130,  0.8279],
          [-0.0657,  0.2528, -0.6658,  ...,  0.8184,  0.2057,  1.2864],
          [-0.1058,  0.1859, -0.0998,  ...,  0.0662,  0.5590,  1.0525],
          [ 0.2651,  0.3719, -0.8170,  ...,  0.2789,  0.3916,  1.0407]],
 
         [[-0.5998,  0.5740, -0.0154,  ...,  0.1746,  0.1982,  0.6338],
          [ 0.0766,  0.1790, -0.1490,  ...,  0.4387,  0.2592,  0.4924],
          [ 0.4765,  0.0485, -0.0226,  ...,  0.2219,  0.3445,  0.2265],
          [-0.1006,  0.8073, -0.1540,  ...,  0.1045,  0.2633,  0.2194],
          [ 0.0157,  0.3997,  0.3131,  ...,  0.0538,  0.0647,  0.4821]]]],
       device='npu:0')
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词