文档
注册

torch_npu.npu_scatter_nd_update_

功能描述

将updates中的值按指定的索引indices更新self中的值,并将结果保存到输出tensor,self中的数据被改变。

接口原型

torch_npu.npu_scatter_nd_update_(Tensor(a!) self, Tensor indices, Tensor updates) -> Tensor(a!)

参数说明

  • self:Device侧的Tensor类型,必选输入,源数据张量,数据类型支持FLOAT32、FLOAT16、BOOL、BFLOAT16(仅Atlas A2 训练系列产品支持)、INT64(仅Atlas A2 训练系列产品支持),数据格式支持ND,支持非连续的Tensor,数据类型需要与updates一致,维数只能是1~8维。
  • indices:Device侧的Tensor类型,必选输入,索引张量,数据类型支持INT32、INT64,数据格式支持ND,支持非连续的Tensor,indices中的索引数据不支持越界。
  • updates:Device侧的Tensor类型,必选输入,更新数据张量,数据类型支持FLOAT32、FLOAT16、BOOL、BFLOAT16(仅Atlas A2 训练系列产品支持)、INT64(仅Atlas A2 训练系列产品支持),数据格式支持ND,支持非连续的Tensor,数据类型需要与self一致。

输出说明

返回被更新后的self。

约束说明

  • indices至少是2维,其最后1维的大小不能超过self的维度大小。
  • 假设indices最后1维的大小是a,则updates的shape等于indices除最后1维外的shape加上self除前a维外的shape。举例:self的shape是(4, 5, 6),indices的shape是(3, 2),则updates的shape必须是(3, 6)。

支持的PyTorch版本

  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

  • Atlas A2 训练系列产品
  • Atlas 训练系列产品
  • Atlas 推理系列产品(Ascend 310P处理器)

调用示例

import torch
import torch_npu
import numpy as np

data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
var = torch.from_numpy(data_var).to(torch.float16).npu()

data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()

data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.float16).npu()

torch_npu.npu_scatter_nd_update_(var, indices, updates)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词