
Replacing the Nonzero Operator

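The Nonzero operator converts a mask into indices. In some computations that select the elements whose mask value is greater than 0, the indexing can be replaced by a multiplication. For example, to sum the masked elements of a tensor, tensor_a[mask].sum() is equivalent to (tensor_a * mask).sum(), provided that mask holds 1 at the selected positions and 0 elsewhere.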

For example:

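import torch
import torch_npu  # Ascend adapter that provides the .npu() method

shape = (1024,)
mask = torch.randint(-1, 2, shape).npu()  # values in {-1, 0, 1}
tensor_a = torch.ones(shape).float().npu()
# nonzero converts the boolean mask into the indices of the positive entries
mask_inds = torch.nonzero(mask > 0, as_tuple=False).squeeze(1)
tensor_sum = tensor_a[mask_inds].sum()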

The replacement code is as follows:

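import torch
import torch_npu  # Ascend adapter that provides the .npu() method

shape = (1024,)
mask = torch.randint(-1, 2, shape).npu()  # values in {-1, 0, 1}
tensor_a = torch.ones(shape).float().npu()
# Build a 0/1 mask instead of calling nonzero, then multiply element-wise
mask_inds = (mask > 0).to(tensor_a.dtype)
tensor_sum = (tensor_a * mask_inds).sum()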
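
The equivalence can be checked on the CPU before moving the code to an NPU. The following is a minimal sketch, not part of the original guide: it assumes plain CPU tensors (no .npu() calls), uses random values instead of ones so that the comparison is not trivial, and introduces the names keep, sum_indexed, and sum_multiplied purely for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

shape = (1024,)
mask = torch.randint(-1, 2, shape)  # values in {-1, 0, 1}
tensor_a = torch.rand(shape)        # CPU tensor; assumption: no NPU is needed for the check

# Original form: nonzero produces indices, then indexing gathers the elements.
mask_inds = torch.nonzero(mask > 0, as_tuple=False).squeeze(1)
sum_indexed = tensor_a[mask_inds].sum()

# Replacement form: multiply by a 0/1 mask with the same dtype as tensor_a.
keep = (mask > 0).to(tensor_a.dtype)
sum_multiplied = (tensor_a * keep).sum()

# The two sums agree up to floating-point summation order.
assert torch.allclose(sum_indexed, sum_multiplied, rtol=1e-4, atol=1e-4)
```

Note that the two forms differ when tensor_a contains NaN or inf at positions the mask excludes: the product form propagates them into the sum, while the indexed form never reads those positions.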