通算融合算子
相比于一般的计算或搬运类算子,通算融合算子是融合集合通信任务和计算任务的算子,在算子执行过程中,计算和通信任务可以部分流水并行,以便提升性能。典型的应用场景如Matmul计算+集合通信。通算融合类算子的实现,请参考并使用Hccl高阶API。关于更多集合通信的内容和相关概念请参考HCCL概述。
通算融合算子的开发过程与一般算子相同,但请注意,当前通算融合算子暂不支持Kernel直调,暂不支持入图(GE图)开发。
算子原型定义
相比于一般算子,通算融合算子在实现算子原型定义时,有如下约束:
- 必须定义一个表示算子通信域名称的属性。通信域是集合通信执行的上下文,管理对应的通信实体(例如一个NPU就是一个通信实体)和通信所需的资源。
- 必须通过MC2接口注册该算子为通算融合算子,并通过HcclGroup接口配置该算子的通信域名称。
以算子属性"group"为例,"group"为该算子的通信域名称,在算子原型定义中,调用方式如下:
this->Attr("group").AttrType(REQUIRED).String(); // "group"为通算融合算子的属性,表示通信域名称 ... this->MC2().HcclGroup("group"); // 将"group"配置为该算子的通信域
调用示例
以AllGatherMatmul自定义算子为例,算子原型定义如下。
namespace ops {
class AllGatherMatmulCustom : public OpDef {
public:
explicit AllGatherMatmulCustom(const char *name) : OpDef(name) {
this->Input("x1")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("x2")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
.IgnoreContiguous();
this->Input("bias")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("y")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("gather_out")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT16, ge::DT_BF16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group").AttrType(REQUIRED).String();
this->Attr("isTransA").AttrType(OPTIONAL).Bool(false);
this->Attr("isTransB").AttrType(OPTIONAL).Bool(false);
this->Attr("gatherIndex").AttrType(OPTIONAL).Int(0);
this->Attr("commTurn").AttrType(OPTIONAL).Int(0);
this->Attr("rank_size").AttrType(OPTIONAL).Int(8);
this->Attr("is_gather_out").AttrType(OPTIONAL).Bool(true);
this->AICore().SetTiling(optiling::AllGatherMatmulCustomTilingFunc);
this->AICore().AddConfig("ascendxxx"); // ascendxxx请修改为对应的昇腾AI处理器型号。
this->MC2().HcclGroup("group");
}
};
父主题: 专题