文档
注册

npu.distribute.shard_and_rebatch_dataset

函数原型

npu.distribute.shard_and_rebatch_dataset(dataset, global_bs)

功能说明

用于NPU分布式部署场景下,不同worker上数据集分片及batch大小调整。

参数说明

参数名

输入/输出

描述

dataset

输入

tensorflow的Dataset类型。

需要进行切分的数据集。

global_bs

输入

全局batch的大小。

返回值

返回一个2个元素的tuple对象,第一个元素为切分后的Dataset,第二个元素为每个worker应当处理的实际batch大小。

调用示例

import npu_device as npu
dataset, batch_size = npu.distribute.shard_and_rebatch_dataset(dataset, batch_size)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词