torch_npu.npu_rotary_mul
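Interface prototype
torch_npu.npu_rotary_mul(Tensor x, Tensor r1, Tensor r2) -> Tensor
Function description
Applies rotary position embedding (RotaryEmbedding) to x. FakeTensor mode is supported. The computation is equivalent to: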
x1, x2 = torch.chunk(x, 2, -1)
x_new = torch.cat((-x2, x1), dim=-1)
output = r1 * x + r2 * x_new
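The pseudocode above can be checked off-device with a plain PyTorch sketch. The helper name `rotary_mul_reference` is hypothetical and not part of torch_npu; it repeats the formula on CPU tensors, letting r1 and r2 broadcast against x. With r1 = cos(θ) and r2 = sin(θ), each pair (x[..., i], x[..., i + D/2]) is rotated by the same angle θ, so the pairwise norm is preserved; the random angles in the usage lines are an assumption for demonstration only.

```python
import torch


def rotary_mul_reference(x: torch.Tensor, r1: torch.Tensor, r2: torch.Tensor) -> torch.Tensor:
    """Hypothetical CPU reference: r1 * x + r2 * rotate_half(x).

    r1 and r2 broadcast against x, e.g. [1, 1, S, D] against x of shape [B, N, S, D].
    """
    # Split the last dimension into two halves of size D/2.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    # "Rotate half": [x1, x2] -> [-x2, x1].
    x_new = torch.cat((-x2, x1), dim=-1)
    return r1 * x + r2 * x_new


# Usage: one random angle per (position, pair), duplicated so that
# element i and element i + D/2 share the same angle.
torch.manual_seed(0)
B, N, S, D = 2, 3, 4, 8
x = torch.randn(B, N, S, D)
half = torch.rand(1, 1, S, D // 2)
theta = torch.cat((half, half), dim=-1)
out = rotary_mul_reference(x, torch.cos(theta), torch.sin(theta))
```

Each output pair equals (cos θ · x_i − sin θ · x_{i+D/2}, cos θ · x_{i+D/2} + sin θ · x_i), which is a 2-D rotation by θ.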
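Parameters
- Tensor x: the query (q) or key (k) tensor. Must be 4-D, typically [B, N, S, D], [B, S, N, D], or [S, B, N, D] (B: batch size, N: number of heads, S: sequence length, D: head dimension). Supported dtypes: bf16, fp16, fp32.
- Tensor r1: the cos values. Must be 4-D, typically [1, 1, S, D], [1, S, 1, D], or [S, 1, 1, D], matching the x layouts in the same order. Supported dtypes: bf16, fp16, fp32.
- Tensor r2: the sin values. Must be 4-D, typically [1, 1, S, D], [1, S, 1, D], or [S, 1, 1, D], matching the x layouts in the same order. Supported dtypes: bf16, fp16, fp32.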
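The parameter list only states that r1 holds cos values and r2 holds sin values. Below is a minimal sketch of one common way to build them, assuming the standard RoPE frequency schedule inv_freq[i] = base^(-2i/D) with base 10000, as used in many Llama-style models (this page does not prescribe it). The per-pair angles are duplicated across both halves of D so that elements i and i + D/2 share an angle, matching the chunk-in-half rotation, and singleton axes are inserted to fit the chosen layout of x. `build_rope_cos_sin` and its `layout` argument are hypothetical names.

```python
import torch


def build_rope_cos_sin(seq_len: int, head_dim: int = 128, base: float = 10000.0,
                       layout: str = "BNSD", dtype: torch.dtype = torch.float32):
    """Hypothetical helper returning (r1, r2) = (cos, sin) tables for npu_rotary_mul.

    Assumes the common RoPE schedule inv_freq[i] = base ** (-2i / head_dim).
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)      # [S, D/2]
    # Duplicate so that element i and element i + D/2 use the same angle.
    emb = torch.cat((freqs, freqs), dim=-1)       # [S, D]
    # Singleton axes let the tables broadcast against x in the given layout.
    shapes = {
        "BNSD": (1, 1, seq_len, head_dim),
        "BSND": (1, seq_len, 1, head_dim),
        "SBND": (seq_len, 1, 1, head_dim),
    }
    shape = shapes[layout]
    r1 = emb.cos().reshape(shape).to(dtype)
    r2 = emb.sin().reshape(shape).to(dtype)
    return r1, r2


r1, r2 = build_rope_cos_sin(seq_len=16, layout="BSND", dtype=torch.float16)
print(r1.shape, r2.shape)  # torch.Size([1, 16, 1, 128]) torch.Size([1, 16, 1, 128])
```

Computing the angles in float32 and casting at the end keeps the table accurate even when x is fp16 or bf16.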
Constraints
All inputs must be 4-D. The B and N dimensions must each be at most 1000, and the D dimension must be 128.
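These constraints can be checked on the host before the tensors are sent to the device. The helper `check_rotary_mul_shapes` below is hypothetical; its `layout` argument is an assumption (the operator itself takes no layout parameter) that tells it where the B and N axes sit in x. It also checks the dtypes listed under the parameters.

```python
import torch

# Supported dtypes, and the (B, N) axis positions for each layout listed under the parameters.
_SUPPORTED_DTYPES = (torch.bfloat16, torch.float16, torch.float32)
_BN_AXES = {"BNSD": (0, 1), "BSND": (0, 2), "SBND": (1, 2)}


def check_rotary_mul_shapes(x, r1, r2, layout="BNSD"):
    """Hypothetical host-side check of the documented constraints; raises ValueError on violation."""
    for name, t in (("x", x), ("r1", r1), ("r2", r2)):
        if t.dim() != 4:
            raise ValueError(f"{name} must be 4-D, got shape {tuple(t.shape)}")
        if t.shape[-1] != 128:
            raise ValueError(f"{name}: D must be 128, got {t.shape[-1]}")
        if t.dtype not in _SUPPORTED_DTYPES:
            raise ValueError(f"{name}: unsupported dtype {t.dtype}")
    b_axis, n_axis = _BN_AXES[layout]
    b, n = x.shape[b_axis], x.shape[n_axis]
    if b > 1000 or n > 1000:
        raise ValueError(f"B and N must be <= 1000, got B={b}, N={n}")


# Usage with the shapes from the example section, reading x as [B, S, N, D].
check_rotary_mul_shapes(torch.empty(2, 2, 5, 128), torch.empty(1, 2, 1, 128),
                        torch.empty(1, 2, 1, 128), layout="BSND")
```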
Example
>>> import torch
>>> import torch_npu
>>> x = torch.rand(2, 2, 5, 128).npu()
>>> x
tensor([[[[0.8594, 0.4914, 0.9075, ..., 0.2126, 0.6520, 0.2206],
          [0.5515, 0.3353, 0.6568, ..., 0.3686, 0.1457, 0.8528],
          [0.0504, 0.2687, 0.4036, ..., 0.3032, 0.8262, 0.6302],
          [0.0537, 0.5141, 0.7016, ..., 0.4948, 0.9778, 0.8535],
          [0.3602, 0.7874, 0.9913, ..., 0.1474, 0.3422, 0.6830]],
         [[0.4641, 0.6254, 0.7415, ..., 0.1834, 0.1067, 0.7171],
          [0.8084, 0.7570, 0.4728, ..., 0.4603, 0.4991, 0.1723],
          [0.0483, 0.6931, 0.0935, ..., 0.7522, 0.0054, 0.1736],
          [0.6196, 0.1028, 0.7076, ..., 0.2745, 0.9943, 0.6971],
          [0.3267, 0.3748, 0.1232, ..., 0.0507, 0.4302, 0.6249]]],
        [[[0.2783, 0.8262, 0.6014, ..., 0.8040, 0.7986, 0.2831],
          [0.6035, 0.2955, 0.7711, ..., 0.7464, 0.3739, 0.6637],
          [0.6282, 0.7243, 0.5445, ..., 0.3755, 0.0533, 0.9468],
          [0.5179, 0.3967, 0.6558, ..., 0.0267, 0.5549, 0.9707],
          [0.4388, 0.7458, 0.2065, ..., 0.6080, 0.4242, 0.8879]],
         [[0.3428, 0.6976, 0.0970, ..., 0.9552, 0.3663, 0.2139],
          [0.2019, 0.2452, 0.1142, ..., 0.3651, 0.6993, 0.5257],
          [0.9636, 0.1691, 0.4807, ..., 0.9137, 0.3510, 0.0905],
          [0.0177, 0.9496, 0.1560, ..., 0.7437, 0.9043, 0.0131],
          [0.9699, 0.5352, 0.9763, ..., 0.1850, 0.2056, 0.0368]]]], device='npu:0')
>>> r1 = torch.rand(1, 2, 1, 128).npu()
>>> r1
tensor([[[[0.8433, 0.5262, 0.2608, 0.8501, 0.7187, 0.6944, 0.0193, 0.1507, 0.0450, 0.2257, 0.4679, 0.8309, 0.4740, 0.8715, 0.7443, 0.3354, 0.5533, 0.9151, 0.4215, 0.4631, 0.9076, 0.3093, 0.0270, 0.7681, 0.1800, 0.0847, 0.6965, 0.2059, 0.8806, 0.3987, 0.8446, 0.6225, 0.1375, 0.8765, 0.5965, 0.3092, 0.0193, 0.9220, 0.4997, 0.8170, 0.8575, 0.5525, 0.8528, 0.7262, 0.4026, 0.5704, 0.0390, 0.9240, 0.9780, 0.3927, 0.7343, 0.3922, 0.5004, 0.8561, 0.6021, 0.6530, 0.6565, 0.9988, 0.4238, 0.0092, 0.5131, 0.5257, 0.1649, 0.0272, 0.9103, 0.2476, 0.7573, 0.8500, 0.9348, 0.4306, 0.3612, 0.5378, 0.7141, 0.3559, 0.6620, 0.3335, 0.4000, 0.2479, 0.3490, 0.7000, 0.5321, 0.3485, 0.9162, 0.9207, 0.3262, 0.7929, 0.1258, 0.6689, 0.1023, 0.1938, 0.3887, 0.6893, 0.0849, 0.3700, 0.5747, 0.9674,
0.4520, 0.5313, 0.0377, 0.1202, 0.9326, 0.0442, 0.4651, 0.7036, 0.3994, 0.9332, 0.5104, 0.0930, 0.4481, 0.8753, 0.5597, 0.6068, 0.9895, 0.5833, 0.6771, 0.4255, 0.4513, 0.6330, 0.9070, 0.3103, 0.0609, 0.8202, 0.6031, 0.3628, 0.1118, 0.2747, 0.4521, 0.8347]],
         [[0.6759, 0.8744, 0.3595, 0.2361, 0.4899, 0.3769, 0.6809, 0.0101, 0.0730, 0.0576, 0.5242, 0.5510, 0.9780, 0.4704, 0.9607, 0.1699, 0.3613, 0.6096, 0.0246, 0.6088, 0.4984, 0.9788, 0.2026, 0.1484, 0.3086, 0.9697, 0.8166, 0.9566, 0.9874, 0.4547, 0.5250, 0.2041, 0.7784, 0.4269, 0.0110, 0.6878, 0.6575, 0.3382, 0.1889, 0.8344, 0.9608, 0.6153, 0.4812, 0.0547, 0.2978, 0.3610, 0.5285, 0.6162, 0.2123, 0.1364, 0.6027, 0.7450, 0.2485, 0.2149, 0.7849, 0.8886, 0.0514, 0.9511, 0.4865, 0.8380, 0.6947, 0.2378, 0.5839, 0.8434, 0.0871, 0.4179, 0.1669, 0.8703, 0.1946, 0.0302, 0.9516, 0.1208, 0.5780, 0.6859, 0.2405, 0.5083, 0.3872, 0.7649, 0.1329, 0.0252, 0.2404, 0.5456, 0.7009, 0.6524, 0.7623, 0.5965, 0.0437, 0.4080, 0.8390, 0.4172, 0.4781, 0.2405, 0.1502, 0.2020, 0.4192, 0.8185, 0.0899, 0.1961, 0.7368, 0.4798, 0.4303, 0.9281, 0.5410, 0.0620, 0.8945, 0.3589, 0.5637, 0.4875, 0.1523, 0.9478, 0.9040, 0.3410, 0.3591, 0.2702, 0.5949, 0.3337, 0.3578, 0.8890, 0.6608, 0.6578, 0.4953, 0.7975, 0.2891, 0.9552, 0.0092, 0.1293, 0.2362, 0.7821]]]], device='npu:0')
>>> r2 = torch.rand(1, 2, 1, 128).npu()
>>> r2
tensor([[[[6.4270e-01, 1.3050e-01, 9.6509e-01, 1.4090e-01, 1.8660e-01, 8.7512e-01, 3.8567e-01, 4.1776e-01, 9.7718e-01, 5.6305e-01, 6.3091e-01, 4.6385e-01, 1.8142e-01, 3.7779e-01, 3.8012e-01, 8.1807e-01, 3.3292e-01, 5.8488e-01, 5.8188e-01, 5.7776e-01, 5.1828e-01, 9.6087e-01, 7.2219e-01, 8.5045e-02, 3.6623e-01, 3.3758e-01, 7.9666e-01, 6.9932e-01, 9.9202e-01, 2.5493e-01, 2.3017e-01, 7.9396e-01, 5.0109e-01, 6.5580e-01, 3.2200e-01, 7.8023e-01, 4.4098e-01, 1.0576e-01, 8.0548e-01, 2.2453e-01, 1.4705e-01, 8.7682e-02, 4.7264e-01, 8.9034e-02, 8.5720e-01, 4.7576e-01, 2.8438e-01, 8.6523e-01, 8.1707e-02, 3.0075e-01, 4.9069e-01, 9.7404e-01, 9.3865e-01,
5.7160e-01, 1.6332e-01, 4.3868e-01, 5.8658e-01, 5.3993e-01, 3.8271e-02, 9.9662e-01, 2.2892e-01, 7.8558e-01, 9.4502e-01, 9.7633e-01, 1.7877e-01, 2.6446e-02, 2.3411e-01, 6.7531e-01, 1.5023e-01, 4.4280e-02, 1.4457e-01, 3.6683e-01, 4.3424e-01, 7.4145e-01, 8.2433e-01, 6.8660e-01, 6.7477e-01, 5.5000e-02, 5.1344e-01, 9.3115e-01, 3.8280e-01, 9.2177e-01, 4.5470e-01, 2.5540e-01, 4.6632e-01, 8.3960e-01, 4.4320e-01, 1.0808e-01, 7.5544e-01, 4.6372e-01, 1.4322e-01, 1.9141e-01, 5.5918e-02, 7.0804e-01, 1.8789e-01, 9.4276e-01, 9.1742e-01, 9.1980e-01, 6.2728e-01, 4.1787e-01, 7.9545e-01, 9.0569e-01, 7.9123e-01, 9.7596e-01, 7.2507e-01, 2.3772e-01, 8.2560e-01, 5.9359e-01, 7.1134e-01, 5.1029e-01, 6.1601e-01, 2.9094e-01, 3.4174e-01, 9.0532e-01, 5.0960e-01, 3.4441e-01, 7.0498e-01, 4.2729e-01, 7.6714e-01, 6.3755e-01, 8.4604e-01, 5.9109e-01, 7.9137e-01, 7.5149e-01, 2.2092e-01, 9.5235e-01, 3.6915e-01, 6.4961e-01]], [[8.7862e-01, 1.1325e-01, 2.4575e-01, 9.7429e-01, 1.9362e-01, 8.2297e-01, 3.5184e-02, 5.2755e-01, 7.6429e-01, 2.4700e-01, 6.2860e-01, 2.4555e-01, 4.4557e-01, 7.0955e-03, 2.0326e-01, 8.6354e-02, 3.5959e-01, 3.4059e-01, 8.6852e-01, 1.3858e-01, 6.8500e-01, 1.3601e-01, 7.3152e-01, 8.3474e-01, 2.7017e-01, 9.8078e-01, 6.1084e-01, 1.6540e-01, 4.3081e-01, 8.5738e-01, 4.1890e-01, 6.6872e-01, 3.1698e-01, 4.2576e-02, 1.5236e-01, 2.0526e-01, 1.9493e-01, 6.6122e-03, 1.8332e-01, 5.6981e-01, 5.4090e-01, 6.0783e-01, 5.8742e-01, 9.1761e-04, 2.0904e-01, 6.6419e-01, 9.9559e-01, 5.8233e-01, 6.8562e-01, 8.6456e-01, 9.9931e-01, 3.5637e-01, 2.4642e-01, 2.3428e-02, 6.9037e-01, 1.7560e-01, 1.8703e-01, 3.5244e-01, 6.3031e-01, 1.8450e-01, 9.2194e-01, 9.3016e-02, 9.0488e-01, 2.4294e-02, 5.1122e-01, 5.0793e-01, 7.9585e-01, 7.9594e-02, 5.2137e-01, 9.8359e-01, 7.5022e-01, 4.1925e-01, 3.3284e-01, 4.7939e-01, 9.9081e-01, 3.3931e-01, 2.6461e-01, 5.3063e-01, 1.0328e-01, 8.0720e-01, 9.9480e-01, 3.0833e-01, 5.6780e-01, 3.9551e-01, 6.7176e-01, 4.8049e-01, 1.5653e-01, 1.7595e-02, 6.6493e-02, 5.1989e-01, 8.2691e-01, 
7.3295e-01, 5.7169e-01, 4.9911e-01, 1.0260e-01, 5.2307e-01, 7.4247e-01, 1.1682e-01, 5.8123e-01, 7.3496e-02, 6.4274e-02, 2.4704e-01, 6.0424e-02, 2.6161e-01, 7.7966e-01, 7.1244e-01, 2.2077e-01, 5.0723e-01, 9.6665e-01, 6.0933e-01, 8.1332e-01, 3.0749e-01, 2.1297e-02, 3.6734e-01, 9.2542e-01, 1.3554e-01, 9.7240e-01, 4.4344e-01, 4.2534e-01, 4.6205e-01, 6.1811e-01, 5.8800e-01, 5.4673e-01, 1.2535e-01, 2.9959e-01, 4.4890e-01, 2.7185e-01, 5.0243e-01]]]], device='npu:0')
>>> out = torch_npu.npu_rotary_mul(x, r1, r2)
>>> out
tensor([[[[ 0.1142, 0.1891, -0.4689, ..., 0.5704, 0.5375, 0.6079],
          [ 0.2264, 0.1155, -0.7678, ..., 0.9857, 0.3382, 0.9441],
          [-0.1902, 0.1329, -0.3613, ..., 0.9793, 0.5628, 0.8669],
          [-0.3349, 0.1532, 0.1124, ..., 0.3125, 0.6741, 1.1248],
          [-0.0473, 0.2978, -0.6940, ..., 0.2753, 0.2604, 1.0379]],
         [[ 0.0136, 0.4723, 0.1371, ..., 0.1081, 0.2462, 0.6316],
          [ 0.0769, 0.6558, -0.0734, ..., 0.2714, 0.2221, 0.2195],
          [-0.3755, 0.5364, -0.1131, ..., 0.3105, 0.1225, 0.6166],
          [ 0.3535, 0.0164, 0.0095, ..., 0.1361, 0.2570, 0.5811],
          [-0.2992, 0.2981, 0.0242, ..., 0.2881, 0.2367, 0.9582]]],
        [[[ 0.1699, 0.3589, -0.7443, ..., 0.4751, 0.7291, 0.2717],
          [ 0.3657, 0.0397, 0.1818, ..., 0.9113, 0.4130, 0.8279],
          [-0.0657, 0.2528, -0.6658, ..., 0.8184, 0.2057, 1.2864],
          [-0.1058, 0.1859, -0.0998, ..., 0.0662, 0.5590, 1.0525],
          [ 0.2651, 0.3719, -0.8170, ..., 0.2789, 0.3916, 1.0407]],
         [[-0.5998, 0.5740, -0.0154, ..., 0.1746, 0.1982, 0.6338],
          [ 0.0766, 0.1790, -0.1490, ..., 0.4387, 0.2592, 0.4924],
          [ 0.4765, 0.0485, -0.0226, ..., 0.2219, 0.3445, 0.2265],
          [-0.1006, 0.8073, -0.1540, ..., 0.1045, 0.2633, 0.2194],
          [ 0.0157, 0.3997, 0.3131, ..., 0.0538, 0.0647, 0.4821]]]], device='npu:0')
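To sanity-check the NPU output, it can be compared with a CPU evaluation of the formula from the function description. This sketch assumes an Ascend environment with torch_npu installed and a usable device; when torch_npu is not installed it returns None instead of raising. The tolerance table and the names `rotary_mul_reference` and `check_against_reference` are illustrative assumptions, not values published for this operator. The shapes mirror the example above, where x of shape [2, 2, 5, 128] is read as [B, S, N, D] so that r1 and r2 of shape [1, 2, 1, 128] follow the [1, S, 1, D] pattern.

```python
import importlib.util

import torch


def rotary_mul_reference(x, r1, r2):
    # Same computation as the function description, evaluated in float32 on CPU.
    x, r1, r2 = x.float(), r1.float(), r2.float()
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return r1 * x + r2 * torch.cat((-x2, x1), dim=-1)


# Hypothetical tolerances per input dtype; tune them to your own accuracy requirements.
TOLERANCES = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-2}


def check_against_reference(x, r1, r2):
    """Runs npu_rotary_mul on the NPU and compares it with the CPU reference.

    Returns None when torch_npu is not installed.
    """
    if importlib.util.find_spec("torch_npu") is None:
        return None
    import torch_npu  # requires an Ascend device and driver

    out = torch_npu.npu_rotary_mul(x.npu(), r1.npu(), r2.npu()).cpu().float()
    ref = rotary_mul_reference(x, r1, r2)
    tol = TOLERANCES[x.dtype]
    return torch.allclose(out, ref, rtol=tol, atol=tol)


# Same shapes as the example: x is [B, S, N, D] = [2, 2, 5, 128], r1/r2 are [1, S, 1, D].
x = torch.rand(2, 2, 5, 128)
r1 = torch.rand(1, 2, 1, 128)
r2 = torch.rand(1, 2, 1, 128)
print(check_against_reference(x, r1, r2))  # prints None when torch_npu is not installed
```

Comparing in float32 keeps the reference itself from adding fp16/bf16 rounding error to the comparison.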