
torch_npu.npu_moe_finalize_routing

Function Description

  • Operator function: performs the final step of a MoE computation, merging the MoE FFN outputs produced by each expert back into per-token results.
  • Computation formula (a reconstruction inferred from the parameter descriptions below):

    out(i, j) = skip1(i, j) + skip2(i, j) + Σ_{k=0}^{K-1} scales(i, k) × (expanded_permuted_rows(expanded_src_to_dst_row(k × NUM_ROWS + i), j) + bias(export_for_source_row(i, k), j))

Interface Prototype

npu_moe_finalize_routing(Tensor expanded_permuted_rows, Tensor? skip1, Tensor? skip2, Tensor? bias, Tensor? scales, Tensor expanded_src_to_dst_row, Tensor? export_for_source_row, int? drop_pad_mode=0) -> Tensor

Parameter Description

  • expanded_permuted_rows: required. The results produced by the experts; a 2D Tensor with dtype FLOAT16, BFLOAT16, or FLOAT32, in ND format. Supported shape is (NUM_ROWS * K, H), where NUM_ROWS is the number of rows (tokens), K is the number of experts selected out of the E experts in total, and H is the number of columns (hidden size).
  • skip1: optional. First additive input; a 2D Tensor whose dtype must match expanded_permuted_rows and whose shape must match the output out.
  • skip2: optional. Second additive input; a 2D Tensor whose dtype must match expanded_permuted_rows and whose shape must match the output out. When skip2 is None, skip1 must also be None.
  • bias: optional. Per-expert bias; a 2D Tensor whose dtype must match expanded_permuted_rows. Supported shape is (E, H), where E is the total number of experts and H is the number of columns.
  • scales: optional. Per-expert weights; a 2D Tensor whose dtype must match expanded_permuted_rows, with shape (NUM_ROWS, K).
  • expanded_src_to_dst_row: required. Index mapping that records where each expert's result is stored; a 1D Tensor with dtype INT32 and shape (NUM_ROWS * K), where NUM_ROWS is the number of rows and K is the number of experts selected out of the E experts in total. When drop_pad_mode is 0, the values lie in [0, NUM_ROWS * K - 1].
  • export_for_source_row: optional. The expert id that handled each row; a 2D Tensor with dtype INT32 and shape (NUM_ROWS, K), where NUM_ROWS is the number of rows and K is the number of experts selected out of the E experts in total.
  • drop_pad_mode: optional. Whether drop/pad mode is enabled; only the value 0 is supported, and the default is 0.
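The combination rule described by the parameters above can be sketched as a plain-PyTorch CPU reference. This is a hypothetical illustration, not the NPU implementation: it assumes drop_pad_mode=0 and that expanded_src_to_dst_row[k * NUM_ROWS + i] gives the row of expanded_permuted_rows holding token i's k-th expert output.

```python
import torch

def moe_finalize_routing_ref(expanded_permuted_rows, skip1, skip2, bias,
                             scales, expanded_src_to_dst_row,
                             expert_for_source_row):
    # Hypothetical CPU reference for drop_pad_mode=0.
    # Starts from the two skip connections, then accumulates each token's
    # K weighted expert outputs (plus the per-expert bias).
    num_rows, top_k = scales.shape
    out = skip1 + skip2
    for i in range(num_rows):
        for k in range(top_k):
            src = expanded_src_to_dst_row[k * num_rows + i]   # assumed layout
            expert = expert_for_source_row[i, k]
            out[i] += scales[i, k] * (expanded_permuted_rows[src] + bias[expert])
    return out
```

With scales set to zero the expert contributions vanish and the result reduces to skip1 + skip2, which makes the role of the two skip inputs easy to verify.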

Output Description

out: a device-side Tensor holding the final merged output of the MoE FFN.

Constraints

  • This API may only be used in inference scenarios.
  • This API supports graph mode (currently only with PyTorch 2.1).

Supported Products

Atlas A2 training series products / Atlas 800I A2 inference products

Invocation Examples

  • Single-operator (eager) mode invocation
    import torch
    import torch_npu
    
    expert_num = 16   # E: total number of experts
    token_len = 10    # H: hidden size per token
    top_k = 4         # K: experts selected per token
    num_rows = 50     # NUM_ROWS: number of tokens
    device = torch.device('npu')
    dtype = torch.float32
    expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype)
    skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype)
    skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype)
    bias = torch.randn((expert_num, token_len), device=device, dtype=dtype)
    scales = torch.randn((num_rows, top_k), device=device, dtype=dtype)
    expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32)
    expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32)
    drop_pad_mode = 0
    output = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
    
  • Graph mode invocation
    import torch
    import torch.nn as nn
    import torch_npu
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    
    class GMMModel(nn.Module):
        def __init__(self):
            super().__init__()
        
        def forward(self, expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode):
            return torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
    
    def main():
        expert_num = 16   # E: total number of experts
        token_len = 10    # H: hidden size per token
        top_k = 4         # K: experts selected per token
        num_rows = 50     # NUM_ROWS: number of tokens
        device = torch.device('npu')
        dtype = torch.float32
        expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype)
        skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype)
        skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype)
        bias = torch.randn((expert_num, token_len), device=device, dtype=dtype)
        scales = torch.randn((num_rows, top_k), device=device, dtype=dtype)
        expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32)
        expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32)
        drop_pad_mode = 0
        model = GMMModel().npu()
        # Compile with the torchair NPU backend to run in graph mode
        model = torch.compile(model, backend=npu_backend, dynamic=False)
        custom_output = model(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
    
    if __name__ == '__main__':
        main()