TGI Ascend Framework Adaptation Notes

Taking TGI v2.0.4 as an example, the adaptation lives in the server/text_generation_server package. The main adaptation areas are: switching in MindIE LLM, modifying the Batch structure, invoking MindIE LLM inference and post-processing, and configuring cache management on the NPU.

  1. Switch in MindIE LLM: route the original TGI framework from GPU models to MindIE LLM

    Adapted file: server/text_generation_server/models/__init__.py

    # 1. Import the tgi_npu adaptation package
    try:
        import torch_npu
        from tgi_npu import MindModel, VlmMindModel
        npu_module_imported = True
    except (ImportError, NotImplementedError):
        npu_module_imported = False

    # 2. Multimodal support covers the Qwen_VL model; text-model coverage matches MindIE LLM
    if npu_module_imported and torch.npu.is_available():
        if model_type == QWEN_VL:
            return VlmMindModel(model_id)
        else:
            return MindModel(model_id)
  2. Add the text-model class MindModel:

    The MindModel class inherits from TGI's FlashCausalLM class. The main adaptation is initializing the MindIE LLM model and obtaining the model, the tokenizer, and the model-related information.

    class MindModel(FlashCausalLM):
        def __init__(
                self,
                model_id: str
        ):
            logger.warning("Initialize mindie-llm model.")
            rank = int(os.getenv("RANK", "0"))
            world_size = int(os.getenv("WORLD_SIZE", "1"))
            model_config = {
                'backend_type': BackendType.ATB,
                'rank': rank,
                'world_size': world_size,
                'model_id': model_id,
                'num_threads': 8,
                'local_rank': rank,
                'npu_device_id': rank
            }
            # 1. Initialize the MindIE LLM model; the resulting self.model_runner carries the model information
            self.model_runner = GeneratorTorch(model_config)
            super(MindModel, self).__init__(
                model=self.model_runner.model_wrapper.model_runner.model,
                tokenizer=self.model_runner.tokenizer,
                num_layers=self.model_runner.model_info.num_layers,
                num_kv_heads=self.model_runner.model_info.num_kv_heads,
                head_size=self.model_runner.model_info.head_size,
                dtype=self.model_runner.model_info.dtype,
                device=self.model_runner.model_info.device,
                rank=self.model_runner.rank,
                world_size=self.model_runner.world_size,
            )
            logger.warning("MindModel from tgi_npu initialized.")
    
    # 2. Generate tokens for each batch of requests
    # Inference now goes through the forward_tensor interface of the self.model_runner created during
    # MindModel initialization; the inputs are the corresponding fields of the MindFlashCausalLMBatch class
    self.model_runner.forward_tensor(
        input_ids=input_ids,
        position_ids=position_ids,
        is_prefill=cu_seqlen_prefill is not None,
        kv_cache=kv_cache,
        block_tables=block_tables,
        slots=slots,
        input_lengths=input_lengths,
        max_seq_len=max_s,
        lm_head_indices=lm_head_indices,
    )
     
    # 3. Post-processing: the inputs are adapted to the sampling method of the
    # MindIELLMHeterogeneousNextTokenChooser class
    (
        next_input_ids,
        next_token_logprobs,
        logprobs,
        accepted_ids,
        speculative_ids,
    ) = batch.next_token_chooser(
        batch.all_input_ids_tensor[:, : batch.max_seqlen],
        next_token_logits,
        speculate,
        batch.speculative_ids,
        speculative_logits,
    )
  3. (Optional) Add the multimodal model class VlmMindModel:

    The multimodal model class inherits from MindModel. During initialization it additionally sets a tokenize method, which is provided by MindIE LLM and dedicated to multimodal encoding.

    class VlmMindModel(MindModel):
        def __init__(
                self,
                model_id: str
        ):
            logger.warning("Initialize mindie-llm model for vlm.")
            # Initialize via the parent class
            super(VlmMindModel, self).__init__(model_id)
            # Additionally set tokenize, the MindIE LLM method dedicated to multimodal encoding
            self.tokenize = TokenizerWrapper(model_id).tokenize
            logger.warning("VlmMindModel from tgi_npu initialized.")

        @property
        def batch_type(self) -> Type[VlmMindFlashCausalLMBatch]:
            # Return the multimodal Batch type, used by the server to convert incoming gRPC requests
            # into a multimodal Batch
            return VlmMindFlashCausalLMBatch
  4. Add the Batch class for text requests, MindFlashCausalLMBatch:

    Create the MindFlashCausalLMBatch class as the text-request Batch for MindIE LLM, inheriting from TGI's original FlashCausalLMBatch class (a small padding example follows the code below).

    @dataclass
    class MindFlashCausalLMBatch(FlashCausalLMBatch):
        # 1. Add two fields on top of the base FlashCausalLMBatch: next_token_chooser (the MindIE LLM
        # post-processing class) and all_input_ids_tensor (the input-id tensor fed to next_token_chooser
        # for sampling)
        next_token_chooser: MindIELLMHeterogeneousNextTokenChooser
        all_input_ids_tensor: torch.Tensor

        @classmethod
        def from_tokenized(
                cls,
                pb: generate_pb2.Batch,
                tokenizer: PreTrainedTokenizerBase,
                batch_tokenized_inputs,
                dtype: torch.dtype,
                device: torch.device,
        ) -> "MindFlashCausalLMBatch":

            …
            # 2. Build the post-processing object
            next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
                pb=next_token_chooser_parameters, dtype=dtype, device=device
            )

            # 3. Build the input-id tensor passed to the post-processing object
            # Padded all_input_ids_tensor
            all_input_ids_tensor = np.zeros(
                (len(all_input_ids), max_length), dtype=np.int64
            )
            for i, input_ids in enumerate(all_input_ids):
                all_input_ids_tensor[i, : len(input_ids)] = input_ids

            all_input_ids_tensor = torch.tensor(
                all_input_ids_tensor, dtype=torch.int64, device=device
            )
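
    For illustration only, a minimal standalone sketch (with made-up token ids and a made-up max_length) of the zero-padding performed above before the tensor is moved to the device:

    import numpy as np
    import torch

    # Two hypothetical requests with 3 and 5 prompt tokens; max_length is 6 here for brevity
    all_input_ids = [[101, 7592, 102], [101, 2054, 2003, 2023, 102]]
    max_length = 6

    padded = np.zeros((len(all_input_ids), max_length), dtype=np.int64)
    for i, input_ids in enumerate(all_input_ids):
        padded[i, : len(input_ids)] = input_ids

    # Each row is right-padded with zeros up to max_length:
    # [[ 101 7592  102    0    0    0]
    #  [ 101 2054 2003 2023  102    0]]
    all_input_ids_tensor = torch.tensor(padded, dtype=torch.int64)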
  5. (Optional) Add the Batch class for multimodal requests, VlmMindFlashCausalLMBatch:

    During the Prefill and Warmup phases, the Batch is sent from the Web Router to text_generation_server over gRPC, and text_generation_server deserializes it into a Batch that MindIE LLM can run inference on. Text and multimodal Batches are handled differently, so multimodal-capable models need to be wired in separately (a dispatch sketch follows this step's code).

    Adapted file: server/text_generation_server/server.py

    # Register the newly introduced VlmMindFlashCausalLMBatch
    from tgi_npu.vlm_mind_models import VlmMindFlashCausalLMBatch
    VLM_BATCH_TYPES = {VlmMindFlashCausalLMBatch}

    Create the VlmMindFlashCausalLMBatch class as the multimodal-request Batch for MindIE LLM, inheriting from the NPU-adapted MindFlashCausalLMBatch class. Its main job is to split and reassemble the incoming inputs and encode them with the tokenize method.

    # IMAGES is the regular expression used to pull image references out of the prompt
    # (in TGI it matches the markdown image syntax "![](...)")
    def split(string) -> List[Dict[str, str]]:
        parts = []
        cursor = 0
        for pattern in IMAGES.finditer(string):
            start = pattern.start()
            if start != cursor:
                parts.append({"text": string[cursor:start]})
            parts.append({"image": pattern.group(1)})
            cursor = pattern.end()
        if cursor != len(string):
            parts.append({"text": string[cursor:]})
        return parts
    @dataclass
    class VlmMindFlashCausalLMBatch(MindFlashCausalLMBatch):
        @classmethod
        def batch_tokenized_inputs(cls, requests, tokenize):
            inputs = []
            for r in requests:
                splits = split(r.inputs)
                single_input = tokenize(splits).tolist()
                inputs.append(single_input)
            return inputs
        @classmethod
        def from_pb_processor(
                cls,
                pb: generate_pb2.Batch,
                tokenizer: PreTrainedTokenizerBase,
                tokenize,
                dtype: torch.dtype,
                device: torch.device,
        ) -> "VlmMindFlashCausalLMBatch":
            batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenize)
            batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
            return batch
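
    To show where this Batch class gets used, here is a minimal sketch (not the verbatim upstream code) of how the Prefill/Warmup handling in server.py can branch on VLM_BATCH_TYPES; the helper name prefill_batch is hypothetical, and model.tokenize is the attribute set in VlmMindModel.__init__ above:

    def prefill_batch(model, request, dtype, device):
        # Multimodal batches need the MindIE LLM tokenize method via from_pb_processor;
        # plain text batches keep the regular from_pb path.
        if model.batch_type in VLM_BATCH_TYPES:
            return model.batch_type.from_pb_processor(
                request.batch, model.tokenizer, model.tokenize, dtype, device
            )
        return model.batch_type.from_pb(
            request.batch, model.tokenizer, dtype, device
        )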
    
  6. Add the post-processing class MindIELLMHeterogeneousNextTokenChooser:

    The post-processing class builds one post-processing object per Batch from the sampling parameters of each Request in the Batch. Sampling means drawing the next token(s) from the logits produced by inference.

    from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam
    from mindie_llm.text_generator.utils.config import SamplerConfig
    from mindie_llm.text_generator.samplers.sampler import Sampler
    from mindie_llm.modeling.backend_type import BackendType

    class MindIELLMHeterogeneousNextTokenChooser:

        # 1. In the initializer, build the MindIE sampler from the sampling parameters received over gRPC
        curr_rank = int(os.getenv("RANK", "0"))
        sample_method = Sampler(SamplerConfig(rank=curr_rank, backend_type=BackendType.ATB, npu_id=curr_rank))
        self.tensor_wrapper = TensorWrapper(BackendType.ATB, device)
        wrapper_dict = {WRAPPER_KEY: TensorWrapper(BackendType.ATB, device)}
        self.sample_params = SamplingParam.from_numpy(
            repetition_penalty=np.array(repetition_penalty, dtype=np.float16),
            presence_penalty=None,
            frequency_penalty=np.array(frequency_penalty, dtype=np.float16),
            temperature=np.array(temperature, dtype=np.float16),
            top_k=np.array(top_k),
            top_p=np.array(top_p),
            seed=np.array(seeds).astype(np.int32),
            do_sample=np.array(do_sample),
            **wrapper_dict
        )

        self.choice = sample_method
        self.dtype = dtype
        self.device = device
        self.seeds = seeds
        self.do_sample = self.sample_params.do_sample_meta.do_sample_array
        # 2. Sample the next token ids (next_ids) from the input token ids and the inference logits (scores)
        def __call__(self,
                     input_ids: torch.Tensor,
                     scores: torch.Tensor,
                     speculate: int,
                     speculated_ids: Optional[torch.Tensor] = None,
                     speculative_scores: Optional[torch.Tensor] = None,
                     ):
            batch_size = scores.shape[0]
            speculate_size = 1
            scores = scores.view(batch_size, speculate_size, -1)

            input_ids_int32 = input_ids.to(torch.int32)
            sample_data = SamplingData(all_input_ids=input_ids_int32, output_ids=input_ids_int32)
            next_ids = torch.zeros((batch_size, speculate_size), device=scores.device, dtype=torch.long)
            for j in range(speculate_size):
                _scores = scores[:, j]

                batch_logits, _next_ids = self.choice(batch_logits=_scores, batch_sampling_data=sample_data,
                                                      batch_sampling_params=self.sample_params)
                scores[:, j] = _scores
                next_ids[:, j] = torch.from_numpy(_next_ids)
            next_ids = next_ids.view(batch_size * speculate_size)
            allscores = scores.view(batch_size * speculate_size, -1)
            alllogprobs = torch.log_softmax(allscores, -1)

            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs

            next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

            speculative_ids = None
            return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
  7. Cache management:

    The goal of the Warmup phase is to work out the maximum KV cache that TGI on NPU can reserve, and from that the maximum number of Batch tokens.

    • The parts that need adapting are: reading the peak NPU memory used while computing the Warmup Batch, reading the NPU card's total memory, and computing the maximum number of tokens that fits into the remaining memory. A worked example follows the code below.

      Class: MindModel

      Method: warmup (replace the GPU memory queries in the base class FlashCausalLM's warmup as follows)

      # Total memory of the NPU card
      total_gpu_memory = torch_npu.npu.get_device_properties(self.device).total_memory

      # Peak memory
      peak_memory = torch_npu.npu.max_memory_allocated()
      # Remaining memory
      total_free_memory = total_gpu_memory - peak_memory
      logger.warning(f">>>>total_free_memory {total_free_memory}, total_gpu_memory {total_gpu_memory}, "
                     f"MEMORY_FRACTION {MEMORY_FRACTION}")
      # Remaining usable memory
      free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
      # Total number of KV blocks that can be supported
      num_blocks = (
          # Leave 5% for some wiggle room
          int((free_memory * 0.95) // total_cache_size)
          # Add batch.blocks as we allocated it above, so it is included in the peak memory
          + cache_manager.num_blocks
      )
      # Total number of KV slots that can be supported (total batch tokens); BLOCK_SIZE is recommended to be 128
      return int(num_blocks * BLOCK_SIZE)
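
      For intuition, a standalone calculation with hypothetical figures (a 64 GiB card, 20 GiB peak usage during warmup, MEMORY_FRACTION = 0.9, and an assumed KV cache size of 16 MiB per block; the warmup blocks added via cache_manager.num_blocks are omitted):

      # All figures below are hypothetical, for illustration only
      total_gpu_memory = 64 * 1024 ** 3     # 64 GiB NPU card
      peak_memory = 20 * 1024 ** 3          # peak usage observed while running the warmup batch
      MEMORY_FRACTION = 0.9                 # fraction of the card TGI is allowed to use
      BLOCK_SIZE = 128                      # recommended number of slots per KV block
      total_cache_size = 16 * 1024 ** 2     # assumed KV cache bytes per block for the model

      total_free_memory = total_gpu_memory - peak_memory
      free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
      num_blocks = int((free_memory * 0.95) // total_cache_size)
      max_batch_tokens = num_blocks * BLOCK_SIZE
      print(free_memory, num_blocks, max_batch_tokens)
      # -> free_memory ≈ 37.6 GiB, num_blocks = 2286, max_batch_tokens = 292608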
    • Allocating KV cache for each request in a Batch is the job of the CacheManager class. Compared with the CacheManager class on GPU, adapt it as follows:

      BLOCK_SIZE (the number of slots in one block) is recommended to be set to 128.

      On NPU cards, the 910 series uses the ND data layout and the 310 series uses the NZ layout; KV cache allocation has to take this into account (see the shape comparison after the code below).

      Adapted file: server/text_generation_server/models/cache_manager.py

      Class: CacheManager

      Method: __init__

      from tgi_npu.info import NPUSocInfo
      # 1. Recommended to change to 128
      BLOCK_SIZE: int = 128
      # Will be set in warmup
      CACHE_MANAGER: Optional["CacheManager"] = None
      
       class CacheManager:
           def __init__(
                   self,
                   num_blocks: int,
                   num_layers: int,
                   num_heads: int,
                   head_size: int,
                   repeat_slots: bool,
                   dtype: torch.dtype,
                   device: torch.device,
           ):
               self.block_size = BLOCK_SIZE
               self.num_blocks = num_blocks
               self.repeat_slots = repeat_slots
               soc_info = NPUSocInfo()
                # 2. Set the data layout of the KV cache tensors according to the NPU card: ND for the 910 series, NZ for the 310 series
               self.need_nz = soc_info.need_nz
               if self.need_nz:
                   self.kv_cache = [
                       (
                           torch.empty(
                               (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                               dtype=dtype,
                               device=device,
                           ),
                           torch.empty(
                               (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                               dtype=dtype,
                               device=device,
                           ),
                       )
                       for _ in range(num_layers)
                   ]
               else:
                   self.kv_cache = [
                       (
                           torch.empty(
                               (num_blocks, self.block_size, num_heads, head_size),
                               dtype=dtype,
                               device=device,
                           ),
                           torch.empty(
                               (num_blocks, self.block_size, num_heads, head_size),
                               dtype=dtype,
                               device=device,
                           ),
                       )
                       for _ in range(num_layers)
                   ]
               self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
               self.slots = torch.arange(
                   0, num_blocks * self.block_size, dtype=torch.int64
               ).view(num_blocks, self.block_size)
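
      For a concrete feel of the two layouts, a standalone sketch with hypothetical dimensions (8 KV heads, head size 128, BLOCK_SIZE 128) comparing the per-layer K/V tensor shapes produced above:

      # Hypothetical dimensions, for illustration only
      num_blocks, num_heads, head_size, block_size = 1024, 8, 128, 128

      # ND layout (910 series): (num_blocks, block_size, num_heads, head_size)
      nd_shape = (num_blocks, block_size, num_heads, head_size)              # (1024, 128, 8, 128)

      # NZ layout (310 series): (num_blocks, num_heads * head_size // 16, block_size, 16)
      nz_shape = (num_blocks, num_heads * head_size // 16, block_size, 16)   # (1024, 64, 128, 16)

      # Both layouts hold the same number of elements per block; only the arrangement differs
      assert nd_shape[1] * nd_shape[2] * nd_shape[3] == nz_shape[1] * nz_shape[2] * nz_shape[3]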
    • NPU card detection:

      @dataclass
      class NPUSocInfo:
          soc_name: str = ""
          soc_version: int = -1
          need_nz: bool = False

          def __post_init__(self):
              self.soc_version = torch_npu._C._npu_get_soc_version()
              if self.soc_version in (100, 101, 102, 103, 104, 200, 201, 202, 203):
                  self.need_nz = True