集合通信算子入图
功能简介
集合通信算子入图能够从整图层面进行通信与计算并行优化,而原生PyTorch 2.1版本暂不支持集合通信算子入图,从PyTorch 2.2版本及更高版本开始陆续支持相关算子入图。
原生PyTorch的dynamo trace入图仅支持算子和python内置函数入图,普通python函数会被dynamo拆解,将算子记录到图中。早期PyTorch提供的集合通信能力以API形式呈现,无法直接将其转换入图,为此PyTorch社区新增了以torch.ops.c10_functional为前缀的系列集合通信算子解决该问题。当用户使用torch.ops.c10_functional系列算子时,需要手动修改网络模型脚本中的集合通信部分,并且该部分通常被Deepspeed等框架封装,因此社区在dynamo内部将集合通信API转换为torch.ops.c10_functional系列算子,实现了早期集合通信API方式调用脚本也能trace入图。然而,社区在dynamo内部将集合通信API转换为torch.ops.c10_functional系列算子的行为,在PyTorch 2.3版本才逐渐补齐相关能力,PyTorch 2.1版本处于不可用状态。
针对PyTorch 2.1版本中不支持入图的部分集合通信算子,TorchAir提供了补丁方案,封装了一系列NPU通信算子,以实现集合通信算子入图。
使用方法
假设训练/推理PyTorch脚本能在PyTorch 2.1版本中以Eager模式正常运行,用户无需修改脚本,直接导入补丁包(Monkey-Patch)可以完成一系列通信算子入图,具体方法如下:
分布式场景下,由于Deepspeed框架对原生allreduce API的封装使其无法入图,TorchAir可同时Patch Deepspeed框架中allreduce封装函数,以解决入图问题。
- 方案1(推荐):
在脚本中导入patch_for_hcom包,并在图执行之前调用patch_for_hcom函数。该函数封装了一系列NPU通信算子,用于替换PyTorch原生集合通信API,具体参见表1。
1 2 3 4 5 6
# 1.导入patch_for_hcom包 import torch_npu import torchair from torchair import patch_for_hcom # 2.在图执行之前调用patch方法 patch_for_hcom()
表1 集合通信算子列表 PyTorch集合通信API
自定义的NPU通信算子
torch.distributed.all_reduce
torch.ops.c10d_functional.all_reduce
说明:本算子是由PyTorch框架定义,其余通信算子均由TorchAir定义。
torch.distributed.all_gather
torch.ops.npu_define.allgather
torch.distributed.all_to_all
torch.ops.npu_define.all_to_all
torch.distributed.broadcast
torch.ops.npu_define.broadcast
torch.distributed.all_to_all_single
torch.ops.npu_define.all_to_all_single
torch.distributed.allgather_in_tensor
torch.ops.npu_define.allgather_in_tensor
- 方案2:
本方案在Ascend Extension for PyTorch的6.0.RC3版本及后续版本不再演进,可能会废弃,建议优先使用方案1(推荐)。
在脚本中导入patch_for_hcom_allreduce包,该包将PyTorch原生的torch.distributed.allreduce API替换成torch.distributed._functional_collectives.all_reduce算子,以实现入图。
1 2 3
# 导入补丁包 import torch_npu import torchair.ge_concrete_graph.ge_converter.experimental patch_for_hcom_allreduce