下载
中文
注册

do_merge_lookup

功能描述

该接口用于自动改图模式下,对多次查询的表进行lookup合并操作。

在模型中,此函数在Optimizer.compute_gradients()中利用patch执行,确保train时拥有正确的梯度和计算图;eval时在改图阶段执行。

函数原型

from mx_rec.graph.merge_lookup import do_merge_lookup

参数说明

参数名

类型

可选/必选

说明

is_train

bool

必选

当前是否为训练模式。

  • True:训练(train)模式。
  • False:评估(eval)模式。

使用示例

例如,train模式,全部的梯度计算都使用tf.gradients,则需要主动调用do_merge_lookup。

from mx_rec.graph.merge_lookup import do_merge_lookup
do_merge_lookup(is_train=True)
sparse_grads = tf.gradients(loss, sparse_variables)
grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)]
optimizer.apply_gradients(grads_and_vars)