(beta)torch_npu.npu_dropout_with_add_softmax
接口原型
torch_npu.npu_dropout_with_add_softmax(Tensor self, Tensor x1, Scalar alpha, float prob, int dim) -> (Tensor, Tensor, Tensor)
功能描述
实现axpy_v2、softmax_v2、drop_out_domask_v3功能。即:
$$y = \mathrm{x1} + \alpha \cdot \mathrm{self}$$
$$\mathrm{Softmax}(y)_i = \frac{\exp(y_i)}{\sum_j \exp(y_j)}$$(沿dim维计算)
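output = 根据mask将softmax的计算结果中的元素置零(每个元素被舍弃的概率为prob),留下来的元素乘以 $\frac{1}{1-\mathrm{prob}}$。
参数说明
- Tensor self:4维张量,shape为(N, C, H, W)。
- Tensor x1:4维张量,shape为(N, C, H, W)。
- Scalar alpha:self的缩放系数,即上式中的alpha。
- float prob:元素被舍弃(置零)的概率。例如调用示例中prob=0.9时,输出中约90%的元素为0。
- int dim:计算softmax的维度,当前仅支持-1。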
约束说明
- self和x1的shape相同;
- H和W是[128, 256, 384, 512]其中之一;
- (N * C)%32结果为0;
- dim为-1。
调用示例
```python
>>> import torch
>>> import torch_npu
>>> self = torch.rand(16, 16, 128, 128).npu()
>>> self
tensor([[[[7.2556e-02, 3.0909e-01, 7.9734e-01, ..., 6.1179e-01, 6.2624e-03, 8.5186e-01],
          [8.9196e-02, 3.3319e-01, 4.0780e-01, ..., 1.9144e-01, 2.2701e-01, 6.4018e-01],
          [4.7275e-01, 7.4895e-01, 4.6215e-01, ..., 9.3753e-01, 6.6048e-02, 8.1877e-02],
          ...,
          [7.9366e-01, 5.1516e-01, 5.6594e-01, ..., 1.6457e-01, 1.0640e-01, 3.4322e-03],
          [1.5743e-02, 1.2893e-01, 5.8990e-01, ..., 4.1721e-01, 8.7816e-02, 6.8886e-01],
          [4.2980e-01, 5.5447e-01, 3.1894e-01, ..., 9.2638e-01, 9.9324e-01, 4.6225e-01]],

         [[6.2426e-01, 4.5948e-01, 1.0837e-01, ..., 8.9386e-01, 3.6932e-01, 1.2406e-01],
          [9.1823e-01, 6.2311e-01, 5.1474e-01, ..., 2.1042e-01, 6.5943e-01, 3.1797e-01],
          [5.2891e-01, 2.0183e-01, 2.1452e-01, ..., 9.1638e-01, 6.4109e-01, 9.4484e-01],
          ...,
          [3.7783e-02, 1.3218e-01, 3.1192e-01, ..., 2.4931e-01, 4.8809e-01, 9.6085e-01],
          [3.3197e-01, 9.1186e-02, 2.4839e-01, ..., 2.1156e-03, 6.4952e-01, 8.5996e-01],
          [1.7941e-01, 5.1532e-01, 7.8133e-01, ..., 3.5526e-01, 5.3576e-01, 6.0538e-01]],

         [[2.6743e-01, 7.4942e-01, 1.9146e-01, ..., 4.9179e-01, 6.3319e-01, 9.9269e-01],
          [1.5163e-01, 3.7388e-01, 8.0604e-02, ..., 8.1193e-01, 1.7922e-01, 8.6578e-01],
          [8.2558e-01, 9.5139e-01, 2.1313e-01, ..., 2.1722e-01, 2.8402e-01, 8.8888e-01],
          ...,
          [1.8222e-01, 2.7645e-01, 6.7305e-01, ..., 6.8003e-01, 4.0917e-01, 7.6655e-01],
          [3.1234e-01, 7.8519e-01, 8.8509e-01, ..., 7.2574e-01, 9.6134e-01, 2.2267e-01],
          [4.9233e-01, 8.8407e-01, 7.4390e-01, ..., 5.2253e-02, 5.5150e-02, 4.4108e-02]],

         ...,

         [[4.3370e-01, 2.1176e-01, 4.7512e-01, ..., 5.7611e-01, 3.2619e-01, 1.1523e-01],
          [6.1469e-01, 7.4528e-01, 7.9559e-02, ..., 9.7112e-01, 1.8391e-01, 8.9883e-01],
          [8.6677e-02, 3.5051e-02, 1.6875e-01, ..., 3.9833e-01, 6.7967e-01, 4.7062e-01],
          ...,
          [7.1648e-01, 1.8378e-01, 5.3054e-01, ..., 8.4282e-01, 9.1972e-01, 7.0031e-01],
          [5.9876e-01, 6.7868e-01, 6.4128e-01, ..., 4.9516e-02, 7.2571e-01, 5.8792e-01],
          [7.6723e-01, 6.9527e-01, 9.3573e-01, ..., 6.3490e-02, 6.6129e-01, 2.4517e-01]],

         [[5.0158e-01, 8.2565e-01, 7.5532e-01, ..., 6.9342e-01, 3.3244e-01, 5.3913e-01],
          [2.3347e-01, 9.7822e-02, 1.5009e-01, ..., 5.5090e-01, 9.1813e-01, 7.9857e-01],
          [7.2416e-02, 5.9086e-01, 1.2243e-01, ..., 7.8511e-01, 2.4803e-01, 5.3717e-01],
          ...,
          [7.4899e-01, 1.5467e-02, 4.9711e-01, ..., 2.2938e-02, 1.6099e-01, 3.1928e-01],
          [3.9111e-01, 1.2422e-01, 6.1795e-02, ..., 8.4212e-01, 6.1346e-01, 1.0957e-01],
          [3.6311e-02, 8.9652e-01, 7.7428e-01, ..., 9.2212e-01, 4.9290e-01, 4.5609e-01]],

         [[2.2052e-01, 4.4260e-01, 8.8627e-01, ..., 9.2381e-01, 7.7046e-01, 9.2057e-01],
          [5.5775e-01, 8.8951e-01, 7.9238e-01, ..., 3.9209e-01, 9.6636e-01, 8.1876e-01],
          [3.4709e-01, 7.8678e-01, 1.4396e-01, ..., 7.9073e-01, 3.9021e-01, 8.5285e-01],
          ...,
          [1.4238e-01, 9.8432e-01, 2.7802e-01, ..., 5.1720e-01, 1.6290e-01, 8.2036e-01],
          [2.0184e-01, 1.0635e-01, 1.9612e-01, ..., 9.7101e-01, 9.6679e-01, 7.0811e-01],
          [5.8240e-01, 3.1642e-01, 9.6549e-01, ..., 5.1130e-02, 5.6725e-01, 3.5238e-01]]]], device='npu:0')
>>> x1 = torch.rand(16, 16, 128, 128).npu()
>>> x1
tensor([[[[2.4353e-01, 8.5665e-01, 5.3571e-01, ..., 5.9101e-01, 4.0872e-01, 6.3873e-01],
          [1.4489e-01, 8.7982e-01, 3.3114e-01, ..., 2.5155e-01, 8.4987e-01, 8.7096e-01],
          [6.5837e-02, 2.2677e-02, 7.2063e-01, ..., 2.3542e-01, 9.3041e-01, 8.9596e-01],
          ...,
          [5.1450e-01, 7.9412e-01, 8.9288e-01, ..., 3.3639e-01, 5.6086e-01, 4.8770e-02],
          [4.7557e-01, 1.4793e-01, 4.9800e-01, ..., 3.9479e-01, 5.6052e-01, 9.8271e-01],
          [7.4438e-01, 7.5646e-01, 2.7942e-02, ..., 3.0381e-01, 4.3703e-01, 1.4037e-02]],

         [[4.0232e-01, 9.4407e-01, 6.4969e-01, ..., 3.4524e-01, 8.2647e-01, 5.4792e-01],
          [1.1801e-01, 1.8281e-01, 6.1723e-01, ..., 1.9393e-01, 4.5877e-01, 8.9990e-01],
          [2.6244e-01, 6.9614e-01, 3.6008e-01, ..., 5.0258e-01, 8.1919e-01, 4.6943e-01],
          ...,
          [7.4710e-01, 5.8911e-01, 1.5292e-01, ..., 6.6590e-01, 4.0754e-01, 3.6944e-01],
          [9.0501e-01, 2.7943e-01, 3.7068e-01, ..., 1.5053e-01, 7.3413e-01, 7.9626e-01],
          [9.5200e-01, 7.8327e-01, 3.4033e-01, ..., 8.0892e-01, 4.0480e-01, 3.8717e-01]],

         [[7.5938e-01, 2.9089e-01, 5.9916e-01, ..., 6.2526e-01, 3.9670e-01, 3.3548e-01],
          [7.0733e-01, 8.1400e-01, 4.9259e-01, ..., 1.6607e-02, 6.5331e-01, 7.3150e-02],
          [5.2770e-01, 7.8141e-01, 4.1904e-01, ..., 3.8917e-01, 4.1405e-01, 9.9596e-01],
          ...,
          [4.8669e-01, 9.9948e-01, 1.2023e-01, ..., 7.0420e-01, 2.8522e-01, 6.6192e-01],
          [4.9718e-01, 7.5792e-01, 6.6748e-01, ..., 9.7302e-01, 3.3443e-01, 3.6536e-01],
          [7.7033e-01, 6.0550e-01, 8.2024e-01, ..., 2.9711e-01, 1.9410e-01, 6.6304e-01]],

         ...,

         [[1.0284e-01, 6.5712e-01, 6.0831e-01, ..., 6.2622e-01, 2.0355e-01, 9.4250e-01],
          [4.9053e-01, 2.0148e-01, 2.4974e-01, ..., 9.2521e-01, 1.9919e-01, 4.4700e-01],
          [7.6515e-01, 8.7755e-01, 1.3500e-01, ..., 8.2136e-01, 2.0848e-01, 5.6432e-01],
          ...,
          [3.3618e-01, 1.8585e-01, 5.3475e-01, ..., 4.9333e-01, 9.1018e-01, 9.5052e-01],
          [2.1400e-01, 1.7407e-01, 5.8925e-01, ..., 7.5722e-01, 2.9850e-01, 3.9298e-01],
          [6.3625e-01, 1.7168e-01, 2.9183e-01, ..., 9.9674e-01, 2.1718e-01, 5.2626e-01]],

         [[1.8651e-01, 2.5385e-01, 2.0384e-01, ..., 3.4462e-01, 8.4150e-01, 4.7431e-01],
          [2.4992e-01, 1.1788e-01, 1.9730e-01, ..., 4.3722e-02, 7.8943e-01, 9.9097e-01],
          [1.4493e-02, 6.4856e-01, 8.3344e-01, ..., 8.6623e-01, 1.5456e-01, 7.8423e-01],
          ...,
          [6.1458e-01, 4.4260e-01, 7.4133e-01, ..., 2.5126e-01, 2.7251e-01, 6.9784e-01],
          [2.2419e-01, 3.4159e-01, 2.3232e-01, ..., 8.2850e-01, 8.2644e-02, 4.8390e-01],
          [1.0171e-01, 8.7662e-01, 2.0457e-01, ..., 7.6868e-01, 7.6592e-01, 3.1254e-01]],

         [[1.8866e-01, 1.5755e-01, 3.1025e-02, ..., 6.5044e-01, 7.8293e-01, 9.8030e-01],
          [3.7703e-01, 5.3198e-01, 1.8633e-01, ..., 4.7398e-01, 8.3618e-01, 8.7283e-01],
          [5.7119e-01, 4.3620e-01, 8.2536e-01, ..., 2.5390e-01, 5.6144e-01, 4.4044e-01],
          ...,
          [1.3243e-01, 6.2002e-02, 7.5278e-01, ..., 7.5907e-01, 4.2472e-01, 1.7624e-01],
          [4.7985e-01, 7.9769e-01, 8.1433e-01, ..., 7.3780e-01, 2.2877e-02, 4.8816e-01],
          [4.5100e-01, 9.9698e-02, 7.0776e-01, ..., 9.8046e-01, 2.2372e-01, 8.6304e-01]]]], device='npu:0')
>>> _, _, out = torch_npu.npu_dropout_with_add_softmax(self, x1, 2, 0.9, -1)
>>> out
tensor([[[[0.0000, 0.0639, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0632, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0794, ..., 0.0000, 0.0000, 0.1571],
          [0.0000, 0.0000, 0.0000, ..., 0.1270, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.1030, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.2134, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0342, 0.0000, 0.0633, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.1578, 0.1334, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

         ...,

         [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.2316, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0237, 0.0000, ..., 0.0000, 0.2128, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.1421, 0.0000, 0.0000, ..., 0.0499, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0218, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1461, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.1130, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.1976, ..., 0.0000, 0.0000, 0.0000]]]], device='npu:0')
```
父主题: torch_npu