下载
中文
注册

插入量化算子后,如何恢复模型训练参数

用户调用quantize_model接口在传入的图中插入量化相关的变量op列表(quant_add_ops),该列表中的变量值无法在模型训练参数文件中找到,故模型训练参数直接恢复会出现变量无法找到的错误,因此需要在模型参数恢复之前,将quant_add_ops列表中的变量值从恢复列表中剔除,具体的剔除方法如下:

  1. 影子变量恢复情况
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    # 1.获取所有变量的{变量名:变量值}对应字典variables_dict
    variables_ema = tf.train.ExponentialMovingAverage(moving_average_decay)
    variables_dict = variables_ema.variables_to_restore()
    
    # 2.定义需要恢复的变量的{变量名:变量值}对应字典params_need_load
    params_need_load = dict()
    
    # 3.基于quant_add_ops从variables_dict中找寻需要恢复的变量
    for key, value in variables_dict.items():
        if value not in quant_add_ops:
            params_need_load[key] = value
    
    # 4.恢复变量
    loader = tf.train.Saver(params_need_load)
    loader.restore(sess, FLAGS.checkpoint)
    
  2. 非影子变量恢复的情况
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    # 1.获取所有变量的{变量名:变量值}对应字典variables_dict
    variables_global = tf.global_variables()
    variables_dict = dict()
    for var in variables_global:
        variables_dict[var.name[:-2]] = var
    
    # 2.定义需要恢复的变量的{变量名:变量值}对应字典params_need_load
    params_need_load = dict()
    
    # 3.基于quant_add_ops从variables_dict中找寻需要恢复的变量
    for key, value in variables_dict.items():
        if value not in quant_add_ops:
           params_need_load[key] = value
    
    # 4.恢复变量
    loader = tf.train.Saver(params_need_load)
    loader.restore(sess, FLAGS.checkpoint)