torch

若API未标明支持情况,则代表该API的支持情况待验证。

API名称

是否支持

限制与说明

torch.SymInt

  

torch.SymFloat

  

torch.SymBool

  

torch.Tag

  

torch.is_tensor

  

torch.is_storage

  

torch.is_complex

  

torch.is_conj

  

torch.is_floating_point

  

torch.is_nonzero

  

torch.set_default_dtype

  

torch.get_default_dtype

  

torch.set_default_device

  

torch.set_default_tensor_type

  

torch.numel

  

torch.set_printoptions

  

torch.set_flush_denormal

  

torch.tensor

  

torch.sparse_coo_tensor

  

torch.sparse_csr_tensor

  

torch.sparse_csc_tensor

  

torch.sparse_bsr_tensor

  

torch.sparse_bsc_tensor

  

torch.asarray

  

torch.as_tensor

  

torch.as_strided

     

torch.from_numpy

  

torch.from_dlpack

  

torch.frombuffer

  

torch.zeros

  

torch.zeros_like

  

torch.ones

  

torch.ones_like

  

torch.arange

  

torch.range

  

torch.linspace

  

torch.logspace

  

torch.eye

  

torch.empty

  

torch.empty_like

  

torch.empty_strided

  

torch.full

  

torch.full_like

  

torch.quantize_per_tensor

  

torch.quantize_per_channel

  

torch.dequantize

     

torch.dequantize

     

torch.complex

  

torch.polar

  

torch.heaviside

  

torch.adjoint

  

torch.argwhere

  

torch.cat

  

torch.concat

不支持float64,不支持8D输入。

torch.concatenate

  

torch.conj

  

torch.chunk

  

torch.dsplit

只支持float16,float32,float64

torch.column_stack

  

torch.dstack

不支持float64

torch.gather

  

torch.hsplit

  

torch.hstack

不支持float64

torch.index_add

  

torch.index_copy

  

torch.index_reduce

  

torch.index_select

  

torch.masked_select

  

torch.movedim

  

torch.moveaxis

  

torch.narrow

  

torch.narrow_copy

  

torch.nonzero

  

torch.permute

     

torch.reshape

  

torch.row_stack

不支持float64

torch.select

  

torch.scatter

  

torch.diagonal_scatter

  

torch.select_scatter

  

torch.slice_scatter

  

torch.scatter_add

  

torch.scatter_reduce

  

torch.split

     

torch.squeeze

  

torch.stack

  

torch.swapaxes

  

torch.swapdims

  

torch.t

  

torch.take

  

torch.take_along_dim

不支持float64,bool

torch.tensor_split

  

torch.tile

  

torch.transpose

     

torch.unbind

     

torch.unsqueeze

     

torch.vsplit

只支持float16,float32,float64

torch.vstack

  

torch.where

  

torch.where

  

torch.Generator

  

torch.Generator.get_state

  

torch.Generator.initial_seed

  

torch.Generator.manual_seed

  

torch.Generator.seed

  

torch.Generator.set_state

  

torch.seed

  

torch.manual_seed

  

torch.initial_seed

  

torch.get_rng_state

  

torch.set_rng_state

  

torch.bernoulli

  

torch.multinomial

  

torch.normal

  

torch.normal

  

torch.normal

  

torch.normal

  

torch.poisson

可以走CPU实现

torch.rand

  

torch.rand_like

不支持int64

torch.randint

  

torch.randint_like

  

torch.randn

  

torch.randn_like

  

torch.randperm

  

torch.quasirandom.SobolEngine

  

torch.quasirandom.SobolEngine.draw

     

torch.quasirandom.SobolEngine.draw_base2

     

torch.quasirandom.SobolEngine.fast_forward

     

torch.quasirandom.SobolEngine.reset

     

torch.save

  

torch.load

     

torch.get_num_threads

只支持CPU,GPU/NPU不支持

torch.set_num_threads

只支持CPU,GPU/NPU不支持

torch.get_num_interop_threads

  

torch.set_num_interop_threads

只支持CPU,GPU/NPU不支持

torch.no_grad

  

torch.enable_grad

  

torch.set_grad_enabled

  

torch.is_grad_enabled

  

torch.inference_mode

  

torch.is_inference_mode_enabled

  

torch.abs

  

torch.absolute

  

torch.acos

不支持int32

torch.arccos

不支持int64

torch.acosh

不支持int64

torch.arccosh

不支持int64

torch.add

  

torch.addcdiv

  

torch.addcmul

支持fp16,fp32,fp64,int8,int32,int64,uint8,bool,bf16。

torch.angle

     

torch.asin

  

torch.arcsin

不支持int64

torch.asinh

  

torch.arcsinh

不支持int64

torch.atan

  

torch.arctan

不支持int64

torch.atanh

不支持int64

torch.arctanh

  

torch.atan2

  

torch.arctan2

  

torch.bitwise_not

  

torch.bitwise_and

  

torch.bitwise_or

  

torch.bitwise_xor

  

torch.bitwise_left_shift

  

torch.bitwise_right_shift

  

torch.ceil

  

torch.clamp

  

