torch_npu.npu_moe_gating_top_k_softmax
功能描述
MoE计算中,对gating的输出做Softmax计算,取topk操作。
接口原型
npu_moe_gating_top_k_softmax(Tensor x, Tensor? finished=None, int k=1) -> (Tensor, Tensor, Tensor)
参数说明
- x(aclTensor*,计算输入):待计算的输入,要求是一个2D/3D的Tensor,数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据格式要求为ND。
- finished(aclTensor*,计算输入) :可选,要求是一个1D/2D的Tensor,数据类型支持BOOL,shape为gating_shape[:-1],数据格式要求为ND。
- k(int,计算输入):topk的k值,大小为0 <= k <= x的-1轴大小,k<=1024。
输出说明
- y(aclTensor*,计算输出):对x做softmax后取的topk值,要求是一个2D/3D的Tensor,数据类型与x需要保持一致,其非-1轴要求与x的对应轴大小一致,其-1轴要求其大小同k值。数据格式要求为ND。
- expert_idx(aclTensor*,计算输出):对x做softmax后取topk值的索引,即专家的序号。shape要求与y一致,数据类型支持int32,数据格式要求为ND。
- row_idx(aclTensor*,计算输出):指示每个位置对应的原始行位置,请参见调用示例,shape要求与y一致,数据类型支持int32,数据格式要求为ND。
约束说明
该融合算子仅在推理场景使用。
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
父主题: torch_npu