Nonzero Operator Replacement
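Nonzero converts a mask into indices. When those indices are only used to select the elements whose mask value is greater than 0, some computations can replace the indexing with a multiplication by the mask. For example, to sum the masked elements of a tensor, tensor_a[mask].sum() is equivalent to (tensor_a * mask).sum(), provided that mask is a Boolean (0/1) mask. The multiplication form keeps the tensor shape fixed, whereas the output size of nonzero depends on the input data.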
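The equivalence can be checked on CPU with plain PyTorch. The sketch below is only an illustration: it assumes a mask holding -1, 0 and 1 (the same range as the example that follows) and shows that the multiplication form matches the index form only once the mask is turned into 0/1 with `mask > 0`. Multiplying by the raw mask would also subtract the -1 entries.

```python
import torch

# Small CPU example; values chosen so the difference is easy to see.
tensor_a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
mask = torch.tensor([1, -1, 0, 1, -1])  # same value range as torch.randint(-1, 2, ...)

# Index form: nonzero returns a data-dependent number of indices.
mask_inds = torch.nonzero(mask > 0, as_tuple=False).squeeze(1)
sum_by_index = tensor_a[mask_inds].sum()

# Multiplication form: the 0/1 mask keeps the shape of tensor_a fixed.
sum_by_mul = (tensor_a * (mask > 0)).sum()

# Multiplying by the raw mask is NOT equivalent when it contains negative values.
sum_raw = (tensor_a * mask).sum()

print(sum_by_index.item(), sum_by_mul.item(), sum_raw.item())  # 5.0 5.0 -2.0
```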
For example:
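import torch
import torch_npu  # Ascend adapter; enables the .npu() device calls

shape = (1024, )
mask = torch.randint(-1, 2, shape).npu()  # values in {-1, 0, 1}
tensor_a = torch.ones(shape).float().npu()
# nonzero returns the indices of the positive elements; its output size depends on the data
mask_inds = torch.nonzero(mask > 0, as_tuple=False).squeeze(1)
tensor_sum = tensor_a[mask_inds].sum()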
The replacement code is as follows:
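import torch
import torch_npu  # Ascend adapter; enables the .npu() device calls

shape = (1024, )
mask = torch.randint(-1, 2, shape).npu()  # values in {-1, 0, 1}
tensor_a = torch.ones(shape).float().npu()
# Convert the condition into a 0/1 float mask with the same shape as tensor_a; no nonzero call needed
mask_pos = (mask > 0).float()
tensor_sum = (tensor_a * mask_pos).sum()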
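The same substitution carries over to other reductions over the selected elements. As one hedged illustration, a masked mean can be written as the masked sum divided by the number of selected elements. `masked_mean` is a hypothetical helper name used only for this sketch, which runs on CPU; on an empty selection it returns nan, as `values[cond].mean()` does.

```python
import torch

def masked_mean(values: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean of values where cond is True, without nonzero."""
    weights = cond.to(values.dtype)  # 0/1 mask with the same shape as values
    return (values * weights).sum() / weights.sum()

tensor_a = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
mask = torch.tensor([1, -1, 0, 1, -1])

print(masked_mean(tensor_a, mask > 0).item())  # 2.5
print(tensor_a[mask > 0].mean().item())        # 2.5, the indexing-based equivalent
```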