Code Example
This code example communicates over the default global communication group (hccl_world_group). It issues the allreduce, broadcast, allgather, reducescatter, reduce, and alltoallv collective operators, then runs them all in a single session.
Assume the code file is named hccl_test.py.
import tensorflow as tf
import numpy as np
from npu_bridge.npu_init import *

def tensor_type(list1, dtype):
    # This sample always builds int64 tensors, regardless of the dtype argument
    return tf.Variable(list1, dtype=tf.int64)

def numpy_type(dtype):
    # Matching NumPy dtype for the tensors above
    return np.int64

def hccl_operator(rank_id, root_rank, rank_size, group, dtype, data):
    tensors = {}
    reductions = ['sum', 'max', 'min', 'prod']

    def make_input():
        # Each rank contributes a [rank_size, data/rank_size] tensor of ones
        # shifted by (rank_id + 1), so every rank's input is distinct
        element_list = [1 for _ in range(data)]
        tensor = tensor_type(element_list, dtype)
        tensor_tmp = tf.add(tensor, rank_id + 1)
        return tf.reshape(tensor_tmp, [rank_size, -1])

    # allreduce: one op per reduction type
    allreduce_inputs = []
    for reduction in reductions:
        input_tensor = make_input()
        allreduce_inputs.append(input_tensor)
        tensors['allreduce_' + reduction] = hccl_ops.allreduce(input_tensor, reduction, group=group)

    # broadcast: root_rank's tensor is copied to every rank
    list_test = np.ones((1, data))
    tensor_test = tensor_type(list_test, dtype)
    tensor_z = tf.add(tensor_test, rank_id + 1)
    new_tensor10 = tf.reshape(tensor_z, [rank_size, -1])
    tensors['broadcast'] = hccl_ops.broadcast([new_tensor10], root_rank, group=group)

    # allgather: reuses the second allreduce input tensor
    tensors['gather_tensor'] = hccl_ops.allgather(allreduce_inputs[1], rank_size, group=group)

    # reducescatter: one op per reduction type
    for reduction in reductions:
        tensors['reducescatter_' + reduction] = hccl_ops.reduce_scatter(make_input(), reduction, rank_size, group=group)

    # reduce: the result lands on root_rank only
    for reduction in reductions:
        tensors['reduce_' + reduction] = hccl_ops.reduce(make_input(), reduction, root_rank, group=group)

    # alltoallv: every rank exchanges variable-sized slices with every other rank
    input_type = numpy_type(dtype)
    data1_shape = data * rank_size + (rank_size - 1) * rank_size
    data1_ = np.arange(1, data1_shape + 1).astype(input_type)
    check_data_shape = (data + rank_id) * rank_size
    check_data_ = np.arange(1, check_data_shape + 1).astype(input_type)
    send_data = tf.Variable(data1_)
    check_data = tf.Variable(check_data_)
    send_counts_list = [data + i for i in range(rank_size)]
    send_counts = tf.constant(send_counts_list, dtype=tf.int64)
    send_displacements = tf.constant([rank_id * (data + i) for i in range(rank_size)], dtype=tf.int64)
    # With static shapes, recv_counts and recv_displacements must be tf.constant
    recv_counts = tf.constant([rank_id + data for _ in range(rank_size)], dtype=tf.int64)
    recv_displacements = tf.constant([(rank_id + data) * i for i in range(rank_size)], dtype=tf.int64)
    all_to_all_v = hccl_ops.all_to_all_v(send_data, send_counts, send_displacements,
                                         recv_counts, recv_displacements, group=group)
    tensors['alltoallv_tensor'] = all_to_all_v
    tensors['check_tensors'] = check_data
    return tensors

def main():
    config = {}
    hccl_session_config = tf.ConfigProto()
    custom_op = hccl_session_config.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name = "NpuOptimizer"
    custom_op.parameter_map["use_off_line"].b = True
    npu_init = npu_ops.initialize_system()
    npu_shutdown = npu_ops.shutdown_system()
    with tf.Session(config=hccl_session_config) as sess:
        # Initialize collective communication
        sess.run(npu_init)
        # Get the number of ranks in the group
        config['rank_size'] = get_rank_size()
        # Get this device's rank index within the group
        config['rank_id'] = get_rank_id()
        try:
            # Issue the collective communication operators
            tensors = hccl_operator(config['rank_id'], 0, config['rank_size'],
                                    "hccl_world_group", "float32", 1024)
            # Initialize TensorFlow global variables
            init_var = tf.global_variables_initializer()
            sess.run(init_var)
            # Run the ops; this stands in for a real training step
            v = sess.run(tensors)
            tf.logging.info(v)
        except Exception as e:
            print('ERROR : %s' % e)
            print('train fail')
        else:
            print('train success')
        # Shut down the NPU system before the session closes
        sess.run(npu_shutdown)

if __name__ == '__main__':
    # Enable INFO-level logging
    tf.logging.set_verbosity(tf.logging.INFO)
    main()
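To see how the alltoallv counts and displacements fit together, the following pure-Python sketch simulates the exchange for a hypothetical rank_size=2 and data=4 (the buffers helper and these small values are illustrative assumptions; the sketch models only the indexing semantics, not the HCCL implementation):

import numpy as np

# Hypothetical small configuration, for illustration only
rank_size, data = 2, 4

def buffers(rank_id):
    # Same count/displacement formulas as in hccl_test.py
    send_counts = [data + i for i in range(rank_size)]
    send_displacements = [rank_id * (data + i) for i in range(rank_size)]
    recv_counts = [rank_id + data for _ in range(rank_size)]
    recv_displacements = [(rank_id + data) * i for i in range(rank_size)]
    send_len = data * rank_size + (rank_size - 1) * rank_size
    send_data = np.arange(1, send_len + 1, dtype=np.int64)
    return send_data, send_counts, send_displacements, recv_counts, recv_displacements

# Simulate the exchange: rank j receives, from each rank i, the
# send_counts[j] elements that rank i stored at send_displacements[j],
# placing them at recv_displacements[i] in its receive buffer
for j in range(rank_size):
    _, _, _, recv_counts, recv_displacements = buffers(j)
    recv = np.zeros(sum(recv_counts), dtype=np.int64)
    for i in range(rank_size):
        send_data, send_counts, send_displacements, _, _ = buffers(i)
        start = send_displacements[j]
        recv[recv_displacements[i]:recv_displacements[i] + recv_counts[i]] = \
            send_data[start:start + send_counts[j]]
    print('rank', j, 'receives', recv)

With these values the sketch prints [1..8] on rank 0 and [1..10] on rank 1, i.e. np.arange(1, (data + rank_id) * rank_size + 1) on each rank, which is exactly the check_data tensor the sample builds for verification.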