torch.nn.functional

若API“是否支持”“是”“限制与说明”为空,说明此API和原生API支持度保持一致。

API名称

是否支持

限制与说明

torch.nn.functional.conv1d

支持fp16,fp32

torch.nn.functional.conv2d

支持bf16,fp16,fp32

torch.nn.functional.conv3d

  

torch.nn.functional.conv_transpose1d

支持fp32

torch.nn.functional.conv_transpose2d

支持fp16,fp32

torch.nn.functional.conv_transpose3d

  

torch.nn.functional.unfold

支持bf16,fp16,fp32

torch.nn.functional.fold

支持fp16,fp32

torch.nn.functional.avg_pool1d

支持fp16,fp32

torch.nn.functional.avg_pool2d

支持fp16,fp32

torch.nn.functional.avg_pool3d

  

torch.nn.functional.max_pool1d

  

torch.nn.functional.max_pool3d

  

torch.nn.functional.max_unpool1d

支持fp16,fp32,fp64,uint8,int8,int32,int64

torch.nn.functional.max_unpool2d

支持fp16,fp32,fp64,uint8,int8,int32,int64

torch.nn.functional.max_unpool3d

  

torch.nn.functional.lp_pool1d

支持fp16,fp32

torch.nn.functional.lp_pool2d

支持fp16,fp32

torch.nn.functional.adaptive_max_pool1d

  

torch.nn.functional.adaptive_max_pool2d

  

torch.nn.functional.adaptive_avg_pool1d

支持fp16,fp32

torch.nn.functional.adaptive_avg_pool2d

支持fp16,fp32

torch.nn.functional.adaptive_avg_pool3d

支持bf16,fp16,fp32

torch.nn.functional.scaled_dot_product_attention

说明:

在当前版本中,sdpa(scaled_dot_product_attention)接口仅作为一项试用特性,此功能在后续版本中可能会有所调整或改进。请用户在使用过程中关注后续版本的迭代。

当requires_grad=true,支持bf16、fp16和fp32;当requires_grad=false,支持bf16、fp16。所有参数输入均符合以下约束

约束

此API仅Atlas A2 训练系列产品时支持。

  • 当requires_grad=true,所有参数输入符合规格:
    • 输入query、key、value的N:batch size,当前只支持[N,head_num, S(L), E(Ev)]的排布方式,取值范围1~2K
    • 输入query的head num和key/value的head num必须成比例关系,即Nq/Nkv必须是非0整数,取值范围1~256
    • 输入query的L:Target sequence length,取值范围1~512K
    • 输入key、value的S:Source sequence length,取值范围1~512K
    • 输入query、key、value的E:Embedding dimension of the query and key,取值范围1~512
    • 输入value的Ev:Embedding dimension of the value,必须与E相等
    • 输入attn_mask:当前支持[N, 1, L, S]、[N, head_num, L, S]、[1, 1, L, S]、[L, S]共4种排布方式
    • 在使能is_causal计算时,attn_mask必须为None;不使能is_causal时,若attn_mask输入有效数据,输入数据类型必须是Bool类型
    • 输入query、key、value的数据类型bf16、fp16、fp32
  • 当requires_grad=false,所有参数输入均符合以下约束:
    • 输入query、key、value的N:batch size,当前只支持[N,head_num, S(L), E(Ev)]的排布方式,取值范围1~ 65536
    • 输入query的head num和key/value的head num必须成比例关系,即Nq/Nkv必须是非0整数,取值范围1~64
    • 输入query、key、value的E:Embedding dimension of the query and key,取值范围1~512,且为16的整数倍
    • 输入value的Ev:Embedding dimension of the value,必须与E相等
    • 输入attn_mask:当前支持[N, 1, L, S]、[N, head_num, L, S]、[1, 1, L, S]、[L, S]共4种排布方式
    • 在使能is_causal计算时,attn_mask必须为None;不使能is_causal时,若attn_mask输入有效数据,输入数据类型必须是Bool类型
    • 输入query、key、value的数据类型bf16、fp16
  • 与原接口除了规格限制之外差异点:
    • NPU用DSA硬件实现,算法在DSA引擎固化存在跟GPU算法实现差异,导致dropout功能和GPU结果不一致
    • 当前接口支持输入query的head num和key/value的head num不等长,而原生PyTorch接口不支持

torch.nn.functional.threshold

支持fp16,fp32,uint8,int8,int16,int32,int64

torch.nn.functional.threshold_

支持fp16,fp32,uint8,int8,int16,int32,int64

torch.nn.functional.relu

支持bf16,fp16,fp32,uint8,int8,int32,int64

torch.nn.functional.relu_

支持bf16,fp16,fp32,uint8,int8,int32,int64

torch.nn.functional.hardtanh

支持fp16,fp32,int8,int16,int32,int64

torch.nn.functional.hardtanh_

支持fp16,fp32,int8,int16,int32,int64

torch.nn.functional.hardswish

支持fp16,fp32

torch.nn.functional.relu6

支持fp16,fp32,uint8,int8,int16,int32,int64

torch.nn.functional.elu

