torch_npu.npu_mm_reduce_scatter_base

功能描述

TP切分场景下,实现matmul和reduce_scatter的融合,融合算子内部实现计算和通信流水并行。

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

接口原型

npu_mm_reduce_scatter_base(Tensor input, Tensor x2, str hcom, int world_size, *, str reduce_op='sum', Tensor? bias=None, int comm_turn=0) -> Tensor

参数说明

输出说明

Tensor类型,数据类型和input保持一致,shape维度和input保持一致。

约束说明

支持的型号

Atlas A2 训练系列产品

调用示例