文档
注册

框架接口调用

PyTorch框架调用为例,通过加速库包调用PARunner可以直接进行PagedAttention优化后的模型推理,具体样例可参考examples/run_pa.py:
# 1.环境设置
torch.npu.set_compile_mode(jit_compile=False)  # 使能二进制优化,消除动态shape的编译问题
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

# 2.初始化PARunner
runner_config = {
    'rank': rank,                                                      # 并行推理参数rank
    'local_rank': local_rank,                                          # 并行推理参数local_rank
    'world_size': world_size,                                          # 并行推理参数world_size
    'max_prefill_tokens': -1,                                          # 最大输入序列长度,-1则为max_batch_size*(max_input_length+max_output_length)
    'block_size': 128,                                                 # PA block大小
    'model_path': weight_dir,                                          # 模型权重路径
    'is_bf16': False,                                                  # 模型dtype,True为bfloat16,False为float16
    'max_position_embeddings': 1024,                                   # 最大序列长度,可配置为max_input_length+max_output_length
    'max_batch_size': 8,                                               # 最大batch size
    'use_refactor': True,
    'max_input_length': 512,                                           # 最大输入序列长度
    'max_output_length': 512                                           # 最大输出序列长度
}
pa_runner = PARunner(**runner_config)

# 模型warmup,用于给kv cache预申请内存
pa_runner.warm_up()

# 执行推理
input_texts = ["What's deep learning?"]
generate_texts, token_nums, ete_time = pa_runner.infer(
                                    input_texts=input_texts,           # 输入文本
                                    batch_size=1,                      # batch大小
                                    max_output_length=512,             # 最大输出序列长度
                                    ignore_eos=False,                  # 是否忽略EOS
                                    input_ids=None                     # 可直接传入tokeninze后的输入id,为None则使用input_texts
)

# 打印推理结果
for i, generate_text in enumerate(generate_texts):
    length = len(args.input_ids) if args.input_ids else len(args.input_texts)
    inputs = args.input_ids if args.input_ids else args.input_texts
    if i < length:
        print_log(rank, logger.info, f'Question[{i}]: {inputs[i]}')
    print_log(rank, logger.info, f'Answer[{i}]: {generate_text}')
    print_log(rank, logger.info, f'Generate[{i}] token num: {token_nums[i]}')
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词