torch.nn

API name

Supported

Limitations and notes

torch.nn.parameter.Parameter

  

torch.nn.parameter.UninitializedParameter

  

torch.nn.parameter.UninitializedBuffer

  

torch.nn.Module

  

torch.nn.Module.add_module

  

torch.nn.Module.apply

  

torch.nn.Module.bfloat16

  

torch.nn.Module.buffers

  

torch.nn.Module.children

  

torch.nn.Module.cpu

  

torch.nn.Module.cuda

  

torch.nn.Module.double

  

torch.nn.Module.eval

  

torch.nn.Module.extra_repr

  

torch.nn.Module.float

  

torch.nn.Module.forward

  

torch.nn.Module.get_buffer

  

torch.nn.Module.get_extra_state

  

torch.nn.Module.get_parameter

  

torch.nn.Module.get_submodule

  

torch.nn.Module.half

  

torch.nn.Module.load_state_dict

  

torch.nn.Module.modules

  

torch.nn.Module.named_buffers

  

torch.nn.Module.named_children

  

torch.nn.Module.named_modules

  

torch.nn.Module.named_parameters

  

torch.nn.Module.parameters

  

torch.nn.Module.register_backward_hook

  

torch.nn.Module.register_buffer

  

torch.nn.Module.register_forward_hook

  

torch.nn.Module.register_forward_pre_hook

  

torch.nn.Module.register_full_backward_hook

  

torch.nn.Module.register_module

  

torch.nn.Module.register_parameter

  

torch.nn.Module.requires_grad_

  

torch.nn.Module.set_extra_state

  

torch.nn.Module.share_memory

  

torch.nn.Module.state_dict

  

torch.nn.Module.to(*args, **kwargs)

  

torch.nn.Module.to(device=None, dtype=None, non_blocking=False)

  

torch.nn.Module.to(dtype, non_blocking=False)

  

torch.nn.Module.to(tensor, non_blocking=False)

  

torch.nn.Module.to(memory_format=torch.channels_last)

  

torch.nn.Module.to_empty

  

torch.nn.Module.train

  

torch.nn.Module.type

  

torch.nn.Module.xpu

  

torch.nn.Module.zero_grad

  

torch.nn.Sequential

  

torch.nn.Sequential.append

  

torch.nn.ModuleList

  

torch.nn.ModuleList.append

  

torch.nn.ModuleList.extend

  

torch.nn.ModuleList.insert

  

torch.nn.ModuleDict

  

torch.nn.ModuleDict.clear

  

torch.nn.ModuleDict.items

  

torch.nn.ModuleDict.keys

  

torch.nn.ModuleDict.pop

  

torch.nn.ModuleDict.update

  

torch.nn.ModuleDict.values

  

torch.nn.ParameterList

  

torch.nn.ParameterList.append

  

torch.nn.ParameterList.extend

  

torch.nn.ParameterDict

  

torch.nn.ParameterDict.clear

  

torch.nn.ParameterDict.copy

  

torch.nn.ParameterDict.fromkeys

  

torch.nn.ParameterDict.get

  

torch.nn.ParameterDict.items

  

torch.nn.ParameterDict.keys

  

torch.nn.ParameterDict.pop

  

torch.nn.ParameterDict.popitem

  

torch.nn.ParameterDict.setdefault

  

torch.nn.ParameterDict.update

  

torch.nn.ParameterDict.values

  

torch.nn.modules.module.register_module_forward_pre_hook

  

torch.nn.modules.module.register_module_forward_hook

  

torch.nn.modules.module.register_module_backward_hook

  

torch.nn.modules.module.register_module_full_backward_hook

  

torch.nn.Conv1d

  

torch.nn.Conv2d

  

torch.nn.Conv3d

  

torch.nn.ConvTranspose1d

Only float16 is supported.
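
A minimal sketch of working within this limitation: cast both the module and its input to float16. The "npu" device string assumes the torch_npu plugin is installed and registered.

    import torch
    import torch.nn as nn

    # Cast module weights and input to float16, the only supported dtype here;
    # "npu" assumes the torch_npu plugin is loaded.
    m = nn.ConvTranspose1d(16, 8, kernel_size=3).half().to("npu")
    x = torch.randn(4, 16, 32, dtype=torch.float16, device="npu")
    y = m(x)  # output is float16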

torch.nn.ConvTranspose2d

  

torch.nn.ConvTranspose3d

  

torch.nn.LazyConv1d

  

torch.nn.LazyConv2d

  

torch.nn.LazyConv3d

  

torch.nn.LazyConvTranspose1d

  

torch.nn.LazyConvTranspose2d

  

torch.nn.LazyConvTranspose3d

  

torch.nn.Unfold

  

torch.nn.Fold

  

torch.nn.MaxPool1d

  

torch.nn.MaxPool2d

  

torch.nn.MaxPool3d

  

torch.nn.MaxUnpool1d

  

torch.nn.MaxUnpool2d

  

torch.nn.MaxUnpool3d

  

torch.nn.AvgPool1d

  

torch.nn.AvgPool2d

  

torch.nn.AvgPool3d

  

torch.nn.FractionalMaxPool2d

  

torch.nn.FractionalMaxPool3d

  

torch.nn.LPPool1d

  

torch.nn.LPPool2d

  

torch.nn.AdaptiveMaxPool1d

  

torch.nn.AdaptiveMaxPool2d

  

torch.nn.AdaptiveMaxPool3d

Can fall back to the CPU implementation.
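
A sketch of forcing the CPU path explicitly, assuming the input starts on the NPU ("npu" assumes torch_npu): compute on the CPU and move the result back.

    import torch
    import torch.nn as nn

    pool = nn.AdaptiveMaxPool3d((2, 2, 2))
    x = torch.randn(1, 3, 8, 8, 8, device="npu")
    # Route the computation through the CPU implementation, then move back.
    y = pool(x.cpu()).to("npu")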

torch.nn.AdaptiveAvgPool1d

  

torch.nn.AdaptiveAvgPool2d

  

torch.nn.AdaptiveAvgPool3d

  

torch.nn.ReflectionPad1d

  

torch.nn.ReflectionPad2d

  

torch.nn.ReflectionPad3d

  

torch.nn.ReplicationPad1d

  

torch.nn.ReplicationPad2d

  

torch.nn.ReplicationPad3d

Can fall back to the CPU implementation.

torch.nn.ZeroPad2d

  

torch.nn.ConstantPad1d

Performance may degrade when the input x has six or more dimensions.

torch.nn.ConstantPad2d

Performance may degrade when the input x has six or more dimensions.

torch.nn.ConstantPad3d

Performance may degrade when the input x has six or more dimensions.
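
If the slowdown matters, one possible mitigation (a sketch, not an official recommendation) is to collapse the leading dimensions before padding and restore the shape afterwards. Shown here with ConstantPad1d; the same idea applies to the 2d/3d variants. The "npu" device string is an assumption.

    import torch
    import torch.nn as nn

    pad = nn.ConstantPad1d((2, 2), 0.0)
    x = torch.randn(2, 2, 2, 2, 2, 3, 16, device="npu")  # 7-D input
    lead = x.shape[:-1]
    # Collapse leading dims to a 3-D view, pad the last dim, restore the shape.
    y = pad(x.reshape(-1, 1, x.shape[-1])).reshape(*lead, -1)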

torch.nn.ELU

  

torch.nn.Hardshrink

  

torch.nn.Hardsigmoid

  

torch.nn.Hardtanh

  

torch.nn.Hardswish

  

torch.nn.LeakyReLU

  

torch.nn.LogSigmoid

  

torch.nn.MultiheadAttention

  

torch.nn.MultiheadAttention.forward

  

torch.nn.PReLU

  

torch.nn.ReLU

  

torch.nn.ReLU6

  

torch.nn.RReLU

  

torch.nn.SELU

  

torch.nn.CELU

  

torch.nn.GELU

  

torch.nn.Sigmoid

  

torch.nn.SiLU

  

torch.nn.Mish

  

torch.nn.Softplus

  

torch.nn.Softshrink

  

torch.nn.Softsign

  

torch.nn.Tanh

  

torch.nn.Tanhshrink

  

torch.nn.Threshold

  

torch.nn.GLU

  

torch.nn.Softmin

  

torch.nn.Softmax

  

torch.nn.Softmax2d

  

torch.nn.LogSoftmax

  

torch.nn.AdaptiveLogSoftmaxWithLoss

  

torch.nn.AdaptiveLogSoftmaxWithLoss.log_prob

  

torch.nn.AdaptiveLogSoftmaxWithLoss.predict

  

torch.nn.BatchNorm1d

  

torch.nn.BatchNorm2d

  

torch.nn.BatchNorm3d

  

torch.nn.LazyBatchNorm1d

  

torch.nn.LazyBatchNorm2d

  

torch.nn.LazyBatchNorm3d

  

torch.nn.GroupNorm

  

torch.nn.SyncBatchNorm

  

torch.nn.SyncBatchNorm.convert_sync_batchnorm

  

torch.nn.InstanceNorm1d

  

torch.nn.InstanceNorm2d

  

torch.nn.InstanceNorm3d

4D input is not supported.
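
A sketch of sidestepping this limitation by adding a batch dimension to an unbatched (C, D, H, W) input and removing it again afterwards ("npu" assumes torch_npu):

    import torch
    import torch.nn as nn

    norm = nn.InstanceNorm3d(3).to("npu")
    x = torch.randn(3, 4, 8, 8, device="npu")    # 4D (C, D, H, W) input
    # Unsqueeze to the supported 5D (N, C, D, H, W) layout, then squeeze back.
    y = norm(x.unsqueeze(0)).squeeze(0)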

torch.nn.LazyInstanceNorm1d

  

torch.nn.LazyInstanceNorm2d

  

torch.nn.LazyInstanceNorm3d

  

torch.nn.LayerNorm

  

torch.nn.LocalResponseNorm

  

torch.nn.RNNBase

  

torch.nn.RNNBase.flatten_parameters

  

torch.nn.RNN

  

torch.nn.LSTM

  

torch.nn.GRU

  

torch.nn.RNNCell

mode does not support LSTM or GRU.

torch.nn.LSTMCell

  

torch.nn.GRUCell

  

torch.nn.Transformer

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.Transformer.forward

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.Transformer.generate_square_subsequent_mask

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerEncoder

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerEncoder.forward

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerDecoder

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerDecoder.forward

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerEncoderLayer

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerEncoderLayer.forward

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerDecoderLayer

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.

torch.nn.TransformerDecoderLayer.forward

The layer_norm_eps, batch_first, norm_first, device, and dtype parameters are not implemented.
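
A minimal sketch that stays within these limitations: leave the unimplemented parameters at their defaults and keep inputs in the default (seq, batch, feature) layout, since batch_first is unavailable. The "npu" device string assumes the torch_npu plugin.

    import torch
    import torch.nn as nn

    # Construct with defaults only; layer_norm_eps, batch_first, norm_first,
    # device, and dtype must not be passed.
    model = nn.Transformer(d_model=512, nhead=8).to("npu")
    src = torch.rand(10, 32, 512, device="npu")  # (seq, batch, feature)
    tgt = torch.rand(20, 32, 512, device="npu")
    out = model(src, tgt)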

torch.nn.Identity

  

torch.nn.Linear

  

torch.nn.Bilinear

  

torch.nn.LazyLinear

  

torch.nn.Dropout

  

torch.nn.Dropout2d

  

torch.nn.Dropout3d

  

torch.nn.AlphaDropout

  

torch.nn.FeatureAlphaDropout

  

torch.nn.Embedding

  

torch.nn.Embedding.from_pretrained

  

torch.nn.EmbeddingBag

  

torch.nn.EmbeddingBag.forward

  

torch.nn.EmbeddingBag.from_pretrained

  

torch.nn.CosineSimilarity

  

torch.nn.PairwiseDistance

  

torch.nn.L1Loss

  

torch.nn.MSELoss

  

torch.nn.CrossEntropyLoss

  

torch.nn.CTCLoss

float64 is not supported.
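
A sketch of respecting this limitation by downcasting any float64 activations to float32 before the loss. The int32 length/target dtypes and the "npu" device string are assumptions, not documented requirements.

    import torch
    import torch.nn as nn

    ctc = nn.CTCLoss()
    # Downcast float64 activations to float32 before calling the loss.
    log_probs = torch.randn(50, 16, 20, dtype=torch.float64)
    log_probs = log_probs.float().log_softmax(2).to("npu")
    targets = torch.randint(1, 20, (16, 30), dtype=torch.int32, device="npu")
    input_lengths = torch.full((16,), 50, dtype=torch.int32)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.int32)
    loss = ctc(log_probs, targets, input_lengths, target_lengths)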

torch.nn.NLLLoss

  

torch.nn.PoissonNLLLoss

  

torch.nn.GaussianNLLLoss

  

torch.nn.KLDivLoss

  

torch.nn.BCELoss

  

torch.nn.BCEWithLogitsLoss

  

torch.nn.MarginRankingLoss

  

torch.nn.HingeEmbeddingLoss

  

torch.nn.MultiLabelMarginLoss

  

torch.nn.HuberLoss

  

torch.nn.SmoothL1Loss

  

torch.nn.SoftMarginLoss

  

torch.nn.MultiLabelSoftMarginLoss

  

torch.nn.CosineEmbeddingLoss

int64 is not supported.
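
A sketch of avoiding int64 by building the ±1 target labels in int32 instead of the default int64 (the "npu" device string is an assumption):

    import torch
    import torch.nn as nn

    loss_fn = nn.CosineEmbeddingLoss()
    x1 = torch.randn(8, 128, device="npu")
    x2 = torch.randn(8, 128, device="npu")
    # Build the +1/-1 labels directly in int32 rather than int64.
    target = torch.randint(0, 2, (8,), dtype=torch.int32, device="npu") * 2 - 1
    loss = loss_fn(x1, x2, target)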

torch.nn.MultiMarginLoss

  

torch.nn.TripletMarginLoss

  

torch.nn.TripletMarginWithDistanceLoss

  

torch.nn.PixelShuffle

8D input is not supported.

torch.nn.PixelUnshuffle

  

torch.nn.Upsample

  

torch.nn.UpsamplingNearest2d

  

torch.nn.UpsamplingBilinear2d

  

torch.nn.ChannelShuffle

  

torch.nn.DataParallel

  

torch.nn.parallel.DistributedDataParallel

  

torch.nn.parallel.DistributedDataParallel.join

  

torch.nn.parallel.DistributedDataParallel.join_hook

  

torch.nn.parallel.DistributedDataParallel.no_sync

  

torch.nn.parallel.DistributedDataParallel.register_comm_hook

  

torch.nn.utils.clip_grad_norm_

  

torch.nn.utils.clip_grad_value_

  

torch.nn.utils.parameters_to_vector

  

torch.nn.utils.vector_to_parameters

  

torch.nn.utils.prune.BasePruningMethod

  

torch.nn.utils.prune.BasePruningMethod.apply

  

torch.nn.utils.prune.BasePruningMethod.apply_mask

  

torch.nn.utils.prune.BasePruningMethod.compute_mask

  

torch.nn.utils.prune.BasePruningMethod.prune

  

torch.nn.utils.prune.BasePruningMethod.remove

  

torch.nn.utils.prune.PruningContainer

  

torch.nn.utils.prune.PruningContainer.add_pruning_method

  

torch.nn.utils.prune.PruningContainer.apply

  

torch.nn.utils.prune.PruningContainer.apply_mask

  

torch.nn.utils.prune.PruningContainer.compute_mask

  

torch.nn.utils.prune.PruningContainer.prune

  

torch.nn.utils.prune.PruningContainer.remove

  

torch.nn.utils.prune.Identity

  

torch.nn.utils.prune.Identity.apply

  

torch.nn.utils.prune.Identity.apply_mask

  

torch.nn.utils.prune.Identity.prune

  

torch.nn.utils.prune.Identity.remove

  

torch.nn.utils.prune.RandomUnstructured

  

torch.nn.utils.prune.RandomUnstructured.apply

  

torch.nn.utils.prune.RandomUnstructured.apply_mask

  

torch.nn.utils.prune.RandomUnstructured.prune

  

torch.nn.utils.prune.RandomUnstructured.remove

  

torch.nn.utils.prune.L1Unstructured

  

torch.nn.utils.prune.L1Unstructured.apply

  

torch.nn.utils.prune.L1Unstructured.apply_mask

  

torch.nn.utils.prune.L1Unstructured.prune

  

torch.nn.utils.prune.L1Unstructured.remove

  

torch.nn.utils.prune.RandomStructured

  

torch.nn.utils.prune.RandomStructured.apply

  

torch.nn.utils.prune.RandomStructured.apply_mask

  

torch.nn.utils.prune.RandomStructured.compute_mask

  

torch.nn.utils.prune.RandomStructured.prune

  

torch.nn.utils.prune.RandomStructured.remove

  

torch.nn.utils.prune.LnStructured

  

torch.nn.utils.prune.LnStructured.apply

  

torch.nn.utils.prune.LnStructured.apply_mask

  

torch.nn.utils.prune.LnStructured.compute_mask

  

torch.nn.utils.prune.LnStructured.prune

  

torch.nn.utils.prune.LnStructured.remove

  

torch.nn.utils.prune.CustomFromMask

  

torch.nn.utils.prune.CustomFromMask.apply

  

torch.nn.utils.prune.CustomFromMask.apply_mask

  

torch.nn.utils.prune.CustomFromMask.prune

  

torch.nn.utils.prune.CustomFromMask.remove

  

torch.nn.utils.prune.random_unstructured

  

torch.nn.utils.prune.l1_unstructured

  

torch.nn.utils.prune.random_structured

  

torch.nn.utils.prune.ln_structured

  

torch.nn.utils.prune.global_unstructured

  

torch.nn.utils.prune.custom_from_mask

  

torch.nn.utils.prune.remove

  

torch.nn.utils.prune.is_pruned

  

torch.nn.utils.weight_norm

  

torch.nn.utils.remove_weight_norm

  

torch.nn.utils.spectral_norm

  

torch.nn.utils.remove_spectral_norm

  

torch.nn.utils.skip_init

  

torch.nn.utils.parametrizations.orthogonal

  

torch.nn.utils.parametrizations.spectral_norm

  

torch.nn.utils.parametrize.register_parametrization

  

torch.nn.utils.parametrize.remove_parametrizations

  

torch.nn.utils.parametrize.cached

  

torch.nn.utils.parametrize.is_parametrized

  

torch.nn.utils.parametrize.ParametrizationList

  

torch.nn.utils.parametrize.ParametrizationList.right_inverse

  

torch.nn.utils.rnn.PackedSequence

  

torch.nn.utils.rnn.PackedSequence.count

  

torch.nn.utils.rnn.PackedSequence.index

  

torch.nn.utils.rnn.PackedSequence.is_pinned

  

torch.nn.utils.rnn.PackedSequence.to

  

torch.nn.utils.rnn.pack_padded_sequence

In some cases, this interface cannot be transferred to the NPU device via .npu(). Two workarounds are available (see the sketch below):

  • Use .to("npu") instead of .npu() to perform the transfer.
  • Move the input tensors to the NPU device; the interface then runs on the device automatically.
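
A sketch of both workarounds, assuming the torch_npu plugin provides the "npu" device:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # Workaround 2: move the inputs first; the op then runs on the device.
    seqs = torch.randn(3, 5, 8).to("npu")
    lengths = torch.tensor([5, 3, 2])          # lengths stay on the CPU
    packed = pack_padded_sequence(seqs, lengths, batch_first=True)

    # Workaround 1: transfer the result with .to("npu") instead of .npu().
    packed_cpu = pack_padded_sequence(seqs.cpu(), lengths, batch_first=True)
    packed_npu = packed_cpu.to("npu")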

torch.nn.utils.rnn.pad_packed_sequence

In some cases, this interface cannot be transferred to the NPU device via .npu(). Two workarounds are available:

  • Use .to("npu") instead of .npu() to perform the transfer.
  • Move the input tensors to the NPU device; the interface then runs on the device automatically.

torch.nn.utils.rnn.pad_sequence

In some cases, this interface cannot be transferred to the NPU device via .npu(). Two workarounds are available:

  • Use .to("npu") instead of .npu() to perform the transfer.
  • Move the input tensors to the NPU device; the interface then runs on the device automatically.

torch.nn.utils.rnn.pack_sequence

In some cases, this interface cannot be transferred to the NPU device via .npu(). Two workarounds are available:

  • Use .to("npu") instead of .npu() to perform the transfer.
  • Move the input tensors to the NPU device; the interface then runs on the device automatically.

torch.nn.Flatten

  

torch.nn.Unflatten

  

torch.nn.modules.lazy.LazyModuleMixin

  

torch.nn.modules.lazy.LazyModuleMixin.has_uninitialized_params

  

torch.nn.modules.lazy.LazyModuleMixin.initialize_parameters