# 算子提取脚本样例 (operator extraction script sample).

import re
from collections import defaultdict
import argparse

def dump_file(ops, file_path):
    """Write every entry of ``ops`` to ``file_path``, one entry per line.

    The file is created or truncated. Each entry, including the last one,
    is followed by a newline character. Entries are expected to be strings.
    """
    with open(file_path, 'w') as out:
        out.writelines(entry + '\n' for entry in ops)

def _split_profiler_row(line):
    """Split one profiler table row into ``(op_name, input_shapes)``.

    The row is split on double spaces, so the code assumes table columns are
    separated by two or more spaces while single spaces only occur inside a
    column (e.g. ``[[2, 3], []]``). The first field is taken as the op name
    and the last field as the input-shapes column.
    """
    fields = [field.strip() for field in line.strip().split('  ') if field]
    return fields[0], fields[-1]


def parse_profiler(profiler_file):
    """Collect operator names from a profiler table dump and flag dynamic-shape ops.

    Each line whose stripped text starts with ``Name`` and ends with
    ``Input Shapes`` is a table header and starts a new step. (This looks like
    PyTorch's ``key_averages(group_by_input_shape=True).table()`` output —
    TODO confirm.) Lines before the first header are ignored. After a header,
    only lines containing ``[[`` or ``[]`` are treated as operator rows.

    Args:
        profiler_file: Path to the text dump.

    Returns:
        A tuple ``(all_ops, dynamic_ops)``:

        - ``all_ops`` lists every operator name seen in any step, in
          first-seen order.
        - ``dynamic_ops`` lists operators that, in some step after the first,
          have an input-shapes string not recorded for them in the first step.
          This includes ops absent from the first step. They are ordered by
          first appearance in those later steps.
    """
    header_re = re.compile(r'Name.*Input Shapes')
    all_ops = {}  # dict used as an insertion-ordered set of op names
    first_step_shapes = defaultdict(set)
    later_step_shapes = defaultdict(set)
    step = 0
    with open(profiler_file, 'r') as f:
        for line in f:  # stream the file instead of loading it with readlines()
            if header_re.fullmatch(line.strip()):
                step += 1
                continue
            if step == 0:
                continue
            if '[[' not in line and '[]' not in line:
                continue

            op_name, shapes = _split_profiler_row(line)
            all_ops.setdefault(op_name, None)
            target = first_step_shapes if step == 1 else later_step_shapes
            target[op_name].add(shapes)

    dynamic_ops = [
        op_name
        for op_name, shapes in later_step_shapes.items()
        # An op missing from step 1 has an empty baseline, so any of its
        # shapes counts as new. .get() avoids inserting into the defaultdict.
        if shapes - first_step_shapes.get(op_name, set())
    ]
    return list(all_ops), dynamic_ops

def extract_ops(profiler_file, all_ops_file='all_ops.txt',
                dynamic_ops_file='dynamic_ops.txt'):
    """Parse ``profiler_file``, print both op lists and write them to text files.

    Args:
        profiler_file: Path to the profiler table dump, passed to
            ``parse_profiler``.
        all_ops_file: Output path for the ``all_ops`` list, one name per line.
            Defaults to ``all_ops.txt`` in the current directory.
        dynamic_ops_file: Output path for the ``dynamic_ops`` list, one name
            per line. Defaults to ``dynamic_ops.txt`` in the current directory.
    """
    all_ops, dynamic_ops = parse_profiler(profiler_file)

    print('all_ops:', all_ops)
    print('dynamic_ops:', dynamic_ops)

    dump_file(all_ops, all_ops_file)
    dump_file(dynamic_ops, dynamic_ops_file)

if __name__ == '__main__':
    # ArgumentParser's first positional parameter is ``prog`` (the program name
    # shown in usage), so the description text must be passed by keyword.
    parser = argparse.ArgumentParser(description='extract ops')
    # Required: the path is opened unconditionally downstream, so the previous
    # empty-string default could only end in FileNotFoundError.
    parser.add_argument('--profiler_file', required=True, type=str,
                        metavar='PATH', help='profiler table dump to parse')

    args = parser.parse_args()
    extract_ops(args.profiler_file)