TGI Ascend Framework Adaptation Notes
Taking TGI v2.0.4 as an example, the adaptation lives in the server/text_generation_server package. The main adaptation points are: switching the framework over to MindIE LLM, modifying the Batch structures, invoking MindIE LLM inference and post-processing, and configuring cache management on the NPU.
- Switching to MindIE LLM: redirect the original TGI framework from calling GPU models to calling MindIE LLM.
Adapted file: server/text_generation_server/models/__init__.py
```python
# 1. Import the TGI adaptation package for NPU: tgi_npu
try:
    import torch_npu
    from tgi_npu import MindModel, VlmMindModel

    npu_module_imported = True
except (ImportError, NotImplementedError):
    npu_module_imported = False

# 2. Multimodal support covers the Qwen_VL model; supported text models follow the MindIE LLM support list
if npu_module_imported and torch.npu.is_available():
    if model_type == QWEN_VL:
        return VlmMindModel(model_id)
    else:
        return MindModel(model_id)
```
- Adding the text model class MindModel:
MindModel inherits from TGI's FlashCausalLM class. The main adaptation is to initialize the MindIE LLM model and obtain the model, the tokenizer, and the model-related information from it.
```python
class MindModel(FlashCausalLM):
    def __init__(self, model_id: str):
        logger.warning("Initialize mindie-llm model.")
        rank = int(os.getenv("RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        model_config = {
            'backend_type': BackendType.ATB,
            'rank': rank,
            'world_size': world_size,
            'model_id': model_id,
            'num_threads': 8,
            'local_rank': rank,
            'npu_device_id': rank
        }
        # 1. Initialize the MindIE LLM model; the resulting self.model_runner carries the model information
        self.model_runner = GeneratorTorch(model_config)
        super(MindModel, self).__init__(
            model=self.model_runner.model_wrapper.model_runner.model,
            tokenizer=self.model_runner.tokenizer,
            num_layers=self.model_runner.model_info.num_layers,
            num_kv_heads=self.model_runner.model_info.num_kv_heads,
            head_size=self.model_runner.model_info.head_size,
            dtype=self.model_runner.model_info.dtype,
            device=self.model_runner.model_info.device,
            rank=self.model_runner.rank,
            world_size=self.model_runner.world_size,
        )
        logger.warning("MindModel from tgi_npu initialized.")

# 2. Generate tokens for each batch of requests.
# Inference now goes through the forward_tensor interface of the self.model_runner created during
# MindModel initialization; the inputs are the corresponding fields of the MindFlashCausalLMBatch class.
self.model_runner.forward_tensor(
    input_ids=input_ids,
    position_ids=position_ids,
    is_prefill=cu_seqlen_prefill is not None,
    kv_cache=kv_cache,
    block_tables=block_tables,
    slots=slots,
    input_lengths=input_lengths,
    max_seq_len=max_s,
    lm_head_indices=lm_head_indices,
)

# 3. Post-processing: the call arguments are adapted to the sampling method of the
# MindIELLMHeterogeneousNextTokenChooser class.
(
    next_input_ids,
    next_token_logprobs,
    logprobs,
    accepted_ids,
    speculative_ids,
) = batch.next_token_chooser(
    batch.all_input_ids_tensor[:, : batch.max_seqlen],
    next_token_logits,
    speculate,
    batch.speculative_ids,
    speculative_logits,
)
```
- (Optional) Adding the multimodal model class VlmMindModel:
The multimodal model class inherits from MindModel. Its initialization additionally sets a tokenize method, provided by the MindIE LLM side, dedicated to multimodal encoding.
```python
class VlmMindModel(MindModel):
    def __init__(self, model_id: str):
        logger.warning("Initialize mindie-llm model for vlm.")
        # Reuse the parent class initialization
        super(VlmMindModel, self).__init__(model_id)
        # Additionally set tokenize, a method provided by MindIE LLM dedicated to multimodal encoding
        self.tokenize = TokenizerWrapper(model_id).tokenize
        logger.warning("VlmMindModel from tgi_npu initialized.")

    @property
    def batch_type(self) -> Type[VlmMindFlashCausalLMBatch]:
        # Return the multimodal Batch type, used by the server to convert incoming gRPC requests
        # into a multimodal Batch
        return VlmMindFlashCausalLMBatch
```
- Adding the text-request Batch class MindFlashCausalLMBatch:
Create MindFlashCausalLMBatch as the text-request Batch for MindIE LLM; it inherits from the original TGI FlashCausalLMBatch class.
```python
@dataclass
class MindFlashCausalLMBatch(FlashCausalLMBatch):
    # 1. Two fields are added on top of the base FlashCausalLMBatch:
    #    next_token_chooser (the MindIE LLM post-processing class) and
    #    all_input_ids_tensor (fed to next_token_chooser for sampling)
    next_token_chooser: MindIELLMHeterogeneousNextTokenChooser
    all_input_ids_tensor: torch.Tensor

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "MindFlashCausalLMBatch":
        …
        # 2. Build the post-processing object
        next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
            pb=next_token_chooser_parameters, dtype=dtype, device=device
        )
        # 3. Build the input-id tensor passed to the post-processing class
        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )
```
- (Optional) Adding the multimodal-request Batch class VlmMindFlashCausalLMBatch:
During the Prefill and Warmup phases, the Batch is sent from the Web Router to text_generation_server over gRPC, and text_generation_server deserializes it into a Batch that MindIE LLM can run inference on. Text and multimodal Batches are handled differently, so models with multimodal support must be registered separately (a dispatch sketch follows the snippet below).
Adapted file: server/text_generation_server/server.py
```python
# Register the newly introduced VlmMindFlashCausalLMBatch
from tgi_npu.vlm_mind_models import VlmMindFlashCausalLMBatch

VLM_BATCH_TYPES = {VlmMindFlashCausalLMBatch}
```
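The following is a minimal sketch, not the verbatim server.py code, of how the Prefill handler can branch on VLM_BATCH_TYPES: multimodal batch types are deserialized through from_pb_processor() with the MindIE LLM tokenize method, while text batches keep the original from_pb() path. The attribute name self.model.tokenize is an assumption based on the VlmMindModel shown above.

```python
# Illustrative dispatch in the Prefill handler (sketch; attribute names assumed from the classes above)
if self.model.batch_type in VLM_BATCH_TYPES:
    batch = self.model.batch_type.from_pb_processor(
        request.batch,
        self.model.tokenizer,
        self.model.tokenize,  # multimodal tokenize provided by VlmMindModel
        self.model.dtype,
        self.model.device,
    )
else:
    batch = self.model.batch_type.from_pb(
        request.batch, self.model.tokenizer, self.model.dtype, self.model.device
    )
```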
Create VlmMindFlashCausalLMBatch as the multimodal-request Batch for MindIE LLM; it inherits from the NPU-adapted MindFlashCausalLMBatch class. Its main job is to split and reassemble the incoming inputs and encode them with the tokenize method (see the split() example after the code below).
```python
def split(string) -> List[Dict[str, str]]:
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append({"text": string[cursor:start]})
        parts.append({"image": pattern.group(1)})
        cursor = pattern.end()
    if cursor != len(string):
        parts.append({"text": string[cursor:]})
    return parts


@dataclass
class VlmMindFlashCausalLMBatch(MindFlashCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenize):
        inputs = []
        for r in requests:
            splits = split(r.inputs)
            single_input = tokenize(splits).tolist()
            inputs.append(single_input)
        return inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        tokenize,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmMindFlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenize)
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        return batch
```
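As a quick illustration of split(), assuming IMAGES is a regular expression that matches markdown-style image references and captures the URL (the exact pattern is not shown above), a mixed prompt is decomposed into image and text parts in order:

```python
# Illustration only; assumes IMAGES matches markdown image syntax and captures the URL, e.g.
# IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\)")
prompt = "![](https://example.com/cat.jpg)What is shown in this picture?"
print(split(prompt))
# [{'image': 'https://example.com/cat.jpg'}, {'text': 'What is shown in this picture?'}]
```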
- Adding the post-processing class MindIELLMHeterogeneousNextTokenChooser:
The post-processing class builds one post-processing object per Batch from the sampling parameters of each Request in that Batch. Sampling selects the next token(s) from the logits produced by inference.
```python
from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam
from mindie_llm.text_generator.utils.config import SamplerConfig
from mindie_llm.text_generator.samplers.sampler import Sampler
from mindie_llm.modeling.backend_type import BackendType


class MindIELLMHeterogeneousNextTokenChooser:
    # 1. Constructor: build the MindIE sampler from the sampling parameters received over gRPC
    #    (the parameters below come from the per-request gRPC parameters)
    def __init__(self, do_sample, seeds, repetition_penalty, frequency_penalty,
                 temperature, top_k, top_p, dtype, device):
        curr_rank = int(os.getenv("RANK", "0"))
        sample_method = Sampler(
            SamplerConfig(rank=curr_rank, backend_type=BackendType.ATB, npu_id=curr_rank)
        )
        self.tensor_wrapper = TensorWrapper(BackendType.ATB, device)
        wrapper_dict = {WRAPPER_KEY: TensorWrapper(BackendType.ATB, device)}
        self.sample_params = SamplingParam.from_numpy(
            repetition_penalty=np.array(repetition_penalty, dtype=np.float16),
            presence_penalty=None,
            frequency_penalty=np.array(frequency_penalty, dtype=np.float16),
            temperature=np.array(temperature, dtype=np.float16),
            top_k=np.array(top_k),
            top_p=np.array(top_p),
            seed=np.array(seeds).astype(np.int32),
            do_sample=np.array(do_sample),
            **wrapper_dict
        )
        self.choice = sample_method
        self.dtype = dtype
        self.device = device
        self.seeds = seeds
        self.do_sample = self.sample_params.do_sample_meta.do_sample_array

    # 2. Sample the next token ids (next_ids) from the input token ids and the inference logits (scores)
    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int,
        speculated_ids: Optional[torch.Tensor] = None,
        speculative_scores: Optional[torch.Tensor] = None,
    ):
        batch_size = scores.shape[0]
        speculate_size = 1
        scores = scores.view(batch_size, speculate_size, -1)
        input_ids_int32 = input_ids.to(torch.int32)
        sample_data = SamplingData(all_input_ids=input_ids_int32, output_ids=input_ids_int32)
        next_ids = torch.zeros(
            (batch_size, speculate_size), device=scores.device, dtype=torch.long
        )
        for j in range(speculate_size):
            _scores = scores[:, j]
            batch_logits, _next_ids = self.choice(
                batch_logits=_scores,
                batch_sampling_data=sample_data,
                batch_sampling_params=self.sample_params,
            )
            scores[:, j] = _scores
            next_ids[:, j] = torch.from_numpy(_next_ids)

        next_ids = next_ids.view(batch_size * speculate_size)
        allscores = scores.view(batch_size * speculate_size, -1)
        alllogprobs = torch.log_softmax(allscores, -1)
        accepted_ids = torch.ones_like(next_ids)
        logprobs = alllogprobs
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
        speculative_ids = None
        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
```
- Cache management:
The purpose of the Warmup phase is to work out the maximum KV Cache that TGI on NPU can reserve, and from that derive the maximum number of batch tokens.
- The parts that need adaptation are: obtaining the peak NPU memory usage during the Warmup Batch computation and the total memory of the NPU card, then computing the maximum number of tokens that fit in the remaining memory (a worked example follows the code below).
Method: warmup (replace how the base class FlashCausalLM's warmup queries GPU memory, as follows)
```python
# Total memory of the NPU card
total_gpu_memory = torch_npu.npu.get_device_properties(self.device).total_memory
# Peak memory usage
peak_memory = torch_npu.npu.max_memory_allocated()
# Remaining memory
total_free_memory = total_gpu_memory - peak_memory
logger.warning(
    f">>>>total_free_memory {total_free_memory}, total_gpu_memory {total_gpu_memory}, "
    f"MEMORY_FRACTION {MEMORY_FRACTION}"
)
# Remaining usable memory
free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
# Total number of KV blocks that can be supported
num_blocks = (
    # Leave 5% for some wiggle room
    int((free_memory * 0.95) // total_cache_size)
    # Add batch.blocks as we allocated it above, so it is included in the peak memory.
    + cache_manager.num_blocks
)
# Total number of KV slots that can be supported (total batch tokens); BLOCK_SIZE is recommended to be 128
return int(num_blocks * BLOCK_SIZE)
```
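For intuition, here is a rough, purely illustrative calculation with hypothetical numbers (none of the values below are measured). total_cache_size is taken to be the memory one KV block occupies across all layers for both K and V, following how the base class warmup sizes the cache:

```python
# Hypothetical example values, for illustrating the formula only
dtype_size = 2                                   # float16
num_layers, num_kv_heads, head_size = 40, 8, 128
BLOCK_SIZE = 128
MEMORY_FRACTION = 0.9

# Memory footprint of one KV block across all layers (K and V)
total_cache_size = num_layers * BLOCK_SIZE * num_kv_heads * head_size * 2 * dtype_size
# = 40 * 128 * 8 * 128 * 2 * 2 bytes ≈ 21 MB per block

total_gpu_memory = 64 * 1024**3                  # assume a 64 GB NPU
peak_memory = 20 * 1024**3                       # assumed peak during the warmup forward pass
total_free_memory = total_gpu_memory - peak_memory
free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)

# The real warmup additionally adds back cache_manager.num_blocks allocated for the warmup batch
num_blocks = int((free_memory * 0.95) // total_cache_size)    # ≈ 1828 blocks
max_batch_total_tokens = num_blocks * BLOCK_SIZE              # ≈ 234,000 KV slots
```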
- Allocating KV Cache for each request in a Batch is the job of the CacheManager class. Compared with the CacheManager on GPU, the following adaptations are made:
BLOCK_SIZE (the number of slots in one block) is recommended to be set to 128.
On NPU cards, the 910 uses the ND data layout while the 310 uses the NZ layout, which KV Cache allocation must take into account (a shape comparison follows the code below).
Adapted file: server/text_generation_server/models/cache_manager.py
Class: CacheManager
Method: __init__
```python
from tgi_npu.info import NPUSocInfo

# 1. Recommended value: 128
BLOCK_SIZE: int = 128
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots
        soc_info = NPUSocInfo()
        # 2. Choose the KV cache tensor layout according to the NPU card: ND for 910, NZ for 310
        self.need_nz = soc_info.need_nz
        if self.need_nz:
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, self.block_size, num_heads, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, self.block_size, num_heads, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)
```
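For intuition, with hypothetical sizes (num_heads=8, head_size=128, BLOCK_SIZE=128), the K (or V) tensor of a single layer takes the following shapes; both layouts store the same number of elements per block, only the arrangement differs:

```python
# Hypothetical sizes, for illustration only
num_blocks, num_heads, head_size, block_size = 1828, 8, 128, 128

# ND layout (need_nz == False)
nd_shape = (num_blocks, block_size, num_heads, head_size)             # (1828, 128, 8, 128)

# NZ layout (need_nz == True): the head dimension is folded into 16-element tiles
nz_shape = (num_blocks, num_heads * head_size // 16, block_size, 16)  # (1828, 64, 128, 16)

# Same element count per block in both layouts
assert block_size * num_heads * head_size == nz_shape[1] * nz_shape[2] * nz_shape[3]
```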
- Determining the NPU card information:
```python
@dataclass
class NPUSocInfo:
    soc_name: str = ""
    soc_version: int = -1
    need_nz: bool = False

    def __post_init__(self):
        self.soc_version = torch_npu._C._npu_get_soc_version()
        if self.soc_version in (100, 101, 102, 103, 104, 200, 201, 202, 203):
            self.need_nz = True
```