下载
中文
注册

npu.distribute.broadcast

函数原型

npu.distribute.broadcast(values, root_rank, fusion=2, fusion_id=0, group="hccl_world_group")

功能说明

用于NPU分布式部署场景下,worker间的变量同步。

参数说明

参数名

输入/输出

描述

values

输入

单个TensorFlow的Variable或者Variable的集合。

针对Atlas 训练系列产品,tensor支持的数据类型为int8、int32、float16、float32、int64、uint64。

针对Atlas 推理系列产品,tensor支持的数据类型为int8、int16、int32、float16、float32、int64、uint64。

针对Atlas A2 训练系列产品,tensor支持的数据类型为int8、int32、float16、float32、int64、uint64、bfloat16。

root_rank

输入

int类型。

作为root节点的rank_id,该id是group内的rank id。默认为0。

fusion

输入

int类型。

broadcast算子融合标识。

  • 0:不融合,该broadcast算子不和其他broadcast算子融合。
  • 2:按照相同fusion_id进行融合。默认为2。
  • 其他值非法。

fusion_id

输入

broadcast算子的融合id。

对相同fusion_id的broadcast算子进行融合。

group

输入

String类型,最大长度为128字节,含结束符。

group名称,可以为用户自定义group或者"hccl_world_group"。

返回值

无。

调用示例

将0卡上的变量广播到其他卡:

1
2
3
4
5
6
# rank_id = 0  rank_size = 8
import npu_device as npu
x = tf.Variable(tf.random.normal(shape=()))
print("before broadcast", x)
npu.distribute.broadcast(x, root_rank=0)
print("after_broadcast", x)

广播前:

广播后: