torch_npu.npu_group_norm_silu
功能描述
计算输入self的组归一化结果out、均值meanOut、标准差的倒数rstdOut、以及silu的输出。
接口原型
torch_npu.npu_group_norm_silu(Tensor self, Tensor weight, Tensor bias, int group, float eps) -> (Tensor, Tensor, Tensor)
参数说明
- self:Device侧的Tensor类型,必选输入,源数据张量,数据类型支持FLOAT16、FLOAT、BFLOAT16,维度需大于一维,数据格式支持ND,支持非连续的Tensor。
- weight:Device侧的Tensor类型,必选输入,索引张量,数据类型支持FLOAT16、FLOAT、BFLOAT16,维度为1且元素数量需与输入self的第1维度保持相同,数据格式支持ND,支持非连续的Tensor。
- bias:Device侧的Tensor类型,必选输入,更新数据张量,数据类型支持FLOAT16、FLOAT、BFLOAT16,维度为1元素数量需与输入self的第1维度保持相同,数据格式支持ND,支持非连续的Tensor。
- group:Host侧的int类型,必选输入,表示将输入self的第1维度分为group组。
- eps:Host侧的float类型,可选参数,数值稳定性而加到分母上的值,若保持精度,则eps需大于0。
输出说明
- out:Device侧的Tensor类型,计算输出,数据类型支持FLOAT16、FLOAT、BFLOAT16,数据类型和shape与self相同,支持ND,支持非连续的Tensor。
- meanOut:Device侧的Tensor类型,计算输出,数据类型支持FLOAT16、FLOAT、BFLOAT16,数据类型与self相同,shape为(N, group)支持ND,支持非连续的Tensor。
- rstdOut:Device侧的Tensor类型,计算输出,数据类型支持FLOAT16、FLOAT、BFLOAT16,数据类型与self相同,shape为(N, group)。
约束说明
- BFLOAT16数据类型仅在Atlas A2 训练系列产品支持。
- self、weight、bias、out、meanOut、rstdOut数据类型必须支持的范围之内。
- out、meanOut、rstdOut的数据类型与self相同;weight、bias与self可以不同。
- self第1维度能整除group。
- out的shape与self相同。
- meanOut与rstdOut的shape为(N, group),其中N为self第0维度值。
- weight与bias的数据类型必须保持一致,且数据类型的精度不能低于self的数据类型。
支持的PyTorch版本
- PyTorch 2.3
- PyTorch 2.2
- PyTorch 2.1
- PyTorch 1.11.0
支持的型号
- Atlas A2 训练系列产品
- Atlas 推理系列产品
调用示例
import torch import numpy as np import torch_npu dtype = np.float32 shape_x = [24,320,48,48] num_groups = 32 shape_c = [320] eps = 0.00001 x_npu=torch.randn(shape_x,dtype=torch.float32).npu() gamma_npu=torch.randn(shape_c,dtype=torch.float32).npu() beta_npu=torch.randn(shape_c,dtype=torch.float32).npu() out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps) x_npu=torch.randn(shape_x,dtype=torch.bfloat16).npu() gamma_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu() beta_npu=torch.randn(shape_c,dtype=torch.bfloat16).npu() out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps) x_npu=torch.randn(shape_x,dtype=torch.float16).npu() gamma_npu=torch.randn(shape_c,dtype=torch.float16).npu() beta_npu=torch.randn(shape_c,dtype=torch.float16).npu() out_npu, mean_npu, rstd_out = torch_npu.npu_group_norm_silu(x_npu, gamma_npu, beta_npu, group=num_groups, eps=eps)
父主题: torch_npu