torch_npu.npu_scatter_nd_update_
功能描述
将updates中的值按指定的索引indices更新self中的值,并将结果保存到输出tensor,self中的数据被改变。
接口原型
1 | torch_npu.npu_scatter_nd_update_(Tensor(a!) self, Tensor indices, Tensor updates) -> Tensor(a!) |
参数说明
- self:Device侧的Tensor类型,必选输入,源数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与updates一致,维数只能是1~8维。
Atlas 推理系列加速卡产品 :数据类型支持FLOAT32、FLOAT16、BOOL。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FFLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。Atlas A3 训练系列产品 :数据类型支持FLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。
- indices:Device侧的Tensor类型,必选输入,索引张量,数据类型支持INT32、INT64,数据格式支持ND,支持非连续的Tensor,indices中的索引数据不支持越界。
- updates:Device侧的Tensor类型,必选输入,更新数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与self一致。
Atlas 推理系列加速卡产品 :数据类型支持FLOAT32、FLOAT16、BOOL。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FFLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。Atlas A3 训练系列产品 :数据类型支持FLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。
输出说明
返回被更新后的self。
约束说明
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- indices至少是2维,其最后1维的大小不能超过self的维度大小。
- 假设indices最后1维的大小是a,则updates的shape等于indices除最后1维外的shape加上self除前a维外的shape。举例:self的shape是(4, 5, 6),indices的shape是(3, 2),则updates的shape必须是(3, 6)。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 Atlas 训练系列产品 Atlas 推理系列产品
调用示例
- 单算子模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import torch import torch_npu import numpy as np data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16) var = torch.from_numpy(data_var).to(torch.float16).npu() data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32) indices = torch.from_numpy(data_indices).to(torch.int32).npu() data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16) updates = torch.from_numpy(data_updates).to(torch.float16).npu() torch_npu.npu_scatter_nd_update_(var, indices, updates)
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
import os import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig import torch.nn as nn import torch import numpy as np import numpy torch_npu.npu.set_compile_mode(jit_compile=True) os.environ["ENABLE_ACLNN"] = "false" class Network(nn.Module): def __init__(self): super(Network, self).__init__() def forward(self, var, indices, update): # 调用目标接口 res = torch_npu.npu_scatter_nd_update_(var, indices, update) return res npu_mode = Network() config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False) dtype = np.float32 x = [33 ,5] indices = [33,25,1] update = [33,25,5] data_x = np.random.uniform(0, 1, x).astype(dtype) data_indices = np.random.uniform(0, 10, indices).astype(dtype) data_update = np.random.uniform(0, 1, update).astype(dtype) tensor_x = torch.from_numpy(data_x).to(torch.float16) tensor_indices = torch.from_numpy(data_indices).to(torch.int32) tensor_update = torch.from_numpy(data_update).to(torch.float16) # 传参 print(npu_mode(tensor_x.npu(), tensor_indices.npu(), tensor_update.npu()))
父主题: torch_npu