torch.nn

若API未标明支持情况则代表该API暂无验证结果,待验证后更新表格。

API名称

是否支持

限制与说明

torch.nn.parameter.Parameter

支持fp32

torch.nn.parameter.UninitializedParameter

  

torch.nn.parameter.UninitializedBuffer

  

torch.nn.Module

支持fp32

torch.nn.Module.add_module

支持fp32

torch.nn.Module.apply

支持fp32

torch.nn.Module.bfloat16

支持bf16,fp32,int64

torch.nn.Module.buffers

  

torch.nn.Module.children

支持fp32

torch.nn.Module.cpu

支持fp32

torch.nn.Module.cuda

支持fp32

torch.nn.Module.double

  

torch.nn.Module.eval

支持fp32,int64

torch.nn.Module.extra_repr

支持fp32

torch.nn.Module.float

支持fp16,fp32,int64

torch.nn.Module.forward

支持fp32

torch.nn.Module.get_buffer

  

torch.nn.Module.get_extra_state

  

torch.nn.Module.get_parameter

支持fp32

torch.nn.Module.get_submodule

支持fp32

torch.nn.Module.half

支持fp16,fp32,int64

torch.nn.Module.load_state_dict

支持fp32

torch.nn.Module.modules

支持fp32

torch.nn.Module.named_buffers

  

torch.nn.Module.named_children

支持fp32

torch.nn.Module.named_modules

支持fp32

torch.nn.Module.named_parameters

  

torch.nn.Module.parameters

  

torch.nn.Module.register_backward_hook

支持fp32

torch.nn.Module.register_buffer

支持fp32

torch.nn.Module.register_forward_hook

支持fp32

torch.nn.Module.register_forward_pre_hook

支持fp32

torch.nn.Module.register_full_backward_hook

支持fp32

torch.nn.Module.register_module

支持fp32

torch.nn.Module.register_parameter

  

torch.nn.Module.requires_grad_

  

torch.nn.Module.set_extra_state

  

torch.nn.Module.share_memory

  

torch.nn.Module.state_dict

支持fp32

torch.nn.Module.to(*args, **kwargs)

支持fp32,int64

torch.nn.Module.to(device=None, dtype=None, non_blocking=False)

支持fp32,int64

torch.nn.Module.to(dtype, non_blocking=False)

支持fp32,int64

torch.nn.Module.to(tensor, non_blocking=False)

支持fp32,int64

torch.nn.Module.to(memory_format=torch.channels_last)

  

torch.nn.Module.to_empty

支持fp32

torch.nn.Module.train

支持fp32

torch.nn.Module.type

支持fp16,fp32,int64

torch.nn.Module.xpu

  

torch.nn.Module.zero_grad

支持fp32

torch.nn.Sequential

支持fp32

torch.nn.Sequential.append

支持fp32

torch.nn.ModuleList

支持fp32

torch.nn.ModuleList.append

支持fp32

torch.nn.ModuleList.extend

支持fp32

torch.nn.ModuleList.insert

支持fp32

torch.nn.ModuleDict

支持fp32

torch.nn.ModuleDict.clear

支持fp32

torch.nn.ModuleDict.items

支持fp32

torch.nn.ModuleDict.keys

支持fp32

torch.nn.ModuleDict.pop

支持fp32

torch.nn.ModuleDict.update

支持fp32

torch.nn.ModuleDict.values

支持fp32

torch.nn.ParameterList

支持fp32

torch.nn.ParameterList.append

支持fp32

torch.nn.ParameterList.extend

支持fp32

torch.nn.ParameterDict

支持fp32

torch.nn.ParameterDict.clear

支持fp32

torch.nn.ParameterDict.copy

支持fp32

torch.nn.ParameterDict.fromkeys

支持fp32

torch.nn.ParameterDict.get

支持fp32

torch.nn.ParameterDict.items

支持fp32

torch.nn.ParameterDict.keys

支持fp32

torch.nn.ParameterDict.pop

支持fp32

torch.nn.ParameterDict.popitem

支持fp32

torch.nn.ParameterDict.setdefault

支持fp32

torch.nn.ParameterDict.update

支持fp32

torch.nn.ParameterDict.values

支持fp32

torch.nn.modules.module.register_module_forward_pre_hook

支持fp32

torch.nn.modules.module.register_module_forward_hook

支持fp32

torch.nn.modules.module.register_module_backward_hook

支持fp32

torch.nn.modules.module.register_module_full_backward_hook

支持fp32

torch.nn.Conv1d

支持fp32,fp16

torch.nn.Conv2d

支持bf16,fp16,fp32

torch.nn.Conv3d

     

torch.nn.ConvTranspose1d

支持fp32

torch.nn.ConvTranspose2d

支持fp16,fp32

torch.nn.ConvTranspose3d

     

torch.nn.LazyConv1d

支持fp32,fp16

torch.nn.LazyConv2d

支持fp32,fp16

torch.nn.LazyConv3d

  

torch.nn.LazyConvTranspose1d

支持fp16

torch.nn.LazyConvTranspose2d

支持fp16,fp32

torch.nn.LazyConvTranspose3d

     

torch.nn.Unfold

支持bf16,fp16,fp32

torch.nn.Fold

     

torch.nn.MaxPool1d

     

torch.nn.MaxPool2d

     

torch.nn.MaxPool3d

     

torch.nn.MaxUnpool1d

支持fp16,fp32

torch.nn.MaxUnpool2d

支持fp16,fp32

torch.nn.MaxUnpool3d

     

torch.nn.AvgPool1d

支持fp16,fp32

torch.nn.AvgPool2d

支持fp32,fp16

torch.nn.AvgPool3d

     

torch.nn.FractionalMaxPool2d

  

torch.nn.FractionalMaxPool3d

  

torch.nn.LPPool1d

支持fp32,fp16,uint8,int8,int16,int32,int64,bool

torch.nn.LPPool2d

支持fp32,fp16,int16,int32,int64,bool

torch.nn.AdaptiveMaxPool1d

     

torch.nn.AdaptiveMaxPool2d

     

torch.nn.AdaptiveMaxPool3d

支持fp32,fp64

torch.nn.AdaptiveAvgPool1d

支持fp32,fp16

torch.nn.AdaptiveAvgPool2d

支持fp32,fp16

torch.nn.AdaptiveAvgPool3d

     

torch.nn.ReflectionPad1d

支持fp32,fp16

torch.nn.ReflectionPad2d

支持fp32,fp16

torch.nn.ReflectionPad3d

  

torch.nn.ReplicationPad1d

支持fp32,fp16,complex64,complex128

torch.nn.ReplicationPad2d

支持fp32,fp16,complex64,complex128

torch.nn.ReplicationPad3d

可以走CPU实现。

torch.nn.ZeroPad2d

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.ConstantPad1d

支持int8,bool

在输入x为六维以上时可能会出现性能下降问题

torch.nn.ConstantPad2d

支持fp32,fp16,uint8,int8,int16,int32,int64,complex64,complex128

在输入x为六维以上时可能会出现性能下降问题

torch.nn.ConstantPad3d

支持fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

在输入x为六维以上时可能会出现性能下降问题

torch.nn.ELU

支持fp32,fp16

torch.nn.Hardshrink

支持fp32,fp16

torch.nn.Hardsigmoid

支持fp16,fp32,int32

torch.nn.Hardtanh

支持fp32,fp16,int8,int16,int32,int64

torch.nn.Hardswish

支持fp16,fp32

torch.nn.LeakyReLU

支持bf16,fp16,fp32,fp64

torch.nn.LogSigmoid

支持fp32,fp16

torch.nn.MultiheadAttention

     

torch.nn.MultiheadAttention.forward

     

torch.nn.PReLU

支持fp32

torch.nn.ReLU

支持bf16,fp16,fp32,int8,uint8,int32,int64

torch.nn.ReLU6

支持fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.RReLU

     

torch.nn.SELU

支持fp16,fp32,int8,int32,fp64,uint8,int16,int64,bool

torch.nn.CELU

支持fp32,fp16

torch.nn.GELU

支持bf16,fp32,fp16

torch.nn.Sigmoid

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.SiLU

支持bf16,fp16,fp32

