下载
中文
注册

集合通信算子入图

功能简介

集合通信算子入图能够从整图层面进行通信与计算并行优化,原生PyTorch 2.1版本暂不支持集合通信算子入图,从PyTorch 2.2及更高版本开始陆续支持。

原生PyTorch的dynamo trace入图仅支持算子和python内置函数入图,普通python函数会被dynamo拆解,将算子记录到图中。早期PyTorch提供的集合通信能力以API形式呈现,无法直接将其转换入图,PyTorch社区新增了一系列torch.ops.c10_functional前缀的集合通信算子解决该问题。当用户使用torch.ops.c10_functional算子时,需手动修改模型脚本中的集合通信部分,该部分通常被Deepspeed等框架封装,因此社区在dynamo内部将集合通信API转换为torch.ops.c10_functional算子,实现了早期集合通信API调用脚本也能trace入图。然而,社区在dynamo内部将集合通信API转换为torch.ops.c10_functional算子的行为,在PyTorch 2.3版本才逐渐补齐相关能力,PyTorch 2.1版本处于不可用状态。

针对PyTorch 2.1版本不支持入图的集合通信算子,TorchAir提供了补丁方案(patch),封装了一系列NPU通信算子,实现集合通信算子入图。

使用方法

假设训练/推理脚本在PyTorch 2.1版本能以Eager模式正常运行,用户无需修改脚本,直接导入补丁包(Monkey-Patch)即可完成通信算子入图,方案如下:

分布式场景下,由于Deepspeed框架对原生allreduce API的封装使其无法入图,TorchAir可同时Patch Deepspeed框架中allreduce封装函数,以解决入图问题。

  • 方案1(推荐)

    脚本中导入patch_for_hcom包,并在图执行之前调用patch_for_hcom函数。该函数封装了一系列NPU通信算子,用于替换PyTorch原生集合通信API,具体参见表1

    1
    2
    3
    4
    5
    6
    import torch_npu
    import torchair
    # 1.导入patch_for_hcom包
    from torchair import patch_for_hcom
    # 2.在图执行之前任意位置调用patch方法
    patch_for_hcom()
    
    表1 集合通信算子列表

    PyTorch集合通信API

    自定义的NPU通信算子

    torch.distributed.all_reduce

    torch.ops.c10d_functional.all_reduce

    说明:

    本算子入图能力由PyTorch社区提供,其余通信算子入图能力均由TorchAir提供。

    torch.distributed.all_gather

    torch.ops.npu_define.allgather

    torch.distributed.all_to_all

    torch.ops.npu_define.all_to_all

    torch.distributed.broadcast

    torch.ops.npu_define.broadcast

    torch.distributed.all_to_all_single

    torch.ops.npu_define.all_to_all_single

    torch.distributed.all_gather_into_tensor

    torch.ops.npu_define.allgather_in_tensor

  • 方案2:

    本方案在Ascend Extension for PyTorch6.0.RC3版本及后续版本不再演进,可能会废弃,建议优先使用方案1(推荐)

    脚本中导入patch_for_hcom_allreduce包,该包将PyTorch原生的torch.distributed.allreduce API替换成torch.distributed._functional_collectives.all_reduce算子,以实现入图。

    1
    2
    3
    # 导入补丁包
    import torch_npu 
    import torchair.ge_concrete_graph.ge_converter.experimental patch_for_hcom_allreduce