SwiGlu

算子基础信息

表1 算子信息

算子名称

SwiGlu

torch_npu api接口

torch_npu. npu_swiglu(x,dim)

支持的torch_npu版本

2.1.0

支持的芯片类型

Atlas A2 训练系列产品 Atlas A3 训练系列产品

支持的数据类型

float16,bfloat16,float

算子IR及torch_npu接口参数

算子IR:
1
2
3
4
5
6
REG_OP(SwiGlu)
.INPUT(x, "T")
    .OUTPUT(y, "T")
    .DATATYPE(T, TensorType({DT_BF16, DT_FLOAT16, DT_FLOAT}))
    .ATTR(dim, Int, -1)
    .OP_END_FACTORY_REG(SwiGlu)

torch_npu接口:
torch_npu.npu_swiglu(Tensor self, int dim=-1) -> (Tensor)
参数说明:
  • x:Tensor类型,shape支持1-8维。
  • dim:int类型,默认为-1。

输出说明:

输出为Tensor,计算公式的最终输出y。

模型中替换代码及算子计算逻辑

图1 计算流程图

算子替换的模型中小算子

使用限制

当前仅支持 Atlas A2 训练系列产品

已支持模型典型case