torch.nn.Mish

支持fp32,fp16

torch.nn.Softplus

支持fp16,fp32

torch.nn.Softshrink

支持bf16,fp16,fp32

torch.nn.Softsign

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.Tanh

支持bf16,fp16,fp32,bool

torch.nn.Tanhshrink

支持fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.Threshold

支持fp16,fp32,int8,uint8,int16,int32,int64

torch.nn.GLU

支持fp32,fp16

torch.nn.Softmin

支持bf16,fp32,fp16

torch.nn.Softmax

支持bf16,fp32,fp16

torch.nn.Softmax2d

支持bf16,fp32,fp16

torch.nn.LogSoftmax

支持bf16,fp32,fp16

torch.nn.AdaptiveLogSoftmaxWithLoss

  

torch.nn.AdaptiveLogSoftmaxWithLoss.log_prob

  

torch.nn.AdaptiveLogSoftmaxWithLoss.predict

  

torch.nn.BatchNorm1d

支持fp32,fp16

torch.nn.BatchNorm2d

支持fp32,fp16

torch.nn.BatchNorm3d

支持fp32,fp16

torch.nn.LazyBatchNorm1d

     

torch.nn.LazyBatchNorm2d

     

torch.nn.LazyBatchNorm3d

     

torch.nn.GroupNorm

  • 支持fp32
  • 不支持jit_compile=True的场景
  • 该API仅支持2维及以上的输入input。 该API反向不支持输入input不为4维,或输入num_groups非32整除,或C轴维度非(10 * num_groups)整除的场景

torch.nn.SyncBatchNorm

支持fp16,fp32

torch.nn.SyncBatchNorm.convert_sync_batchnorm

  

torch.nn.InstanceNorm1d

     

torch.nn.InstanceNorm2d

     

torch.nn.InstanceNorm3d

     

torch.nn.LazyInstanceNorm1d

     

torch.nn.LazyInstanceNorm2d

     

torch.nn.LazyInstanceNorm3d

     

torch.nn.LayerNorm

支持bf16,fp16,fp32

torch.nn.LocalResponseNorm

     

torch.nn.RNNBase

     

torch.nn.RNNBase.flatten_parameters

     

torch.nn.RNN

     

torch.nn.LSTM

支持fp32

不支持proj_size参数

不支持dropout参数

入参input不支持2维

torch.nn.GRU

  

torch.nn.RNNCell

     

torch.nn.LSTMCell

接口暂不支持jit_compile=False,需要在该模式下使用时请将"DynamicGRUV2"添加至"NPU_FUZZY_COMPILE_BLACKLIST"选项内,具体操作可参考添加二进制黑名单示例

torch.nn.GRUCell

支持fp32,fp16

torch.nn.Transformer

支持fp32,fp16

torch.nn.Transformer.forward

     

torch.nn.Transformer.generate_square_subsequent_mask

     

torch.nn.TransformerEncoder

     

torch.nn.TransformerEncoder.forward

     

torch.nn.TransformerDecoder

     

torch.nn.TransformerDecoder.forward

     

torch.nn.TransformerEncoderLayer

     

torch.nn.TransformerEncoderLayer.forward

     

torch.nn.TransformerDecoderLayer

     

torch.nn.TransformerDecoderLayer.forward

     

torch.nn.Identity

支持fp32

torch.nn.Linear

支持fp32,fp16

torch.nn.Bilinear

支持fp32,fp16

torch.nn.LazyLinear

支持fp16,fp32

torch.nn.Dropout

支持bf16,fp16,fp32

torch.nn.Dropout2d

支持fp16,fp32,int64,bool

torch.nn.Dropout3d

     

torch.nn.AlphaDropout

支持fp32,fp16

torch.nn.FeatureAlphaDropout

支持fp32,fp16

torch.nn.Embedding

支持int32,int64

属性max_norm不支持nan,仅支持非负值

torch.nn.Embedding.from_pretrained

支持fp64

torch.nn.EmbeddingBag

支持int32,int64

torch.nn.EmbeddingBag.forward

支持int64

torch.nn.EmbeddingBag.from_pretrained

支持int64

