v1版本TilingData
功能说明
AI CPU启动下发通信任务前,需获取固定的通信配置Mc2Msg。在算子实现中,由Tiling组装通信配置项,通过配置固定参数和固定参数顺序的Tiling Data,将通信配置信息在调用AI CPU通信接口时传递给AI CPU。
参数说明
参数名 |
描述 |
---|---|
preparePosition |
设置服务端组装任务的方式,用户需要在Tiling中显示赋值,uint32_t类型,当前支持的取值如下: 1:AI CPU与AI Core通过通信任务机制实现消息传递和任务下发;由AI Core侧通过消息通知时设置为1,即算子中使用Hccl时设置为1。 |
sendOff |
预留参数,不可配置。 |
recvOff |
预留参数,不可配置。 |
tailSendOff |
预留参数,不可配置。 |
tailRecvOff |
预留参数,不可配置。 |
sendCnt |
预留参数,不可配置。 |
recvCnt |
预留参数,不可配置。 |
tailSendCnt |
预留参数,不可配置。 |
tailRecvCnt |
预留参数,不可配置。 |
totalCnt |
预留参数,不可配置。 |
turnNum |
预留参数,不可配置。 |
tailNum |
预留参数,不可配置。 |
stride |
预留参数,不可配置。 |
workspaceOff |
预留参数,不可配置。 |
notifyOff |
预留参数,不可配置。 |
notifyBeginCnt |
预留参数,不可配置。 |
notifyEndCnt |
预留参数,不可配置。 |
useBufferType |
设置通信算法获取输入数据的位置,uint8_t类型,参数取值如下:
|
funID |
预留参数,不可配置。 |
dataType |
预留参数,不可配置。 |
groupNum |
预留参数,不可配置。 |
reuseMode |
预留参数,不可配置。 |
commType |
预留参数,不可配置。 |
reduceOp |
预留参数,不可配置。 |
commOrder |
预留参数,不可配置。 |
waitPolicy |
预留参数,不可配置。 |
rspPolicy |
预留参数,不可配置。 |
exitPolicy |
预留参数,不可配置。 |
commAlg |
设置具体通信算法,用户需要在Tiling中显示赋值,uint8_t类型,当前支持的取值如下: 1:FullMesh算法,即NPU之间的全连接,任意两个NPU之间可以直接进行数据收发。详细的算法内容可参见集合通信算法。 |
taskType |
预留参数,不可配置。 |
debugMode |
预留参数,不可配置。 |
stepSize |
预留参数,不可配置。 |
sendArgIndex |
预留参数,不可配置。 |
recvArgIndex |
预留参数,不可配置。 |
commOutArgIndex |
预留参数,不可配置。 |
hasCommOut |
本卡的通信算法的计算结果是否输出到recvBuf(目的数据buffer地址)。仅AllGather算法与AlltoAll算法支持配置该参数。uint8_t类型,参数取值如下:
|
reserve |
保留字段。 |
reserve2 |
保留字段。 |
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
注意事项
- 算子的Tiling Data结构需要按顺序完整包含Mc2Msg参数。
- AI CPU需获取固定数据结构的通信配置,算子注册Tiling Data时保持该结构的一致性。
调用示例
以自定义算子AllGatherMatmulCustom为例,如下为该算子的算子原型,"gather_out"为通信任务AllGather的输出。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
[ { "op": "AllGatherMatmulCustom", "input_desc": [ { "name": "x1", "param_type": "required", "format": [ "ND", "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "x2", "param_type": "required", "format": [ "ND", "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "bias", "param_type": "optional", "format": [ "ND", "ND" ], "type": [ "float16", "bfloat16" ] } ], "output_desc":[ { "name": "y", "param_type": "required", "format": [ "ND", "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "gather_out", "param_type": "required", "format": [ "ND", "ND" ], "type": [ "float16", "bfloat16" ] } ], "attr": [ { "name": "group", "type": "string", "default_value":"", "param_type":"required" }, { "name": "rank_size", "type": "int", "default_value":0, "param_type":"optional" }, { "name": "is_gather_out", "type": "bool", "default_value":true, "param_type":"optional" } ] } ] |
算子的Tiling Data结构需要按顺序完整包含Mc2Msg参数,如下为算子Tiling Data代码示例。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
// 声明Mc2Msg结构 BEGIN_TILING_DATA_DEF(Mc2Msg) TILING_DATA_FIELD_DEF(uint32_t, preparePosition); TILING_DATA_FIELD_DEF(uint32_t, sendOff); TILING_DATA_FIELD_DEF(uint32_t, recvOff); TILING_DATA_FIELD_DEF(uint32_t, tailSendOff); TILING_DATA_FIELD_DEF(uint32_t, tailRecvOff); TILING_DATA_FIELD_DEF(uint64_t, sendCnt); TILING_DATA_FIELD_DEF(uint32_t, recvCnt); TILING_DATA_FIELD_DEF(uint32_t, tailSendCnt); TILING_DATA_FIELD_DEF(uint32_t, tailRecvCnt); TILING_DATA_FIELD_DEF(uint32_t, totalCnt); TILING_DATA_FIELD_DEF(uint32_t, turnNum); TILING_DATA_FIELD_DEF(uint32_t, tailNum); TILING_DATA_FIELD_DEF(uint32_t, stride); TILING_DATA_FIELD_DEF(uint32_t, workspaceOff); TILING_DATA_FIELD_DEF(uint32_t, notifyOff); TILING_DATA_FIELD_DEF(uint16_t, notifyBeginCnt); TILING_DATA_FIELD_DEF(uint16_t, notifyEndCnt); TILING_DATA_FIELD_DEF(uint8_t, useBufferType); TILING_DATA_FIELD_DEF(uint8_t, funID); TILING_DATA_FIELD_DEF(uint8_t, dataType); TILING_DATA_FIELD_DEF(uint8_t, groupNum); TILING_DATA_FIELD_DEF(uint8_t, reuseMode); TILING_DATA_FIELD_DEF(uint8_t, commType); TILING_DATA_FIELD_DEF(uint8_t, reduceOp); TILING_DATA_FIELD_DEF(uint8_t, commOrder); TILING_DATA_FIELD_DEF(uint8_t, waitPolicy); TILING_DATA_FIELD_DEF(uint8_t, rspPolicy); TILING_DATA_FIELD_DEF(uint8_t, exitPolicy); TILING_DATA_FIELD_DEF(uint8_t, commAlg); TILING_DATA_FIELD_DEF(uint8_t, taskType); TILING_DATA_FIELD_DEF(uint8_t, debugMode); TILING_DATA_FIELD_DEF(uint8_t, stepSize); TILING_DATA_FIELD_DEF(uint8_t, sendArgIndex); TILING_DATA_FIELD_DEF(uint8_t, recvArgIndex); TILING_DATA_FIELD_DEF(uint8_t, commOutArgIndex); TILING_DATA_FIELD_DEF(uint8_t, hasCommOut); TILING_DATA_FIELD_DEF(uint8_t, reserve); TILING_DATA_FIELD_DEF(uint32_t, reserve2); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(Mc2MsgOp, Mc2Msg) BEGIN_TILING_DATA_DEF(AllGatherMatmulCustomTilingData) TILING_DATA_FIELD_DEF_STRUCT(Mc2Msg, msg); END_TILING_DATA_DEF; |
1 2 3 4 5 6 |
// 配置Mc2Msg AllGatherMatmulCustomTilingData tiling; tiling.msg.set_preparePosition(1); tiling.msg.set_commAlg(1); tiling.msg.set_useBufferType(1); tiling.msg.set_hasCommOut(1); |