import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
# Runtime setup: enable JIT operator compilation, build the torchair graph
# backend for torch.compile, and bind this process to NPU device 0.
torch_npu.npu.set_compile_mode(jit_compile=True)
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
device = torch.device("npu:0")  # plain literal: the original f-prefix had no placeholders
torch_npu.npu.set_device(device)
class MoeInitRoutingModel(torch.nn.Module):
    """Thin nn.Module wrapper around ``torch_npu.npu_moe_init_routing``.

    Wrapping the op in a Module lets torch.compile (torchair backend)
    capture it as a graph. ``forward`` returns the op's result tuple:
    ``(expanded_x, expanded_row_idx, expanded_expert_idx)``.
    """

    def forward(self, x, row_idx, expert_idx, active_num):
        # Pass the routing inputs straight through; no state is kept on self.
        return torch_npu.npu_moe_init_routing(
            x, row_idx, expert_idx, active_num=active_num
        )
# Sample inputs: 3 tokens with hidden size 4, each routed to 2 experts.
x = torch.tensor(
    [
        [0.1, 0.1, 0.1, 0.1],
        [0.2, 0.2, 0.2, 0.2],
        [0.3, 0.3, 0.3, 0.3],
    ],
    dtype=torch.float32,
).to("npu")
row_idx = torch.tensor([[0, 3], [1, 4], [2, 5]], dtype=torch.int32).to("npu")
expert_idx = torch.tensor([[1, 2], [0, 1], [0, 2]], dtype=torch.int32).to("npu")
active_num = 3

# Compile the wrapper with the torchair NPU backend and run it once.
moe_init_routing_model = MoeInitRoutingModel().npu()
moe_init_routing_model = torch.compile(
    moe_init_routing_model, backend=npu_backend, dynamic=True
)
expanded_x, expanded_row_idx, expanded_expert_idx = moe_init_routing_model(
    x, row_idx, expert_idx, active_num=active_num
)

# Show the routed/expanded outputs.
for out in (expanded_x, expanded_row_idx, expanded_expert_idx):
    print(out)