文档
注册

自动改图

推荐场景下客户希望零代码改动迁移至昇腾并应用mxRec框架,自动改图通过修改TensorFlow计算图的方式,使得客户训练脚本无需创建特征类(Feature Spec),无需调用嵌入“read emb key”算子的函数即可开始模型训练。

使用流程介绍

本节介绍如何使用自动改图模式进行训练,整体操作流程请参见图1

图1 自动改图使用流程

正常训练流程一般包含处理数据、创建稀疏表、查表、开始训练,使用自动改图模式和正常训练流程一致,仅需在查表接口设置“modify_graph”参数为“True”,并且在开始训练之前需要调用自动改图的接口。其中查表和调用自动改图接口操作为关键步骤。

关键步骤介绍

  1. 查询稀疏特征表。

    调用sparse_lookup接口,设置参数modify_graph = True表示在查表时采用自动改图模式,该参数默认为“False”

  2. 调用自动改图接口。

    自动改图接口为modify_graph_and_start_emb_cache,但区分使用TF.Session训练模式使用NPUEstimator训练模式

    1. 使用TF.Session训练模式。

      在TF.Session训练模式下,需要显示调用modify_graph_and_start_emb_cache接口,同时sess.run(iterator.initializer)也需修改为自动改图的数据集初始化接口sess.run(get_initializer(True))或者sess.run(get_initializer(False)),前者用于训练(train)、后者用于评估(eval)。

    2. 使用NPUEstimator训练模式。

      在NPUEstimator模式下,需要在NPUEstimator的多个模式(train、predict、train_and_evaluate)中添加自动改图的GraphModifierHook,如当前为训练(train),则在训练的钩子(Hook)中添加GraphModifierHook,即可完成自动改图模式的训练。

示例代码

  1. 查询稀疏特征表。
    from mx_rec.core.embedding import sparse_lookup
    
    embedding = sparse_lookup(hash_table, feature, send_count, dim=None, is_train=is_train,
                              access_and_evict_config=access_and_evict_config,
                              name=hash_table.table_name + "_lookup", modify_graph=True, batch=batch)
  2. 调用改图接口。
    1. 使用TF.Session训练模式。
      from mx_rec.util.initialize import get_initializer
      from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
      
      if MODIFY_GRAPH_FLAG:
          logging.info("start to modifying graph")
          modify_graph_and_start_emb_cache(dump_graph=True)
      else:
          start_asc_pipeline()
      
      # train
      with tf.compat.v1.Session(config=sess_config()) as sess:
          if MODIFY_GRAPH_FLAG:
              sess.run(get_initializer(True))
          else:
              sess.run(train_iterator.initializer)
      
      # eval
      def evaluate():
          if MODIFY_GRAPH_FLAG:
              sess.run(get_initializer(False))
          else:
              sess.run(eval_iterator.initializer)
    2. 使用NPUEstimator训练模式。
      from mx_rec.graph.modifier import GraphModifierHook
      
      est.train(input_fn=lambda: input_fn(), hooks=[GraphModifierHook()])   #est为创建的NPUEstimator对象
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词