由于硬件约束,Atlas 200/300/500 推理产品、Atlas 推理系列产品、Atlas 训练系列产品不支持4选2结构化稀疏特性:使能后获取不到性能收益。
AMCT支持基于重训练的4选2结构化稀疏特性,稀疏示例请参见获取更多样例。该特性支持的层以及约束如下:
支持的层类型 |
约束 |
备注 |
---|---|---|
MatMul |
transpose_a=False 权重数据类型为 Float32, Float64 |
- |
Conv2D |
权重数据类型为 Float32, Float64 |
weight的输入来源不含有placeholder等可动态变化的节点,且weight的节点类型只能是const。 |
Conv2DBackpropInput |
- |
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现,用户在TensorFlow原始网络推理的代码中导入库,并在特定的位置调用相应API,即可实现稀疏功能。
1 2 |
import amct_tensorflow as amct amct.set_logging_level(print_level='info', save_level='info') |
推荐执行该步骤,以确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
1
|
user_test_evaluate_model(evaluate_model, test_data) |
1
|
train_graph = user_load_train_graph() |
用户基于构造的训练模式的图结构(BN的is_training参数为True) ,调用create_prune_retrain_model接口(对应图1中的序号1),根据稀疏配置文件(对应图1中的序号2)对训练的图进行稀疏前的图结构修改。create_prune_retrain_model接口会在图结构中插入4选2结构化稀疏算子,达到推理时伪稀疏的效果。稀疏配置文件需要参见量化感知训练简易配置文件自行构造。
1 2 3 4 5 6 |
record_file = './tmp/record.txt' simple_cfg = './retrain.cfg' amct.create_prune_retrain_model(graph=train_graph, outputs=user_model_outputs, record_file=record_file, config_defination=simple_cfg) |
1 2 3 |
optimizer = tf.compat.v1.train.RMSPropOptimizer( ARGS.learning_rate, momentum=ARGS.momentum) train_op = optimizer.minimize(loss) |
1 2 3 4 5 |
with tf.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(outputs) #将训练后的参数保存为checkpoint文件 saver_save.save(sess, retrain_ckpt, global_step=0) |
1
|
test_graph = user_load_test_graph() |
1 2 3 4 5 6 |
record_file = './tmp/record.txt' simple_cfg = './retrain.cfg' amct.create_prune_retrain_model(graph=test_graph, outputs=user_model_outputs, record_file=record_file, config_defination=simple_cfg) |
1 2 3 4 5 6 7 8 9 10 11 |
variables_to_restore = tf.compat.v1.global_variables() saver_restore = tf.compat.v1.train.Saver(variables_to_restore) with tf.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) #恢复训练参数 saver_restore.restore(sess, retrain_ckpt) #固化pb模型 constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants( sess, eval_graph.as_graph_def(), [output.name[:-2] for output in outputs]) with tf.io.gfile.GFile(masked_pb_path, 'wb') as f: f.write(constant_graph.SerializeToString()) |
1 2 3 4 5 |
pruned_model_path = './result/user_model' amct.save_prune_retrain_model(pb_model=masked_pb_path, outputs=user_model_outputs, record_file=record_file, save_path=pruned_model_path) |
1 2 |
pruned_model = './results/user_model_pruned.pb' user_do_inference(pruned_model, test_data) |