动态shape固定

场景解析

当前PyTorch在执行算子时会检查shape和attribute,如遇到未编译的则会进行编译,导致运行速度下降。对动态shape场景会反复编译算子导致性能较差。

调优思路

可通过脚本获取host侧日志,解析得到重编译算子信息并固定shape。

优化方法

算子过程中的调试方法可参考下表。

表1 基础命令参考

命令

功能说明

ls -l kernel_meta/ | wc -l

查看当前编译的算子个数。

watch "ls -l kernel_meta/ | wc -l"

持续观察下算子个数是否随epoch增长来判断是否动态shape。

ls -lt kernel_meta/ | head -5

查看最新编译的5个算子,通过查看算子名称,可大致推断动态shape所在的地方。

动态shape算子提取步骤如下:

  1. host日志提取。

    添加每个iter的标志位。
    # Example
    for idx, (img, target) in enumerate(dataset_loader):
        # 添加标志位
        print('To compile op|=====iter%d'%idx)
        ...

  2. 在拉起训练命令前加上nohup,其余不变。示例命令如下。

    nohup python3 main.py /home/data/resnet50/imagenet   --batch-size 128 \       # 训练批次大小
                                                   --lr 0.1 \               # 学习率
                                                   --epochs 90 \            # 训练迭代轮数
                                                   --arch resnet50 \        # 模型架构
                                                   --world-size 1 \
                                                   --rank 0 \         
                                                   --workers 40 \           # 加载数据进程数
                                                   --momentum 0.9 \         # 动量  
                                                   --weight-decay 1e-4      # 权重衰减
                                                   --gpu 0                  # device号, 这里参数名称仍为gpu, 但迁移完成后实际训练设备已在代码中定义为npu

    推荐执行5-10个step,等待运行完成。

  3. 解析得到recompile_op.txt。

    1. 获取标志位信息。
      cat nohup.out | grep "To compile op|=====iter" > recompile_op.log
    2. 获取重编译算子列表。脚本recompiled_op.py样例见样例参考
      python3  recompiled_op.py
    3. 重点关注recompile_op_list.txt中第三个step以后的compile的算子,视为动态shape引起的重编译算子。
    4. 固定动态shape参考固定动态shape范例

样例参考

recompiled_op.py样例参考:

#用于导出动态shape算子
import os
import sys
import argparse


def func(log_path, split_flag):
    """
    :param log_path: where log_path addr is.
    :return:
    """
    recompile_flag = 'To compile op: '
    output_list = [[]]

    with open(log_path, 'r')as f:
        log_list = f.read().split('\n')
        for log in log_list:
            log = log.strip()
            if split_flag in log:
                output_list.append([])
            elif recompile_flag in log:
                op_name = log.split(recompile_flag)[1]
                if op_name not in output_list[-1]:
                    output_list[-1].append(op_name)

    with open('recompile_op_list.txt', 'w')as f:
        for idx, output in enumerate(output_list):
            f.write('iter: %d' % idx + '\n')
            f.write(','.join(output) + '\n')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='trans the log')
    parser.add_argument('--log_path', default="./recompile_op.log",
                        help="input the dir name, trans the current dir with default")
    parser.add_argument('--split_flag', default='=====iter',
                        help="flag for split epochs")
    args = parser.parse_args()
    func(args.log_path, args.split_flag)