npu.distribute.broadcast
功能说明
用于NPU分布式部署场景下,worker间的变量同步。
函数原型
npu.distribute.broadcast(values, root_rank, fusion=2, fusion_id=0, group="hccl_world_group")
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
values |
输入 |
单个TensorFlow的Variable或者Variable的集合。 针对Atlas 训练系列产品,tensor支持的数据类型为int8、int32、float16、float32、int64、uint64。 针对Atlas 推理系列产品,tensor支持的数据类型为int8、int16、int32、float16、float32、int64、uint64。 针对Atlas A2 训练系列产品,tensor支持的数据类型为int8、int32、float16、float32、int64、uint64、bfloat16。 |
root_rank |
输入 |
int类型。 作为root节点的rank_id,该id是group内的rank id。默认为0。 |
fusion |
输入 |
int类型。 broadcast算子融合标识。
|
fusion_id |
输入 |
broadcast算子的融合id。 对相同fusion_id的broadcast算子进行融合。 |
group |
输入 |
String类型,最大长度为128字节,含结束符。 group名称,可以为用户自定义group或者"hccl_world_group"。 |
返回值
无。
调用示例
将0卡上的变量广播到其他卡:
1 2 3 4 5 6 |
# rank_id = 0 rank_size = 8 import npu_device as npu x = tf.Variable(tf.random.normal(shape=())) print("before broadcast", x) npu.distribute.broadcast(x, root_rank=0) print("after_broadcast", x) |
广播前:
广播后: