文档
注册

add_model

函数功能

添加模型。

函数原型

add_model(model_options: Dict[str, str], model_runner: Optional[ModelRunner] = None)

参数说明

参数名称

数据类型

取值说明

model_options

options: Dict[str, str]

模型配置项。

传入options通过ModelConfig来生成,必填字段:kv_shapes和kv_dtypes。

model_runner

Optional[ModelRunner]

ModelRunner的实现类。

调用示例

from llm_datadist import LLMDataDist, LLMRole, ModelConfig, ModelRunner
llm_datadist = LLMDataDist(LLMRole.DECODER, 0)
model_config = ModelConfig()
kv_nums = 80
model_config.kv_shapes = ["1,1024" for _ in range(kv_nums)]
model_config.kv_dtypes = [DataType.DT_FLOAT for _ in range(kv_nums)]
model_options = model_config.generate_options()

class TestModelRunner(ModelRunner):    
   def run_model(self, kv_cache: KvCache, input_tensors, **kwargs):        
       return []
llm_model = llm_datadist.add_model(model_options, TestModelRunner())

返回值

正常情况下返回新创建的LLMModel。

参数错误可能抛出TypeError或RuntimeError。

约束说明

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词