predict
函数功能
执行模型推理。
函数原型
predict(llm_req: Union[LLMReq, List[LLMReq], Tuple[LLMReq]], inputs: Any, **kwargs)
参数说明
参数名称 |
数据类型 |
取值说明 |
---|---|---|
llm_req |
Union[LLMReq, List[LLMReq], Tuple[LLMReq]] |
请求信息。
|
inputs |
Any |
模型输入, 会透传给ModelRunner.run_model方法。 如果使用默认的ModelRunner, inputs的类型需要为Union[Tuple[Tensor], List[Tensor]]。 |
kwargs |
Optional[Dict] |
可选参数,会透传给ModelRunner.run_model方法。 |
调用示例
from llm_datadist import LLMDataDist, LLMRole, ModelConfig, ModelRunner, LLMReq llm_datadist = LLMDataDist(LLMRole.DECODER, 0) model_config = ModelConfig() kv_nums = 80 model_config.kv_shapes = ["1,1024" for _ in range(kv_nums)] model_config.kv_dtypes = [DataType.DT_FLOAT for _ in range(kv_nums)] model_options = model_config.generate_options() class TestModelRunner(ModelRunner): def run_model(self, kv_cache: KvCache, input_tensors, **kwargs): return [] llm_model = llm_datadist.add_model(model_options, TestModelRunner()) llm_req = LLMReq() inputs = [Tensor(np.array([1]))] llm_model.predict(llm_req, inputs)
约束说明
无
父主题: LLMModel