下载
中文
注册

Tiling Data

功能说明

AI Cpu启动下发通信任务前,需获取固定的通信配置。在算子实现中,由Tiling组装通信配置项,通过配置固定参数和固定参数顺序的Tiling Data,将通信配置信息在调用AI Cpu通信接口时传递给AI Cpu。

参数说明

表1 参数说明

参数名

描述

preparePosition

设置服务端组装任务的方式,参数取值如下:

  • 0:AI Cpu一次性获取任务并下发;AI Cpu通信任务由host下发时设置为0。
  • 1:AI Cpu与AI Core通过通信任务机制实现消息传递和任务下发;由AI Core侧通过消息通知时设置为1,即算子中使用Hccl时设置为1。

sendOff

预留参数,无需配置。Host下发通信任务使用。

recvOff

预留参数,无需配置。Host下发通信任务使用。

tailSendOff

预留参数,无需配置。Host下发通信任务使用。

tailRecvOff

预留参数,无需配置。Host下发通信任务使用。

sendCnt

预留参数,无需配置。Host下发通信任务使用。

recvCnt

预留参数,无需配置。Host下发通信任务使用。

tailSendCnt

预留参数,无需配置。Host下发通信任务使用。

tailRecvCnt

预留参数,无需配置。Host下发通信任务使用。

totalCnt

预留参数,无需配置。Host下发通信任务使用。

turnNum

预留参数,无需配置。Host下发通信任务使用。

tailNum

预留参数,无需配置。Host下发通信任务使用。

stride

预留参数,无需配置。Host下发通信任务使用。

workspaceOff

预留参数,无需配置。Host下发通信任务使用。

notifyOff

预留参数,无需配置。Host下发通信任务使用。

notifyBeginCnt

预留参数,无需配置。Host下发通信任务使用。

notifyEndCnt

预留参数,无需配置。Host下发通信任务使用。

useBufferType

预留参数,无需配置。Host下发通信任务使用。

funID

预留参数,无需配置。Host下发通信任务使用。

dataType

预留参数,无需配置。Host下发通信任务使用。

groupNum

预留参数,无需配置。Host下发通信任务使用。

reuseMode

预留参数,无需配置。Host下发通信任务使用。

commType

预留参数,无需配置。Host下发通信任务使用。

reduceOp

预留参数,无需配置。Host下发通信任务使用。

commOrder

预留参数,无需配置。Host下发通信任务使用。

waitPolicy

预留参数,无需配置。Host下发通信任务使用。

rspPolicy

预留参数,无需配置。Host下发通信任务使用。

exitPolicy

预留参数,无需配置。Host下发通信任务使用。

commAlg

设置具体通信算法,参数取值如下:

  • 1:FullMesh算法,即NPU之间的全连接,任意两个NPU之间可以直接进行数据收发。详细的算法内容可参见集合通信用户指南>集合通信算法。

taskType

预留参数,无需配置。Host下发通信任务使用。

debugMode

预留参数,无需配置。Host下发通信任务使用。

stepSize

预留参数,无需配置。Host下发通信任务使用。

sendArgIndex

算子中第一个输入的参数索引。

recvArgIndex

算子中第一个输出的参数索引。

commOutArgIndex

通信任务的输出在算子原型中的参数索引。

hasCommOut

预留参数,无需配置。Host下发通信任务使用。

reserve

保留字段。

reserve2

保留字段。

支持的型号

Atlas A2训练系列产品/Atlas 800I A2推理产品

注意事项

AI Cpu需获取固定数据结构的通信配置,算子注册Tiling Data时保持该结构的一致性。

调用示例

以自定义算子AllGatherMatmulCustom为例,如下为该算子的算子原型,"gather_out"为通信任务AllGather的输出。

