文档
注册

基础稀疏

AMCT目前主要支持基于重训练的通道稀疏模型压缩特性,稀疏示例请参见获取更多样例>resnet101,支持通道稀疏的层以及约束如下:

表1 通道稀疏支持的层以及约束

优化方式

支持的层类型

约束

通道稀疏

torch.nn.Linear:全连接层

复用层(共用weight和bias参数)不支持稀疏。

torch.nn.Conv2d:卷积层

复用层(共用weight和bias参数)不支持稀疏。

depthwise只能被动稀疏(groups=in_channels),不能主动稀疏。

接口调用流程

通道稀疏功能接口调用流程如图1所示。

图1 通道稀疏接口调用流程
蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现:
  1. 用户首先构造PyTorch的原始模型,调用create_prune_retrain_model接口对原始模型进行修改,在图结构中插入通道稀疏mask算子,修改后的模型参数量被裁剪。
  2. 对修改后的模型进行训练,直至精度满足要求;如果训练过程中断,则可基于原始模型和记录稀疏信息的文件,重新调用restore_prune_retrain_model接口稀疏模型,继续进行量化感知的训练,直至精度满足要求。
  3. 根据最终的重训练好的通道稀疏模型,生成满足精度要求的pth文件;或者调用save_prune_retrain_model接口,生成最终ONNX仿真模型以及部署模型。

调用示例

  1. 如下示例标有“由用户补充处理”的步骤,需要用户根据自己的模型和数据集进行补充处理,示例中仅为示例代码。
  2. 调用AMCT的部分,函数入参可以根据实际情况进行调整。稀疏基于用户的训练过程,请确保已经有基于PyTorch环境进行训练的脚本,并且训练后的精度正常。
  1. 导入AMCT包,并通过安装后处理中的环境变量设置日志级别。
    import amct_pytorch as amct
  2. (可选,由用户补充处理)建议使用原始待稀疏的模型和测试集,在PyTorch环境下推理,验证环境、推理脚本是否正常。

    推荐执行该步骤,请确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。

    ori_model.load()
    # 测试模型
    user_test_model(ori_model, test_data, test_iterations)
  3. 调用AMCT,执行带稀疏算子的训练流程。
    1. 对原始模型进行修改,在图结构中插入通道稀疏mask算子。
      实现该步骤前,应先恢复训练好的参数,如2中的ori_model.load()。
      simple_cfg = './retrain.cfg'
      record_file = './tmp/record.txt'
      prune_retrain_model = amct.create_prune_retrain_model(model=ori_model,
                                      input_data=ori_model_input_data,
                                      config_defination=simple_cfg,
                                      record_file=record_file)
    2. (由用户补充处理)使用修改后的图,创建反向梯度,在训练集上做训练。
      1. 使用修改后的图,创建反向梯度。
        该步骤需要在稀疏模型后执行。
        optimizer = user_create_optimizer(prune_retrain_model)
      2. 从训练好的checkpoint恢复模型,并训练模型。

        注意:从训练好的checkpoint恢复模型参数后再训练。

        quant_pth = './ckpt/user_model'
        user_train_model(optimizer, prune_retrain_model, train_data)
    3. (可选)如果调用save_prune_retrain_model接口,则需要参考该步骤,如果保存为pth文件则不需要。
      保存模型,实现通道稀疏。
      prune_retrain_model = amct.save_prune_retrain_model(
           model=pruned_retrain_model,
           save_path=save_path,
           input_data=input_data)
  4. (可选)如果调用save_prune_retrain_model接口,则需要参考该步骤,如果保存为pth文件则不需要。

    (由用户补充处理)基于ONNX Runtime的环境,使用通道稀疏后模型(prune_retrain_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。

    使用稀疏后仿真模型精度与2中的原始精度做对比,可以观察通道稀疏对精度的影响。

    prune_retrain_model = './results/user_model_fake_prune_model.onnx'
    user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)

如果训练过程中断,需要从ckpt中恢复数据,继续训练,则调用流程为:

  1. 导入AMCT包,并通过安装后处理中的环境变量设置日志级别。
    import amct_pytorch as amct
  1. 准备原始模型。
    ori_model= user_create_model()
  2. 调用AMCT,恢复量化训练流程。
    1. 修改模型,在图结构中插入通道稀疏mask算子,保存为新的prune_model。
      model = ori_model
      input_data = ori_model_input_data
      record_file = './tmp/record.txt'
      config_defination = './prune_cfg.cfg'
      save_pth_path = /your/path/to/save/tmp.pth
      model.load_state_dict(torch.load(state_dict_path))
      prune_retrain_model = amct.restore_prune_retrain_model(model=ori_model,
                                                             input_data=ori_model_input_data,
                                                             record_file=record_file,
                                                             config_defination='./prune_cfg.cfg',
                                                             save_pth_path=/your/path/to/save/tmp.pth,
                                                             'state_dict')
    2. (由用户补充处理)使用修改后的模型,恢复断点,创建反向梯度,在训练集上做训练。
      1. 从稀疏后训练中断的checkpoint恢复模型参数。
        quant_pth = './ckpt/user_prune_model'
        user_train_model(optimizer, prune_retrain_model, train_data)
      2. 使用修改后的图,创建反向梯度。
        该步骤需要在恢复模型参数后执行。
        optimizer = user_create_optimizer(prune_retrain_model)
      3. 从训练好的checkpoint恢复模型,并训练模型。

        注意:从训练好的checkpoint恢复模型参数后再训练。

        user_train_model(optimizer, prune_retrain_model, train_data)
    3. (可选)如果调用save_prune_retrain_model接口,则需要参考该步骤,如果保存为pth文件则不需要。
      保存模型,实现通道稀疏。
      prune_retrain_model = amct.save_prune_retrain_model(
           model=pruned_retrain_model,
           save_path=save_path,
           input_data=input_data)
  3. (可选)如果调用save_prune_retrain_model接口,则需要参考该步骤,如果保存为pth文件则不需要

    (由用户补充处理)基于ONNX Runtime的环境,使用通道稀疏后模型(prune_retrain_model)在测试集(test_data)上做推理,测试量化后仿真模型的精度。

    使用稀疏后仿真模型精度与2中的原始精度做对比,可以观察通道稀疏对精度的影响。

    prune_retrain_model = './results/user_model_fake_prune_model.onnx'
    user_do_inference_onnx(prune_retrain_model, test_data, test_iterations)

后续处理

如果稀疏后输出的模型为pth格式,则需要参考该章节,如果调用save_prune_retrain_model接口,则不需要。

由于输出的pth模型无法直接用于推理,需要用户自行将pth模型转成ONNX网络模型,或者调用save_prune_retrain_model接口保存为最终ONNX仿真模型以及部署模型,然后才能使用ATC工具进行模型转换。调用save_prune_retrain_model接口的调用示例如下:

prune_retrain_model = amct.10.6.3-save_prune_retrain_model(
     model=pruned_retrain_model,
     save_path=save_path,
     input_data=input_data)
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词