torch.nn.functional

API name

Supported

Restrictions and notes

torch.nn.functional.conv1d

  

torch.nn.functional.conv2d

  

torch.nn.functional.conv3d

  

torch.nn.functional.conv_transpose1d

Only float16 is supported.

torch.nn.functional.conv_transpose2d

  

torch.nn.functional.conv_transpose3d

  

torch.nn.functional.unfold

  

torch.nn.functional.fold

  

torch.nn.functional.avg_pool1d

  

torch.nn.functional.avg_pool2d

  

torch.nn.functional.avg_pool3d

  

torch.nn.functional.max_pool1d

  

torch.nn.functional.max_pool2d

  

torch.nn.functional.max_pool3d

  

torch.nn.functional.max_unpool1d

  

torch.nn.functional.max_unpool2d

  

torch.nn.functional.max_unpool3d

  

torch.nn.functional.lp_pool1d

  

torch.nn.functional.lp_pool2d

  

torch.nn.functional.adaptive_max_pool1d

  

torch.nn.functional.adaptive_max_pool2d

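The input shape must be evenly divisible by the output shape (each input spatial dimension must be an integer multiple of the corresponding output dimension).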
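When this restriction holds, every output cell of adaptive pooling covers a fixed, non-overlapping window, so the call reduces to an ordinary max pooling with kernel_size = stride = input // output. The sketch below is a hypothetical pre-check (the helper name `adaptive_pool2d_equivalent_params` is ours, not a torch or torch_npu API). It only inspects shapes and does not call PyTorch.

```python
from typing import Optional, Sequence, Tuple

Pair = Tuple[int, int]


def adaptive_pool2d_equivalent_params(
    input_hw: Sequence[int], output_hw: Sequence[int]
) -> Optional[Tuple[Pair, Pair]]:
    """Hypothetical pre-check for the adaptive_max_pool2d restriction.

    Returns (kernel_size, stride) of the plain pooling window that the
    adaptive operator reduces to, or None if some input dimension is not
    divisible by the corresponding output dimension.
    """
    if len(input_hw) != 2 or len(output_hw) != 2:
        raise ValueError("expected (H, W) for both input and output")
    kernel = []
    for size_in, size_out in zip(input_hw, output_hw):
        if size_out <= 0 or size_in % size_out != 0:
            return None  # violates the divisibility restriction
        kernel.append(size_in // size_out)
    window = (kernel[0], kernel[1])
    # Evenly divisible bins do not overlap, so the stride equals the kernel size.
    return window, window


# A 32x48 feature map pooled to 8x6 satisfies the restriction; 7x7 -> 3x3 does not.
print(adaptive_pool2d_equivalent_params((32, 48), (8, 6)))  # ((4, 8), (4, 8))
print(adaptive_pool2d_equivalent_params((7, 7), (3, 3)))    # None
```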

torch.nn.functional.adaptive_max_pool3d

  

torch.nn.functional.adaptive_avg_pool1d

  

torch.nn.functional.adaptive_avg_pool2d

  

torch.nn.functional.adaptive_avg_pool3d

Only the D=1, H=1, W=1 scenario (an output size of (1, 1, 1)) is supported.
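In other words, the operator is usable here only as global average pooling over depth, height, and width, which in PyTorch is numerically the same as x.mean(dim=(-3, -2, -1), keepdim=True). The hypothetical helper below (not a torch or torch_npu API) expands an output_size argument the way the PyTorch documentation describes it (a single int for a cube, or a 3-tuple whose None entries keep the input size) and reports whether it resolves to the supported (1, 1, 1) case.

```python
from typing import Optional, Sequence, Tuple, Union

OutputSize = Union[int, Sequence[Optional[int]]]


def resolve_output_size_3d(
    output_size: OutputSize, input_dhw: Sequence[int]
) -> Tuple[int, int, int]:
    """Expand output_size to a concrete (D, H, W) triple.

    A single int means a cube; a None entry keeps the corresponding input size.
    """
    if isinstance(output_size, int):
        return (output_size, output_size, output_size)
    if len(output_size) != 3 or len(input_dhw) != 3:
        raise ValueError("expected three spatial dimensions")
    d, h, w = (
        size_in if size_out is None else size_out
        for size_out, size_in in zip(output_size, input_dhw)
    )
    return (d, h, w)


def is_supported_adaptive_avg_pool3d(
    output_size: OutputSize, input_dhw: Sequence[int]
) -> bool:
    """Hypothetical check for the restriction above: only D = H = W = 1."""
    return resolve_output_size_3d(output_size, input_dhw) == (1, 1, 1)


print(is_supported_adaptive_avg_pool3d(1, (4, 8, 8)))             # True
print(is_supported_adaptive_avg_pool3d((1, None, 1), (4, 1, 8)))  # True
print(is_supported_adaptive_avg_pool3d((2, 2, 2), (4, 8, 8)))     # False
```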

torch.nn.functional.fractional_max_pool2d

  

torch.nn.functional.fractional_max_pool3d

  

torch.nn.functional.threshold

  

torch.nn.functional.threshold_

  

torch.nn.functional.relu

  

torch.nn.functional.relu_

  

torch.nn.functional.hardtanh

  

torch.nn.functional.hardtanh_

  

torch.nn.functional.hardswish

  

torch.nn.functional.relu6

  

torch.nn.functional.elu

  

torch.nn.functional.elu_

  

torch.nn.functional.selu

  

torch.nn.functional.celu

  

torch.nn.functional.leaky_relu

  

torch.nn.functional.leaky_relu_

  

torch.nn.functional.prelu

  

torch.nn.functional.rrelu

  

torch.nn.functional.rrelu_

  

torch.nn.functional.glu

  

torch.nn.functional.gelu

  

torch.nn.functional.logsigmoid

  

torch.nn.functional.hardshrink

  

torch.nn.functional.tanhshrink

  

torch.nn.functional.softsign

  

torch.nn.functional.softplus

  

torch.nn.functional.softmin

  

torch.nn.functional.softmax

  

torch.nn.functional.softshrink

  

torch.nn.functional.gumbel_softmax

Can fall back to the CPU implementation.

torch.nn.functional.log_softmax

  

torch.nn.functional.tanh

  

torch.nn.functional.sigmoid

  

torch.nn.functional.hardsigmoid

  

torch.nn.functional.silu

The double data type is not supported.

torch.nn.functional.mish

  

torch.nn.functional.batch_norm

  

torch.nn.functional.group_norm

  

torch.nn.functional.instance_norm

  

torch.nn.functional.layer_norm

  

torch.nn.functional.local_response_norm

  

torch.nn.functional.normalize

  

torch.nn.functional.linear

  

torch.nn.functional.bilinear

Atlas A2 training series products: int32 input is not supported.

torch.nn.functional.dropout

  

torch.nn.functional.alpha_dropout

  

torch.nn.functional.feature_alpha_dropout

  

torch.nn.functional.dropout2d

  

torch.nn.functional.dropout3d

  

torch.nn.functional.embedding

  

torch.nn.functional.embedding_bag

  

torch.nn.functional.one_hot

  

torch.nn.functional.pairwise_distance

  

torch.nn.functional.cosine_similarity

  

torch.nn.functional.pdist

  

torch.nn.functional.binary_cross_entropy

  

torch.nn.functional.binary_cross_entropy_with_logits

  

torch.nn.functional.poisson_nll_loss

int64 is not supported.

torch.nn.functional.cosine_embedding_loss

int64 is not supported.

torch.nn.functional.cross_entropy

  

torch.nn.functional.ctc_loss

float64 is not supported.

torch.nn.functional.gaussian_nll_loss

  

torch.nn.functional.hinge_embedding_loss

  

torch.nn.functional.kl_div

  

torch.nn.functional.l1_loss

  

torch.nn.functional.mse_loss

  

torch.nn.functional.margin_ranking_loss

  

torch.nn.functional.multilabel_margin_loss

  

torch.nn.functional.multilabel_soft_margin_loss

  

torch.nn.functional.multi_margin_loss

Can fall back to the CPU implementation.

torch.nn.functional.nll_loss

  

torch.nn.functional.huber_loss

  

torch.nn.functional.smooth_l1_loss

  

torch.nn.functional.soft_margin_loss

  

torch.nn.functional.triplet_margin_loss

  

torch.nn.functional.triplet_margin_with_distance_loss

  

torch.nn.functional.pixel_shuffle

8D input is not supported.
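pixel_shuffle rearranges a tensor of shape (*, C×r², H, W) into (*, C, H×r, W×r), so extra leading batch dimensions are what push an input toward 8D. The plain-Python sketch below computes that output shape and rejects 8D input before any tensor is built; the helper name is hypothetical, and treating ranks above 8 the same way is our assumption, since the table only names 8D.

```python
from typing import Sequence, Tuple


def pixel_shuffle_output_shape(
    shape: Sequence[int], upscale_factor: int
) -> Tuple[int, ...]:
    """Shape rule of pixel_shuffle: (*, C*r*r, H, W) -> (*, C, H*r, W*r).

    Hypothetical pre-check: rejects 8-D input per the restriction above;
    rejecting ranks above 8 as well is an assumption.
    """
    if len(shape) < 3:
        raise ValueError("pixel_shuffle needs at least (C*r*r, H, W)")
    if len(shape) >= 8:
        raise ValueError("8-D input is not supported on this backend")
    r = upscale_factor
    *batch, channels, height, width = shape
    if channels % (r * r) != 0:
        raise ValueError("channel count must be divisible by upscale_factor**2")
    return (*batch, channels // (r * r), height * r, width * r)


print(pixel_shuffle_output_shape((1, 8, 3, 5), 2))  # (1, 2, 6, 10)
```

If an input does reach 8D, merging its leading batch dimensions with a reshape before the call (and splitting them afterwards) keeps the same result, because pixel_shuffle only touches the last three dimensions.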

torch.nn.functional.pixel_unshuffle

  

torch.nn.functional.pad

Performance may degrade when the input x has more than six dimensions.
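F.pad reads its pad tuple from the last dimension backward, so any leading dimensions it does not pad can be merged into one before the call. The sketch below computes, in plain Python, both the padded shape and such a lower-rank view; the helper names are hypothetical, and whether padding the lower-rank view actually avoids the slowdown on a given NPU is an assumption to measure, not something the table states.

```python
from math import prod
from typing import Sequence, Tuple


def padded_shape(shape: Sequence[int], pad: Sequence[int]) -> Tuple[int, ...]:
    """Output shape of F.pad in constant mode.

    The pad tuple is read from the last dimension backward:
    (last_before, last_after, second_last_before, second_last_after, ...).
    """
    if len(pad) % 2 != 0 or len(pad) // 2 > len(shape):
        raise ValueError("pad must have an even length of at most 2 * input.dim()")
    out = list(shape)
    for i in range(len(pad) // 2):
        dim = len(shape) - 1 - i  # pair i pads the (i + 1)-th dimension from the end
        out[dim] += pad[2 * i] + pad[2 * i + 1]
    return tuple(out)


def collapsed_shape_for_pad(shape: Sequence[int], pad: Sequence[int]) -> Tuple[int, ...]:
    """Lower-rank view for padding: untouched leading dims merged into one."""
    num_padded = len(pad) // 2
    split = len(shape) - num_padded
    return (prod(shape[:split]), *shape[split:])


shape, pad = (2, 3, 4, 5, 6, 7, 8), (1, 1, 2, 2)  # 7-D input, pad the last two dims
flat = collapsed_shape_for_pad(shape, pad)        # (720, 7, 8): 3-D instead of 7-D
print(padded_shape(flat, pad))                    # (720, 11, 10)
print(padded_shape(shape, pad))                   # (2, 3, 4, 5, 6, 11, 10)
```

With a real contiguous tensor the same idea would read F.pad(x.reshape(flat), pad).reshape(padded_shape(shape, pad)); for constant padding this gives the same result because the merged dimensions are never padded.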

torch.nn.functional.interpolate

recompute_scale_factor=True is not supported. When jit_compile=False is set, precision issues may occur.
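The flag only matters when input_size × scale_factor is not an integer. PyTorch computes the output size as floor(input_size × scale_factor); with recompute_scale_factor=True it then infers new scales from output_size / input_size, otherwise the caller's scale_factor is used directly. The plain-Python sketch below (the helper name is hypothetical) reproduces that arithmetic to show which scales the interpolation would receive.

```python
import math
from typing import Sequence, Tuple


def interpolate_plan(
    input_size: Sequence[int],
    scale_factor: float,
    recompute_scale_factor: bool = False,
) -> Tuple[Tuple[int, ...], Tuple[float, ...]]:
    """Return (output spatial size, scales the interpolation would use)."""
    out = tuple(math.floor(n * scale_factor) for n in input_size)
    if recompute_scale_factor:
        # New scales are inferred from the rounded-down output size.
        scales = tuple(o / n for o, n in zip(out, input_size))
    else:
        # The requested scale_factor is used directly (the supported default).
        scales = tuple(float(scale_factor) for _ in input_size)
    return out, scales


print(interpolate_plan((5, 4), 1.5))        # ((7, 6), (1.5, 1.5))
print(interpolate_plan((5, 4), 1.5, True))  # ((7, 6), (1.4, 1.5))
```

Code that relied on the recomputed scales could presumably pass the computed output size through size= instead, since PyTorch then derives the scales from the sizes in the same way.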

torch.nn.functional.upsample

Atlas training series products: Yes

Atlas A2 training series products: No

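Atlas training series products:

  • Large input shapes (on the order of millions of elements) are not supported.
  • When size is (1, *) in torch.nn.Upsample(size=size, mode='bicubic', align_corners=False), the align_corners parameter of torch.nn.functional.upsample must be set to True.
  • float32 is recommended; other data types reduce computational performance.
  • When align_corners is False, the input and output tensor shapes must be the same; otherwise the output is incorrect.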
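The four Atlas training series notes for upsample can be encoded as a pre-flight check. The helper below is a hypothetical sketch, not part of torch or torch_npu: it inspects only shapes and arguments and returns the notes a call would trip. The one-million-element threshold stands in for "on the order of millions" and is our assumption, as is passing the dtype as a string.

```python
from math import prod
from typing import List, Optional, Sequence

LARGE_NUMEL = 1_000_000  # assumption: "millions of elements" taken as >= 1e6


def upsample_notes(
    input_shape: Sequence[int],
    output_shape: Sequence[int],
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    dtype: str = "float32",
) -> List[str]:
    """Return the Atlas training series notes that a given call would trip."""
    notes = []
    if prod(input_shape) >= LARGE_NUMEL or prod(output_shape) >= LARGE_NUMEL:
        notes.append("shape too large (millions of elements)")
    spatial_out = tuple(output_shape[2:])  # (H, W) for 4-D NCHW tensors
    is_size_1_star = len(spatial_out) == 2 and spatial_out[0] == 1
    if mode == "bicubic" and is_size_1_star and align_corners is not True:
        notes.append("size (1, *) with bicubic needs align_corners=True")
    if dtype != "float32":
        notes.append("float32 recommended for performance")
    if align_corners is False and tuple(input_shape) != tuple(output_shape):
        notes.append("align_corners=False requires equal input and output shapes")
    return notes


print(upsample_notes((1, 3, 8, 8), (1, 3, 16, 16), mode="bilinear", align_corners=True))
print(len(upsample_notes((1, 3, 8, 8), (1, 3, 1, 16), mode="bicubic",
                         align_corners=False, dtype="float16")))
```

The first call trips no notes and prints an empty list; the second trips the bicubic, dtype, and align_corners notes.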

torch.nn.functional.upsample_nearest

  

torch.nn.functional.upsample_bilinear

  

torch.nn.functional.grid_sample

  

torch.nn.functional.affine_grid

  

torch.nn.parallel.data_parallel