下载
中文
注册

torch_npu.npu_mm_reduce_scatter_base

功能描述

TP切分场景下,实现matmul和reduce_scatter的融合,融合算子内部实现计算和通信流水并行。

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

接口原型

npu_mm_reduce_scatter_base(Tensor self, Tensor x2, str hcom, int world_size, *, str reduce_op='sum', Tensor? bias=None, int comm_turn=0) -> Tensor

参数说明

  • self:Device侧的Tensor类型,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,输入shape支持2维。
  • x2:Device侧的Tensor类型,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,数据类型需要和self保持一致,输入shape维度和self保持一致。
  • hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
  • world_size:Host侧的int类型,通信域内的rank总数,仅支持为2、4、8。
  • *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
  • reduce_op:Host侧的String类型,reduce操作类型,当前仅支持'sum',默认值:'sum'。
  • bias:Device侧的Tensor类型,可选输入,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND格式。数据类型需要和self保持一致。bias仅支持一维,且维度大小与output的第1维大小相同。当前版本暂不支持bias输入。
  • comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。

输出说明

Tensor类型,数据类型和self保持一致,shape维度和self保持一致。

约束说明

  • 输入self、x2必须是2维,分别为(m, k),(k, n),轴满足matmul算子入参要求,m轴需要整除world_size,k轴相等,且k轴取值范围为[256, 65535)。
  • x1不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与x1的最后一维相同,满足matmul的计算条件。

支持的PyTorch版本

  • PyTorch 2.1
  • PyTorch 1.11.0

支持的型号

Atlas A2 训练系列产品

调用示例

import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    if torch.__version__ > '2.0.1':
        hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcomm_info = default_pg.get_hccl_comm_name(rank)

    input_ = torch.randn(x1_shape, dtype=dtype).npu()
    weight = torch.randn(x2_shape, dtype=dtype).npu()
    output = torch_npu.npu_mm_reduce_scatter_base(input_, weight, hcomm_info, world_size)

if __name__ == "__main__":
    worksize = 8
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [128, 512]
    x2_shape = [512, 64]
    dtype = torch.float16

    mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)