将mask转化为index,对于所有值大于0的tensor在某些计算中可以利用乘法进行替代。比如要对mask的tensor求和,tensor_a[mask].sum()就相当于(tensor_a * mask).sum()。
例如:
import torch import torch_npu shape = (1024, ) mask= torch.randint(-1, 2, shape).npu() gt_inds = torch.randint(-1, 2, shape).npu() tensor_a = torch.ones(shape).float().npu() mask_inds = torch.nonzero(gt_inds > 0, as_tuple=False).squeeze(1) tensor_sum = tensor_a[mask_inds].sum()
替换代码如下:
import torch import torch_npu shape = (1024, ) mask= torch.randint(-1, 2, shape).npu() gt_inds = torch.randint(-1, 2, shape).npu() tensor_a = torch.ones(shape).float().npu() mask_inds = torch.nonzero( gt_inds > 0, as_tuple=False).squeeze(1) tensor_sum = (tensor_a * mask_inds).sum()