支持fp16,fp32

torch.nn.functional.elu_

支持fp16,fp32

torch.nn.functional.selu

支持fp16,fp32

fp16的反向场景下对比GPU存在精度累加误差,可以通过如下方式进行规避:

将正向调用的torch.nn.functional.selu替换成torch.ops.aten.elu,例如:将torch.nn.functional.selu(input_x)替换为torch.ops.aten.elu(input_x, 1.6732632423543772848170429916717, 1.0507009873554804934193349852946)

torch.nn.functional.celu

支持fp16,fp32

torch.nn.functional.leaky_relu

支持bf16,fp16,fp32,fp64

torch.nn.functional.leaky_relu_

支持fp16,fp32,fp64

torch.nn.functional.prelu

支持fp16,fp32

torch.nn.functional.rrelu

支持bf16,fp16,fp32

torch.nn.functional.glu

支持fp16,fp32

torch.nn.functional.gelu

支持bf16,fp16,fp32

torch.nn.functional.logsigmoid

支持fp16,fp32

torch.nn.functional.hardshrink

支持fp16,fp32

torch.nn.functional.softsign

支持bf16,fp16,fp32,uint8,int8,int16,int32,int64

torch.nn.functional.softplus

支持fp16,fp32

torch.nn.functional.softmax

支持bf16,fp16,fp32

torch.nn.functional.softshrink

支持fp16,fp32

torch.nn.functional.gumbel_softmax

  

torch.nn.functional.log_softmax

支持bf16,fp16,fp32

torch.nn.functional.tanh

支持fp16,fp32,uint8,int8,int16,int32,int64,bool

torch.nn.functional.sigmoid

支持fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.functional.hardsigmoid

支持fp16,fp32

torch.nn.functional.silu

支持fp16,fp32

torch.nn.functional.mish

支持fp16,fp32

torch.nn.functional.batch_norm

支持fp16,fp32

torch.nn.functional.group_norm

支持fp16,fp32

该API仅支持2维及以上的输入input

torch.nn.functional.layer_norm

支持bf16,fp16,fp32

torch.nn.functional.normalize

支持bf16,fp16,fp32

torch.nn.functional.linear

支持fp16,fp32

torch.nn.functional.bilinear

支持fp16,fp32

torch.nn.functional.dropout

支持bf16,fp16,fp32

torch.nn.functional.alpha_dropout

支持bf16,fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.functional.feature_alpha_dropout

支持bf16,fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128

torch.nn.functional.dropout2d

支持bf16,fp16,fp32

torch.nn.functional.embedding

支持int32,int64

torch.nn.functional.embedding_bag

  

torch.nn.functional.one_hot

支持int32,int64

torch.nn.functional.cosine_similarity

支持fp16,fp32

torch.nn.functional.pdist

  

torch.nn.functional.binary_cross_entropy

支持fp16,fp32

torch.nn.functional.binary_cross_entropy_with_logits

支持fp16,fp32

torch.nn.functional.poisson_nll_loss

支持bf16,fp16,fp32,int64

torch.nn.functional.cross_entropy

支持fp16,fp32

torch.nn.functional.ctc_loss

支持fp32

目标序列的长度不支持0,即属性target_lengths的取值不能包含0

torch.nn.functional.gaussian_nll_loss

支持bf16,fp16,fp32,uint8,int8,int16,int32,int64

torch.nn.functional.kl_div

支持bf16,fp16,fp32

当前log_target参数仅支持False

当前target不支持求导

torch.nn.functional.l1_loss

支持fp16,fp32

torch.nn.functional.mse_loss

支持fp16,fp32

torch.nn.functional.margin_ranking_loss

支持bf16,fp16,fp32

torch.nn.functional.multilabel_margin_loss

支持fp16,fp32

输入tensor的元素个数不能超过10万

torch.nn.functional.multilabel_soft_margin_loss

  

torch.nn.functional.nll_loss

支持fp32

torch.nn.functional.smooth_l1_loss

支持fp16,fp32

torch.nn.functional.soft_margin_loss

支持bf16,fp16,fp32,不支持double,complex64,complex128数据类型

torch.nn.functional.triplet_margin_loss

  

torch.nn.functional.triplet_margin_with_distance_loss

  

torch.nn.functional.pixel_shuffle

支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool

torch.nn.functional.pixel_unshuffle

支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool

torch.nn.functional.pad

属性mode为constant时,支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool

属性mode非constant时,支持fp16,fp32,fp64

在输入x为六维以上时可能会出现性能下降问题

torch.nn.functional.interpolate

支持fp16,fp32

支持nearest,linear,bilinear,bicubic,trilinear, area

不支持scale_factors

torch.nn.functional.upsample

支持fp16,fp32,fp64

只支持mode = nearest

torch.nn.functional.upsample_nearest

支持fp16,fp32,fp64

只支持3-5维

torch.nn.functional.upsample_bilinear

支持fp16,fp32

torch.nn.functional.grid_sample

  

torch.nn.functional.affine_grid

支持fp16,fp32

torch.nn.parallel.data_parallel