torch_npu. scatter_update
功能描述
将tensor updates中的值按指定的轴axis和索引indices更新tensor data中的值,并将结果保存到输出tensor,data本身的数据不变。
接口原型
scatter_update(Tensor data, Tensor indices, Tensor updates, int axis) -> Tensor
参数说明
- data:Device侧的Tensor类型,计算输入;数据类型支持INT8、FLOAT16、FLOAT32、BFLOAT16类型;data只支持2-8维,且维度大小需要与updates一致;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
- indices:Device侧的Tensor类型,计算输入;数据类型支持INT32、INT64;目前仅支持一维跟二维;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
- updates:Device侧的Tensor类型,计算输入,数据类型支持INT8、FLOAT16、FLOAT32、BFLOAT16类型;updates的维度大小需要与data一致;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
- axis(int64_t, 计算输入):用来scatter的维度,数据类型为INT64。
输出说明
out:Device侧的Tensor类型,计算输出;数据类型支持INT8、FLOAT16、FLOAT32、BFLOAT16类型;out只支持2-8维,且维度大小需要与data一致;支持非连续的tensor;数据格式支持ND;不支持空Tensor。
支持的PyTorch版本
- PyTorch 2.3
- PyTorch 2.2
- PyTorch 2.1
- PyTorch 1.11.0
支持的型号
- Atlas 训练系列产品
- Atlas A2 训练系列产品
调用示例
import torch import torch_npu import numpy as np data = torch.tensor([[[[1,1,1,1,1,1,1,1],[2,2,2,2,2,2,2,2]]]], dtype=torch.float32).npu() indices = torch.tensor ([1],dtype=torch.int64).npu() updates = torch.tensor([[[[3,3,3,3,3,3,3,3]]]] , dtype=torch.float32).npu() out = torch_npu.scatter_update(data, indices, updates, axis=-2)
父主题: torch_npu