文档
注册
评分
提单
论坛
小AI

torch_npu. scatter_update_

功能描述

将tensor updates中的值按指定的轴axis和索引indices更新tensor data中的值,并将结果保存到输出tensor,data本身的数据被改变。

接口原型

scatter_update_(Tensor(a!) data, Tensor indices, Tensor updates, int axis) -> Tensor(a!)

参数说明

  • data:Device侧的Tensor类型,计算输入;数据类型支持INT8、FLOAT16、FLOAT32、BFLOAT16类型;data只支持2-8维,且维度大小需要与updates一致;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
  • indices:Device侧的Tensor类型,计算输入;数据类型支持INT32、INT64;目前仅支持一维跟二维;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
  • updates:Device侧的Tensor类型,计算输入;数据类型支持INT8、FLOAT16、FLOAT32、BFLOAT16类型;updates的维度大小需要与data一致;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
  • axis(int64_t, 计算输入):用来scatter的维度,数据类型为INT64。

输出说明

out:Device侧的Tensor类型,计算输出,复用输入地址;数据类型支持INT8、FLOAT16、FLOAT32、BFLOAT16类型;out只支持2-8维,且维度大小需要与data一致;支持非连续的tensor;数据格式支持ND;不支持空Tensor。

约束说明

  • BFLOAT16数据类型仅支持如下产品型号:

    Atlas A2 训练系列产品

  • data与updates的秩一致。
  • 不支持索引越界,索引越界不校验。

支持的PyTorch版本

  • PyTorch 2.3
  • PyTorch 2.2
  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

  • Atlas 训练系列产品
  • Atlas A2 训练系列产品

调用示例

import torch
import torch_npu
import numpy as np
data = torch.tensor([[[[1,1,1,1,1,1,1,1],[2,2,2,2,2,2,2,2]]]], dtype=torch.float32).npu()
indices = torch.tensor ([1],dtype=torch.int64).npu()
updates = torch.tensor([[[[3,3,3,3,3,3,3,3]]]] , dtype=torch.float32).npu()
out = torch_npu.scatter_update_(data, indices, updates, axis=-2)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词