torch.nn

若API未标明支持情况,则代表该API的支持情况待验证。

API名称

是否支持

限制与说明

torch.nn.parameter.Parameter

支持fp32

torch.nn.parameter.UninitializedParameter

  

torch.nn.parameter.UninitializedBuffer

  

torch.nn.Module

支持fp32

torch.nn.Module.add_module

支持fp32

torch.nn.Module.apply

支持fp32

torch.nn.Module.bfloat16

支持bf16,fp32,int64

torch.nn.Module.buffers

  

torch.nn.Module.children

支持fp32

torch.nn.Module.cpu

支持fp32

torch.nn.Module.cuda

支持fp32

torch.nn.Module.double

  

torch.nn.Module.eval

支持fp32,int64

torch.nn.Module.extra_repr

支持fp32

torch.nn.Module.float

支持fp16,fp32,int64

torch.nn.Module.forward

支持fp32

torch.nn.Module.get_buffer

  

torch.nn.Module.get_extra_state

  

torch.nn.Module.get_parameter

支持fp32

torch.nn.Module.get_submodule

支持fp32

torch.nn.Module.half

支持fp16,fp32,int64

torch.nn.Module.ipu

  

torch.nn.Module.load_state_dict

支持fp32

torch.nn.Module.modules

支持fp32

torch.nn.Module.named_buffers

  

torch.nn.Module.named_children

支持fp32

torch.nn.Module.named_modules

支持fp32

torch.nn.Module.named_parameters

  

torch.nn.Module.parameters

  

torch.nn.Module.register_backward_hook

支持fp32

torch.nn.Module.register_buffer

支持fp32

torch.nn.Module.register_forward_hook

支持fp32

torch.nn.Module.register_forward_pre_hook

支持fp32

torch.nn.Module.register_full_backward_hook

支持fp32

torch.nn.Module.register_full_backward_pre_hook

  

torch.nn.Module.register_load_state_dict_post_hook

  

torch.nn.Module.register_module

支持fp32

torch.nn.Module.register_parameter

  

torch.nn.Module.register_state_dict_pre_hook

  

torch.nn.Module.requires_grad_

  

torch.nn.Module.set_extra_state

  

torch.nn.Module.share_memory

     

torch.nn.Module.state_dict

支持fp32

torch.nn.Module.to

支持fp32,int64

torch.nn.Module.to_empty

支持fp32

torch.nn.Module.train

支持fp32

torch.nn.Module.type

支持fp16,fp32,int64

torch.nn.Module.xpu

  

torch.nn.Module.zero_grad

支持fp32

torch.nn.Sequential

支持fp32

torch.nn.Sequential.append

支持fp32

torch.nn.ModuleList

支持fp32

torch.nn.ModuleList.append

支持fp32

torch.nn.ModuleList.extend

支持fp32

torch.nn.ModuleList.insert

支持fp32

torch.nn.ModuleDict

支持fp32

torch.nn.ModuleDict.clear

支持fp32

torch.nn.ModuleDict.items

支持fp32

torch.nn.ModuleDict.keys

支持fp32

torch.nn.ModuleDict.pop

支持fp32

torch.nn.ModuleDict.update

支持fp32

torch.nn.ModuleDict.values

支持fp32

torch.nn.ParameterList

支持fp32

torch.nn.ParameterList.append

支持fp32

torch.nn.ParameterList.extend

支持fp32

torch.nn.ParameterDict

支持fp32

torch.nn.ParameterDict.clear

支持fp32

torch.nn.ParameterDict.copy

支持fp32

torch.nn.ParameterDict.fromkeys

支持fp32

torch.nn.ParameterDict.get

支持fp32

torch.nn.ParameterDict.items

支持fp32

torch.nn.ParameterDict.keys

支持fp32

torch.nn.ParameterDict.pop

支持fp32

torch.nn.ParameterDict.popitem

支持fp32

