HcclCommInitRootInfo初始化方式
该样例支持单机N卡的组网，N需要小于等于8。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
#include <iostream> #include <vector> #include <memory> #include <thread> #include <chrono> #include "hccl/hccl.h" #include "hccl/hccl_types.h" #include "mpi.h" #define ACLCHECK(ret) do {\ if(ret != ACL_SUCCESS)\ {\ printf("acl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, ret);\ return ret;\ }\ } while(0)\ #define HCCLCHECK(ret) do {\ if(ret != HCCL_SUCCESS)\ {\ printf("hccl interface return errreturn err %s:%d, retcode: %d \n", __FILE__, __LINE__, ret);\ return ret;\ }\ } while(0) struct ThreadContext { HcclComm comm; int32_t device; }; int Sample(void *arg) { ThreadContext* ctx = (ThreadContext *)arg; void* host_buf = nullptr; void* send_buff = nullptr; void* recv_buff = nullptr; uint64_t count = 1; int malloc_kSize = count * sizeof(float); aclrtEvent start_event, end_event; aclrtStream stream; ACLCHECK(aclrtCreateStream(&stream)); ACLCHECK(aclrtCreateEvent(&start_event)); ACLCHECK(aclrtCreateEvent(&end_event)); //申请集合通信操作的内存 ACLCHECK(aclrtMalloc((void**)&send_buff, malloc_kSize, ACL_MEM_MALLOC_HUGE_FIRST)); ACLCHECK(aclrtMalloc((void**)&recv_buff, malloc_kSize, ACL_MEM_MALLOC_HUGE_FIRST)); //初始化输入内存 ACLCHECK(aclrtMallocHost((void**)&host_buf, malloc_kSize)); ACLCHECK(aclrtMemcpy((void*)send_buff, malloc_kSize, (void*)host_buf, malloc_kSize, ACL_MEMCPY_HOST_TO_DEVICE)); //执行集合通信操作 HCCLCHECK(HcclAllReduce((void *)send_buff, (void*)recv_buff, count, HCCL_DATA_TYPE_FP32, HCCL_REDUCE_SUM, ctx->comm, stream)); //等待stream中集合通信任务执行完成 ACLCHECK(aclrtSynchronizeStream(stream)); if (ctx->device < 8) { void* resultBuff; ACLCHECK(aclrtMallocHost((void**)&resultBuff, malloc_kSize)); ACLCHECK(aclrtMemcpy((void*)resultBuff, malloc_kSize, (void*)recv_buff, malloc_kSize, ACL_MEMCPY_DEVICE_TO_HOST)); float* tmpResBuff = static_cast<float*>(resultBuff); for (uint32_t i = 0; i < count; ++i) { std::cout << "rankId:" << ctx->device << ",i" << i << " " << tmpResBuff[i] << std::endl; } ACLCHECK(aclrtFreeHost(resultBuff)); } ACLCHECK(aclrtFree(send_buff)); 
ACLCHECK(aclrtFree(recv_buff)); ACLCHECK(aclrtFreeHost(host_buf)); //销毁任务流 ACLCHECK(aclrtDestroyStream(stream)); ACLCHECK(aclrtDestroyEvent(start_event)); ACLCHECK(aclrtDestroyEvent(end_event)); return 0; } int main(int argc, char*argv[]) { MPI_Init(&argc, &argv); int procSize = 0; int procRank = 0; // 获取当前进程在所属进程组的编号 MPI_Comm_size(MPI_COMM_WORLD, &procSize); MPI_Comm_rank(MPI_COMM_WORLD, &procRank); int devId = procRank; int devCount = procSize; //设备资源初始化 ACLCHECK(aclInit(NULL)); // 指定集合通信操作使用的设备 ACLCHECK(aclrtSetDevice(devId)); // 在 rootRank 获取 rootInfo HcclRootInfo rootInfo; int32_t rootRank = 0; if(devId == rootRank) { HCCLCHECK(HcclGetRootInfo(&rootInfo)); } // 将root_info广播到通信域内的其他rank MPI_Bcast(&rootInfo, HCCL_ROOT_INFO_BYTES, MPI_CHAR, rootRank, MPI_COMM_WORLD); MPI_Barrier(MPI_COMM_WORLD); // 初始化集合通信域 HcclComm hcclComm; HCCLCHECK(HcclCommInitRootInfo(devCount, &rootInfo, devId, &hcclComm)); // 创建任务stream struct ThreadContext args; args.comm = hcclComm; args.device = devId; Sample((void *)&args); //销毁集合通信域 HCCLCHECK(HcclCommDestroy(hcclComm)); //重置设备 ACLCHECK(aclrtResetDevice(devId)); //设备去初始化 ACLCHECK(aclFinalize()); return 0; } |