torch_npu.npu_all_gather_base_mm

功能说明

TP切分场景下,实现allgather和matmul的融合,融合算子内部实现通信和计算流水并行。

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

接口原型

npu_all_gather_base_mm(Tensor input, Tensor x2, str hcom, int world_size, *, Tensor? bias=None, int gather_index=0, bool gather_output=True, int comm_turn=0) -> (Tensor, Tensor)

参数说明

输出说明

两个输出,均为Tensor类型:(Tensor, Tensor)

约束说明

支持的型号

Atlas A2 训练系列产品

调用示例