torch_npu.distributed.reinit_process_group(group: Optional[ProcessGroup] = None, rebuild_link: bool = True) -> None
Rebuilds the collective communication domain of a ProcessGroup.
Make sure a valid device has been set (e.g. via torch.npu.set_device) before calling this interface.
import os
import torch
import torch.distributed as dist
import multiprocessing as mp
import torch_npu

def _do_allreduce(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29688'
    torch.npu.set_device(rank)
    dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
    # Rebuild the process group
    torch_npu.distributed.reinit_process_group()
    a = torch.ones(2, 2, device=f"npu:{rank}")
    dist.all_reduce(a)

def _multiprocess(world_size, f):
    ctx = mp.get_context('spawn')
    ps = []
    for i in range(world_size):
        p = ctx.Process(target=f, args=(i, world_size))
        p.start()
        ps.append(p)  # keep a handle so every worker is joined below
    for p in ps:
        p.join()

if __name__ == '__main__':
    _multiprocess(4, _do_allreduce)
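The example above rebuilds the default communication domain with no arguments. The signature also accepts an explicit group and a rebuild_link flag; the sketch below is a minimal illustration of passing them, not part of the original example. The two-rank subgroup is hypothetical, and the comment on rebuild_link reflects only the parameter's default in the signature, not documented semantics.

# Minimal sketch (assumptions noted above): rebuild a specific subgroup
# instead of the default group. Call this after dist.init_process_group,
# from within a worker such as _do_allreduce.
sub_group = dist.new_group(ranks=[0, 1])  # hypothetical two-rank subgroup
torch_npu.distributed.reinit_process_group(group=sub_group,     # rebuild only this domain
                                           rebuild_link=True)   # default value from the signature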