文档
注册

模型稀疏加速训练

模型训练往往存在大量冗余计算,稀疏加速算法基于网络扩增训练的思想,结合参数继承方法,对模型中的部分算子实现权重的稀疏化。分别实现了宽度、深度层面的网络扩增算法,以应对不同场景,支撑模型训练阶段加速。

模型稀疏加速训练前需要参考环境准备完成开发环境部署、Python环境变量、所需框架及训练服务器环境变量配置

目前支持对PyTorch框架下包括但不限于表1中的模型稀疏加速训练。

表1 已验证模型列表

类型

模型

自然语言处理

BertBase Chinese

语音识别

espnet-conformer

图像分类

ResNet50

vgg16

swin_tiny

目标检测

Yolov5

Yolox

宽度扩增模型稀疏训练加速(以ResNet50模型为例)

  1. 用户需自行准备模型、训练脚本和数据集,本样例以PyTorch框架的ResNet50和数据集ImageNet为例。
  2. 编辑模型的训练脚本pytorch_resnet50_apex.py文件,导入稀疏加速训练的接口,新增代码参见如下示例。
    from modelslim.pytorch import sparse
  3. (可选)新增如下代码,调整日志输出等级。启动训练任务后,将打屏显示调试的日志信息。
    from modelslim import set_logger_level
    set_logger_level("info")        #根据实际情况配置
  4. 当前训练脚本包含初始化模型、优化器optimizer和训练数据集,用户可以参考训练脚本pytorch_resnet50_apex.py自定义实现。
    在初始化模型、优化器optimizer和训练数据集后,新增如下代码调用稀疏化训练接口,将模型包装为稀疏化训练模型,请参考sparse_model_width进行配置。
    epochs_each_stage = [10, 20, 30]     #定义稀疏化每个阶段的epoch数量
    model = sparse.sparse_model_width(
            model, optimizer, steps_per_epoch=len(train_loader), epochs_each_stage=epochs_each_stage
    )

    其中epochs_each_stage = [10, 20, 30]表示稀疏化分3个阶段:

    • 第1个阶段,从原模型裁剪为1/4的初始模型开始训练10个epoch。
    • 第2个阶段将初始模型扩增2倍,训练20个epoch。
    • 第3个阶段初始模型扩增为4倍,恢复为原模型大小训练,训练30个epoch。

    若原训练脚本中已使用torch.nn.parallel.DistributedDataParallel封装了模型,稀疏加速接口sparse.sparse_model_width()的参数 model需为非ddp模式的模型。

  5. 启动训练任务,根据单卡或多卡调用不同的执行脚本,并指定data_path为数据集路径。
    • 单卡训练时,执行命令启动训练任务。
      bash ./test/train_full_1p.sh --data_path=./datasets/imagenet  #请根据实际情况配置数据集路径
    • 多卡训练时,执行命令启动训练任务。以下示例为8卡训练,请根据实际情况替换启动脚本。
      bash ./test/train_full_8p.sh --data_path=./datasets/imagenet   #请根据实际情况配置数据集路径

深度扩增模型稀疏训练加速(以Swin-Transformer模型为例)

  1. 用户需自行准备模型、训练脚本和数据集,本样例以PyTorch框架的Swin-Transformer和数据集ImageNet为例。
  2. 编辑模型的训练脚本main.py文件,导入稀疏加速训练的接口,新增代码参见如下示例。
    from modelslim.pytorch import sparse
  3. (可选)新增如下代码,调整日志输出等级。启动训练任务后,将打屏显示调试的日志信息。
    from modelslim import set_logger_level
    set_logger_level("info")        #根据实际情况配置
  4. 当前训练脚本包含初始化模型、优化器optimizer和训练数据集,用户可以参考训练脚本main.py自定义实现。
    在初始化模型、优化器optimizer和训练数据集后,新增如下代码,调用稀疏化训练接口,将模型包装为稀疏化训练模型,请参考sparse_model_depth进行配置。
    epochs_each_stage = [10, 20, -1]     #定义稀疏化每个阶段的epoch数量
    model = sparse.sparse_model_depth(model, optimizer, steps_per_epoch=len(data_loader_train), epochs_each_stage=epochs_each_stage)

    其中epochs_each_stage = [10, 20, -1]表示稀疏化分3个阶段:

    • 第1个阶段,从原模型裁剪为1/4的初始模型开始训练10个epoch。
    • 第2个阶段将初始模型扩增2倍,训练20个epoch。
    • 第3个阶段epoch 数量“-1”表示训练直到总的 epoch 结束,初始模型扩增为4倍,恢复为原模型大小训练。

    若原训练脚本中已使用torch.nn.parallel.DistributedDataParallel封装了模型,稀疏加速接口sparse.sparse_model_depth()的参数 model应为非ddp模式的模型。

  5. 启动训练任务,根据单卡或多卡调用不同的执行脚本,并指定data_path为数据集路径。
    • 单卡训练时,执行命令启动训练任务。
      bash ./test/train_full_1p.sh --data_path=./datasets/imagenet  #请根据实际情况配置数据集路径
    • 多卡训练时,执行命令启动训练任务。以下示例为8卡训练,请根据实际情况替换启动脚本。
      bash ./test/train_full_8p.sh --data_path=./datasets/imagenet   #请根据实际情况配置数据集路径

huggingface库稀疏训练加速

  1. 用户需自行准备模型、训练脚本和数据集,本样例以PyTorch框架的Bert-base-Chinese为例。

    访问Bert_Chinese获取训练代码,参考对应readme准备数据集zhwiki-latest-pages-articles.txt,并获取配置模型和分词文件bert-base-chinese

  2. 编辑模型的训练脚本run_mlm.py,导入稀疏加速训练的接口,新增代码参见如下示例。
    from modelslim.pytorch import sparse

    稀疏加速训练API接口请参考sparse_huggingface_trainer_depth

  3. (可选)新增如下代码,调整日志输出等级。启动训练任务后,将打屏显示调试的日志信息。
    from modelslim import set_logger_level
    set_logger_level("info")        #根据实际情况配置
  4. 新增如下代码,将原本的trainer包装为稀疏化训练模型,请参考sparse_huggingface_trainer_depth进行配置。
    from transformers import AutoConfig
    model_config = AutoConfig.from_pretrained("bert-base-chinese/config.json")  # 模型配置类,一般已创建好
    trainer = sparse.sparse_huggingface_trainer_depth(trainer=trainer, model_config=model_config, epochs_each_stage=[2, 2, 2])
  5. 启动训练任务,根据单卡或多卡调用不同的执行脚本。
    • 单卡训练时,执行命令启动训练任务。
      bash test/train_full_1p.sh --data_path=./datasets/zhwiki-latest-pages-articles.txt --batch_size=32 --model_size=base   #请根据实际情况配置数据集的路径
    • 多卡训练,执行命令启动训练任务。以下示例为8卡训练时,请根据实际情况替换启动脚本。
      bash test/train_full_8p.sh --data_path=./datasets/zhwiki-latest-pages-articles.txt --batch_size=32 --model_size=base
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词