torch_npu.npu_scatter_nd_update_

功能描述

将updates中的值按指定的索引indices更新input中的值,并将结果保存到输出tensor,input中的数据被改变。

接口原型

torch_npu.npu_scatter_nd_update_(Tensor(a!) input, Tensor indices, Tensor updates) -> Tensor(a!)

参数说明

输出说明

返回被更新后的input。

约束说明

支持的PyTorch版本

支持的型号

调用示例

import torch
import torch_npu
import numpy as np

data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
var = torch.from_numpy(data_var).to(torch.float16).npu()

data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
indices = torch.from_numpy(data_indices).to(torch.int32).npu()

data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
updates = torch.from_numpy(data_updates).to(torch.float16).npu()

torch_npu.npu_scatter_nd_update_(var, indices, updates)