下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

torch_npu.npu_moe_gating_top_k_softmax

功能描述

MoE计算中,对gating的输出做Softmax计算,取topk操作。

接口原型

npu_moe_gating_top_k_softmax(Tensor x, Tensor? finished=None, int k=1) -> (Tensor, Tensor, Tensor)

参数说明

  • x(aclTensor*,计算输入):待计算的输入,要求是一个2D/3D的Tensor,数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据格式要求为ND。
  • finished(aclTensor*,计算输入) :可选,要求是一个1D/2D的Tensor,数据类型支持BOOL,shape为gating_shape[:-1],数据格式要求为ND。
  • k(int,计算输入):topk的k值,大小为0 <= k <= x的-1轴大小,k<=1024。

输出说明

  • y(aclTensor*,计算输出):对x做softmax后取的topk值,要求是一个2D/3D的Tensor,数据类型与x需要保持一致,其非-1轴要求与x的对应轴大小一致,其-1轴要求其大小同k值。数据格式要求为ND。
  • expert_idx(aclTensor*,计算输出):对x做softmax后取topk值的索引,即专家的序号。shape要求与y一致,数据类型支持int32,数据格式要求为ND。
  • row_idx(aclTensor*,计算输出):指示每个位置对应的原始行位置,请参见调用示例,shape要求与y一致,数据类型支持int32,数据格式要求为ND。

约束说明

该融合算子仅在推理场景使用。

支持的型号

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

import torch
import torch_npu
x = torch.rand((3, 3), dtype=torch.float32).to("npu")
finished = torch.randint(2, size=(3,), dtype=torch.bool).to("npu")
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x, finished, k=2)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词