下载
中文
注册

predict

函数功能

执行模型推理。

函数原型

predict(llm_req: Union[LLMReq, List[LLMReq], Tuple[LLMReq]], inputs: Any, **kwargs)

参数说明

参数名称

数据类型

取值说明

llm_req

Union[LLMReq, List[LLMReq], Tuple[LLMReq]]

请求信息。

  • 对于Decoder,其个数需要与模型的batch_size相等, 空闲的位置需要用req_id = UINT64_MAX的占位。
  • 对于Prompt,最后一个有效Req之后的无效Req可以省略。

inputs

Any

模型输入, 会透传给ModelRunner.run_model方法。

如果使用默认的ModelRunner, inputs的类型需要为Union[Tuple[Tensor], List[Tensor]]。

kwargs

Optional[Dict]

可选参数,会透传给ModelRunner.run_model方法。

调用示例

from llm_datadist import LLMDataDist, LLMRole, ModelConfig, ModelRunner, LLMReq
llm_datadist = LLMDataDist(LLMRole.DECODER, 0)
model_config = ModelConfig()
kv_nums = 80
model_config.kv_shapes = ["1,1024" for _ in range(kv_nums)]
model_config.kv_dtypes = [DataType.DT_FLOAT for _ in range(kv_nums)]
model_options = model_config.generate_options()
class TestModelRunner(ModelRunner):    
   def run_model(self, kv_cache: KvCache, input_tensors, **kwargs):        
       return []
llm_model = llm_datadist.add_model(model_options, TestModelRunner())
llm_req = LLMReq()
inputs = [Tensor(np.array([1]))]
llm_model.predict(llm_req, inputs)

返回值

正常情况下,返回LLMModel关联的ModelRunnerrun_model输出。

异常情况会抛出LLMException。

约束说明