torch_npu.npu_moe_finalize_routing
接口原型
1 | npu_moe_finalize_routing(Tensor expanded_permuted_rows, Tensor? skip1, Tensor? skip2, Tensor? bias, Tensor? scales, Tensor expanded_src_to_dst_row, Tensor? export_for_source_row, int? drop_pad_mode=0) -> Tensor |
参数说明
- expanded_permuted_rows:必选参数,经过专家处理过的结果,要求是一个2D的Tensor,数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据格式要求为ND。shape支持(NUM_ROWS * K, H),NUM_ROWS为行数,K为从总的专家E中选出K个专家,H为列数。
- skip1:可选参数,求和的输入参数1,要求是一个2D的Tensor,数据类型要求与expanded_permuted_rows一致 ,shape要求与输出out的shape一致。
- skip2:可选参数,求和的输入参数2,要求是一个2D的Tensor,数据类型要求与expanded_permuted_rows一致 ,shape要求与输出out的shape一致。skip2参数为None时,skip1参数必须也为None。
- bias:可选参数,专家的偏差,要求是一个2D的Tensor,数据类型要求与expanded_permuted_rows一致。shape支持(E,H),E为总的专家个数,H为列数。
- scales:可选参数,专家的权重,要求是一个2D的Tensor,数据类型要求与expanded_permuted_rows一致,shape支持(NUM_ROWS,K)。
- expanded_src_to_dst_row: 必选参数,保存每个专家处理结果的索引,要求是一个1D的Tensor,数据类型支持INT32。shape支持(NUM_ROWS * K),NUM_ROWS为行数,K为从总的专家E中选出K个专家,drop_pad_mode参数为0时,Tensor中的值取值范围是[0, NUM_ROWS * K-1]。
- export_for_source_row: 可选参数,每行处理的专家号,要求是一个2D的Tensor,数据类型支持INT32。shape支持(NUM_ROWS,K),NUM_ROWS为行数,K为从总的专家E中选出K个专家。
- drop_pad_mode:可选参数,表示是否支持丢弃模式,取值范围为0,默认值为0。
输出说明
out:Device侧的Tensor类型,最后处理合并MoE FFN的输出结果。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
支持的型号
调用示例
- 单算子模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import torch import torch_npu expert_num = 16 token_len = 10 top_k = 4 num_rows = 50 device =torch.device('npu') dtype = torch.float32 expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype) skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype) skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype) bias = torch.randn((expert_num, token_len), device=device, dtype=dtype) scales = torch.randn((num_rows, top_k), device=device, dtype=dtype) expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32) expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32) drop_pad_mode = 0 output = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode)
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
import torch import torch.nn as nn import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class GMMModel(nn.Module): def __init__(self): super().__init__() def forward(self, expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode): return torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode) def main(): expert_num = 16 token_len = 10 top_k = 4 num_rows = 50 device =torch.device('npu') dtype = torch.float32 expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype) skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype) skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype) bias = torch.randn((expert_num, token_len), device=device, dtype=dtype) scales = torch.randn((num_rows, top_k), device=device, dtype=dtype) expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32) expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32) drop_pad_mode = 0 model = GMMModel().npu() model = torch.compile(model, backend=npu_backend, dynamic=False) custom_output = model(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row, drop_pad_mode) if __name__ == '__main__': main()
父主题: torch_npu