下载
中文
注册

torch_npu.npu_moe_compute_expert_tokens

功能描述

  • 算子功能:MoE计算中,通过二分查找的方式查找每个专家处理的最后一行的位置。
  • 计算公式:

接口原型

npu_moe_compute_expert_tokens(Tensor sorted_expert_for_source_row, int num_expert) -> Tensor

参数说明

  • sorted_expert_for_source_row:必选参数,经过专家处理过的结果,要求是一个1D的Tensor,数据类型支持INT32,数据格式要求为ND。shape大小需要小于2147483647。
  • num_expert:必选参数,总专家数。

输出说明

expertTokens:Device侧的aclTensor,公式中的输出,要求的是一个1D的Tensor,数据类型与sorted_expert_for_source_row保持一致。

约束说明

  • 该接口仅在推理场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

调用示例

  • 单算子模式调用
    1
    2
    3
    4
    5
    import torch
    import torch_npu
    sorted_experts = torch.tensor([3,3,4,5,6,7], dtype=torch.int32)
    num_experts = 5
    output = torch_npu.npu_moe_compute_expert_tokens(sorted_experts.npu(), num_experts)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import torch
    import torch.nn as nn
    import torch_npu
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    class GMMModel(nn.Module):
        def __init__(self):
            super().__init__()
        
        def forward(self, sorted_experts, num_experts):
            return torch_npu.npu_moe_compute_expert_tokens(sorted_experts, num_experts)
    def main():
        sorted_experts = torch.tensor([3,3,4,5,6,7], dtype=torch.int32)
        num_experts = 5
        model = GMMModel().npu()
        model = torch.compile(model, backend=npu_backend, dynamic=False)
        custom_output = model(sorted_experts, num_experts)
    if __name__ == '__main__':
        main()