下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

torch_npu.npu_moe_init_routing

功能描述

接口原型

npu_moe_init_routing(Tensor x, Tensor row_idx, Tensor expert_idx, int active_num) -> (Tensor, Tensor, Tensor)

参数说明

  • x :Device侧的Tensor类型,必选输入,MOE的输入即token特征输入,要求为一个2D的Tensor,shape为 (NUM_ROWS, H)。数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据格式要求为ND。shape大小需要小于2^24。
  • row_idx:Device侧的Tensor类型,必选输入,指示每个位置对应的原始行位置,shape要求与expert_idx一致。数据类型支持INT32,数据格式要求为ND。
  • expert_idx: Device侧的Tensor类型,必选输入,torch_npu.npu_moe_gating_top_k_softmax的输出每一行特征对应的K个处理专家,要求是一个2D的shape (NUM_ROWS, K),数据类型支持int32,数据格式要求为ND。
  • active_num:Host侧的int类型,表示总的最大处理row数,输出expanded_x只有这么多行是有效的。

输出说明

  • expanded_x:Device侧的Tensor类型,根据expert_idx进行扩展过的特征,要求是一个2D的Tensor,shape (min(NUM_ROWS, activeNum) * k, H)。数据类型同x,数据格式要求为ND。
  • expanded_row_idx:Device侧的Tensor类型,expanded_x和x的映射关系, 要求是一个1D的Tensor,Shape为(NUM_ROWS*K, ),数据类型支持INT32,数据格式要求为ND。
  • expanded_expert_idx:Device侧的Tensor类型,输出expert_idx排序后的结果。

约束说明

该融合算子仅在推理场景使用。

支持的型号

Atlas A2训练系列产品/Atlas 800I A2推理产品

调用示例

import torch
import torch_npu
x = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2],[0.3, 0.3, 0.3, 0.3]], dtype=torch.float32).to("npu")
row_idx = torch.tensor([[0, 3], [1, 4], [2, 5]], dtype=torch.int32).to("npu")
expert_idx = torch.tensor([[1, 2], [0, 1], [0, 2]], dtype=torch.int32).to("npu")
active_num = 3
expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(x, row_idx, expert_idx, active_num)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词