(beta)torch_npu.npu_swiglu
接口原型
torch_npu.npu_swiglu(Tensor self, int dim=-1) -> (Tensor)
功能描述
提供swiglu的激活函数。
公式如下:
outputs = swiglu\(x,dim = -1) = swish(A) * B = A * sigmoid(A) * B
- “x”是输入Tensor。
- “dim”是切分维度,默认为-1。
- “A”和“B”是x沿dim维度切分的Tensor。
参数说明
“dim”:Int类型,默认为-1。
输入说明(针对于class需要)
“x”:Tensor类型,shape支持1-8维,dtype支持FP32、FP16或BF16类型。
输出说明
输出为Tensor,计算公式的最终输出outputs。
支持的型号
Atlas A2 训练系列产品
调用示例
import torch import torch_npu input_tensor = torch.randn(2, 32, 6, 6).npu() output = torch_npu.npu_swiglu(input_tensor, dim = -1)
父主题: torch_npu