SetUserDefInfo
功能说明
使能模板参数MatmulCallBack(自定义回调函数)时,设置算子tiling地址,用于回调函数使用,该接口仅需调用一次。
函数原型
1 | __aicore__ inline void SetUserDefInfo(const uint64_t tilingPtr) |
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
tilingPtr |
输入 |
设置的算子tiling地址。 |
返回值
无
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
注意事项
若回调函数中需要使用tilingPtr参数时,必须调用此接口;若回调函数不使用tilingPtr参数,无需调用此接口。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | //用户自定义回调函数 void DataCopyOut(const __gm__ void *gm, const LocalTensor<int8_t> &co1Local, const void *dataCopyOutParams, const uint64_t tilingPtr, const uint64_t dataPtr); void CopyA1(const LocalTensor<int8_t> &aMatrix, const __gm__ void *gm, int row, int col, int useM, int useK, const uint64_t tilingPtr, const uint64_t dataPtr); void CopyB1(const LocalTensor<int8_t> &bMatrix, const __gm__ void *gm, int row, int col, int useK, int useN, const uint64_t tilingPtr, const uint64_t dataPtr); typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> aType; typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> bType; typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> cType; typedef matmul::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> biasType; Matmul<aType, bType, cType, biasType, CFG_NORM, MatmulCallBackFunc<DataCopyOut, CopyA1, CopyB1>> mm; REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling); uint64_t tilingPtr = reinterpret_cast<uint64_t>(tiling); mm.SetUserDefInfo(tilingPtr); mm.SetTensorA(gmA); mm.SetTensorB(gmB); mm.IterateAll(); |
父主题: Matmul