NPUEstimatorSpec构造函数
功能说明
NPUEstimatorSpec类的构造函数,NPUEstimatorSpec类继承了TensorFlow的EstimatorSpec类,可以调用基类的原生接口,定义具体的模型对象。
EstimatorSpec是model_fn的返回数据结构,包含了mode,predictions,loss,train_op和export_outputs等字段,传递到Estimator。直接使用EstimatorSpec无法满足训练的功能,定义NPUEstimatorSpec,代替EstimatorSpec的功能。
函数原型
def __new__(cls, mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None, export_outputs=None, training_chief_hooks=None, training_hooks=None, scaffold=None, evaluation_hooks=None, prediction_hooks=None, host_call=None)
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
NPUEstimatorSpec继承EstimatorSpec的参数: |
||
mode |
输入 |
模式,指明当前是在训练、验证、还是推理。
|
predictions |
输入 |
推理的输出Tensor,当mode为ModeKeys.PREDICT时必须指定该参数。 |
loss |
输入 |
训练的损失。 |
train_op |
输入 |
训练算子。 |
eval_metric_ops |
输入 |
度量结果的字典(按照Tensor名称)。 字典值可以是以下之一:
|
export_outputs |
输入 |
用于模型保存,描述了导出到SavedModel的输出格式。 |
training_chief_hooks |
输入 |
训练执行时主节点的SessionRunHooks集合。 |
training_hooks |
输入 |
训练执行时的SessionRunHooks集合。 |
scaffold |
输入 |
定义scaffold(提供定制saver、init_op、summary_op、global_step的能力) |
evaluation_hooks |
输入 |
验证执行时的SessionRunHooks集合 |
prediction_hook |
输入 |
推理执行时的SessionRunHooks集合 |
NPUEstimatorSpec新增参数: |
||
host_call |
输入 |
捕捉Summary信息,将每个step的信息传回Host侧查看。 host_call是一个function和一个tensor的列表或字典组成的元组,用于返回tensor列表。 host_call目前适用于train()和evaluate()。 |
返回值
返回NPUEstimatorSpec类对象。
调用示例
1 2 3 4 |
from npu_bridge.npu_init import * ... host_call = (_host_call_fn, [global_step, loss]) return NPUEstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op, host_call=host_call) |