torch_npu.npu_random_choice_with_mask(x, count=256, seed=0, seed2=0) -> (Tensor, Tensor)
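Shuffles the indices of the non-zero elements of `x`. Returns a tuple `(result, mask)`: `result` holds the selected indices, one row per element, and `mask` marks which rows of `result` are valid.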
Example:

```python
>>> import torch
>>> import torch_npu
>>> x = torch.tensor([1, 0, 1, 0], dtype=torch.bool).to("npu")
>>> result, mask = torch_npu.npu_random_choice_with_mask(x, 2, 1, 0)
>>> result
tensor([[0], [2]], device='npu:0', dtype=torch.int32)
>>> mask
tensor([True, True], device='npu:0')
```
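The sketch below is a hypothetical CPU reference written in NumPy, not the NPU kernel. It illustrates the behaviour suggested by the example above: collect the coordinates of every non-zero element, shuffle them, and return `count` rows plus a validity mask. The function name `random_choice_with_mask_ref` is made up for illustration. Two points are assumptions: that `result` has shape `[count, x.ndim]`, and that missing rows are zero-padded with `mask` set to `False`. Because it uses NumPy's seeded generator rather than the device RNG, the row order will not match what the NPU returns for the same `seed`.

```python
import numpy as np


def random_choice_with_mask_ref(x, count=256, seed=0):
    """Hypothetical CPU reference for shuffling indices of non-zero elements.

    Assumed semantics: result has shape [count, x.ndim] (int32) and mask has
    shape [count] (bool), with True marking rows that hold a real index.
    """
    x = np.asarray(x)
    # Coordinates of every non-zero (True) element, one row per element.
    idx = np.argwhere(x != 0).astype(np.int32)
    # Shuffle the rows with NumPy's seeded generator (not the NPU's RNG).
    rng = np.random.default_rng(seed)
    idx = idx[rng.permutation(len(idx))]
    # Assumption: pad with zero rows (mask False) when there are fewer
    # non-zero elements than count, and truncate when there are more.
    result = np.zeros((count, x.ndim), dtype=np.int32)
    mask = np.zeros(count, dtype=bool)
    k = min(count, len(idx))
    result[:k] = idx[:k]
    mask[:k] = True
    return result, mask


if __name__ == "__main__":
    x = np.array([1, 0, 1, 0], dtype=bool)
    result, mask = random_choice_with_mask_ref(x, count=3, seed=1)
    # result.shape == (3, 1); mask == [True, True, False]
    print(result.shape, mask.tolist())
```

Asking for more rows than there are non-zero elements shows the assumed padding behaviour. The trailing rows of `result` are zeros, so only rows where `mask` is `True` should be read as indices.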