用户调用quantize_model接口在传入的图中插入量化相关的变量op列表(quant_add_ops),该列表中的变量值无法在模型训练参数文件中找到,故模型训练参数直接恢复会出现变量无法找到的错误,因此需要在模型参数恢复之前,将quant_add_ops列表中的变量值从恢复列表中剔除,具体的剔除方法如下:
- 影子变量恢复情况
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 | # 1.获取所有变量的{变量名:变量值}对应字典variables_dict
variables_ema = tf.train.ExponentialMovingAverage(moving_average_decay)
variables_dict = variables_ema.variables_to_restore()
# 2.定义需要恢复的变量的{变量名:变量值}对应字典params_need_load
params_need_load = dict()
# 3.基于quant_add_ops从variables_dict中找寻需要恢复的变量
for key, value in variables_dict.items():
if value not in quant_add_ops:
params_need_load[key] = value
# 4.恢复变量
loader = tf.train.Saver(params_need_load)
loader.restore(sess, FLAGS.checkpoint)
|
- 非影子变量恢复的情况
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 | # 1.获取所有变量的{变量名:变量值}对应字典variables_dict
variables_global = tf.global_variables()
variables_dict = dict()
for var in variables_global:
variables_dict[var.name[:-2]] = var
# 2.定义需要恢复的变量的{变量名:变量值}对应字典params_need_load
params_need_load = dict()
# 3.基于quant_add_ops从variables_dict中找寻需要恢复的变量
for key, value in variables_dict.items():
if value not in quant_add_ops:
params_need_load[key] = value
# 4.恢复变量
loader = tf.train.Saver(params_need_load)
loader.restore(sess, FLAGS.checkpoint)
|