
vLLM 0.4.2 Ascend Framework Adaptation Notes

The directory structure of the vllm_npu 0.4.2 package is as follows:
vllm_npu
|-- __init__.py
|-- attention
|   |-- __init__.py
|   |-- backends.py
|   |-- selector.py
|-- config.py
|-- core
|   |-- __init__.py
|-- engine
|   |-- __init__.py
|   |-- ascend_engine.py
|   |-- async_ascend_engine.py
|-- executor
|   |-- __init__.py
|   |-- ascend_executor.py
|   |-- ascend_ray_executor.py
|   |-- ray_utils.py
|-- model_executor
|   |-- __init__.py
|   |-- ascend_model_loader.py
|   |-- layers
|   |   |-- __init__.py
|   |   |-- ascend_sampler.py
|   |-- models
|       |-- __init__.py
|       |-- ascend
|           |-- __init__.py
|           |-- mindie_llm_wrapper.py
|-- npu_adaptor.py
|-- usage
|   |-- __init__.py
|   |-- usage_lib.py
|-- utils.py
|-- worker
    |-- __init__.py
    |-- ascend_model_runner.py
    |-- ascend_worker.py
    |-- cache_engine.py

The vllm_npu 0.4.2 package rewrites six modules (attention, engine, executor, model_executor, usage, and worker), each hot-replacing the module of the same name in the native vLLM framework.
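
The replacement itself is plain monkey patching: importing the vllm_npu package runs each submodule's __init__.py, which rebinds the corresponding attribute on the native vLLM module. A minimal sketch of the pattern, using the rewritten from_engine_args as an assumed example (the concrete replacements are described per module below):

    # Hypothetical sketch of the hot-replacement pattern: rebind a native vLLM
    # attribute to its Ascend rewrite when the patch package is imported.
    import vllm.engine.llm_engine as native_llm_engine
    from vllm_npu.engine.ascend_engine import from_engine_args  # assumed rewritten function

    native_llm_engine.LLMEngine.from_engine_args = classmethod(from_engine_args)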

  • attention module:

    Rewrites the AttentionBackend class of the vLLM framework and defines the key information needed to interface with MindIE LLM on Ascend, such as the attention computation data and the shape of the KV cache. In the module's __init__ file, the native get_attn_backend function is hot-replaced so that, in an Ascend environment, the framework runs the attention backend class defined in this module.
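
    A minimal sketch of that hot replacement, assuming the backend class defined in backends.py is named AscendAttentionBackend and that the patch targets vLLM 0.4.2's vllm.attention.selector module:

      # Hypothetical sketch of vllm_npu/attention/__init__.py: always hand back the
      # Ascend backend that carries the MindIE LLM attention data and KV cache shapes.
      import vllm.attention.selector as selector
      from vllm_npu.attention.backends import AscendAttentionBackend  # assumed class name

      def get_attn_backend(*args, **kwargs):
          return AscendAttentionBackend

      selector.get_attn_backend = get_attn_backend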

  • engine module:

    This module mainly rewrites the engine's from_engine_args function. In vLLM 0.4.2 the engine is instantiated through this class method, which selects an executor according to the runtime environment. A new branch is added here: when an Ascend environment is detected, the engine selects the AscendExecutor and RayAscendExecutor classes defined in the vllm_npu patch package for single-card and multi-card inference, respectively. The from_engine_args functions of both the offline synchronous engine LLMEngine and the online asynchronous engine AsyncLLMEngine are rewritten and replaced; the replacement happens in the module's __init__ file.
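
    A minimal sketch of the added selection logic, assuming the is_ascend() helper described later in utils.py and vLLM 0.4.2's engine_config.parallel_config.worker_use_ray flag:

      # Hypothetical helper mirroring the branch added to the rewritten from_engine_args.
      from vllm_npu.utils import is_ascend

      def select_ascend_executor_class(engine_config):
          if not is_ascend():
              return None  # keep vLLM's native executor selection (GPU, CPU, ...)
          if engine_config.parallel_config.worker_use_ray:
              from vllm_npu.executor.ascend_ray_executor import RayAscendExecutor
              return RayAscendExecutor
          from vllm_npu.executor.ascend_executor import AscendExecutor
          return AscendExecutor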

  • executor module:

    This module implements four executor classes: AscendExecutor and AscendExecutorAsync are used for synchronous and asynchronous inference on a single card, while RayAscendExecutor and RayAscendExecutorAsync are used for synchronous and asynchronous inference in a multi-card Ray distributed environment.

    In addition, the initialize_ray_cluster function in ray_utils.py is rewritten, mainly because Ray cannot automatically detect the number of NPUs in an Ascend environment, so the count has to be specified explicitly.
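
    A minimal sketch of specifying the count by hand, assuming the torch_npu extension (for torch.npu.device_count) and registering the NPUs as a custom Ray resource:

      # Hypothetical sketch: Ray does not discover Ascend NPUs on its own, so the
      # rewritten initialize_ray_cluster declares them explicitly as a custom resource.
      import ray
      import torch
      import torch_npu  # noqa: F401  # provides torch.npu on Ascend

      def initialize_ray_cluster_for_npu() -> None:
          num_npus = torch.npu.device_count()
          ray.init(resources={"NPU": num_npus}, ignore_reinit_error=True)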

  • model_executor module:
    • This module is where model inference and post-processing are actually wired to MindIE LLM. It contains the layers and models submodules, which correspond to post-processing and model inference, respectively.

      In the models submodule, the MindIELlmWrapper class instantiates the unified GeneratorTorch interface provided by MindIE LLM, extracts the model inference parameters that MindIE LLM needs from the native vLLM data structures, and passes them to the unified interface to invoke the model inference service. The construction of the fake data used during warmup is also implemented in this class. The implementation is shown below:

      class MindIELlmWrapper(nn.Module):
          def __init__(self, mindie_model_config, linear_method=None, lora_config=None):
              super(MindIELlmWrapper, self).__init__()
              
              self.mindie_model_config = mindie_model_config
              self.rank = mindie_model_config['rank']
              self.local_rank = mindie_model_config['local_rank']
              self.npu_id = self.local_rank
              self.world_size = mindie_model_config['world_size']
              self.mindie_model = None
              self.sampler = None
          def forward(
                  self,
                  input_ids: torch.Tensor,
                  positions: torch.Tensor,
                  kv_caches: List[KVCache],
                  attn_metadata: AttentionMetadata,
          ) -> torch.Tensor:
              is_prompt = attn_metadata.num_prefill_tokens > 0
              
              if kv_caches[0][0] is None:
                  kv_caches, block_tables, slots = self.create_dummy_kv_cache(attn_metadata, input_ids)
              else:
                  if is_prompt:
                      block_tables = torch.tensor([0], dtype=torch.int32, device="npu")
                  else:
                      block_tables = attn_metadata.decode_metadata.block_tables
                  slots = attn_metadata.slot_mapping
              if is_prompt:
                  input_lengths = attn_metadata.prefill_metadata.seq_lens_tensor.to(torch.int32)
                  max_seq_len = int(attn_metadata.prefill_metadata.seq_lens_tensor.max())
                  lm_head_indices = (attn_metadata.prefill_metadata.seq_lens_tensor.cumsum(dim=-1) - 1).to(torch.int64)
              else:
                  input_lengths = attn_metadata.decode_metadata.seq_lens_tensor
                  max_seq_len = attn_metadata.decode_metadata.max_seq_len
                  lm_head_indices = None
              
              logits = self.mindie_model.forward_tensor(input_ids, positions, is_prompt, kv_caches, block_tables, slots,
                                      input_lengths, max_seq_len, lm_head_indices)
              return logits
          def compute_logits(self, hidden_states: torch.Tensor,
                             sampling_metadata: SamplingMetadata) -> torch.Tensor:
              return hidden_states
          def sample(
              self,
              logits: torch.Tensor,
              sampling_metadata: SamplingMetadata,
          ) -> Optional[SamplerOutput]:
              # hidden_states is logits
              next_tokens = self.sampler(logits, sampling_metadata)
              return next_tokens
          def load_weights(self,
                           model_name_or_path: str,
                           cache_dir: Optional[str] = None,
                           load_format: str = "auto",
                           revision: Optional[str] = None):
              if load_format not in ['auto', 'safetensors', 'pt']:
                  raise ValueError('load-format support [safetensors, pt]')
              self.weight_dtype = torch.get_default_dtype()
              torch.set_default_dtype(torch.float32)
              self.mindie_model = GeneratorTorch(self.mindie_model_config)
              self.sampler = AscendSampler(self.mindie_model)
              torch.set_default_dtype(self.weight_dtype)
          # when warmup, create dummy kvcache, block_tables, slot_mapping
          def create_dummy_kv_cache(self, attn_metadata, input_ids):
              dummy_block_num = 1
              dummy_block_size = 128
              model_runner = self.mindie_model.model_wrapper.model_runner
              kv_cache = [
                  (
                      torch.empty(
                          (dummy_block_num, dummy_block_size, model_runner.num_kv_heads, model_runner.head_size),
                          dtype=self.weight_dtype,
                          device="npu",
                      ),
                      torch.empty(
                          (dummy_block_num, dummy_block_size, model_runner.num_kv_heads, model_runner.head_size),
                          dtype=self.weight_dtype,
                          device="npu",
                      ),
                  )
                  for _ in range(model_runner.num_layers)
              ]
              max_s = max(attn_metadata.prefill_metadata.seq_lens_tensor)
              max_need_block = math.ceil(max_s / dummy_block_size)
              batch_size = len(attn_metadata.prefill_metadata.seq_lens_tensor)
              block_tables = torch.zeros(batch_size, max_need_block, dtype=int, device="npu")
              slot = [i for i in range(dummy_block_size)]
              slots = []
              warm_up_len = len(input_ids)
              while warm_up_len > 0:
                  if warm_up_len > dummy_block_size:
                      slots.extend(slot)
                      warm_up_len -= dummy_block_size
                  else:
                      slots.extend(slot[:warm_up_len])
                      warm_up_len = 0
              slots = torch.tensor(slots, dtype=torch.long, device="npu")
              return kv_cache, block_tables, slots
    • The layers submodule implements the AscendSampler class, which bridges the data structures of the native vLLM framework and the underlying data structures of the model repository. The implementation is shown below:
      class AscendSampler(nn.Module):
          def __init__(self, mindie_model):
              super().__init__()
              self.mindie_model = mindie_model
              self.include_gpu_probs_tensor = False
          def forward(
              self,
              logits: torch.Tensor,
              sampling_metadata: SamplingMetadata,
          ) -> Optional[SamplerOutput]:
              _, vocab_size = logits.shape
              mindie_sampling_data, mindie_sampling_param = self.construct_data(sampling_metadata, vocab_size)
              probs = torch.softmax(logits, dim=-1, dtype=torch.float)
              logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
              next_tokens = self.mindie_model.sample(
                  logits, 
                  sampling_data=mindie_sampling_data, 
                  sampling_param=mindie_sampling_param,
              )
              
              sample_results, maybe_sampled_tokens_tensor = recover_data(
                  sampling_metadata=sampling_metadata, 
                  sampled_tokens=next_tokens, 
                  logprobs=logprobs, 
                  include_gpu_probs_tensor=self.include_gpu_probs_tensor,
              )
              if self.include_gpu_probs_tensor:
                  if maybe_sampled_tokens_tensor is None:
                      raise RuntimeError("maybe_sampled_tokens_tensor is None")
                  on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
              else:
                  on_device_tensors = None
              # Get the logprobs query results.
              prompt_logprobs, sample_logprobs = _get_logprobs(
                  logprobs, sampling_metadata, sample_results)
              return _build_sampler_output(sample_results,
                                           sampling_metadata,
                                           prompt_logprobs,
                                           sample_logprobs,
                                           on_device_tensors=on_device_tensors)
          def construct_data(
              self,
              sampling_metadata: SamplingMetadata,
              vocab_size: int,
          ) -> Tuple[SamplingData, SamplingParam]:
              all_input_tokens: List[List[int]] = []
              prompt_tokens: List[List[int]] = []
              output_tokens: List[List[int]] = []
              top_ks: List[int] = []
              temperatures: List[float] = []
              top_ps: List[float] = []
              min_ps: List[float] = []
              presence_penalties: List[float] = []
              frequency_penalties: List[float] = []
              repetition_penalties: List[float] = []
              sampling_seeds: List[int] = []
              sample_indices: List[int] = []
              do_samples: List[bool] = []  # To Do
              do_penalties = False
              do_top_p_top_k = False
              do_min_p = False
              greedy_flag = False
              
              if sampling_metadata.seq_groups is None:
                  raise RuntimeError("sampling_metadata.seq_group is None, no data received.")
              for seq_group in sampling_metadata.seq_groups:
                  do_samples.append(seq_group.do_sample)
                  seq_ids = seq_group.seq_ids
                  sampling_params = seq_group.sampling_params
                  temperature = sampling_params.temperature
                  p = sampling_params.presence_penalty
                  f = sampling_params.frequency_penalty
                  r = sampling_params.repetition_penalty
                  top_p = sampling_params.top_p
                  min_p = sampling_params.min_p
                  is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
                  seed = sampling_params.seed
                  if seed is None:
                      if is_greedy:
                          seed = 0
                      else:
                          lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
                          seed = random.randint(lo, hi)
                  if is_greedy:
                      greedy_flag = True
                  # k should not be greater than the vocab size.
                  top_k = min(sampling_params.top_k, vocab_size)
                  top_k = vocab_size if top_k == -1 else top_k
                  if temperature < _SAMPLING_EPS:
                      temperature = 1.0
                  if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                             or top_k != vocab_size):
                      do_top_p_top_k = True
                  if not do_min_p and min_p > _SAMPLING_EPS:
                      do_min_p = True
                  if not do_penalties:
                      if abs(p) >= _SAMPLING_EPS:
                          do_penalties = True
                      elif abs(f) >= _SAMPLING_EPS:
                          do_penalties = True
                      elif abs(r - 1.0) >= _SAMPLING_EPS:
                          do_penalties = True
                  is_prompt = seq_group.is_prompt
                  if (seq_group.is_prompt
                          and sampling_params.prompt_logprobs is not None):
                      # For tokens in the prompt that we only need to get
                      # their logprobs
                      query_len = seq_group.query_len
                      if query_len is None:
                          raise RuntimeError("query_len is None")
                      prefill_len = len(seq_group.prompt_logprob_indices)
                      temperatures += [temperature] * prefill_len
                      sampling_seeds += [seed] * prefill_len
                      top_ps += [top_p] * prefill_len
                      top_ks += [top_k] * prefill_len
                      min_ps += [min_p] * prefill_len
                      presence_penalties += [0] * prefill_len
                      frequency_penalties += [0] * prefill_len
                      repetition_penalties += [1] * prefill_len
                      prompt_tokens.extend([] for _ in range(prefill_len))
                      output_tokens.extend([] for _ in range(prefill_len))
                      all_input_tokens.extend([] for _ in range(prefill_len))
                  if seq_group.do_sample:
                      sample_lens = len(seq_group.sample_indices)
                      if sample_lens != len(seq_ids):
                          raise ValueError("sample_lens != len(seq_ids)")
                      for seq_id in seq_ids:
                          seq_data = seq_group.seq_data[seq_id]
                          prompt_tokens.append(seq_data.prompt_token_ids)
                          output_tokens.append(seq_data.output_token_ids)
                          all_input_tokens.append(seq_data.prompt_token_ids + seq_data.output_token_ids)
                      temperatures += [temperature] * len(seq_ids)
                      sampling_seeds += [seed] * len(seq_ids)
                      top_ps += [top_p] * len(seq_ids)
                      top_ks += [top_k] * len(seq_ids)
                      min_ps += [min_p] * len(seq_ids)
                      presence_penalties += [p] * len(seq_ids)
                      frequency_penalties += [f] * len(seq_ids)
                      repetition_penalties += [r] * len(seq_ids)
              repetition_penalties = np.array(repetition_penalties, dtype=np.float32)
              frequency_penalties = np.array(frequency_penalties, dtype=np.float32)
              presence_penalties = np.array(presence_penalties, dtype=np.float32)
              temperatures = np.array(temperatures, dtype=np.float32)
              top_ks = np.array(top_ks, dtype=np.int32)
              top_ps = np.array(top_ps, dtype=np.float32)
              sampling_seeds = np.array(sampling_seeds)
              do_samples = np.array(do_samples)
              max_tokens_len = max([len(tokens) for tokens in all_input_tokens], default=0)
              padded_all_input_tokens = [
                  tokens + [vocab_size] * (max_tokens_len - len(tokens))
                  for tokens in all_input_tokens
              ]
              padded_all_input_tokens = np.array(padded_all_input_tokens, dtype=np.int32)
              output_max_len = max([len(tokens) for tokens in output_tokens], default=0)
              padded_output_tokens = [
                  tokens + [vocab_size] * (output_max_len - len(tokens))
                  for tokens in output_tokens
              ]
              padded_output_tokens = np.array(padded_output_tokens, dtype=np.int32)
              all_input_ids_tensor = _to_tensor(
                  padded_all_input_tokens, 
                  torch.int32
              ) if padded_all_input_tokens is not None else None
              output_ids_tensor = _to_tensor(
                  padded_output_tokens, 
                  torch.int32
              ) if padded_output_tokens is not None else None
              mindie_sampling_data = SamplingData(
                  all_input_ids=all_input_ids_tensor, 
                  output_ids=output_ids_tensor
              )
              if greedy_flag:
                  mindie_sampling_param = None
              else:
                  mindie_sampling_param = SamplingParam.from_numpy(
                      repetition_penalty=repetition_penalties,
                      frequency_penalty=frequency_penalties,
                      presence_penalty=presence_penalties,
                      temperature=temperatures,
                      top_k=top_ks,
                      top_p=top_ps,
                      seed=sampling_seeds,
                      do_sample=do_samples,
                      to_tensor=_to_tensor,
                  )
              return (mindie_sampling_data, mindie_sampling_param)
      def recover_data(
          sampling_metadata: SamplingMetadata,
          sampled_tokens: np.ndarray,
          logprobs: torch.Tensor,
          include_gpu_probs_tensor: bool,
      ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
          categorized_seq_group_ids: Dict[SamplingType,
                                          List[int]] = {t: []
                                                        for t in SamplingType}
          categorized_sample_indices = sampling_metadata.categorized_sample_indices
          for i, seq_group in enumerate(sampling_metadata.seq_groups):
              sampling_params = seq_group.sampling_params
              sampling_type = sampling_params.sampling_type
              categorized_seq_group_ids[sampling_type].append(i)
          sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
          sample_metadata = {}
          # Create output tensor for sampled token ids.
          if include_gpu_probs_tensor:
              sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                                     1,
                                                     dtype=torch.long,
                                                     device=logprobs.device)
          else:
              sampled_token_ids_tensor = None
          for sampling_type in SamplingType:
              sample_indices = categorized_sample_indices[sampling_type][:, 0]
              num_tokens = len(sample_indices)
              if num_tokens == 0:
                  continue
              seq_group_id = categorized_seq_group_ids[sampling_type]
              seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
              sample_metadata[sampling_type] = (seq_group_id, seq_groups)
          for sampling_type in SamplingType:
              if sampling_type not in sample_metadata:
                  continue
              (seq_group_id, seq_groups) = sample_metadata[sampling_type]
              if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, SamplingType.RANDOM_SEED):
                  sample_results = normal_wrap(seq_groups, sampled_tokens)
              elif sampling_type == SamplingType.BEAM:
                  sample_results = beam_wrap(seq_groups, sampled_tokens)
              sample_results_dict.update(zip(seq_group_id, sample_results))
          sample_results = [
              sample_results_dict.get(i, ([], []))
              for i in range(len(sampling_metadata.seq_groups))
          ]
          return sample_results, sampled_token_ids_tensor
      def normal_wrap(
          selected_seq_groups: List[SequenceGroupToSample],
          samples: np.ndarray,
      ):
          samples = samples.tolist()
          sample_idx = 0
          results: SampleResultType = []
          for seq_group in selected_seq_groups:
              if not seq_group.do_sample:
                  results.append(([], []))
                  continue
              seq_ids = seq_group.seq_ids
              num_parent_seqs = len(seq_ids)
              parent_ids = list(range(num_parent_seqs))
              next_token_ids = [samples[sample_idx]]
              results.append((next_token_ids, parent_ids))
              sample_idx += num_parent_seqs
          return results

      In addition, this module rewrites the vLLM framework's get_model and get_architecture_class_name functions so that the MindIELlmWrapper class is plugged into the vLLM framework.
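
      A minimal sketch of those replacements, assuming vLLM 0.4.2's loader arguments and hypothetical plumbing of the mindie_model_config dictionary expected by MindIELlmWrapper:

      # Hypothetical sketch of ascend_model_loader.py: every architecture resolves to
      # MindIELlmWrapper, and loading a model means instantiating the wrapper.
      import torch
      from vllm_npu.model_executor.models.ascend.mindie_llm_wrapper import MindIELlmWrapper

      def get_architecture_class_name(model_config) -> str:
          return "MindIELlmWrapper"

      def get_model(model_config, device_config, mindie_model_config, **kwargs) -> torch.nn.Module:
          with torch.device(device_config.device):
              model = MindIELlmWrapper(mindie_model_config)
              model.load_weights(model_config.model)
          return model.eval()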

  • usage module:

    Rewrites the _report_usage_once member function of the vLLM framework's UsageMessage class, changing how torch.cuda.get_device_properties is used, since the way this call is made currently differs between the Ascend environment and the GPU environment.
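
    A minimal sketch of that change, assuming the torch_npu extension's torch.npu.get_device_properties and the device fields recorded by vLLM's usage reporting:

      # Hypothetical excerpt of the rewritten _report_usage_once: query device
      # properties through torch.npu instead of torch.cuda when running on Ascend.
      import torch
      import torch_npu  # noqa: F401  # exposes torch.npu on Ascend

      def _collect_device_info(usage_message) -> None:
          if torch.npu.is_available():
              props = torch.npu.get_device_properties(0)
              usage_message.gpu_count = torch.npu.device_count()
          else:
              props = torch.cuda.get_device_properties(0)
              usage_message.gpu_count = torch.cuda.device_count()
          usage_message.gpu_type = props.name
          usage_message.gpu_memory_per_device = props.total_memory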

  • worker module:

    Implements the AscendWorker class, which is called by the executor classes in the executor module, and the AscendModelRunner class, which is called from within AscendWorker.

    Replaces the _allocate_kv_cache function of the native CacheEngine, mainly changing the format of the generated kv_cache from torch.Tensor to Tuple[torch.Tensor, torch.Tensor].
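
    A minimal sketch of the tuple-based allocation, with the CacheEngine attribute names assumed from vLLM 0.4.2 and a block layout matching the dummy KV cache built in MindIELlmWrapper:

      # Hypothetical sketch of the replaced _allocate_kv_cache: each layer gets a
      # (key_cache, value_cache) tuple instead of a single combined tensor.
      from typing import List, Tuple
      import torch

      def _allocate_kv_cache(self, num_blocks: int, device: str) -> List[Tuple[torch.Tensor, torch.Tensor]]:
          kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
          for _ in range(self.num_layers):
              key_cache = torch.empty(
                  (num_blocks, self.block_size, self.num_heads, self.head_size),
                  dtype=self.dtype, device=device)
              value_cache = torch.empty_like(key_cache)
              kv_cache.append((key_cache, value_cache))
          return kv_cache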

    The AscendModelRunner class inherits from the native ModelRunner class, mainly to rewrite the native load_model, execute_model, and profile_run functions. In newer vLLM versions a model call is split into steps: the model first produces hidden_states, a separate processing step turns the hidden_states into logits, and a final sample step produces the result. In the MindIE model repository the first two steps are completed by a single model call, so execute_model is modified accordingly. profile_run is modified mainly to construct the fake data used during warmup.

Beyond the adaptation of the six modules above, a few functions in Python files outside those modules also need to be hot-replaced: config.py introduces NPU as a device_type in the DeviceConfig class, and utils.py introduces an is_ascend() function for detecting whether the current environment is Ascend. Finally, npu_adaptor.py masks some packages imported by the native vLLM framework that are not available on Ascend (for example, precompiled CUDA kernels and Triton).
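
A minimal sketch of the is_ascend() helper, assuming detection is done by probing for the torch_npu extension:

    # Hypothetical implementation of the environment check in vllm_npu/utils.py.
    def is_ascend() -> bool:
        try:
            import torch_npu  # noqa: F401
        except ImportError:
            return False
        return True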

  • Multimodal model Qwen-VL support:

    To support inference with the multimodal model Qwen-VL, the generate function of vLLM's offline inference class LLM must be modified, as must the part of AscendModelRunner that constructs fake warmup data. In addition, a multimodal input data format required by the Qwen-VL tokenizer is defined, together with a new MindIETokenizer class that interfaces with MindIE LLM's preprocessing. The adaptation details are as follows:

    First, create two new modules, entrypoints and transformers_utils, plus a sequence.py file in the vllm_npu package. Under entrypoints create __init__.py and llm.py; under transformers_utils create __init__.py and mindie_tokenizer.py.

    1. Add the following new class in vllm_npu/sequence.py to define Qwen-VL's multimodal data.
      from typing import Dict, List
      class MultiModalData:
          """Multi modal input for MindIE LLM.
          Args:
              data_list: List of input data in the format of dict. For example:
              [
                  {"image": url_of_image1},
                  {"image": url_of_image1},
                  {"text": input_prompt1},
                  {"text": input_prompt1}
              ]
          """
          def __init__(self, data_list: List[Dict[str, str]]):
              self.data_list = data_list
    2. Introduce the new MindIETokenizer class in vllm_npu/transformers_utils/mindie_tokenizer.py to interface with MindIE LLM's preprocessing.
      from typing import List, Optional
      from transformers import PreTrainedTokenizer
      from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
          BaseTokenizerGroup)
      from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_lora_tokenizer_async
      from vllm.lora.request import LoRARequest
      from atb_llm.runner.tokenizer_wrapper import TokenizerWrapper
      class MindIETokenizer(BaseTokenizerGroup):
          """A group of tokenizers that can be used for LoRA adapters."""
          def __init__(self, model_path: str):
              self.tokenizer_wrapper = TokenizerWrapper(model_path)
              self.tokenizer = self.tokenizer_wrapper.tokenizer
              self.lora_tokenizers = None
              self.enable_lora = False
          def ping(self) -> bool:
              """Check if the tokenizer group is alive."""
              return True
          def get_max_input_len(self,
                                lora_request: Optional[LoRARequest] = None
                                ) -> Optional[int]:
              """Get the maximum input length for the LoRA request."""
              return 0
          def encode(self,
                     prompt: str,
                     request_id: Optional[str] = None,
                     lora_request: Optional[LoRARequest] = None) -> List[int]:
              tokenizer = self.get_lora_tokenizer(lora_request)
              return tokenizer.encode(prompt)
          async def encode_async(
                  self,
                  prompt: str,
                  request_id: Optional[str] = None,
                  lora_request: Optional[LoRARequest] = None) -> List[int]:
              tokenizer = await self.get_lora_tokenizer_async(lora_request)
              return tokenizer.encode(prompt)
          def get_lora_tokenizer(
                  self,
                  lora_request: Optional[LoRARequest] = None
          ) -> "PreTrainedTokenizer":
              if not lora_request or not self.enable_lora:
                  return self.tokenizer
              if lora_request.lora_int_id not in self.lora_tokenizers:
                  tokenizer = (get_lora_tokenizer(
                      lora_request, **self.tokenizer_config) or self.tokenizer)
                  self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
                  return tokenizer
              else:
                  return self.lora_tokenizers.get(lora_request.lora_int_id)
          async def get_lora_tokenizer_async(
                  self,
                  lora_request: Optional[LoRARequest] = None
          ) -> "PreTrainedTokenizer":
              if not lora_request or not self.enable_lora:
                  return self.tokenizer
              if lora_request.lora_int_id not in self.lora_tokenizers:
                  tokenizer = (await get_lora_tokenizer_async(
                      lora_request, **self.tokenizer_config) or self.tokenizer)
                  self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
                  return tokenizer
              else:
                  return self.lora_tokenizers.get(lora_request.lora_int_id)
    3. Redefine the generate function in vllm_npu/entrypoints/llm.py to add an input interface for multimodal data.
      import torch
      from typing import List, Optional, Union
      from vllm.lora.request import LoRARequest
      from vllm.outputs import RequestOutput
      from vllm.sampling_params import SamplingParams
      from vllm.transformers_utils.detokenizer import Detokenizer
      from vllm_npu.transformers_utils import MindIETokenizer
      from vllm_npu.sequence import MultiModalData
      
      def generate(
          self,
          prompts: Optional[Union[str, List[str]]] = None,
          sampling_params: Optional[Union[SamplingParams,
                                          List[SamplingParams]]] = None,
          prompt_token_ids: Optional[List[List[int]]] = None,
          use_tqdm: bool = True,
          lora_request: Optional[LoRARequest] = None,
          multi_modal_data_list: Optional[List[MultiModalData]] = None,
      ) -> List[RequestOutput]:
          """Generates the completions for the input prompts.
          Args:
              prompts: A list of prompts to generate completions for.
              sampling_params: The sampling parameters for text generation. If
                  None, we use the default sampling parameters. 
                  When it is a single value, it is applied to every prompt. 
                  When it is a list, the list must have the same length as the 
                  prompts and it is paired one by one with the prompt.
              prompt_token_ids: A list of token IDs for the prompts. If None, we
                  use the tokenizer to convert the prompts to token IDs.
              use_tqdm: Whether to use tqdm to display the progress bar.
              lora_request: LoRA request to use for generation, if any.
              multi_modal_data_list: List of Multi modal data. Each element in this list 
              is organized in the form of List[Dict[str, str]]
          Returns:
              A list of `RequestOutput` objects containing the generated
              completions in the same order as the input prompts.
          """
          if prompts is None and prompt_token_ids is None and multi_modal_data_list is None:
              raise ValueError("Either prompts or prompt_token_ids of multi_modal_data must be "
                                  "provided.")
          if self.llm_engine.model_config.skip_tokenizer_init \
              and prompts is not None:
              raise ValueError("prompts must be None if skip_tokenizer_init "
                                  "is True")
          if isinstance(prompts, str):
              # Convert a single prompt to a list.
              prompts = [prompts]
          if (prompts is not None and prompt_token_ids is not None
                  and len(prompts) != len(prompt_token_ids)):
              raise ValueError("The lengths of prompts and prompt_token_ids "
                                  "must be the same.")
          if multi_modal_data_list and self.llm_engine.tokenizer is None:
              self.llm_engine.tokenizer = MindIETokenizer(self.llm_engine.model_config.model)
              self.llm_engine.detokenizer = Detokenizer(self.llm_engine.tokenizer)
              self.llm_engine.output_processor.detokenizer = self.llm_engine.detokenizer
          if prompts is not None:
              num_requests = len(prompts)
          elif multi_modal_data_list is not None:
              num_requests = len(multi_modal_data_list)
          else:
              assert prompt_token_ids is not None
              num_requests = len(prompt_token_ids)
          if sampling_params is None:
              # Use default sampling params.
              sampling_params = SamplingParams()
          elif isinstance(sampling_params,
                          list) and len(sampling_params) != num_requests:
              raise ValueError("The lengths of prompts and sampling_params "
                                  "must be the same.")
          # Add requests to the engine.
          for i in range(num_requests):
              prompt = prompts[i] if prompts is not None else None
              token_ids = None if prompt_token_ids is None else prompt_token_ids[
                  i]
              if multi_modal_data_list:
                  token_ids = self.llm_engine.tokenizer.tokenizer_wrapper.tokenize(
                      multi_modal_data_list[i].data_list).tolist()
                  # print(token_ids)
              self._add_request(
                  prompt,
                  sampling_params[i]
                  if isinstance(sampling_params, list) else sampling_params,
                  token_ids,
                  lora_request=lora_request,
                  # Get ith image while maintaining the batch dim.
                  multi_modal_data=None,
              )
          outputs = self._run_engine(use_tqdm)
          for output in outputs:
              token_ids = output.outputs[0].token_ids
              token_ids_tensor = torch.tensor(token_ids, dtype=torch.int64)
              mindie_generated_text = self.llm_engine.tokenizer.tokenizer.decode(token_ids_tensor, False)
              output.outputs[0].text = mindie_generated_text
          return outputs
    4. Modify vllm_npu/worker/ascend_model_runner.py: remove the _prepare_fake_inputs imported from the vLLM framework and redefine the function as follows.
      def _prepare_fake_inputs(
              seq_len: int, model_config: ModelConfig):
          """Prepare fake inputs for profile run."""
          if getattr(model_config.hf_config, "visual", None) is not None:
              img_start_id = model_config.hf_config.visual["image_start_id"]
              img_end_id = img_start_id + 1
              img_patch_id = img_start_id + 2
              fake_img_token_ids = [24669, 220, 16, 25, 151857, 120, 121] + \
                  [img_patch_id] * 254 + [img_end_id, 198]
              img_token_nums = len(fake_img_token_ids)
              if seq_len < img_token_nums:
                  raise ValueError(f"The number of max_model_len/max_num_seqs is smaller than the img_token_nums({img_token_nums}) of Qwen-VL.")
              prompt_tokens = fake_img_token_ids + \
                  [0] * (seq_len - img_token_nums)
          else:
              prompt_tokens = [0] * seq_len
          fake_image_input = None
          return SequenceData(prompt_tokens), fake_image_input

      In addition, change the code that calls this function to:

      seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                      seq_len, self.model_config)
    5. Replace the generate function in vllm_npu/entrypoints/__init__.py.
      from vllm_npu.entrypoints.llm import generate
      import vllm.entrypoints.llm as vllm_entry_llm
      vllm_entry_llm.LLM.generate = generate

      Import the new class in vllm_npu/transformers_utils/__init__.py.

      from .mindie_tokenizer import MindIETokenizer
    6. Add imports of the new modules in vllm_npu/__init__.py.
      import vllm_npu.transformers_utils
      import vllm_npu.entrypoints
      import vllm.sequence as v_sequence
      from vllm_npu.sequence import MultiModalData

      v_sequence.MultiModalData = MultiModalData
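
    With these patches in place, a hypothetical offline Qwen-VL call could look like the sketch below; the model path, the image URL, and the skip_tokenizer_init flag (used so that the MindIETokenizer branch in the patched generate takes effect) are assumptions:

      # Hypothetical usage sketch of the patched multimodal interface.
      import vllm_npu  # importing the patch package applies the hot replacements (assumed)
      from vllm import LLM, SamplingParams
      from vllm_npu.sequence import MultiModalData

      llm = LLM(model="/path/to/Qwen-VL", trust_remote_code=True, skip_tokenizer_init=True)
      data = MultiModalData(data_list=[
          {"image": "https://example.com/demo.jpg"},
          {"text": "Describe the picture."},
      ])
      outputs = llm.generate(
          sampling_params=SamplingParams(temperature=0.0, max_tokens=64),
          multi_modal_data_list=[data],
      )
      print(outputs[0].outputs[0].text)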

Finally, Table 1 summarizes the amount of code modified in the key components for interfacing with MindIE LLM.

Table 1 Amount of code modified in the key components for interfacing with MindIE LLM

Key component       Code volume   Purpose
MindIELlmWrapper    ~120 lines    Connects to the MindIE LLM model invocation interface.
AscendSampler       ~300 lines    Connects to the MindIE LLM post-processing interface.
AscendModelRunner   ~170 lines    Adapts the key functions execute_model, load_model, and profile_run of the native ModelRunner; the model loading step is replaced with loading MindIELlmWrapper.