插入量化算子后,如何恢复模型训练参数
用户调用quantize_model接口在传入的图中插入量化相关的变量op列表(quant_add_ops),该列表中的变量值无法在模型训练参数文件中找到,故模型训练参数直接恢复会出现变量无法找到的错误,因此需要在模型参数恢复之前,将quant_add_ops列表中的变量值从恢复列表中剔除,具体的剔除方法如下:
- 影子变量恢复情况
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# 1.获取所有变量的{变量名:变量值}对应字典variables_dict variables_ema = tf.train.ExponentialMovingAverage(moving_average_decay) variables_dict = variables_ema.variables_to_restore() # 2.定义需要恢复的变量的{变量名:变量值}对应字典params_need_load params_need_load = dict() # 3.基于quant_add_ops从variables_dict中找寻需要恢复的变量 for key, value in variables_dict.items(): if value not in quant_add_ops: params_need_load[key] = value # 4.恢复变量 loader = tf.train.Saver(params_need_load) loader.restore(sess, FLAGS.checkpoint)
- 非影子变量恢复的情况
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# 1.获取所有变量的{变量名:变量值}对应字典variables_dict variables_global = tf.global_variables() variables_dict = dict() for var in variables_global: variables_dict[var.name[:-2]] = var # 2.定义需要恢复的变量的{变量名:变量值}对应字典params_need_load params_need_load = dict() # 3.基于quant_add_ops从variables_dict中找寻需要恢复的变量 for key, value in variables_dict.items(): if value not in quant_add_ops: params_need_load[key] = value # 4.恢复变量 loader = tf.train.Saver(params_need_load) loader.restore(sess, FLAGS.checkpoint)
父主题: FAQ