(beta)torch_npu.npu_scatter
接口原型
torch_npu.npu_scatter(self, indices, updates, dim) -> Tensor
功能描述
使用dim对scatter结果进行计数。类似于torch.scatter,优化NPU设备实现。
参数说明
- self (Tensor) - 输入张量。
- indices (Tensor) - 待scatter的元素index,可以为空,也可以与src有相同的维数。当为空时,操作返回“self unchanged”。
- updates (Tensor) - 待scatter的源元素。
- dim (Int) - 要进行index的轴。
调用示例
>>> input = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu() >>> input tensor([[1.6279, 0.1226], [0.9041, 1.0980]], device='npu:0') >>> indices = torch.tensor([0, 1],dtype=torch.int32).npu() >>> indices tensor([0, 1], device='npu:0', dtype=torch.int32) >>> updates = torch.tensor([-1.1993, -1.5247]).npu() >>> updates tensor([-1.1993, -1.5247], device='npu:0') >>> dim = 0 >>> output = torch_npu.npu_scatter(input, indices, updates, dim) >>> output tensor([[-1.1993, 0.1226], [ 0.9041, -1.5247]], device='npu:0')
父主题: torch_npu