torch_npu.npu_moe_compute_expert_tokens
接口原型
npu_moe_compute_expert_tokens(Tensor sorted_expert_for_source_row, int num_expert) -> Tensor
参数说明
- sorted_expert_for_source_row:必选参数,经过专家处理过的结果,要求是一个1D的Tensor,数据类型支持INT32,数据格式要求为ND。shape大小需要小于2147483647。
- num_expert:必选参数,总专家数。
输出说明
expertTokens:Device侧的aclTensor,公式中的输出,要求的是一个1D的Tensor,数据类型与sorted_expert_for_source_row保持一致。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
支持的型号
调用示例
- 单算子模式调用
1 2 3 4 5
import torch import torch_npu sorted_experts = torch.tensor([3,3,4,5,6,7], dtype=torch.int32) num_experts = 5 output = torch_npu.npu_moe_compute_expert_tokens(sorted_experts.npu(), num_experts)
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
import torch import torch.nn as nn import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class GMMModel(nn.Module): def __init__(self): super().__init__() def forward(self, sorted_experts, num_experts): return torch_npu.npu_moe_compute_expert_tokens(sorted_experts, num_experts) def main(): sorted_experts = torch.tensor([3,3,4,5,6,7], dtype=torch.int32) num_experts = 5 model = GMMModel().npu() model = torch.compile(model, backend=npu_backend, dynamic=False) custom_output = model(sorted_experts, num_experts) if __name__ == '__main__': main()
父主题: torch_npu