torch.clip

  

torch.conj_physical

     

torch.copysign

  

torch.cos

  

torch.cosh

  

torch.deg2rad

  

torch.div

  

torch.divide

  

torch.digamma

可以走CPU实现

torch.erf

  

torch.erfc

  

torch.erfinv

  

torch.exp

  

torch.exp2

     

torch.expm1

  

torch.fake_quantize_per_channel_affine

  

torch.fake_quantize_per_tensor_affine

  

torch.fix

  

torch.float_power

不支持float64

torch.floor

  

torch.floor_divide

  

torch.fmod

  

torch.frac

  

torch.frexp

  

torch.gradient

不支持8D输入,不支持bool,float64

torch.imag

  

torch.ldexp

在int64带out场景下,out场景也必须是int64类型。

torch.lerp

  

torch.lgamma

可以走CPU实现

torch.log

  

torch.log10

不支持bool, int8, int16,int32,int64,float64

torch.log1p

  

torch.log2

不支持bool, int8, int16,int32,int64,float64

torch.logaddexp

不支持double数据类型。

torch.logaddexp2

不支持double数据类型。

torch.logical_and

  

torch.logical_not

  

torch.logical_or

只支持bool

torch.logical_xor

     

torch.logit

可以走CPU实现

torch.hypot

  

torch.i0

  

torch.igamma

  

torch.igammac

  

torch.mul

  

torch.multiply

不支持bool

torch.mvlgamma

可以走CPU实现

torch.nan_to_num

  

torch.neg

  

torch.negative

  

torch.nextafter

  

torch.polygamma

  

torch.positive

  

torch.pow

不支持int64

torch.pow

不支持int64

torch.quantized_batch_norm

  

torch.quantized_max_pool1d

  

torch.quantized_max_pool2d

  

torch.rad2deg

不支持int64

torch.real

  

torch.reciprocal

不支持uint8,int8,int16,int32,int64,bool

torch.remainder

不支持int8,unit8,bool

torch.round

  

torch.rsqrt

不支持uint8,int8,int16,int32,int64,bool

torch.sigmoid

  

torch.sign

  

torch.sgn

  

torch.signbit

  

torch.sin

不支持uint8,int8,int16,int32,int64,bool

torch.sinc

  

torch.sinh

不支持uint8,int8,int16,int32,int64,bool

torch.softmax

  

torch.sqrt

不支持uint8,int8,int16,int32,int64,bool

torch.square

bool类型不支持out参数

torch.sub

不支持uint8

torch.subtract

不支持bool

torch.tan

不支持int8,int16,int64,bool

torch.tanh

  

torch.true_divide

不支持out参数

torch.trunc

只支持float16,float32

torch.xlogy

  

torch.argmax

  

torch.argmax

  

torch.argmin

  

torch.amax

  

torch.amin

  

torch.aminmax

  

torch.all

  

torch.all

  

torch.any

  

torch.any

  

torch.max

  

torch.max

  

torch.max

  

torch.min

  

torch.min

  

torch.min

  

torch.dist

  

torch.logsumexp

  

torch.mean

  

torch.mean

  

torch.nanmean

  

torch.median

只支持float16,float32

torch.median

只支持float16,float32

torch.nanmedian

可以走CPU实现

torch.nanmedian

可以走CPU实现

torch.mode

可以走CPU实现

torch.norm

  

torch.nansum

可以走CPU实现

torch.nansum

可以走CPU实现

torch.prod

  

torch.prod

  

torch.quantile

  

torch.nanquantile

  

torch.std

  

torch.std_mean

  

torch.sum

  

torch.sum

  

torch.unique

  

torch.unique_consecutive

传参时必须使用关键字,否则精度不达标。return_inverse=return_inverse,return_counts=return_counts,dim=dim。

torch.var

  

torch.var_mean

  

torch.count_nonzero

     

torch.allclose

  

torch.argsort

  

torch.eq

  

torch.equal

  

torch.ge

  

torch.greater_equal

  

torch.gt

  

torch.greater

  

torch.isclose

  

torch.isfinite

  

torch.isin

  

torch.isinf

  

torch.isposinf

  

torch.isneginf

  

torch.isnan

Atlas 训练系列产品:是

Atlas A2 训练系列产品:否

  

torch.isreal

  

torch.kthvalue

  

torch.le

  

torch.less_equal

  

torch.lt

  

torch.less

  

torch.maximum

  

torch.minimum

  

torch.fmax

  

torch.fmin

     

torch.ne

  

torch.not_equal

  

torch.sort

  

torch.topk

  

torch.msort

不支持float64,bool

torch.stft

  

torch.istft

  

torch.bartlett_window

  

torch.blackman_window

  

torch.hamming_window

  

torch.hann_window

  

torch.kaiser_window

  

torch.atleast_1d

  

torch.atleast_2d

  

torch.atleast_3d

  

torch.bincount

  

torch.block_diag

  

torch.broadcast_tensors

  

torch.broadcast_to

  

torch.broadcast_shapes

  

torch.bucketize

可以走CPU实现

torch.cartesian_prod

  

