下载
中文
注册

torch_npu.npu_scatter_nd_update_

功能描述

将updates中的值按指定的索引indices更新self中的值,并将结果保存到输出tensor,self中的数据被改变。

接口原型

1
torch_npu.npu_scatter_nd_update_(Tensor(a!) self, Tensor indices, Tensor updates) -> Tensor(a!)

参数说明

  • self:Device侧的Tensor类型,必选输入,源数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与updates一致,维数只能是1~8维。
    • Atlas 推理系列加速卡产品:数据类型支持FLOAT32、FLOAT16、BOOL。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FFLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。
    • Atlas A3 训练系列产品:数据类型支持FLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。
  • indices:Device侧的Tensor类型,必选输入,索引张量,数据类型支持INT32、INT64,数据格式支持ND,支持非连续的Tensor,indices中的索引数据不支持越界。
  • updates:Device侧的Tensor类型,必选输入,更新数据张量,数据格式支持ND,支持非连续的Tensor,数据类型需要与self一致。
    • Atlas 推理系列加速卡产品:数据类型支持FLOAT32、FLOAT16、BOOL。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品:数据类型支持FFLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。
    • Atlas A3 训练系列产品:数据类型支持FLOAT32、FLOAT16、BOOL、BFLOAT16、INT64。

输出说明

返回被更新后的self。

约束说明

  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • indices至少是2维,其最后1维的大小不能超过self的维度大小。
  • 假设indices最后1维的大小是a,则updates的shape等于indices除最后1维外的shape加上self除前a维外的shape。举例:self的shape是(4, 5, 6),indices的shape是(3, 2),则updates的shape必须是(3, 6)。

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品
  • Atlas 训练系列产品
  • Atlas 推理系列产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    import torch
    import torch_npu
    import numpy as np
    
    data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
    var = torch.from_numpy(data_var).to(torch.float16).npu()
    
    data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
    indices = torch.from_numpy(data_indices).to(torch.int32).npu()
    
    data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
    updates = torch.from_numpy(data_updates).to(torch.float16).npu()
    
    torch_npu.npu_scatter_nd_update_(var, indices, updates)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    import os
    import torch_npu
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    import torch.nn as nn
    import torch
    import numpy as np
    import numpy
    torch_npu.npu.set_compile_mode(jit_compile=True)
    
    os.environ["ENABLE_ACLNN"] = "false"
    class Network(nn.Module):
        def __init__(self):
            super(Network, self).__init__()
    
        def forward(self, var, indices, update):
            # 调用目标接口
            res = torch_npu.npu_scatter_nd_update_(var, indices, update)
            return res
    		
    npu_mode = Network()
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
    
    dtype = np.float32
    x = [33 ,5]
    indices = [33,25,1]
    update = [33,25,5]
    
    data_x = np.random.uniform(0, 1, x).astype(dtype)
    data_indices = np.random.uniform(0, 10, indices).astype(dtype)
    data_update = np.random.uniform(0, 1, update).astype(dtype)
    
    tensor_x = torch.from_numpy(data_x).to(torch.float16)
    tensor_indices = torch.from_numpy(data_indices).to(torch.int32)
    tensor_update = torch.from_numpy(data_update).to(torch.float16)
    
    # 传参
    print(npu_mode(tensor_x.npu(), tensor_indices.npu(), tensor_update.npu()))