torch_npu.npu_group_norm_silu
功能描述
计算输入input的组归一化结果out、均值meanOut、标准差的倒数rstdOut、以及silu的输出。
接口原型
torch_npu.npu_group_norm_silu(Tensor input, Tensor weight, Tensor bias, int group, float eps) -> (Tensor, Tensor, Tensor)
参数说明
- input:Device侧的Tensor类型,必选输入,源数据张量,维度需大于一维,数据格式支持ND,支持非连续的Tensor。
Atlas 推理系列产品 :数据类型支持FLOAT16、FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、FLOAT、BFLOAT16。
- weight:Device侧的Tensor类型,必选输入,索引张量,维度为1且元素数量需与输入input的第1维度保持相同,数据格式支持ND,支持非连续的Tensor。
Atlas 推理系列产品 :数据类型支持FLOAT16、FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、FLOAT、BFLOAT16。
- bias:Device侧的Tensor类型,必选输入,更新数据张量,维度为1元素数量需与输入input的第1维度保持相同,数据格式支持ND,支持非连续的Tensor。
Atlas 推理系列产品 :数据类型支持FLOAT16、FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、FLOAT、BFLOAT16。
- group:Host侧的int类型,必选输入,表示将输入input的第1维度分为group组。
- eps:Host侧的float类型,可选参数,数值稳定性而加到分母上的值,若保持精度,则eps需大于0。
输出说明
- out:Device侧的Tensor类型,计算输出,数据类型和shape与input相同,支持ND,支持非连续的Tensor。
Atlas 推理系列产品 :数据类型支持FLOAT16、FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、FLOAT、BFLOAT16。
- meanOut:Device侧的Tensor类型,计算输出,数据类型与input相同,shape为(N, group)支持ND,支持非连续的Tensor。
Atlas 推理系列产品 :数据类型支持FLOAT16、FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、FLOAT、BFLOAT16。
- rstdOut:Device侧的Tensor类型,计算输出,数据类型与input相同,shape为(N, group)。
Atlas 推理系列产品 :数据类型支持FLOAT16、FLOAT。Atlas A2 训练系列产品/Atlas 800I A2 推理产品 :数据类型支持FLOAT16、FLOAT、BFLOAT16。
约束说明
- input、weight、bias、out、meanOut、rstdOut数据类型必须支持的范围之内。
- out、meanOut、rstdOut的数据类型与input相同;weight、bias与input可以不同。
- input第1维度能整除group。
- out的shape与input相同。
- meanOut与rstdOut的shape为(N, group),其中N为input第0维度值。
- weight与bias的数据类型必须保持一致,且数据类型的精度不能低于input的数据类型。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品 Atlas 推理系列产品
调用示例
单算子调用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import torch import numpy as np import torch_npu dtype = np.float32 shape_x = [24,320,48,48] num_groups = 32 shape_c = [320] eps = 0.00001 x_npu=torch.randn(shape_x,dtype=torch.float32).npu() gamma_npu=torch.randn(shape_c,dtype=torch.float32).npu() beta_npu=torch.randn(shape_c,dtype=torch.float32).npu() out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps) x_npu=torch.randn(shape_x,dtype=torch.bfloat16).npu() gamma_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu() beta_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu() out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps) x_npu=torch.randn(shape_x,dtype=torch.float16).npu() gamma_npu=torch.randn(shape_c,dtype=torch.float16).npu() beta_npu=torch.randn(shape_c,dtype=torch.float16).npu() out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps) |
父主题: torch_npu