torch.cdist

  

torch.clone

  

torch.combinations

r不能大于8

torch.corrcoef

不支持float64

torch.cov

不支持bool,float16,float64

torch.cross

  

torch.cummax

  

torch.cummin

  

torch.cumprod

  

torch.cumsum

  

torch.diag

仅支持diagonal=0场景

torch.diag_embed

不支持复数

torch.diagflat

  

torch.diagonal

仅支持diagonal=0场景

torch.diff

  

torch.einsum

  

torch.flatten

  

torch.flip

  

torch.fliplr

  

torch.flipud

  

torch.kron

不支持5维度及以上输入

torch.rot90

  

torch.gcd

     

torch.histc

  

torch.histogram

  

torch.histogramdd

  

torch.meshgrid

  

torch.lcm

  

torch.logcumsumexp

  

torch.ravel

  

torch.renorm

不支持fp64数据类型输入。

torch.repeat_interleave

  

torch.repeat_interleave

  

torch.roll

  

torch.searchsorted

不支持float64,bool

torch.tensordot

  

torch.trace

支持fp16,fp32,fp64,int8,int16,int32,int64,uint8,complex64,complex128

torch.tril

  

torch.tril_indices

  

torch.triu

  

torch.triu_indices

  

torch.unflatten

  

torch.vander

  

torch.view_as_real

  

torch.view_as_complex

  

torch.resolve_conj

  

torch.resolve_neg

  

torch.addbmm

  

torch.addmm

  

torch.addmv

  

torch.addr

不支持int32

torch.baddbmm

  

torch.bmm

  

torch.chain_matmul

  

torch.cholesky

  

torch.cholesky_inverse

  

torch.cholesky_solve

  

torch.dot

  

torch.geqrf

  

torch.ger

  

torch.inner

只支持float16,float32

torch.inverse

  

torch.det

     

torch.logdet

  

torch.slogdet

  

torch.lu

  

torch.lu_solve

  

torch.lu_unpack

  

torch.matmul

  

torch.matrix_power

  

torch.matrix_exp

  

torch.mm

只支持float16,float32

torch.mv

  

torch.orgqr

  

torch.ormqr

  

torch.outer

  

torch.pinverse

  

torch.qr

  

torch.svd

  

torch.svd_lowrank

可以走CPU实现

torch.pca_lowrank

可以走CPU实现

torch.lobpcg

  

torch.trapz

  

torch.trapezoid

  

torch.cumulative_trapezoid

  

torch.triangular_solve

  

torch.vdot

可以走CPU实现

torch._foreach_abs

  

torch._foreach_abs_

  

torch._foreach_acos

  

torch._foreach_acos_

  

torch._foreach_asin

  

torch._foreach_asin_

  

torch._foreach_atan

  

torch._foreach_atan_

  

torch._foreach_ceil

  

torch._foreach_ceil_

  

torch._foreach_cos

  

torch._foreach_cos_

  

torch._foreach_cosh

  

torch._foreach_cosh_

  

torch._foreach_erf

  

torch._foreach_erf_

  

torch._foreach_erfc

  

torch._foreach_erfc_

  

torch._foreach_exp

  

torch._foreach_exp_

  

torch._foreach_expm1

  

torch._foreach_expm1_

  

torch._foreach_floor

  

torch._foreach_floor_

  

torch._foreach_log

  

torch._foreach_log_

  

torch._foreach_log10

  

torch._foreach_log10_

  

torch._foreach_log1p

  

torch._foreach_log1p_

  

torch._foreach_log2

  

torch._foreach_log2_

  

torch._foreach_neg

  

torch._foreach_neg_

  

torch._foreach_tan

  

torch._foreach_tan_

  

torch._foreach_sin

  

torch._foreach_sin_

  

torch._foreach_sinh

  

torch._foreach_sinh_

  

torch._foreach_round

  

torch._foreach_round_

  

torch._foreach_sqrt

  

torch._foreach_sqrt_

  

torch._foreach_lgamma

  

torch._foreach_lgamma_

  

torch._foreach_frac

  

torch._foreach_frac_

  

torch._foreach_reciprocal

  

torch._foreach_reciprocal_

  

torch._foreach_sigmoid

  

torch._foreach_sigmoid_

  

torch._foreach_trunc

  

torch._foreach_trunc_

  

torch._foreach_zero_

  

torch.compiled_with_cxx11_abi

  

torch.result_type

  

torch.can_cast

  

torch.promote_types

  

torch.use_deterministic_algorithms

  

torch.are_deterministic_algorithms_enabled

  

torch.is_deterministic_algorithms_warn_only_enabled

  

torch.set_deterministic_debug_mode

  

torch.get_deterministic_debug_mode

  

torch.set_float32_matmul_precision

  

torch.get_float32_matmul_precision

  

torch.set_warn_always

  

torch.is_warn_always_enabled

  

torch.vmap

  

torch._assert

  

torch.sym_float

  

torch.sym_int

  

torch.sym_max

  

torch.sym_min

  

torch.sym_not

  

torch.compile

  

torch.autograd.set_multithreading_enabled