torch.nn.ParameterDict.setdefault

支持fp32

torch.nn.ParameterDict.update

支持fp32

torch.nn.ParameterDict.values

支持fp32

torch.nn.modules.module.register_module_forward_pre_hook

支持fp32

torch.nn.modules.module.register_module_forward_hook

支持fp32

torch.nn.modules.module.register_module_backward_hook

支持fp32

torch.nn.modules.module.register_module_full_backward_hook

支持fp32

torch.nn.Conv1d

支持fp32,fp16

torch.nn.Conv2d

支持fp16,fp32

torch.nn.Conv3d

  

torch.nn.ConvTranspose1d

支持fp32

torch.nn.ConvTranspose2d

支持fp16,fp32

torch.nn.ConvTranspose3d

  

torch.nn.LazyConv1d

支持fp32,fp16

torch.nn.LazyConv2d

支持fp32,fp16

torch.nn.LazyConv3d

  

torch.nn.LazyConvTranspose1d

支持fp16

torch.nn.LazyConvTranspose2d

支持fp16,fp32

torch.nn.LazyConvTranspose3d

     

torch.nn.Unfold

     

torch.nn.Fold

     

torch.nn.MaxPool1d

     

torch.nn.MaxPool2d

     

torch.nn.MaxPool3d

     

torch.nn.MaxUnpool1d

支持fp16,fp32

torch.nn.MaxUnpool2d

支持fp16,fp32

torch.nn.MaxUnpool3d

     

torch.nn.AvgPool1d

  

torch.nn.AvgPool2d

  

torch.nn.AvgPool3d

  

torch.nn.FractionalMaxPool2d

  

torch.nn.FractionalMaxPool3d

  

torch.nn.LPPool1d

  

torch.nn.LPPool2d

  

torch.nn.AdaptiveMaxPool1d

     

torch.nn.AdaptiveMaxPool2d

     

torch.nn.AdaptiveMaxPool3d

     

torch.nn.AdaptiveAvgPool1d

支持fp32,fp16

torch.nn.AdaptiveAvgPool2d

支持fp32,fp16

torch.nn.AdaptiveAvgPool3d

     

torch.nn.ReflectionPad1d

支持fp32,fp16

torch.nn.ReflectionPad2d

支持fp32,fp16

torch.nn.ReflectionPad3d

  

torch.nn.ReplicationPad1d

支持fp32,fp16,complex64,complex128

torch.nn.ReplicationPad2d

支持fp32,fp16,complex64,complex128

torch.nn.ReplicationPad3d

可以走CPU实现

torch.nn.ZeroPad2d

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.ConstantPad1d

支持int8,bool

在输入x为六维以上时可能会出现性能下降问题

torch.nn.ConstantPad2d

支持fp32,fp16,uint8,int8,int16,int32,int64,complex64,complex128

在输入x为六维以上时可能会出现性能下降问题

torch.nn.ConstantPad3d

  

torch.nn.ELU

支持fp32,fp16

torch.nn.Hardshrink

支持fp32,fp16

torch.nn.Hardsigmoid

支持fp16,fp32,int32

torch.nn.Hardtanh

支持fp32,fp16,int8,int16,int32,int64

torch.nn.Hardswish

支持fp16,fp32

torch.nn.LeakyReLU

支持fp16,fp32,fp64

torch.nn.LogSigmoid

支持fp32,fp16

torch.nn.MultiheadAttention

  

torch.nn.MultiheadAttention.forward

  

torch.nn.MultiheadAttention.merge_masks

  

torch.nn.PReLU

支持fp32

torch.nn.ReLU

支持bf16,fp16,fp32,int8,uint8,int32,int64

torch.nn.ReLU6

支持fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.RReLU

     

torch.nn.SELU

支持fp16,fp32,int8,int32,fp64,uint8,int16,int64,bool

torch.nn.CELU

支持fp32,fp16

torch.nn.GELU

