torch_npu.npu_mm_all_reduce_base

功能描述

TP切分场景下,实现mm和all_reduce的融合,融合算子内部实现计算和通信流水并行。

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

接口原型

npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None Tensor? pertoken_scale=None, Tensor? comm_quant_scale_1=None, Tensor? comm_quant_scale_2=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor

参数说明

输出说明

Tensor类型,数据类型非量化场景以及伪量化场景与x1保持一致,全量化场景输出数据类型为FLOAT16或BFLOAT16。shape第0维度和x1的0维保持一致,若x1为2维,shape第1维度和x2的1维保持一致,若x1为3维,shape第1维度和x1的1维保持一致,shape第2维度和x2的1维保持一致。

约束说明

支持的型号

调用示例