文档
注册

beta)torch_npu.npu_scatter

该接口计划废弃,可以参考小算子拼接方案进行替换。

接口原型

torch_npu.npu_scatter(self, indices, updates, dim) -> Tensor

功能描述

使用dim对scatter结果进行计数。类似于torch.scatter,优化NPU设备实现。

参数说明

  • self (Tensor) - 输入张量。
  • indices (Tensor) - 待scatter的元素index,可以为空,也可以与src有相同的维数。当为空时,操作返回“self unchanged”。
  • updates (Tensor) - 待scatter的源元素。
  • dim (Int) - 要进行index的轴。

调用示例

>>> input    = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu()
>>> input
tensor([[1.6279, 0.1226],
        [0.9041, 1.0980]], device='npu:0')
>>> indices  = torch.tensor([0, 1],dtype=torch.int32).npu()
>>> indices
tensor([0, 1], device='npu:0', dtype=torch.int32)
>>> updates  = torch.tensor([-1.1993, -1.5247]).npu()
>>> updates
tensor([-1.1993, -1.5247], device='npu:0')
>>> dim = 0
>>> output = torch_npu.npu_scatter(input, indices, updates, dim)
>>> output
tensor([[-1.1993,  0.1226],
        [ 0.9041, -1.5247]], device='npu:0')
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词