(beta)torch_npu.npu_geglu
Prototype
torch_npu.npu_geglu(Tensor self, int dim=-1, int approximate=1) -> (Tensor, Tensor)
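Description
Performs the GeGlu operation on the input tensor.
The input self is split evenly along the specified axis dim (default -1) into two halves, A and B. GELU is applied to one half and the result is multiplied elementwise by the other.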
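The computation can be written in standard GeGLU notation. This sketch assumes that GELU is applied to the second half B and that outGelu holds that activation, because this page does not say which half is activated. approximate=0 (`none`) corresponds to the exact erf form and approximate=1 (`tanh`) to the tanh approximation:

```latex
\begin{aligned}
(A, B) &= \operatorname{split}(\mathit{self},\ 2,\ \mathit{dim}) \\
\mathit{outGelu} &= \operatorname{GELU}(B), \qquad \mathit{out} = A \odot \mathit{outGelu} \\
\operatorname{GELU}_{\text{none}}(x) &= \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right) \\
\operatorname{GELU}_{\text{tanh}}(x) &= \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
\end{aligned}
```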
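Parameters
- Tensor self: The input to the GeGlu computation, an aclTensor on the NPU device. Supported data types: FLOAT32, FLOAT16, and BFLOAT16 (BFLOAT16 only on Atlas A2 Training Series Products). Non-contiguous tensors are supported. Supported data format: ND.
- int dim: Optional. The axis along which self is split. Data type: INT64.
- int approximate: Optional. Index of the GELU variant used in GeGlu: 0 selects none (no approximation), 1 selects tanh. Data type: INT64.
- out: The first returned tensor of the GeGlu computation, an aclTensor on the NPU device. Its data type must match self. Non-contiguous tensors are supported. Supported data format: ND.
- outGelu: The second returned tensor of the GeGlu computation, an aclTensor on the NPU device. Its data type must match self. Non-contiguous tensors are supported. Supported data format: ND.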
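A CPU reference can help when checking NPU results. The sketch below is not part of torch_npu: `gelu_reference` and `geglu_reference` are hypothetical NumPy helpers. They assume that GELU is applied to the second half B, that out = A * GELU(B), and that the size along dim is even. BFLOAT16 is not covered because NumPy has no native bfloat16 type.

```python
import math

import numpy as np


def gelu_reference(x: np.ndarray, approximate: int = 1) -> np.ndarray:
    """GELU on CPU: approximate=0 -> exact erf form, approximate=1 -> tanh form."""
    xf = x.astype(np.float32)  # compute in float32 even for float16 input
    if approximate == 1:
        c = math.sqrt(2.0 / math.pi)
        y = 0.5 * xf * (1.0 + np.tanh(c * (xf + 0.044715 * xf ** 3)))
    elif approximate == 0:
        # NumPy has no erf ufunc, so apply math.erf elementwise
        erf = np.vectorize(math.erf, otypes=[np.float64])
        y = 0.5 * xf * (1.0 + erf(xf / math.sqrt(2.0)))
    else:
        raise ValueError("approximate must be 0 (none) or 1 (tanh)")
    return y.astype(x.dtype)  # cast back to the input dtype


def geglu_reference(x: np.ndarray, dim: int = -1, approximate: int = 1):
    """Hypothetical CPU reference returning (out, out_gelu).

    ASSUMPTION: GELU is applied to the second half B and out = A * GELU(B).
    """
    if x.shape[dim] % 2 != 0:
        raise ValueError("size of x along dim must be even")
    a, b = np.split(x, 2, axis=dim)  # two equal halves along dim
    out_gelu = gelu_reference(b, approximate)
    out = (a.astype(np.float32) * out_gelu.astype(np.float32)).astype(x.dtype)
    return out, out_gelu
```

For example, x = [[1.0, 2.0, 0.0, 1.0]] with approximate=0 splits into A = [1, 2] and B = [0, 1], giving out_gelu ≈ [[0, 0.8413]] and out ≈ [[0, 1.6827]]. To compare against the NPU, one could convert the outputs with `.cpu().numpy()` and use `np.allclose` with a tolerance suited to the input dtype.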
Constraints
The size of out and outGelu along dim is half the size of self along dim.
If self.dim() == 0, dim must be in the range [-1, 0]; if self.dim() > 0, dim must be in the range [-self.dim(), self.dim() - 1].
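The two constraints can be expressed as a small shape check. `is_valid_geglu_dim` and `geglu_output_shape` are hypothetical helpers, not torch_npu APIs. Requiring an even size along dim is an assumption implied by splitting into two equal halves, and 0-d inputs are only range-checked because halving a scalar has no obvious meaning.

```python
from typing import Sequence, Tuple


def is_valid_geglu_dim(ndim: int, dim: int) -> bool:
    """Check dim against the documented range for a tensor with `ndim` dims."""
    if ndim == 0:
        return -1 <= dim <= 0
    return -ndim <= dim <= ndim - 1


def geglu_output_shape(shape: Sequence[int], dim: int = -1) -> Tuple[int, ...]:
    """Expected shape of out and outGelu: self's shape with size[dim] halved."""
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("this sketch only computes shapes for tensors with ndim > 0")
    if not is_valid_geglu_dim(ndim, dim):
        raise ValueError(f"dim={dim} is out of range for a {ndim}-d tensor")
    axis = dim % ndim  # normalize a negative dim to a non-negative axis
    if shape[axis] % 2 != 0:
        # ASSUMPTION: an even size is needed to split into two equal halves
        raise ValueError("size along dim must be even")
    result = list(shape)
    result[axis] //= 2
    return tuple(result)
```

For the input shape used in the usage example, `geglu_output_shape([24, 9216, 2560], dim=-1)` returns `(24, 9216, 1280)`.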
Example
```python
import numpy as np
import torch
import torch_npu

# Random float16 input of shape [24, 9216, 2560], moved to the NPU
data_x = np.random.uniform(-2, 2, [24, 9216, 2560]).astype(np.float16)
x_npu = torch.from_numpy(data_x).npu()

# Split along the last axis into two halves of size 1280; approximate=1 selects tanh
y_npu, y_gelu_npu = torch_npu.npu_geglu(x_npu, dim=-1, approximate=1)

# Both outputs are float16 tensors on npu:0 with shape [24, 9216, 1280]
print(y_npu.shape, y_gelu_npu.shape)
```