Keras迁移
Keras简介
Keras和Estimator类似,都属于TensorFlow高阶API,提供了方便的构图功能,并对训练、评估、验证、导出提供了方便的接口。
TensorFlow 2.6版本中继续支持Keras API,如果需要沿用在TensorFlow1.x版本中的用法,则可以通过compat.v1模块调用,调用方式如下:
1 | tf.compat.v1.Session |
使用Keras API进行训练脚本开发的一般步骤为:
- 数据预处理。
- 模型搭建。
- 模型编译。
- 模型训练。
当前仅支持通过Tensorflow的Keras API编写的训练脚本,而不支持原生Keras API。
下面介绍如何迁移此类Keras训练脚本,以便在昇腾AI处理器上进行训练。
头文件增加
对于以下步骤中涉及修改的python文件,新增以下头文件引用,用于导入NPU相关库。
1 2 3 | import npu_device from npu_device.compat.v1.npu_init import * npu_device.compat.enable_v1() |
迁移点说明
如果您是Keras训练脚本,由于Keras迁移到Ascend平台运行时,某些功能受限支持,例如不支持动态学习率等,因此不建议用户在Ascend平台上迁移Keras开发的网络脚本。如需在Ascend平台运行Keras脚本,您需要对脚本进行如下修改:
创建一个TensorFlow会话并且注册Keras,并增加相关配置项以便在昇腾AI处理器执行训练。同时在训练结束时,关闭会话。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import tensorflow as tf import tensorflow.keras as keras from tensorflow.keras import backend as K from npu_device.compat.v1.npu_init import * sess_config = tf.compat.v1.ConfigProto() custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add() custom_op.name = "NpuOptimizer" sess_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF sess_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF sess = tf.compat.v1.Session(config=sess_config) K.set_session(sess) #数据预处理... #模型搭建... #模型编译... #模型训练... sess.close() |
父主题: compat.v1脚本手工迁移