集合通信算子入图
功能简介
集合通信算子入图能够从整图层面进行通信与计算并行优化,原生PyTorch 2.1版本暂不支持集合通信算子入图,从PyTorch 2.2及更高版本开始陆续支持。
原生PyTorch的dynamo trace入图仅支持算子和python内置函数入图,普通python函数会被dynamo拆解,将算子记录到图中。早期PyTorch提供的集合通信能力以API形式呈现,无法直接将其转换入图,PyTorch社区新增了一系列torch.ops.c10_functional前缀的集合通信算子解决该问题。当用户使用torch.ops.c10_functional算子时,需手动修改模型脚本中的集合通信部分,该部分通常被Deepspeed等框架封装,因此社区在dynamo内部将集合通信API转换为torch.ops.c10_functional算子,实现了早期集合通信API调用脚本也能trace入图。然而,社区在dynamo内部将集合通信API转换为torch.ops.c10_functional算子的行为,在PyTorch 2.3版本才逐渐补齐相关能力,PyTorch 2.1版本处于不可用状态。
针对PyTorch 2.1版本不支持入图的集合通信算子,TorchAir提供了补丁方案(patch),封装了一系列NPU通信算子,实现集合通信算子入图。
使用方法
假设训练/推理脚本在PyTorch 2.1版本能以Eager模式正常运行,用户无需修改脚本,直接导入补丁包(Monkey-Patch)即可完成通信算子入图,方案如下:
分布式场景下,由于Deepspeed框架对原生allreduce API的封装使其无法入图,TorchAir可同时Patch Deepspeed框架中allreduce封装函数,以解决入图问题。
- 方案1(推荐):
脚本中导入patch_for_hcom包,并在图执行之前调用patch_for_hcom函数。该函数封装了一系列NPU通信算子,用于替换PyTorch原生集合通信API,具体参见表1。
1 2 3 4 5 6
import torch_npu import torchair # 1.导入patch_for_hcom包 from torchair import patch_for_hcom # 2.在图执行之前任意位置调用patch方法 patch_for_hcom()
表1 集合通信算子列表 PyTorch集合通信API
自定义的NPU通信算子
torch.distributed.all_reduce
torch.ops.c10d_functional.all_reduce
说明:本算子入图能力由PyTorch社区提供,其余通信算子入图能力均由TorchAir提供。
torch.distributed.all_gather
torch.ops.npu_define.allgather
torch.distributed.all_to_all
torch.ops.npu_define.all_to_all
torch.distributed.broadcast
torch.ops.npu_define.broadcast
torch.distributed.all_to_all_single
torch.ops.npu_define.all_to_all_single
torch.distributed.all_gather_into_tensor
torch.ops.npu_define.allgather_in_tensor
- 方案2:
本方案在Ascend Extension for PyTorch的6.0.RC3版本及后续版本不再演进,可能会废弃,建议优先使用方案1(推荐)。
脚本中导入patch_for_hcom_allreduce包,该包将PyTorch原生的torch.distributed.allreduce API替换成torch.distributed._functional_collectives.all_reduce算子,以实现入图。
1 2 3
# 导入补丁包 import torch_npu import torchair.ge_concrete_graph.ge_converter.experimental patch_for_hcom_allreduce