下载
中文
注册

beta)torch_npu.npu_dtype_cast

该接口计划废弃,可以使用torch.to接口进行替换。

接口原型

torch_npu.npu_dtype_cast(input, dtype) -> Tensor

功能描述

执行张量数据类型(dtype)转换。支持FakeTensor模式。

参数说明

  • input (Tensor) - 输入张量。
  • dtype (torch.dtype) - 返回张量的目标数据类型。

支持的型号

  • Atlas 训练系列产品
  • Atlas A2 训练系列产品
  • Atlas A3 训练系列产品
  • Atlas 推理系列产品

调用示例

示例一:

1
2
>>> torch_npu.npu_dtype_cast(torch.tensor([0, 0.5, -1.]).npu(), dtype=torch.int)
tensor([ 0,  0, -1], device='npu:0', dtype=torch.int32)

示例二:

1
2
3
4
5
6
7
8
//FakeTensor模式
>>> from torch._subclasses.fake_tensor import FakeTensorMode
>>> with FakeTensorMode():
...     x = torch.rand(2, dtype=torch.float32).npu()
...     res = torch_npu.npu_dtype_cast(x, torch.float16)
...
>>> res
FakeTensor(..., device='npu:0', size=(2,), dtype=torch.float16)