| Operator name | SwiGlu |
|---|---|
| torch_npu API | `torch_npu.npu_swiglu(x, dim)` |
| Supported torch_npu version | 2.1.0 |
| Supported chip types | Currently only supports |
| Supported data types | float16, bfloat16, float32 |
Operator prototype (IR registration):

```cpp
REG_OP(SwiGlu)
    .INPUT(x, "T")
    .OUTPUT(y, "T")
    .DATATYPE(T, TensorType({DT_BF16, DT_FLOAT16, DT_FLOAT}))
    .ATTR(dim, Int, -1)
    .OP_END_FACTORY_REG(SwiGlu)
```
Function signature:

```python
torch_npu.npu_swiglu(Tensor self, int dim=-1) -> (Tensor)
```

Output description:
The output is a `Tensor`: the final output `y` of the computation formula.
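For reference, the computation formula (consistent with the native PyTorch implementation below) splits the input into two halves along `dim` and gates one half with the other:

`y = SiLU(A) * B`, where `[A, B] = chunk(x, 2, dim)` and `SiLU(t) = t * sigmoid(t)`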
Reference implementation in native PyTorch (small-operator equivalent):

```python
import torch

class SwiGlu(torch.nn.Module):
    def __init__(self, dim: int = -1):
        """
        Initialize the SwiGlu module.

        Args:
            dim (int, optional): The splitting dimension of the input tensor. Default is -1.

        Attributes:
            dim (int): The splitting dimension of the input tensor.
        """
        super().__init__()
        self.dim = dim

    def _swiglu(self, x):
        """
        Apply SwiGlu to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The tensor after applying SwiGlu.
        """
        # Split the input into two halves along the configured dimension,
        # then gate the first half (SiLU) with the second half.
        x = torch.chunk(x, 2, self.dim)
        return torch.nn.functional.silu(x[0]) * x[1]

    def forward(self, x):
        """
        Forward pass through the SwiGlu module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying SwiGlu.
        """
        output = self._swiglu(x)
        return output
```
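A minimal sanity check of this reference implementation (illustrative only; the tensor shapes are arbitrary and assume the `SwiGlu` class defined above):

```python
# Uses the SwiGlu class defined above; runs on CPU.
x = torch.randn(4, 8)
model = SwiGlu(dim=-1)
y = model(x)
print(y.shape)  # torch.Size([4, 4]): the split dimension is halved
```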
Calling the fused operator through torch_npu:

```python
import torch
import torch_npu

class SwiGlu(torch.nn.Module):
    def __init__(self, dim: int = -1):
        """
        Initialize the SwiGlu module.

        Args:
            dim (int, optional): The splitting dimension of the input tensor. Default is -1.

        Attributes:
            dim (int): The splitting dimension of the input tensor.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        Forward pass through the SwiGlu module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying SwiGlu.
        """
        # Dispatch to the fused NPU kernel instead of the chunk/silu/mul sequence.
        return torch_npu.npu_swiglu(x, dim=self.dim)
```
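A usage sketch of the fused operator, assuming torch_npu is installed and an Ascend NPU device is reachable (the shapes and dtype here are illustrative, chosen from the supported data types listed above):

```python
# Assumes the SwiGlu module defined above and an available NPU device.
x = torch.randn(4, 2048, dtype=torch.float16).npu()
model = SwiGlu(dim=-1)
y = model(x)
print(y.shape)  # torch.Size([4, 1024]): the split dimension is halved
```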