支持bf16,fp32,fp16

torch.nn.Sigmoid

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.SiLU

支持bf16,fp16,fp32

torch.nn.Mish

     

torch.nn.Softplus

支持fp16,fp32

torch.nn.Softshrink

支持bf16,fp16,fp32

torch.nn.Softsign

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.Tanh

支持bf16,fp16,fp32,bool

torch.nn.Tanhshrink

支持fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.Threshold

支持fp16,fp32,int8,uint8,int16,int32,int64

torch.nn.GLU

支持fp32,fp16

torch.nn.Softmin

支持bf16,fp32,fp16

torch.nn.Softmax

支持bf16,fp32,fp16

torch.nn.Softmax2d

支持bf16,fp32,fp16

torch.nn.LogSoftmax

支持bf16,fp32,fp16

torch.nn.AdaptiveLogSoftmaxWithLoss

  

torch.nn.AdaptiveLogSoftmaxWithLoss.log_prob

  

torch.nn.AdaptiveLogSoftmaxWithLoss.predict

  

torch.nn.BatchNorm1d

支持fp32,fp16

torch.nn.BatchNorm2d

支持fp32,fp16

torch.nn.BatchNorm3d

支持fp32,fp16

torch.nn.LazyBatchNorm1d

     

torch.nn.LazyBatchNorm2d

     

torch.nn.LazyBatchNorm3d

     

torch.nn.GroupNorm

支持fp32

torch.nn.SyncBatchNorm

支持fp16,fp32

torch.nn.SyncBatchNorm.convert_sync_batchnorm

  

torch.nn.InstanceNorm1d

  

torch.nn.InstanceNorm2d

  

torch.nn.InstanceNorm3d

  

torch.nn.LazyInstanceNorm1d

     

torch.nn.LazyInstanceNorm2d

     

torch.nn.LazyInstanceNorm3d

     

torch.nn.LayerNorm

支持fp32,fp16

torch.nn.LocalResponseNorm

  

torch.nn.RNNBase

     

torch.nn.RNNBase.flatten_parameters

     

torch.nn.RNN

  

torch.nn.LSTM

  

torch.nn.GRU

  

torch.nn.RNNCell

     

torch.nn.LSTMCell

  

torch.nn.GRUCell

支持fp32,fp16

torch.nn.Transformer

支持fp32,fp16

torch.nn.Transformer.forward

     

torch.nn.Transformer.generate_square_subsequent_mask

  

torch.nn.TransformerEncoder

  

torch.nn.TransformerEncoder.forward

  

torch.nn.TransformerDecoder

     

torch.nn.TransformerDecoder.forward

     

torch.nn.TransformerEncoderLayer

  

torch.nn.TransformerEncoderLayer.forward

  

torch.nn.TransformerDecoderLayer

  

torch.nn.TransformerDecoderLayer.forward

  

torch.nn.Identity

支持fp32

torch.nn.Linear

支持fp32,fp16

torch.nn.Bilinear

支持fp32,fp16

torch.nn.LazyLinear

支持fp16,fp32

torch.nn.Dropout

     

torch.nn.Dropout1d

  

torch.nn.Dropout2d

支持fp16,fp32,int64,bool

torch.nn.Dropout3d

     

torch.nn.AlphaDropout

支持fp32,fp16

torch.nn.FeatureAlphaDropout

支持fp32,fp16

torch.nn.Embedding

支持int32,int64

属性max_norm仅支持非负值

torch.nn.Embedding.from_pretrained

支持fp64

torch.nn.EmbeddingBag

支持int32,int64

torch.nn.EmbeddingBag.forward

支持int64

torch.nn.EmbeddingBag.from_pretrained

支持int64

torch.nn.CosineSimilarity

     

torch.nn.PairwiseDistance

     

torch.nn.L1Loss

支持fp32,fp16,int64,

torch.nn.MSELoss

支持fp32,fp16

torch.nn.CrossEntropyLoss

支持fp32,fp16

torch.nn.CTCLoss

支持fp32

torch.nn.NLLLoss

支持fp32,fp16

torch.nn.PoissonNLLLoss

支持bf16,fp32,fp16,int64

torch.nn.GaussianNLLLoss

支持bf16,fp32,fp16,int16,int32,int64

torch.nn.KLDivLoss

支持fp16,fp32,fp64

torch.nn.BCELoss

支持fp32,fp16

torch.nn.BCEWithLogitsLoss

支持fp32,fp16,int16,int32,int64,bool

torch.nn.MarginRankingLoss

支持bf16,fp32,fp16,int8,int16,int32,int64

torch.nn.HingeEmbeddingLoss

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.MultiLabelMarginLoss

     

torch.nn.HuberLoss

  

torch.nn.SmoothL1Loss

支持fp32,fp16

torch.nn.SoftMarginLoss

     

torch.nn.MultiLabelSoftMarginLoss

支持fp16,fp32

torch.nn.CosineEmbeddingLoss

     

torch.nn.MultiMarginLoss

  

torch.nn.TripletMarginLoss

  

torch.nn.TripletMarginWithDistanceLoss

支持bf16,fp32,fp16

torch.nn.PixelShuffle

支持fp16,fp32,fp64,int8,uint8,int16,int32,int64,bool

torch.nn.PixelUnshuffle

支持fp16,fp32,fp64,int8,uint8,int16,int32,int64,bool

torch.nn.Upsample

支持fp16,fp32,fp64

只支持mode = nearest

torch.nn.UpsamplingNearest2d

支持fp32,fp16,uint8

torch.nn.UpsamplingBilinear2d

     

torch.nn.ChannelShuffle

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.DataParallel

  

torch.nn.parallel.DistributedDataParallel

  

torch.nn.parallel.DistributedDataParallel.join

  

torch.nn.parallel.DistributedDataParallel.join_hook

  

torch.nn.parallel.DistributedDataParallel.no_sync

  

torch.nn.parallel.DistributedDataParallel.register_comm_hook

  

torch.nn.utils.clip_grad_norm_

  

torch.nn.utils.clip_grad_value_

支持bf16,fp32,fp16

torch.nn.utils.parameters_to_vector

     

torch.nn.utils.vector_to_parameters

支持bf16,fp16,fp32,fp64,complex64

torch.nn.utils.prune.BasePruningMethod

  

torch.nn.utils.prune.BasePruningMethod.apply

  

torch.nn.utils.prune.BasePruningMethod.apply_mask

支持fp32

torch.nn.utils.prune.BasePruningMethod.prune

支持fp32

torch.nn.utils.prune.BasePruningMethod.remove

支持fp32

torch.nn.utils.prune.PruningContainer

  

torch.nn.utils.prune.PruningContainer.add_pruning_method

  

torch.nn.utils.prune.PruningContainer.apply

     

torch.nn.utils.prune.PruningContainer.apply_mask

  

torch.nn.utils.prune.PruningContainer.compute_mask

支持fp32

torch.nn.utils.prune.PruningContainer.prune

支持fp32

torch.nn.utils.prune.PruningContainer.remove

支持fp32

torch.nn.utils.prune.Identity

支持fp32

torch.nn.utils.prune.Identity.apply

支持fp32

torch.nn.utils.prune.Identity.apply_mask

支持fp32

torch.nn.utils.prune.Identity.prune

支持fp32

torch.nn.utils.prune.Identity.remove

支持fp32

torch.nn.utils.prune.RandomUnstructured

支持fp32

torch.nn.utils.prune.RandomUnstructured.apply

支持fp32

torch.nn.utils.prune.RandomUnstructured.apply_mask

支持fp32

torch.nn.utils.prune.RandomUnstructured.prune

支持fp32

torch.nn.utils.prune.RandomUnstructured.remove

  

