文档
注册

Estimator模式下使能训练迭代循环下沉

自动迁移场景

  1. 在迁移后的脚本中查找“npu_run_config_init”,找到运行配置参数(例如示例中的“run_config”),在运行配置函数中传入session_config参数,并在session_config参数中添加“iterations_per_loop”配置。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    session_config = tf.ConfigProto(allow_soft_placement=True)
    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name = 'NpuOptimizer'
    custom_op.parameter_map["enable_data_pre_proc"].b = True # getnext算子下沉是迭代循环下沉的必要条件
    custom_op.parameter_map["iterations_per_loop"].i = 10
    
    run_config = tf.estimator.RunConfig(
        train_distribute=distribution_strategy,
        session_config=session_config,       # 运行配置参数中添加session_config配置
        save_checkpoints_secs=60*60*24)
    
    classifier = tf.estimator.Estimator(
        model_fn=model_function, model_dir=flags_obj.model_dir, config=npu_run_config_init(run_config=run_config))
    
  2. 增加“SetIterationsVarHook”:
    1
    2
    3
    4
    5
    train_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        model_dir=flags_obj.model_dir,
        batch_size=flags_obj.batch_size)
    train_hooks.append(SetIterationsVarHook(10))
    
  3. 在train_op中增加“IterationOp”:
    1
    2
    train_op = opt.apply_gradients( grad_var_list, global_step = global_step )
    train_op = tf.group(train_op, name="IterationOp")   #该name设置到梯度更新返回的op
    

手工迁移场景

Estimator模式下,通过NPURunConfig中的iterations_per_loop参数配置,配置方法如下:

1
2
3
4
from npu_bridge.npu_init import *

session_config=tf.ConfigProto(allow_soft_placement=True)
config = NPURunConfig(session_config=session_config, iterations_per_loop=10)  

同时需要使能getnext算子下沉,getnext算子下沉是迭代循环下沉的必要条件。Estimator模式下getnext算子默认下沉,即enable_data_pre_proc默认为True,可不用手工配置。

检查iterations_per_loop生效

开启“训练迭代循环下沉”功能后,可通过查看Host侧INFO日志中是否存在关键字“Insert op success”来判断iterations_per_loop是否生效。

可通过如下命令设置Host侧日志级别为INFO,INFO日志的默认输出路径为“$HOME/ascend/log/run/plog/”。

export ASCEND_GLOBAL_LOG_LEVEL=1
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词