torch_npu.npu_random_choice_with_mask

API接口

torch_npu.npu_random_choice_with_mask(x, count=256, seed=0, seed2=0) -> (Tensor, Tensor)

功能描述

混洗非零元素的index。

参数说明

输出说明

示例

>>> x = torch.tensor([1, 0, 1, 0], dtype=torch.bool).to("npu")
>>> result, mask = torch_npu.npu_random_choice_with_mask(x, 2, 1, 0)
>>> resulttensor([[0],
        [2]], device='npu:0', dtype=torch.int32)
>>> mask
tensor([True, True], device='npu:0')