文档
注册

npu.distribute.broadcast

函数原型

npu.distribute.broadcast(values, root_rank, fusion=2, fusion_id=0, group="hccl_world_group")

功能说明

用于NPU分布式部署场景下,worker间的变量同步。

参数说明

参数名

输入/输出

描述

values

输入

单个tensorflow的Variable或者Variable的集合。

tensor支持的数据类型为int8, int32, float16, float32, int64, uint64。

root_rank

输入

int类型。

作为root节点的rank_id,该id是group内的rank id。默认为0。

fusion

输入

int类型。

broadcast算子融合标识。

  • 0:不融合,该broadcast算子不和其他broadcast算子融合。
  • 2:按照相同fusion_id进行融合。默认为2。
  • 其他值非法。

fusion_id

输入

broadcast算子的融合id。

对相同fusion_id的broadcast算子进行融合。

group

输入

String类型,最大长度为128字节,含结束符。

group名称,可以为用户自定义group或者"hccl_world_group"。

返回值

无。

调用示例

将0卡上的变量广播到其他卡:

# rank_id = 0  rank_size = 8
import npu_device as npu
x = tf.Variable(tf.random.normal(shape=()))
print("before broadcast", x)
npu.distribute.broadcast(x, root_rank=0)
print("after_broadcast", x)

广播前:

广播后:

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词