torch.nn.CosineSimilarity

     

torch.nn.PairwiseDistance

     

torch.nn.L1Loss

支持fp32,fp16,int64,

torch.nn.MSELoss

支持fp32,fp16

torch.nn.CrossEntropyLoss

支持fp32,fp16

torch.nn.CTCLoss

支持fp32

不支持log_probs 2D输入

torch.nn.NLLLoss

支持fp32,fp16

torch.nn.PoissonNLLLoss

支持bf16,fp32,fp16,int64

torch.nn.GaussianNLLLoss

支持bf16,fp32,fp16,int16,int32,int64

torch.nn.KLDivLoss

支持bf16,fp16,fp32

当前log_target参数仅支持False

torch.nn.BCELoss

支持fp32,fp16

torch.nn.BCEWithLogitsLoss

支持fp32,fp16,int16,int32,int64,bool

torch.nn.MarginRankingLoss

支持bf16,fp32,fp16,int8,int16,int32,int64

torch.nn.HingeEmbeddingLoss

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64

torch.nn.MultiLabelMarginLoss

     

torch.nn.HuberLoss

  

torch.nn.SmoothL1Loss

支持fp32,fp16

torch.nn.SoftMarginLoss

     

torch.nn.MultiLabelSoftMarginLoss

支持fp16,fp32

torch.nn.CosineEmbeddingLoss

     

torch.nn.MultiMarginLoss

  

torch.nn.TripletMarginLoss

支持fp16,fp32,int8,uint8,int16,int32,int64

torch.nn.TripletMarginWithDistanceLoss

支持bf16,fp32,fp16

torch.nn.PixelShuffle

支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool

torch.nn.PixelUnshuffle

支持fp16,fp32,fp64,int8,uint8,int16,int32,int64,bool

torch.nn.Upsample

支持fp16,fp32,fp64

只支持mode = nearest

torch.nn.UpsamplingNearest2d

支持fp32,fp16,uint8

torch.nn.UpsamplingBilinear2d

     

torch.nn.ChannelShuffle

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.DataParallel

  

torch.nn.parallel.DistributedDataParallel

  

torch.nn.parallel.DistributedDataParallel.join

  

torch.nn.parallel.DistributedDataParallel.join_hook

  

torch.nn.parallel.DistributedDataParallel.no_sync

  

torch.nn.parallel.DistributedDataParallel.register_comm_hook

  

torch.nn.utils.clip_grad_norm_

     

torch.nn.utils.clip_grad_value_

支持bf16,fp32,fp16

torch.nn.utils.parameters_to_vector

     

torch.nn.utils.vector_to_parameters

支持bf16,fp16,fp32,fp64,complex64

torch.nn.utils.prune.BasePruningMethod

  

torch.nn.utils.prune.BasePruningMethod.apply

  

torch.nn.utils.prune.BasePruningMethod.apply_mask

支持fp32

torch.nn.utils.prune.BasePruningMethod.abstract compute_mask

  

torch.nn.utils.prune.BasePruningMethod.prune

支持fp32

torch.nn.utils.prune.BasePruningMethod.remove

支持fp32

torch.nn.utils.prune.PruningContainer

  

torch.nn.utils.prune.PruningContainer.add_pruning_method

  

torch.nn.utils.prune.PruningContainer.apply

  

torch.nn.utils.prune.PruningContainer.apply_mask

  

torch.nn.utils.prune.PruningContainer.compute_mask

支持fp32

torch.nn.utils.prune.PruningContainer.prune

支持fp32

torch.nn.utils.prune.PruningContainer.remove

支持fp32

torch.nn.utils.prune.Identity

支持fp32

torch.nn.utils.prune.Identity.apply

支持fp32

torch.nn.utils.prune.Identity.apply_mask

支持fp32

torch.nn.utils.prune.Identity.prune

支持fp32

torch.nn.utils.prune.Identity.remove

支持fp32

torch.nn.utils.prune.RandomUnstructured

支持fp32

torch.nn.utils.prune.RandomUnstructured.apply

支持fp32

torch.nn.utils.prune.RandomUnstructured.apply_mask

支持fp32

torch.nn.utils.prune.RandomUnstructured.prune

支持fp32

torch.nn.utils.prune.RandomUnstructured.remove

  

torch.nn.utils.prune.L1Unstructured

支持fp32

torch.nn.utils.prune.L1Unstructured.apply

支持fp32

torch.nn.utils.prune.L1Unstructured.apply_mask

支持fp32

torch.nn.utils.prune.L1Unstructured.prune

支持fp32

torch.nn.utils.prune.L1Unstructured.remove

支持fp32

torch.nn.utils.prune.RandomStructured

支持fp32

torch.nn.utils.prune.RandomStructured.apply

支持fp32

torch.nn.utils.prune.RandomStructured.apply_mask

支持fp32

torch.nn.utils.prune.RandomStructured.compute_mask

支持fp32

torch.nn.utils.prune.RandomStructured.prune

  

torch.nn.utils.prune.RandomStructured.remove

  

torch.nn.utils.prune.LnStructured

支持fp32

torch.nn.utils.prune.LnStructured.apply

支持fp32

torch.nn.utils.prune.LnStructured.apply_mask

支持fp32

torch.nn.utils.prune.LnStructured.compute_mask

支持fp32

torch.nn.utils.prune.LnStructured.prune

支持fp32

torch.nn.utils.prune.LnStructured.remove

支持fp32

torch.nn.utils.prune.CustomFromMask

支持int64

torch.nn.utils.prune.CustomFromMask.apply

支持int64

torch.nn.utils.prune.CustomFromMask.apply_mask

  

torch.nn.utils.prune.CustomFromMask.prune

  

torch.nn.utils.prune.CustomFromMask.remove

  

torch.nn.utils.prune.Identity

支持fp32

torch.nn.utils.prune.Identity.apply

支持fp32

torch.nn.utils.prune.Identity.apply_mask

支持fp32

torch.nn.utils.prune.Identity.prune

支持fp32

torch.nn.utils.prune.Identity.remove

支持fp32

torch.nn.utils.prune.random_unstructured

  

torch.nn.utils.prune.l1_unstructured

  

torch.nn.utils.prune.random_structured

  

torch.nn.utils.prune.ln_structured

  

torch.nn.utils.prune.global_unstructured

  

torch.nn.utils.prune.custom_from_mask

支持int64

torch.nn.utils.prune.remove

  

torch.nn.utils.prune.is_pruned

  

torch.nn.utils.weight_norm

  

torch.nn.utils.remove_weight_norm

     

torch.nn.utils.spectral_norm

  

torch.nn.utils.remove_spectral_norm

     

torch.nn.utils.skip_init

  

torch.nn.utils.parametrizations.orthogonal

  

torch.nn.utils.parametrizations.spectral_norm

  

torch.nn.utils.parametrize.register_parametrization

  

torch.nn.utils.parametrize.remove_parametrizations

  

torch.nn.utils.parametrize.cached

  

torch.nn.utils.parametrize.is_parametrized

  

torch.nn.utils.parametrize.ParametrizationList

  

torch.nn.utils.parametrize.ParametrizationList.right_inverse

支持fp32

torch.nn.utils.rnn.PackedSequence

支持fp32,int64

torch.nn.utils.rnn.PackedSequence.count

支持fp32

torch.nn.utils.rnn.PackedSequence.index

支持fp32

torch.nn.utils.rnn.PackedSequence.is_pinned

  

torch.nn.utils.rnn.PackedSequence.to

支持fp32,int64

torch.nn.utils.rnn.pack_padded_sequence

     

torch.nn.utils.rnn.pad_packed_sequence

     

torch.nn.utils.rnn.pad_sequence

支持fp32,fp16

torch.nn.utils.rnn.pack_sequence

     

torch.nn.Flatten

支持bf16,fp32,fp16,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.Unflatten

支持fp16,fp32,fp64,int8,int16,uint8,int32,int64,bool

torch.nn.modules.lazy.LazyModuleMixin

支持fp32

torch.nn.modules.lazy.LazyModuleMixin.has_uninitialized_params

支持fp32

torch.nn.modules.lazy.LazyModuleMixin.initialize_parameters

支持fp32