torch.nn.utils.prune.L1Unstructured

支持fp32

torch.nn.utils.prune.L1Unstructured.apply

支持fp32

torch.nn.utils.prune.L1Unstructured.apply_mask

支持fp32

torch.nn.utils.prune.L1Unstructured.prune

支持fp32

torch.nn.utils.prune.L1Unstructured.remove

支持fp32

torch.nn.utils.prune.RandomStructured

支持fp32

torch.nn.utils.prune.RandomStructured.apply

支持fp32

torch.nn.utils.prune.RandomStructured.apply_mask

支持fp32

torch.nn.utils.prune.RandomStructured.compute_mask

支持fp32

torch.nn.utils.prune.RandomStructured.prune

  

torch.nn.utils.prune.RandomStructured.remove

  

torch.nn.utils.prune.LnStructured

支持fp32

torch.nn.utils.prune.LnStructured.apply

支持fp32

torch.nn.utils.prune.LnStructured.apply_mask

支持fp32

torch.nn.utils.prune.LnStructured.compute_mask

支持fp32

torch.nn.utils.prune.LnStructured.prune

支持fp32

torch.nn.utils.prune.LnStructured.remove

支持fp32

torch.nn.utils.prune.CustomFromMask

支持int64

torch.nn.utils.prune.CustomFromMask.apply

支持int64

torch.nn.utils.prune.CustomFromMask.apply_mask

  

torch.nn.utils.prune.CustomFromMask.prune

  

torch.nn.utils.prune.CustomFromMask.remove

  

torch.nn.utils.prune.identity

支持fp32

torch.nn.utils.prune.random_unstructured

  

torch.nn.utils.prune.l1_unstructured

  

torch.nn.utils.prune.random_structured

  

torch.nn.utils.prune.ln_structured

  

torch.nn.utils.prune.global_unstructured

  

torch.nn.utils.prune.custom_from_mask

支持int64

torch.nn.utils.prune.remove

  

torch.nn.utils.prune.is_pruned

  

torch.nn.utils.weight_norm

     

torch.nn.utils.remove_weight_norm

     

torch.nn.utils.spectral_norm

  

torch.nn.utils.remove_spectral_norm

     

torch.nn.utils.skip_init

  

torch.nn.utils.parametrizations.orthogonal

  

torch.nn.utils.parametrizations.spectral_norm

  

torch.nn.utils.parametrize.register_parametrization

  

torch.nn.utils.parametrize.remove_parametrizations

  

torch.nn.utils.parametrize.cached

  

torch.nn.utils.parametrize.is_parametrized

  

torch.nn.utils.parametrize.ParametrizationList

  

torch.nn.utils.parametrize.ParametrizationList.right_inverse

支持fp32

torch.nn.utils.stateless.functional_call

  

torch.nn.utils.rnn.PackedSequence

支持fp32,int64

torch.nn.utils.rnn.PackedSequence.count

支持fp32

torch.nn.utils.rnn.PackedSequence.index

支持fp32

torch.nn.utils.rnn.PackedSequence.is_pinned

  

torch.nn.utils.rnn.PackedSequence.to

支持fp32,int64

torch.nn.utils.rnn.pack_padded_sequence

     

torch.nn.utils.rnn.pad_packed_sequence

     

torch.nn.utils.rnn.pad_sequence

支持fp32,fp16

torch.nn.utils.rnn.pack_sequence

     

torch.nn.utils.rnn.unpack_sequence

  

torch.nn.utils.rnn.unpad_sequence

  

torch.nn.Flatten

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.Unflatten

支持fp16,fp32,fp64,int8,int16,uint8,int32,int64,bool

torch.nn.modules.lazy.LazyModuleMixin

支持fp32

torch.nn.modules.lazy.LazyModuleMixin.has_uninitialized_params

支持fp32

torch.nn.modules.lazy.LazyModuleMixin.initialize_parameters