[
    {
        "op": "AllGatherMatmulCustom",
        "input_dsec": [
            {
                "name": "x1",
                "param_type": "required",
                "format": [
                    "ND"
                ],
                "type": [
                    "float16",
                    "bfloat16"
                ]
            },
            {
                "name": "x2",
                "param_type": "required",
                "format": [
                    "ND"
                ],
                "type": [
                    "float16",
                    "bfloat16"
                ]
            },
            {
                "name": "bias",
                "param_type": "optional",
                "format": [
                    "ND"
                ],
                "type": [
                    "float16",
                    "bfloat16"
                ]
            }
        ],
        "output_desc":[
            {
                "name": "y",
                "param_type": "required",
                "format": [
                    "ND"
                ],
                "type": [
                    "float16",
                    "bfloat16"
                ]
            },
            {
                "name": "gather_out",
                "param_type": "required",
                "format": [
                    "ND"
                ],
                "type": [
                    "float16",
                    "bfloat16"
                ]
            }
        ],
        "attr": [
            {
                "name": "group",
                "dtype": "string",
                "default_value":"",
                "param_type":"required"
            },
            {
                "name": "is_trans_a",
                "dtype": "bool",
                "default_value":false,
                "param_type":"optional"
            },
            {
                "name": "is_trans_b",
                "dtype": "bool",
                "default_value":false,
                "param_type":"optional"
            },
            {
                "name": "gather_index",
                "dtype": "int",
                "default_value":0,
                "param_type":"optional"
            },
            {
                "name": "comm_turn",
                "dtype": "int",
                "default_value":0,
                "param_type":"optional"
            },
            {
                "name": "rank_size",
                "dtype": "int",
                "default_value":0,
                "param_type":"optional"
            },
            {
                "name": "is_gather_out",
                "dtype": "bool",
                "default_value":true,
                "param_type":"optional"
            }
        ]
    }
]
// 声明Mc2Msg结构
BEGIN_TILING_DATA_DEF(Mc2Msg)
    TILING_DATA_FIELD_DEF(uint32_t, preparePosition);
    TILING_DATA_FIELD_DEF(uint32_t, sendOff); 
    TILING_DATA_FIELD_DEF(uint32_t, recvOff);
    TILING_DATA_FIELD_DEF(uint32_t, tailSendOff);
    TILING_DATA_FIELD_DEF(uint32_t, tailRecvOff);
    TILING_DATA_FIELD_DEF(uint64_t, sendCnt);
    TILING_DATA_FIELD_DEF(uint32_t, recvCnt);
    TILING_DATA_FIELD_DEF(uint32_t, tailSendCnt);
    TILING_DATA_FIELD_DEF(uint32_t, tailRecvCnt);
    TILING_DATA_FIELD_DEF(uint32_t, totalCnt);
    TILING_DATA_FIELD_DEF(uint32_t, turnNum);
    TILING_DATA_FIELD_DEF(uint32_t, tailNum);
    TILING_DATA_FIELD_DEF(uint32_t, stride);
    TILING_DATA_FIELD_DEF(uint32_t, workspaceOff);
    TILING_DATA_FIELD_DEF(uint32_t, notifyOff);
    TILING_DATA_FIELD_DEF(uint16_t, notifyBeginCnt);
    TILING_DATA_FIELD_DEF(uint16_t, notifyEndCnt);
    TILING_DATA_FIELD_DEF(uint8_t, useBufferType);
    TILING_DATA_FIELD_DEF(uint8_t, funID);
    TILING_DATA_FIELD_DEF(uint8_t, dataType);
    TILING_DATA_FIELD_DEF(uint8_t, groupNum);
    TILING_DATA_FIELD_DEF(uint8_t, reuseMode);
    TILING_DATA_FIELD_DEF(uint8_t, commType);
    TILING_DATA_FIELD_DEF(uint8_t, reduceOp);
    TILING_DATA_FIELD_DEF(uint8_t, commOrder);
    TILING_DATA_FIELD_DEF(uint8_t, waitPolicy);
    TILING_DATA_FIELD_DEF(uint8_t, rspPolicy);
    TILING_DATA_FIELD_DEF(uint8_t, exitPolicy);
    TILING_DATA_FIELD_DEF(uint8_t, commAlg);
    TILING_DATA_FIELD_DEF(uint8_t, taskType);
    TILING_DATA_FIELD_DEF(uint8_t, debugMode);
    TILING_DATA_FIELD_DEF(uint8_t, stepSize);
    TILING_DATA_FIELD_DEF(uint8_t, sendArgIndex);
    TILING_DATA_FIELD_DEF(uint8_t, recvArgIndex);
    TILING_DATA_FIELD_DEF(uint8_t, commOutArgIndex);
    TILING_DATA_FIELD_DEF(uint8_t, hasCommOut);
    TILING_DATA_FIELD_DEF(uint8_t, reserve);
    TILING_DATA_FIELD_DEF(uint32_t, reserve2);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(Mc2MsgOp, Mc2Msg)

BEGIN_TILING_DATA_DEF(AllGatherMatmulCustomTilingData)
    TILING_DATA_FIELD_DEF_STRUCT(Mc2Msg, msg);
END_TILING_DATA_DEF;

// 配置Mc2Msg 
AllGatherMatmulCustomTilingData tiling;
tiling.msg.set_preparePosition(1);
tiling.msg.set_commAlg(1);
tiling.msg.set_sendArgIndex(0);    // 设置算子原型中第一个输入数据的参数索引
tiling.msg.set_recvArgIndex (3);   // 设置算子原型中第一个输出数据的参数索引
tiling.msg.set_commOutArgIndex(4); // 设置算子原型中通信输出数据的参数索引