下载
中文
注册

beta)torch_npu.npu_swiglu

接口原型

torch_npu.npu_swiglu(Tensor self, int dim=-1) -> (Tensor)

功能描述

提供swiglu的激活函数。

公式如下:

outputs = swiglu\(x,dim = -1) = swish(A) * B = A * sigmoid(A) * B

  • “x”是输入Tensor。
  • “dim”是切分维度,默认为-1。
  • “A”“B”是x沿dim维度切分的Tensor。

参数说明

“dim”:Int类型,默认为-1。

输入说明(针对于class需要)

“x”:Tensor类型,shape支持1-8维,dtype支持FP32、FP16或BF16类型。

输出说明

输出为Tensor,计算公式的最终输出outputs。

支持的型号

  • Atlas A2 训练系列产品
  • Atlas 推理系列产品

调用示例

1
2
3
4
import torch
import torch_npu
input_tensor = torch.randn(2, 32, 6, 6).npu()
output = torch_npu.npu_swiglu(input_tensor, dim = -1)