
TGI v2.0.4 Reference Adaptation Code

The file and directory structure is as follows:
Tgi-MindIE
 |______cover
 |        |______models
 |        |        |______ __init__.py
 |        |______cli.py
 |        |______server.py
 |______tgi_npu
 |        |______ __init__.py
 |        |______cache_manager.py
 |        |______info.py
 |        |______mind_model.py
 |        |______token_mindie.py
 |        |______vlm_mind_models.py
 |______pyproject.toml
 |______install.sh
 |______README.md

The meaning and purpose of each source file are described below:

  • cover/models/__init__.py: Replaces server/text_generation_server/models/__init__.py in the original repository and routes model inference to MindIE LLM.

  • cover/cli.py: Replaces server/text_generation_server/cli.py in the original repository and adds log filtering for the tgi_npu module.

  • cover/server.py: Replaces server/text_generation_server/server.py in the original repository and adds the tgi_npu entry point for multimodal models.

  • tgi_npu/__init__.py: Performs the initialization required for the NPU hardware environment.

  • tgi_npu/cache_manager.py: KV cache manager; initializes the KV cache for the NPU.

  • tgi_npu/info.py: NPU SoC information (NPUSocInfo), used to decide whether the NZ data format is required.

  • tgi_npu/mind_model.py: Defines the inference model entry class MindModel and its data communication format MindFlashCausalLMBatch, which inherit from the original repository's FlashCausalLM and FlashCausalLMBatch respectively. In MindModel, generate_token keeps most of the original code and is adjusted to the MindIE LLM call flow: forward now calls the forward_tensor method provided by MindIE LLM, and warmup is modified for the NPU's memory-access characteristics.

  • tgi_npu/token_mindie.py: Post-sampling code; provides the next-token chooser (MindIELLMHeterogeneousNextTokenChooser) used by MindModel.

  • tgi_npu/vlm_mind_models.py: Defines the multimodal model entry class VlmMindModel and its data communication format VlmMindFlashCausalLMBatch, which inherit from MindModel and MindFlashCausalLMBatch respectively. In VlmMindFlashCausalLMBatch, the batch_tokenized_inputs method uses the tokenize method provided by the MindIE LLM module to encode inputs into the format required by multimodal models (see the sketch after this list).

  • pyproject.toml: Configuration file for installing the adaptation package.
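
The sample code below does not include tgi_npu/vlm_mind_models.py, so the following is only a rough sketch of the mechanism described above. It is modeled on the from_pb_processor(...) call made in cover/server.py and on MindFlashCausalLMBatch.from_tokenized shown later; the way the tokenize callable is invoked inside batch_tokenized_inputs is an assumption, not the shipped implementation.

    # Hypothetical sketch only; not the shipped vlm_mind_models.py.
    from typing import Callable, List

    import torch
    from transformers import PreTrainedTokenizerBase

    from text_generation_server.pb import generate_pb2
    from tgi_npu.mind_models import MindFlashCausalLMBatch  # import path as used in tgi_npu/__init__.py


    class VlmMindFlashCausalLMBatch(MindFlashCausalLMBatch):
        @classmethod
        def batch_tokenized_inputs(cls, requests, tokenize: Callable) -> List[List[int]]:
            # Assumption: the MindIE LLM tokenize callable encodes a raw multimodal
            # input (text plus image reference) into the token ids the model expects.
            return [tokenize(r.inputs) for r in requests]

        @classmethod
        def from_pb_processor(
                cls,
                pb: generate_pb2.Batch,
                tokenizer: PreTrainedTokenizerBase,
                tokenize: Callable,        # self.model.tokenize, passed in by cover/server.py
                dtype: torch.dtype,
                device: torch.device,
        ) -> "VlmMindFlashCausalLMBatch":
            batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenize)
            # Reuse the text-only construction path once the inputs are encoded.
            return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)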

Sample code:

  • Tgi-MindIE/install.sh
    #!/usr/bin/env bash
    # install-origin
    if [ -d "./tgi_origin" ]; then 
        echo "./tgi_origin directory has already exist!"
        exit 1
    fi
    
    git clone -b v2.0.4 https://github.com/huggingface/text-generation-inference.git tgi_origin
    
    cp cover/cli.py tgi_origin/server/text_generation_server/
    cp cover/models/__init__.py tgi_origin/server/text_generation_server/models
    sed -i "s/requires_padding, 16, window_size/requires_padding, 128, window_size/g" tgi_origin/router/src/infer.rs
    sed -i "s/prefill_logprobs: true/prefill_logprobs: false/g" tgi_origin/router/client/src/client.rs
    sed -i "s/bnb, accelerate, quantize, peft, outlines/accelerate, quantize, peft, outlines/g" tgi_origin/server/Makefile
    
    cd tgi_origin && make install-server && make install-router && make install-launcher
    
    cd .. && pip install -e .
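
    After install.sh completes, a quick import check (a minimal sanity test, not part of the adaptation itself) can confirm that the NPU stack and the freshly installed tgi_npu package are importable:

    # Minimal post-install check; assumes an Ascend NPU environment with CANN,
    # torch_npu and MindIE LLM already set up.
    import torch
    import torch_npu                   # registers the torch.npu device backend
    from tgi_npu import MindModel      # importing tgi_npu also runs its NPU init()

    print("NPU available:", torch.npu.is_available())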
    
  • Tgi-MindIE/cover/models/__init__.py
    # This file was copied from the huggingface/text-generation-inference project
    
    from typing import Optional
    
    import torch
    from loguru import logger
    
    from text_generation_server.utils.speculate import get_speculate, set_speculate
    from text_generation_server.models.model import Model
    
    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
    # in PyTorch 1.12 and later.
    torch.backends.cuda.matmul.allow_tf32 = True
    
    # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
    torch.backends.cudnn.allow_tf32 = True
    
    # Disable gradients
    torch.set_grad_enabled(False)
    
    __all__ = [
        "Model",
        "get_model",
    ]
    
    
    def get_model(
        model_id: str,
        revision: Optional[str],
        sharded: bool,
        quantize: Optional[str],
        speculate: Optional[int],
        dtype: Optional[str],
        trust_remote_code: bool,
    ) -> Model:
        if speculate is not None:
            logger.warning("Speculate Decoding is not supported now!")
        set_speculate(0)
    
        try:
            import torch_npu
            from tgi_npu import MindModel
            npu_module_imported = True
        except (ImportError, NotImplementedError) as excp:
            npu_module_imported = False
            logger.error(f"Error catched: {str(excp)}")
    
        if npu_module_imported and torch.npu.is_available():
            return MindModel(model_id)
        else:
            logger.error("NPU enviroment error!!!!!!!!!!!!")
            raise ValueError("NPU enviroment error!!!!!!!!!!!!")
  • Tgi-MindIE/cover/cli.py
    import os
    import sys
    from pathlib import Path
    from typing import Optional
    from enum import Enum
    import typer
    from loguru import logger
    from huggingface_hub import hf_hub_download
    app = typer.Typer()
    MODEL_SUFFIX = ".safetensors"
    CONFIG_FILENAME = "config.json"
    
    
    class Quantization(str, Enum):
        bitsandbytes = "bitsandbytes"
        bitsandbytes_nf4 = "bitsandbytes-nf4"
        bitsandbytes_fp4 = "bitsandbytes-fp4"
        gptq = "gptq"
        awq = "awq"
        eetq = "eetq"
        fp8 = "fp8"
    
    
    class Dtype(str, Enum):
        float16 = "float16"
        bloat16 = "bfloat16"
    
    
    @app.command()
    def serve(
            model_id: str,
            revision: Optional[str] = None,
            sharded: bool = False,
            quantize: Optional[Quantization] = None,
            speculate: Optional[int] = None,
            dtype: Optional[Dtype] = None,
            trust_remote_code: bool = False,
            uds_path: Path = "/tmp/text-generation-server",
            logger_level: str = "INFO",
            json_output: bool = False,
            otlp_endpoint: Optional[str] = None,
    ):
        if sharded:
            assert (
                    os.getenv("RANK", None) is not None
            ), "RANK must be set when sharded is True"
            assert (
                    os.getenv("WORLD_SIZE", None) is not None
            ), "WORLD_SIZE must be set when sharded is True"
            assert (
                    os.getenv("MASTER_ADDR", None) is not None
            ), "MASTER_ADDR must be set when sharded is True"
            assert (
                    os.getenv("MASTER_PORT", None) is not None
            ), "MASTER_PORT must be set when sharded is True"
    
        # Remove default handler
        logger.remove()
        logger.add(
            sys.stdout,
            format="{message}",
            filter="text_generation_server",
            level=logger_level,
            serialize=json_output,
            backtrace=True,
            diagnose=False,
        )
        logger.add(
            sys.stdout,
            format="{message}",
            filter="tgi_npu",
            level=logger_level,
            serialize=json_output,
            backtrace=True,
            diagnose=False,
        )
  • Tgi-MindIE/cover/server.py
    import asyncio
    import os
    import torch
    import torch_npu
    import time
    import signal
    from grpc import aio
    from loguru import logger
    from grpc_reflection.v1alpha import reflection
    from pathlib import Path
    from typing import List, Optional
    from text_generation_server.cache import Cache
    from text_generation_server.interceptor import ExceptionInterceptor
    from text_generation_server.models import Model, get_model
    try:
        from tgi_npu.vlm_mind_models import VlmMindFlashCausalLMBatch
        VLM_BATCH_TYPES = {VlmMindFlashCausalLMBatch}
    except (ImportError, NotImplementedError):
        # These imports can fail on CPU/Non flash.
        VLM_BATCH_TYPES = set()
    from text_generation_server.pb import generate_pb2_grpc, generate_pb2
    from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
    from text_generation_server.models.globals import set_model_id
    soc_version = torch_npu._C._npu_get_soc_version()
    if soc_version not in [104, 220, 221, 222, 223, 224]:
        logger.info("Some ops are not supported on this SoC!")
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
    else:
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
    torch.npu.set_option(option)
    class SignalHandler:
        KEEP_PROCESSING = True
    
        def __init__(self):
            signal.signal(signal.SIGINT, self.exit_gracefully)
            signal.signal(signal.SIGTERM, self.exit_gracefully)
        def exit_gracefully(self, signum, frame):
            print(f"Exiting gracefully: Signal {signum}")
            self.KEEP_PROCESSING = False
    
    
    class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        def __init__(
            self,
            model: Model,
            cache: Cache,
            quantize: Optional[str],
            server_urls: List[str],
        ):
            self.cache = cache
            self.model = model
            self.quantize = quantize
            self.server_urls = server_urls
            # For some reason, inference_mode does not work well with GLOO which we use on CPU
            if model.device.type == "cuda" or model.device.type == "npu":
                # Force inference mode for the lifetime of TextGenerationService
                self._inference_mode_raii_guard = torch._C._InferenceMode(True)
            self.step = 0
    
        async def Info(self, request, context):
            return self.model.info
        async def Health(self, request, context):
            if self.model.device.type == "cuda":
                torch.zeros((2, 2)).cuda()
            return generate_pb2.HealthResponse()
        async def ServiceDiscovery(self, request, context):
            return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
        async def ClearCache(self, request, context):
            if request.HasField("id"):
                self.cache.delete(request.id)
            else:
                self.cache.clear()
            return generate_pb2.ClearCacheResponse()
        async def FilterBatch(self, request, context):
            batch = self.cache.pop(request.batch_id)
            if batch is None:
                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
            filtered_batch = batch.filter(request.request_ids)
            self.cache.set(filtered_batch)
            return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
        async def Warmup(self, request, context):
            for i, r in enumerate(request.batch.requests):
                r.parameters.typical_p = 1.0
                r.prefill_logprobs = False
            if self.quantize == "gptq":
                try:
                    # When using GPTQ, Exllama kernels need some global kernels
                    # For which we have the final shapes only after the model has loaded
                    # This will allocate those buffers.
                    from text_generation_server.layers.gptq import (
                        create_exllama_buffers,
                        set_device,
                    )
                    set_device(self.model.device)
                    create_exllama_buffers(request.max_prefill_tokens)
                except ImportError:
                    pass
    
            if (
                self.model.batch_type in VLM_BATCH_TYPES
            ):  # Hack, i would rather use kwargs in the `from_pb` call
                for i, r in enumerate(request.batch.requests):
                    r.inputs = r.inputs.split('!')[0]
                batch = self.model.batch_type.from_pb_processor(
                    request.batch,
                    self.model.tokenizer,
                    self.model.tokenize,
                    self.model.dtype,
                    self.model.device,
                )
            else:
                batch = self.model.batch_type.from_pb(
                    request.batch, self.model.tokenizer, self.model.dtype, self.model.device
                )
            max_supported_total_tokens = self.model.warmup(batch)
            return generate_pb2.WarmupResponse(
                max_supported_total_tokens=max_supported_total_tokens
            )
        async def Prefill(self, request, context):
            start = time.time_ns()
            if (
                self.model.batch_type in VLM_BATCH_TYPES
            ):  # Hack, i would rather use kwargs in the `from_pb` call
                batch = self.model.batch_type.from_pb_processor(
                    request.batch,
                    self.model.tokenizer,
                    self.model.tokenize,
                    self.model.dtype,
                    self.model.device,
                )
            else:
                batch = self.model.batch_type.from_pb(
                    request.batch, self.model.tokenizer, self.model.dtype, self.model.device
                )
            generations, next_batch, timings = self.model.generate_token(batch)
            self.cache.set(next_batch)
            return generate_pb2.PrefillResponse(
                generations=[generation.to_pb() for generation in generations],
                batch=next_batch.to_pb() if next_batch else None,
                forward_ns=timings[0],
                decode_ns=timings[1],
                total_ns=time.time_ns() - start,
            )
  • Tgi-MindIE/tgi_npu/__init__.py
    #!/usr/bin/env python3
    # coding=utf-8
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    
    
    import torch
    import torch_npu
    from loguru import logger
    
    from tgi_npu.mind_models import MindModel
    
    
    def init():
        torch._C._InferenceMode(True)
        soc_version = torch_npu._C._npu_get_soc_version()
        if soc_version not in [104, 220, 221, 222, 223, 224]:
            logger.info("Some op does not support for this soc!")
            option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
        else:
            option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
        try:
            torch.npu.set_option(option)
            logger.warning("Finish init for NPU device!")
        except Exception as e:
            logger.error(f"Failed to init for NPU device: {e}!")
    
    
    init()
  • Tgi-MindIE/tgi_npu/cache_manager.py
    # Parts of the code in this file were copied from the huggingface/text-generation-inference project
    
    import math
    from typing import Optional, List, Tuple
    import gc
    import torch
    from tgi_npu.info import NPUSocInfo
    
    BLOCK_SIZE: int = 128
    # Will be set in warmup
    CACHE_MANAGER: Optional["CacheManager"] = None
    
    
    class CacheManager:
        def __init__(
                self,
                num_blocks: int,
                num_layers: int,
                num_heads: int,
                head_size: int,
                repeat_slots: bool,
                dtype: torch.dtype,
                device: torch.device,
        ):
            self.block_size = BLOCK_SIZE
            self.num_blocks = num_blocks
            self.repeat_slots = repeat_slots
            # for NZ/ND data format display
            self.need_nz = NPUSocInfo().need_nz
    
            if self.need_nz:
                self.kv_cache = [
                    (
                        torch.empty(
                            (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                            dtype=dtype,
                            device=device,
                        ),
                        torch.empty(
                            (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                            dtype=dtype,
                            device=device,
                        ),
                    )
                    for _ in range(num_layers)
                ]
            else:
                self.kv_cache = [
                    (
                        torch.empty(
                            (num_blocks, self.block_size, num_heads, head_size),
                            dtype=dtype,
                            device=device,
                        ),
                        torch.empty(
                            (num_blocks, self.block_size, num_heads, head_size),
                            dtype=dtype,
                            device=device,
                        ),
                    )
                    for _ in range(num_layers)
                ]
            self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
            self.slots = torch.arange(
                0, num_blocks * self.block_size, dtype=torch.int64
            ).view(num_blocks, self.block_size)
    
        def __repr__(self):
            return (f"CacheManager: "
                    f"num_blocks={self.num_blocks},"
                    f"block_size={self.block_size},"
                    f"free_block_mask={self.free_block_mask},"
                    f"slots={self.slots},"
                    f"k_cache shape={self.kv_cache[0][0].shape},"
                    f"v_cache shape={self.kv_cache[0][1].shape}")
    
        def allocate(
                self,
                needed_blocks_slots: List[Tuple[int, int]],
                blocks: int,
                max_blocks: int,
                device: torch.device,
        ):
            # Get free blocks indices by finding values in mask that are not set to 0
            free_block_indices = self.free_block_mask.nonzero()
            if blocks > len(free_block_indices):
                raise RuntimeError(
                    f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
                )
    
            # Slice by the number of required blocks
            block_indices = free_block_indices[:blocks]
            block_indices = block_indices.flatten()
    
            # Padded block tables
            block_tables_tensor = torch.zeros(
                (len(needed_blocks_slots), max_blocks), dtype=torch.int32
            )
    
            # Allocate paged attention blocks
            cumulative_blocks = 0
            slots = []
            block_tables = []
            for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
                # Get allocated blocks for this sequence
                allocated_blocks = block_indices[
                                   cumulative_blocks: cumulative_blocks + needed_blocks
                                   ]
                # Get slots for the allocated blocks
                all_slots = self.slots[allocated_blocks].flatten()
    
                # Repeat slots in the case of context sliding window
                if needed_slots > len(all_slots) and self.repeat_slots:
                    repeats = math.ceil(needed_slots / len(all_slots))
                    all_slots = all_slots.repeat(repeats)
    
                allocated_slots = all_slots[:needed_slots]
    
                slots.append(allocated_slots)
                block_tables.append(allocated_blocks.tolist())
                block_tables_tensor[i, :needed_blocks] = allocated_blocks
                cumulative_blocks += needed_blocks
    
            block_tables_tensor = block_tables_tensor.to(device)
            slots = torch.concat(slots).to(device)
    
            # Allocate the required number of blocks by setting the mask to 0
            self.free_block_mask[block_indices] = 0
    
            return block_tables, block_tables_tensor, slots
    
        def free(self, block_indices: Optional[List[int]]):
            if block_indices is not None and block_indices:
                # Reset mask
                self.free_block_mask[block_indices] = 1
    
    
    def set_cache_manager(
            num_blocks: int,
            num_layers: int,
            num_heads: int,
            head_size: int,
            repeat_slots: bool,
            dtype: torch.dtype,
            device: torch.device,
    ) -> CacheManager:
        global CACHE_MANAGER
        if CACHE_MANAGER is not None:
            del CACHE_MANAGER
            torch.npu.empty_cache()
            gc.collect()
    
        CACHE_MANAGER = CacheManager(
            num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
        )
        return CACHE_MANAGER
    
    
    def get_cache_manager() -> CacheManager:
        global CACHE_MANAGER
        if CACHE_MANAGER is None:
            raise RuntimeError("cache manager was not initialized")
    
        return CACHE_MANAGER
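
    For reference, a minimal, hypothetical usage sketch of the cache manager above; the block counts below are illustrative, and in the adapted server these calls are made from MindModel.warmup and generate_token rather than by hand:

    # Illustrative only; block counts and shapes are made up.
    import torch
    from tgi_npu.cache_manager import BLOCK_SIZE, get_cache_manager, set_cache_manager

    device = torch.device("npu:0")    # assumes an Ascend NPU device
    set_cache_manager(
        num_blocks=64,
        num_layers=32,
        num_heads=8,
        head_size=128,
        repeat_slots=False,
        dtype=torch.float16,
        device=device,
    )

    manager = get_cache_manager()
    # One sequence that needs 2 blocks and 200 slots (200 <= 2 * BLOCK_SIZE, BLOCK_SIZE = 128).
    block_tables, block_tables_tensor, slots = manager.allocate(
        needed_blocks_slots=[(2, 200)],
        blocks=2,
        max_blocks=2,
        device=device,
    )
    # ... run prefill/decode against get_cache_manager().kv_cache ...
    manager.free(block_tables[0])     # return the blocks once the sequence finishes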
  • Tgi-MindIE/tgi_npu/info.py
    from dataclasses import dataclass
    import torch_npu
    
    @dataclass
    class NPUSocInfo:
        soc_name: str = ""
        soc_version: int = -1
        need_nz: bool = False

        def __post_init__(self):
            self.soc_version = torch_npu._C._npu_get_soc_version()
            if self.soc_version in (100, 101, 102, 103, 104, 200, 201, 202, 203):
                self.need_nz = True
  • Tgi-MindIE/tgi_npu/mind_model.py
    # Parts of the code in this file were copied from the huggingface/text-generation-inference project
    
    
    import math
    import time
    import os
    from typing import Optional, Tuple, List, Type
    from dataclasses import dataclass
    
    import torch_npu
    import torch
    from loguru import logger
    from opentelemetry import trace
    import numpy as np
    
    from text_generation_server.models.flash_causal_lm import FlashCausalLM, FlashCausalLMBatch
    from text_generation_server.models.types import (
        Batch,
        Tokens,
        Generation,
        GeneratedText
    )
    
    from text_generation_server.utils import StoppingCriteria
    from text_generation_server.pb import generate_pb2
    from text_generation_server.utils.speculate import get_speculate
    from text_generation_server.utils.dist import RANK, MEMORY_FRACTION
    from text_generation_server.utils.tokens import batch_top_tokens
    from text_generation_server.models import cache_manager as tgi_cache_manager
    
    from transformers import PreTrainedTokenizerBase
    from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch
    from mindie_llm.modeling.backend_type import BackendType
    
    from tgi_npu.tokens_mindie import MindIELLMHeterogeneousNextTokenChooser
    
    from tgi_npu.cache_manager import (
        BLOCK_SIZE,
        get_cache_manager,
        set_cache_manager,
    )
    
    tracer = trace.get_tracer(__name__)
    
    
    @dataclass
    class MindFlashCausalLMBatch(FlashCausalLMBatch):
        next_token_chooser: MindIELLMHeterogeneousNextTokenChooser
        all_input_ids_tensor: torch.Tensor
    
        def __repr__(self):
            return (f"MindFlashCausalLMBatch: batch_id={self.batch_id},"
                    f"requests_idx_mapping={self.requests_idx_mapping},"
                    f"input_ids={self.input_ids},"
                    f"position_ids={self.position_ids},"
                    f"cu_seqlen_prefill={self.cu_seqlen_prefill},"
                    f"start_slots={self.start_slots},"
                    f"slot_indices={self.slot_indices},"
                    f"needed_blocks_slots={self.needed_blocks_slots},"
                    f"block_tables={self.block_tables},"
                    f"block_tables_tensor={self.block_tables_tensor},"
                    f"slots={self.slots},"
                    f"max_seqlen={self.max_seqlen},"
                    f"prefill_head_indices={self.prefill_head_indices},"
                    f"prefill_next_token_indices={self.prefill_next_token_indices},"
                    f"prefill_cu_outlens={self.prefill_cu_outlens},"
                    f"input_lengths={self.input_lengths},"
                    f"input_lengths_tensor={self.input_lengths_tensor},"
                    f"prefix_offsets={self.prefix_offsets},"
                    f"read_offsets={self.read_offsets},"
                    f"all_input_ids_tensor={self.all_input_ids_tensor},"
                    f"next_token_chooser={self.next_token_chooser},"
                    f"stopping_criterias={self.stopping_criterias},"
                    f"blocks={self.blocks},"
                    f"max_blocks={self.max_blocks}")
    
        @classmethod
        def from_pb(
                cls,
                pb: generate_pb2.Batch,
                tokenizer: PreTrainedTokenizerBase,
                dtype: torch.dtype,
                device: torch.device
        ) -> "MindFlashCausalLMBatch":
            batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
            return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
    
        @classmethod
        def from_tokenized(
                cls,
                pb: generate_pb2.Batch,
                tokenizer: PreTrainedTokenizerBase,
                batch_tokenized_inputs,
                dtype: torch.dtype,
                device: torch.device,
        ) -> "MindFlashCausalLMBatch":
            position_ids = []
            cu_seqlen_prefill = [0]
            needed_blocks_slots = []
            start_slots = []
            slot_indices = []
    
            input_lengths = []
            prefix_offsets = []
            read_offsets = []
            all_input_ids = []
            requests_idx_mapping = {}
    
            all_prefill_logprobs = True
            no_prefill_logprobs = True
            prefill_head_indices = []
            prefill_next_token_indices = []
            prefill_cu_outlens = [0]
    
            next_token_chooser_parameters = []
            stopping_criterias = []
            top_n_tokens = []
    
            # Cumulative length
            cumulative_length = 0
            cumulative_max_length = 0
            prefill_out_cumulative_length = 0
    
            blocks = 0
            max_seqlen = 0
            max_length = 0
            max_blocks = 0
    
            # Parse batch
            for i, (r, tokenized_input) in enumerate(
                    zip(pb.requests, batch_tokenized_inputs)
            ):
                # request id -> idx in list mapping
                requests_idx_mapping[r.id] = i
    
                tokenized_input = tokenized_input[-r.truncate:]
                if (
                        tokenized_input[0] == tokenizer.bos_token_id
                        and tokenized_input[1] == tokenizer.bos_token_id
                ):
                    tokenized_input = tokenized_input[1:]
    
                input_length = len(tokenized_input)
                input_lengths.append(input_length)
    
                prefix_offsets.append(input_length - 5)
                read_offsets.append(input_length)
    
                all_input_ids.append(tokenized_input)
    
                # Position ids
                request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
                position_ids.append(request_position_ids)
    
                # Add cumulative lengths of all previous inputs
                cu_seqlen_prefill.append(cumulative_length + input_length)
    
                next_token_chooser_parameters.append(r.parameters)
    
                stopping_criteria = StoppingCriteria.from_pb(
                    r.stopping_parameters, tokenizer
                )
                max_new_tokens = stopping_criteria.max_new_tokens
                stopping_criterias.append(stopping_criteria)
                top_n_tokens.append(r.top_n_tokens)
    
                # Paged attention
                # Remove one as the first token does not have a past
                speculative_length = get_speculate()
                speculative_length = 0 if speculative_length is None else speculative_length
                total_tokens = input_length + max_new_tokens - 1 + speculative_length
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                blocks += needed_blocks
                needed_blocks_slots.append((needed_blocks, total_tokens))
                start_slots.append(cumulative_max_length)
    
                request_slot_indices = torch.arange(
                    cumulative_max_length,
                    cumulative_max_length + input_length,
                    dtype=torch.int64,
                )
                slot_indices.append(request_slot_indices)
    
                all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
                no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
    
                if r.prefill_logprobs:
                    prefill_head_indices.append(request_position_ids + cumulative_length)
                    prefill_next_token_indices.append(
                        prefill_out_cumulative_length + input_length - 1
                    )
                    prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                    prefill_out_cumulative_length += input_length
                else:
                    prefill_head_indices.append(
                        torch.tensor(
                            [cumulative_length + input_length - 1], dtype=torch.int64
                        )
                    )
                    prefill_next_token_indices.append(prefill_out_cumulative_length)
                    prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                    prefill_out_cumulative_length += 1
    
                # Update
                cumulative_length += input_length
                cumulative_max_length += total_tokens
                max_seqlen = max(max_seqlen, input_length)
                max_blocks = max(max_blocks, needed_blocks)
                max_length = max(
                    max_length, input_length + max_new_tokens + speculative_length
                )
    
            next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
                pb=next_token_chooser_parameters, dtype=dtype, device=device
            )
    
            start_slots = torch.tensor(start_slots, dtype=torch.int64)
    
            # Padded all_input_ids_tensor
            all_input_ids_tensor = np.zeros(
                (len(all_input_ids), max_length), dtype=np.int64
            )
            for i, input_ids in enumerate(all_input_ids):
                all_input_ids_tensor[i, : len(input_ids)] = input_ids
    
            all_input_ids_tensor = torch.tensor(
                all_input_ids_tensor, dtype=torch.int64, device=device
            )
    
            if len(pb.requests) > 1:
                input_ids = np.concatenate(all_input_ids, dtype=np.int64)
                position_ids = torch.cat(position_ids)
                slot_indices = torch.cat(slot_indices)
            else:
                input_ids = all_input_ids[0]
                position_ids = position_ids[0]
                slot_indices = slot_indices[0]
    
            cu_seqlen_prefill = torch.tensor(
                cu_seqlen_prefill, device=device, dtype=torch.int64
            )
            position_ids = position_ids.to(device)
            slot_indices = slot_indices.to(device)
            input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
            input_lengths_tensor = torch.tensor(
                input_lengths, dtype=torch.int64, device=device
            )
    
            if all_prefill_logprobs:
                prefill_head_indices = None
                prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
            elif no_prefill_logprobs:
                prefill_head_indices = cu_seqlen_prefill[1:] - 1
                prefill_next_token_indices = None
            else:
                prefill_head_indices = torch.tensor(
                    torch.cat(prefill_head_indices), dtype=torch.int64, device=device
                )
                prefill_next_token_indices = torch.tensor(
                    prefill_next_token_indices, dtype=torch.int64, device=device
                )
            top_n_tokens_tensor = torch.tensor(
                top_n_tokens, device=device, dtype=torch.int64
            )
    
            return cls(
                batch_id=pb.id,
                requests=pb.requests,
                requests_idx_mapping=requests_idx_mapping,
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                start_slots=start_slots,
                slot_indices=slot_indices,
                needed_blocks_slots=needed_blocks_slots,
                block_tables=None,
                block_tables_tensor=None,
                slots=None,
                max_seqlen=max_seqlen,
                prefill_head_indices=prefill_head_indices,
                prefill_next_token_indices=prefill_next_token_indices,
                prefill_cu_outlens=prefill_cu_outlens,
                input_lengths=input_lengths,
                input_lengths_tensor=input_lengths_tensor,
                prefix_offsets=prefix_offsets,
                read_offsets=read_offsets,
                all_input_ids=all_input_ids,
                all_input_ids_tensor=all_input_ids_tensor,
                next_token_chooser=next_token_chooser,
                stopping_criterias=stopping_criterias,
                top_n_tokens=top_n_tokens,
                top_n_tokens_tensor=top_n_tokens_tensor,
                blocks=blocks,
                max_blocks=max_blocks,
                speculative_ids=None,
            )
    
        @classmethod
        @tracer.start_as_current_span("concatenate")
        def concatenate(cls, batches: List["MindFlashCausalLMBatch"]) -> "MindFlashCausalLMBatch":
            # Batch attributes
            requests = []
            requests_idx_mapping = {}
    
            blocks = 0
            total_batch_size = 0
            total_slots = 0
            max_blocks = 0
            max_length = 0
            max_seqlen = 0
            for b in batches:
                total_batch_size += len(b)
                total_slots += len(b.slots)
                blocks += b.blocks
                speculative_length = (
                    b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
                )
                max_blocks = max(max_blocks, b.max_blocks)
                max_seqlen = max(max_seqlen, b.max_seqlen)
                max_length = max(
                    max_length,
                    max(
                        input_length
                        + stopping_criteria.max_new_tokens
                        + speculative_length
                        - stopping_criteria.current_tokens
                        for input_length, stopping_criteria in zip(
                            b.input_lengths, b.stopping_criterias
                        )
                    ),
                )
    
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            position_ids = batches[0].position_ids.new_empty(total_batch_size)
            slots = batches[0].slots.new_empty(total_slots)
            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                total_batch_size
            )
            block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
                (total_batch_size, max_blocks)
            )
    
            all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
                (total_batch_size, max_length)
            )
            top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                total_batch_size,
            )
    
            start_slots = []
            block_tables = []
            all_input_ids = []
    
            input_lengths = []
            prefix_offsets = []
            read_offsets = []
    
            next_token_chooser_parameters = []
            fsm_grammar_states = []
            stopping_criterias = []
            top_n_tokens = []
    
            # Cumulative length
            cumulative_batch_size = 0
            cumulative_slots = 0
    
            for i, batch in enumerate(batches):
                requests.extend(batch.requests)
    
                if i == 0:
                    requests_idx_mapping = batch.requests_idx_mapping
                else:
                    # We need to offset the mapping for each batch by the cumulative batch size
                    for k, v in batch.requests_idx_mapping.items():
                        requests_idx_mapping[k] = v + cumulative_batch_size
    
                start_index = cumulative_batch_size
                end_index = cumulative_batch_size + len(batch)
                slots_start_index = cumulative_slots
                slots_end_index = cumulative_slots + len(batch.slots)
    
                # Copy tensors (GPU)
                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
                slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
                slots[slots_start_index:slots_end_index] = batch.slots
    
                all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
                ] = batch.all_input_ids_tensor[:, :max_length]
    
                block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
                ] = batch.block_tables_tensor[:, :max_blocks]
    
                start_slots.append(batch.start_slots + cumulative_slots)
    
                block_tables.extend(batch.block_tables)
                all_input_ids.extend(batch.all_input_ids)
    
                input_lengths.extend(batch.input_lengths)
                prefix_offsets.extend(batch.prefix_offsets)
                read_offsets.extend(batch.read_offsets)
    
                next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
                fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
                stopping_criterias.extend(batch.stopping_criterias)
    
                top_n_tokens.extend(batch.top_n_tokens)
    
                # Update
                cumulative_batch_size += len(batch)
                cumulative_slots += len(batch.slots)
    
            start_slots = torch.concat(start_slots)
    
            next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
                pb=next_token_chooser_parameters,
                dtype=batches[0].next_token_chooser.dtype,
                device=batches[0].next_token_chooser.device
            )
    
            speculative_ids = (
                torch.cat([b.speculative_ids for b in batches], dim=0)
                if batches[0].speculative_ids is not None
                else None
            )
    
            # Needed to avoid dropping blocks when the batches will go out of scope
            for b in batches:
                b.block_tables = None
                del b
    
            return cls(
                batch_id=batches[0].batch_id,
                requests=requests,
                requests_idx_mapping=requests_idx_mapping,
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                start_slots=start_slots,
                slot_indices=slot_indices,
                needed_blocks_slots=None,
                block_tables=block_tables,
                block_tables_tensor=block_tables_tensor,
                slots=slots,
                max_seqlen=max_seqlen,
                prefill_head_indices=None,
                prefill_next_token_indices=None,
                prefill_cu_outlens=None,
                input_lengths=input_lengths,
                input_lengths_tensor=input_lengths_tensor,
                prefix_offsets=prefix_offsets,
                read_offsets=read_offsets,
                all_input_ids=all_input_ids,
                all_input_ids_tensor=all_input_ids_tensor,
                next_token_chooser=next_token_chooser,
                stopping_criterias=stopping_criterias,
                top_n_tokens=top_n_tokens,
                top_n_tokens_tensor=top_n_tokens_tensor,
                blocks=blocks,
                max_blocks=max_blocks,
                speculative_ids=speculative_ids,
            )
    
        def to_pb(self) -> generate_pb2.CachedBatch:
            return generate_pb2.CachedBatch(
                id=self.batch_id,
                request_ids=[r.id for r in self.requests],
                size=len(self),
                max_tokens=self.blocks * BLOCK_SIZE,
            )
    
    class MindModel(FlashCausalLM):
        def __init__(
                self,
                model_id: str
        ):
            logger.warning("Initialize mindie-llm model.")
            rank = int(os.getenv("RANK", "0"))
            world_size = int(os.getenv("WORLD_SIZE", "1"))
            model_config = {
                'backend_type': BackendType.ATB,
                'rank': rank,
                'world_size': world_size,
                'model_id': model_id,
                'num_threads': 8,
                'local_rank': rank,
                'npu_device_id': rank
            }
            self.model_runner = GeneratorTorch(model_config)
            super(MindModel, self).__init__(
                model=self.model_runner.model_wrapper.model_runner.model,
                tokenizer=self.model_runner.tokenizer,
                num_layers=self.model_runner.model_info.num_layers,
                num_kv_heads=self.model_runner.model_info.num_kv_heads,
                head_size=self.model_runner.model_info.head_size,
                dtype=self.model_runner.model_info.dtype,
                device=self.model_runner.model_info.device,
                rank=self.model_runner.rank,
                world_size=self.model_runner.world_size,
            )
            logger.warning("MindModel from tgi_npu initialized.")
    
        def __del__(self):
            del self.model_runner.model_wrapper
    
        @property
        def batch_type(self) -> Type[MindFlashCausalLMBatch]:
            return MindFlashCausalLMBatch
    
        def forward(
                self, batch: MindFlashCausalLMBatch
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            """Assume return logits, speculative_logits"""
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices
            speculative_logits = None
    
            return self.model_runner.forward_tensor(
                input_ids=input_ids,
                position_ids=position_ids,
                is_prefill=cu_seqlen_prefill is not None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_seq_len=max_s,
                lm_head_indices=lm_head_indices,
            ), speculative_logits
    
        @tracer.start_as_current_span("generate_token")
        def generate_token(
                self, batch: MindFlashCausalLMBatch
        ) -> Tuple[List[Generation], Optional[MindFlashCausalLMBatch], Tuple[int, int]]:
            start = time.time_ns()
            prefill = batch.cu_seqlen_prefill is not None
            prefill_logprobs = batch.prefill_next_token_indices is not None
            # check if need slots
            if batch.needed_blocks_slots:
                # Allocate blocks to this batch
                block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                    batch.needed_blocks_slots,
                    batch.blocks,
                    batch.max_blocks,
                    batch.input_ids.device,
                )
                batch.needed_blocks_slots = None
                batch.block_tables = block_tables
                batch.block_tables_tensor = block_tables_tensor
                batch.slots = slots
    
            try:
                out, speculative_logits = self.forward(batch)
            except Exception as e:
                del batch
                raise e
    
            if prefill:
                next_token_logits = (
                    out[batch.prefill_next_token_indices] if prefill_logprobs else out
                )
                if speculative_logits is not None:
                    speculative_logits = (
                        speculative_logits[batch.prefill_next_token_indices]
                        if prefill_logprobs
                        else speculative_logits
                    )
            else:
                logger.debug(f"Decode batch size {batch.input_ids.shape[0]}")
                next_token_logits = out
    
            speculate = get_speculate()
    
            request_ids = [req.id for req in batch.requests]
            (
                next_input_ids,
                next_token_logprobs,
                logprobs,
                accepted_ids,
                speculative_ids,
            ) = batch.next_token_chooser(
                request_ids,
                prefill,
                batch.all_input_ids_tensor[:, : batch.max_seqlen],
                next_token_logits,
                speculate
            )
    
            batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
            )
    
            if prefill:
                if len(batch) > 1 and prefill_logprobs:
                    # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                    # When batch == 1, we will just use the batch.input_ids values directly
                    prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
    
                next_position_ids = batch.position_ids.new_empty(len(batch))
                batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
                # We do not need cu_seqlen_prefill anymore
                batch.cu_seqlen_prefill = None
            else:
                prefill_logprobs = None
                next_position_ids = batch.position_ids
    
            # Cumulative length
            cumulative_length = 0
    
            # Results
            generations: List[Generation] = []
            stopped = True
    
            # Zipped iterator
            iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
    
            # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
            # one, we need to first do a GPU <-> CPU sync
            # It is faster if we delay this sync for the maximum amount of time
    
            # For each member of the batch
            index = 0
            for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
                # Indexing metadata
                start_index = cumulative_length
                end_index = cumulative_length + input_length
    
                if prefill:
                    # Indexing metadata
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]
                    out_length = out_end_index - out_start_index
    
                    # Initialize position_ids
                    # In decode, we do not need this as we can just increment position ids
                    next_position_ids[i] = batch.position_ids[end_index - 1]
    
                    # Used to gather prefill logprobs
                    # Copy batch.input_ids to prefill_token_indices
                    if prefill_logprobs:
                        if len(batch) > 1:
                            prefill_tokens_indices[out_start_index: out_end_index - 1] = (
                                batch.input_ids[start_index + 1: start_index + out_length]
                            )
                        else:
                            # Set prefill_tokens_indices to the correct slice
                            prefill_tokens_indices = batch.input_ids[
                                                     start_index + 1: start_index + out_length
                                                     ]
    
                for _ in range(n_accepted_ids):
                    index += 1
    
                cumulative_length += input_length
    
            batch.all_input_ids_tensor.scatter_(1, batch.input_lengths_tensor.view(batch.input_lengths_tensor.shape[0], 1),
                                                next_input_ids.view(next_input_ids.shape[0], 1))
            # Update values
            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
            batch.speculative_ids = speculative_ids
            batch.position_ids = next_position_ids + accepted_ids
            batch.input_lengths_tensor += accepted_ids
            batch.slot_indices += accepted_ids
    
            if prefill and prefill_logprobs:
                # Get prefill logprobs
                prefill_logprobs_tensor = torch.log_softmax(out, -1)
                prefill_logprobs = torch.gather(
                    prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
                )
                # GPU <-> CPU sync
                prefill_logprobs = prefill_logprobs.view(-1).tolist()
    
            # GPU <-> CPU sync
            next_token_logprobs = next_token_logprobs.tolist()
            next_token_ids = next_input_ids.tolist()
            accepted_ids = accepted_ids.tolist()
            start_decode = time.time_ns()
    
            # Zipped iterator
            iterator = zip(
                batch.requests,
                batch.input_lengths,
                batch.prefix_offsets,
                batch.read_offsets,
                batch.stopping_criterias,
                batch.all_input_ids,
                batch.next_token_chooser.do_sample,
                batch.next_token_chooser.seeds,
                batch.top_n_tokens,
                accepted_ids,
                batch_top_token_ids,
                batch_top_token_logprobs,
            )
    
            # For each member of the batch
            index = 0
            for i, (
                    request,
                    input_length,
                    prefix_offset,
                    read_offset,
                    stopping_criteria,
                    all_input_ids,
                    do_sample,
                    seed,
                    top_n_tokens,
                    n_accepted_ids,
                    top_token_ids,
                    top_token_logprobs,
            ) in enumerate(iterator):
                # Append next token to all tokens
                next_token_texts = []
                left = 0
    
                if n_accepted_ids > 1:
                    if RANK == 0:
                        logger.debug(f"Speculated ids {n_accepted_ids - 1}")
    
                current_stopped = False
                for j in range(index, index + n_accepted_ids):
                    # Generated token
                    next_token_id = next_token_ids[j]
                    all_input_ids.append(next_token_id)
                    # Generated token
                    next_token_text, prefix_offset, read_offset = self.decode_token(
                        all_input_ids,
                        prefix_offset,
                        read_offset,
                    )
    
                    next_token_texts.append(next_token_text)
    
                    stop, reason = stopping_criteria(
                        next_token_id,
                        next_token_text,
                    )
    
                    if stop:
                        left = index + n_accepted_ids - j - 1
                        current_stopped = True
                        break
                    else:
                        current_stopped = False
                stopped = stopped and current_stopped
    
                _next_token_ids = next_token_ids[index: index + n_accepted_ids - left]
                _next_token_logprobs = next_token_logprobs[
                                       index: index + n_accepted_ids - left
                                       ]
                index += n_accepted_ids
    
                # Shard generations
                # All generations will be appended in the rust sharded client
                if i % self.world_size == self.rank:
                    if stop:
                        # Decode generated tokens
                        output_text, _, _ = self.decode_token(
                            all_input_ids,
                            prefix_offset=len(all_input_ids)
                                          - stopping_criteria.current_tokens
                                          - 1,
                            read_offset=len(all_input_ids)
                                        - stopping_criteria.current_tokens,
                            skip_special_tokens=True,
                        )
                        generated_text = GeneratedText(
                            output_text,
                            stopping_criteria.current_tokens,
                            reason,
                            seed if do_sample else None,
                        )
                    else:
                        generated_text = None
    
                    # Prefill
                    if prefill and request.prefill_logprobs:
                        out_start_index = batch.prefill_cu_outlens[i]
                        out_end_index = batch.prefill_cu_outlens[i + 1]
    
                        # Remove generated token to only have prefill and add nan for first prompt token
                        request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                                                                    out_start_index: out_end_index - 1
                                                                    ]
                        prefill_token_ids = all_input_ids[:-1]
                        prefill_texts = self.tokenizer.batch_decode(
                            prefill_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
    
                        prefill_tokens = Tokens(
                            prefill_token_ids,
                            request_prefill_logprobs,
                            prefill_texts,
                            is_special=[],
                        )
                    else:
                        prefill_tokens = None
    
                    if top_n_tokens > 0:
                        all_top_tokens = []
                        for top_token_ids, top_token_logprobs in zip(
                                top_token_ids, top_token_logprobs
                        ):
                            toptoken_texts = self.tokenizer.batch_decode(
                                top_token_ids,
                                clean_up_tokenization_spaces=False,
                                skip_special_tokens=False,
                            )
                            special_toptokens = [
                                token_id in self.all_special_ids
                                for token_id in top_token_ids
                            ]
                            top_tokens = Tokens(
                                top_token_ids,
                                top_token_logprobs,
                                toptoken_texts,
                                special_toptokens,
                            )
                            all_top_tokens.append(top_tokens)
                        top_tokens = all_top_tokens
                    else:
                        top_tokens = None
    
                    generation = Generation(
                        request.id,
                        prefill_tokens,
                        Tokens(
                            _next_token_ids,
                            _next_token_logprobs,
                            next_token_texts,
                            [nid in self.all_special_ids for nid in _next_token_ids],
                        ),
                        generated_text,
                        top_tokens,
                    )
    
                    generations.append(generation)
    
                # Update values
                batch.input_lengths[i] = input_length + n_accepted_ids
                if batch.input_lengths[i] > batch.max_seqlen:
                    batch.max_seqlen = batch.input_lengths[i]
                batch.prefix_offsets[i] = prefix_offset
                batch.read_offsets[i] = read_offset
                batch.all_input_ids[i] = all_input_ids
    
            if stopped:
                del batch
                # No need to return a batch if we know that all requests stopped
                forward_ns = start_decode - start
                decode_ns = time.time_ns() - start_decode
                none_batch = None
                return generations, none_batch, (forward_ns, decode_ns)
    
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None
    
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, batch, (forward_ns, decode_ns)
    
        def warmup(self, batch: MindFlashCausalLMBatch):
            # The warmup batch is the biggest batch we could ever receive
            torch.npu.empty_cache()
    
            peak_memory = torch_npu.npu.max_memory_allocated()
            logger.info(f">>>>before warmup peak_memory {peak_memory}")
            try:
                cache_manager = set_cache_manager(
                    batch.blocks,
                    self.num_layers,
                    self.num_kv_heads,
                    self.head_size,
                    False,
                    self.dtype,
                    self.device,
                )
                _, batch, _ = self.generate_token(batch)
            except torch.cuda.OutOfMemoryError as e:
                raise RuntimeError(
                    f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                    f"You need to decrease `--max-batch-prefill-tokens`"
                ) from e
    
            # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
            # Calculate the number of blocks that can be allocated with the free memory
            dtype_size = torch.tensor([], dtype=self.dtype).element_size()
            cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
            total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
            torch_npu.npu.synchronize()
    
            total_gpu_memory = torch_npu.npu.get_device_properties(self.device).total_memory
    
            peak_memory = torch_npu.npu.max_memory_allocated()
            logger.info(
                f">>>>dtype_size {dtype_size}, cache_block_size {cache_block_size}, num_kv_heads {self.num_kv_heads}, "
                f"total_cache_size {total_cache_size}, peak_memory {peak_memory}")
            total_free_memory = total_gpu_memory - peak_memory
            logger.info(f">>>>total_free_memory {total_free_memory}, total_gpu_memory {total_gpu_memory}, "
                           f"MEMORY_FRACTION {MEMORY_FRACTION}")
            free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
    
            num_blocks = (
                # Leave 5% for some wiggle room
                    int((free_memory * 0.95) // total_cache_size)
                    # Add batch.blocks as we allocated it above, so it is included in the peak memory.
                    + cache_manager.num_blocks
            )
    
            del batch
            del cache_manager
    
            real_manager = set_cache_manager(
                num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            tgi_cache_manager.CACHE_MANAGER = real_manager
            logger.warning(f">>>>real CacheManager {get_cache_manager()}")
            peak_memory = torch_npu.npu.max_memory_allocated()
            logger.warning(f">>>>end warmup peak_memory {peak_memory}")
            logger.warning(f"Warmup return {int(num_blocks * BLOCK_SIZE)}")
            return int(num_blocks * BLOCK_SIZE)
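
    The block count returned by warmup is derived from the free-memory arithmetic above: measure the peak NPU memory after one warmup forward pass, subtract it and the reserved fraction from the device total, then divide the remainder by the per-block KV cache footprint. Below is a minimal, self-contained sketch of that calculation; every concrete figure in it (device capacity, peak memory, warmup block count, model shape, and the BLOCK_SIZE / MEMORY_FRACTION values) is an assumption for illustration only.

    # Stand-alone sketch of the warmup() block-count arithmetic.
    # All concrete numbers here are hypothetical; they only illustrate the formula.
    import torch

    BLOCK_SIZE = 128        # assumed tokens per KV cache block
    MEMORY_FRACTION = 0.9   # assumed fraction of device memory TGI may use

    num_layers, num_kv_heads, head_size = 40, 8, 128   # assumed model shape
    dtype = torch.float16

    dtype_size = torch.tensor([], dtype=dtype).element_size()          # 2 bytes for fp16
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size           # elements per layer per block
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # bytes per block (K and V)

    total_npu_memory = 64 * 1024 ** 3   # hypothetical device capacity
    peak_memory = 30 * 1024 ** 3        # hypothetical peak after the warmup forward pass
    warmup_blocks = 256                 # hypothetical blocks allocated for the warmup batch

    total_free_memory = total_npu_memory - peak_memory
    free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_npu_memory)

    # Leave 5% headroom, then add back the blocks already allocated during warmup.
    num_blocks = int((free_memory * 0.95) // total_cache_size) + warmup_blocks
    print(num_blocks, num_blocks * BLOCK_SIZE)  # usable blocks and total cacheable tokens
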
  • Tgi-MindIE/tgi_npu/token_mindie.py
    # Part of codes in this file was copied from project[huggingface][text-generation-inference]
    
    import os
    from typing import List, Optional
    from loguru import logger
    import numpy as np
    import torch
    from text_generation_server.pb import generate_pb2
    
    from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam
    from mindie_llm.text_generator.utils.config import SamplerConfig
    from mindie_llm.text_generator.samplers.sampler import Sampler
    
    from mindie_llm.modeling.backend_type import BackendType
    
    WRAPPER_KEY = "tensor_wrapper"
    try:
        from mindie_llm.text_generator.utils.sampling_metadata import TensorWrapper
    except ImportError:
        class TensorWrapper:
            def __init__(self, backend, device):
                self.device = device
                self.backend = backend
    
            def __call__(self, data):
                if data.dtype == np.int32:
                    dtype = torch.int32
                elif data.dtype == np.bool_:
                    dtype = torch.bool
                else:
                    dtype = None
                return torch.tensor(data, dtype=dtype, device=self.device)
    
    
        WRAPPER_KEY = "to_tensor"
    
    
    def do_filter(sample_param: List, indices):
        if any(sample_param):
            return [sample_param[i] for i in indices]
        return sample_param
    
    
    class MindIELLMHeterogeneousNextTokenChooser:
        def __init__(
                self,
                dtype: torch.dtype,
                device: torch.device,
                watermark: List[bool],
                temperature: List[float],
                repetition_penalty: List[float],
                frequency_penalty: List[float],
                top_k: List[int],
                top_p: List[float],
                typical_p: List[float],
                do_sample: List[bool],
                seeds: List[int],
                grammars: List[str],
                grammar_types: List[int],
                fsm_grammar_states: List[int],
                sample_method,  # mindie-llm sampling method
        ):
            if any(watermark):
                logger.warning(f"Watermark not supported now in mindie-llm")
            if any([x < 1.0 for x in typical_p]):
                logger.warning(f"Typical_p not supported now in mindie-llm")
            if any(grammar_types) or any(grammars) or any(fsm_grammar_states):
                logger.warning(f"Grammar not supported now in mindie-llm")
    
            self.tensor_wrapper = TensorWrapper(BackendType.ATB, device)
            self.wrapper_dict = {WRAPPER_KEY: TensorWrapper(BackendType.ATB, device)}
            self.sample_params = SamplingParam.from_numpy(
                repetition_penalty=np.array(repetition_penalty, dtype=np.float16),
                presence_penalty=None,
                frequency_penalty=np.array(frequency_penalty, dtype=np.float16),
                temperature=np.array(temperature, dtype=np.float16),
                top_k=np.array(top_k),
                top_p=np.array(top_p),
                seed=np.array(seeds).astype(np.int32),
                do_sample=np.array(do_sample),
                **self.wrapper_dict
            )
    
            # Temp store for filter
            self.temperature = temperature
            self.repetition_penalty = repetition_penalty
            self.frequency_penalty = frequency_penalty
            self.top_k = top_k
            self.top_p = top_p
            self.seeds = seeds
            self.do_sample = do_sample
    
            self.choice = sample_method
            self.dtype = dtype
            self.device = device
            self.seeds = seeds
            self.do_sample = self.sample_params.do_sample_meta.do_sample_array
            self.fsm_grammar_states = fsm_grammar_states
    
        def __call__(self,
                     request_ids: List,
                     is_prefill: bool,
                     input_ids: torch.Tensor,
                     scores: torch.Tensor,
                     speculate: int
                     ):
            batch_size = scores.shape[0]
            speculate_size = 1
            scores = scores.view(batch_size, speculate_size, -1)
    
            # Don't use SamplingData.from_numpy here: it would call tensor.cpu() and transfer large tensors off the device
            input_ids_int32 = input_ids.to(torch.int32)
            sample_data = SamplingData(all_input_ids=input_ids_int32, output_ids=input_ids_int32, is_prefill=is_prefill,
                                       request_ids=np.array(request_ids))
            next_ids = torch.zeros((batch_size, speculate_size), device=scores.device, dtype=torch.long)
            for j in range(speculate_size):
                _scores = scores[:, j]
    
                batch_logits, _next_ids = self.choice(batch_logits=_scores, batch_sampling_data=sample_data,
                                                      batch_sampling_params=self.sample_params)
                scores[:, j] = _scores
                next_ids[:, j] = torch.from_numpy(_next_ids)
            next_ids = next_ids.view(batch_size * speculate_size)
            allscores = scores.view(batch_size * speculate_size, -1)
            alllogprobs = torch.log_softmax(allscores, -1)
    
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs
    
            next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
    
            speculative_ids = None
            return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
    
        @classmethod
        def from_pb(
                cls,
                pb: List[generate_pb2.NextTokenChooserParameters],
                dtype: torch.dtype,
                device: torch.device,
                fsm_grammar_states: Optional[List[int]] = None,
        ) -> "MindIELLMHeterogeneousNextTokenChooser":
            curr_rank = int(os.getenv("RANK", "0"))
            sample_method = Sampler(SamplerConfig(rank=curr_rank, backend_type=BackendType.ATB, npu_id=curr_rank))
            return MindIELLMHeterogeneousNextTokenChooser(
                watermark=[pb_.watermark for pb_ in pb],
                temperature=[pb_.temperature for pb_ in pb],
                repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
                frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
                top_k=[pb_.top_k for pb_ in pb],
                top_p=[pb_.top_p for pb_ in pb],
                typical_p=[pb_.typical_p for pb_ in pb],
                do_sample=[pb_.do_sample for pb_ in pb],
                seeds=[pb_.seed for pb_ in pb],
                device=device,
                dtype=dtype,
                grammars=[pb_.grammar for pb_ in pb],
                grammar_types=[pb_.grammar_type for pb_ in pb],
                fsm_grammar_states=(
                    fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
                ),
                sample_method=sample_method
            )
    
        def filter(self, indices):
            self.repetition_penalty = do_filter(self.repetition_penalty, indices)
            self.frequency_penalty = do_filter(self.frequency_penalty, indices)
            self.temperature = do_filter(self.temperature, indices)
            self.top_k = do_filter(self.top_k, indices)
            self.top_p = do_filter(self.top_p, indices)
            self.seeds = do_filter(self.seeds, indices)
            self.do_sample = do_filter(self.do_sample, indices)
    
            self.sample_params = SamplingParam.from_numpy(
                repetition_penalty=np.array(self.repetition_penalty, dtype=np.float16),
                presence_penalty=None,
                frequency_penalty=np.array(self.frequency_penalty, dtype=np.float16),
                temperature=np.array(self.temperature, dtype=np.float16),
                top_k=np.array(self.top_k),
                top_p=np.array(self.top_p),
                seed=np.array(self.seeds).astype(np.int32),
                do_sample=np.array(self.do_sample),
                **self.wrapper_dict
            )
            return self
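
    When finished requests are removed from a batch, filter reindexes each per-request sampling list with do_filter and rebuilds the SamplingParam from the surviving entries. The plain-Python sketch below only exercises the do_filter helper defined above; the request values and surviving indices are made up.

    # Stand-alone illustration of do_filter(): keep only the entries whose requests
    # survived a batch filter. The values below are invented for demonstration.
    from typing import List

    def do_filter(sample_param: List, indices):
        # Reindex only when at least one entry is truthy; an all-default list
        # (e.g. all False or all zeros) is returned unchanged, without shrinking.
        if any(sample_param):
            return [sample_param[i] for i in indices]
        return sample_param

    temperature = [1.0, 0.7, 0.9]        # per-request temperatures
    do_sample = [False, False, False]    # no request asked for sampling
    surviving = [0, 2]                   # requests 0 and 2 are still generating

    print(do_filter(temperature, surviving))  # [1.0, 0.9]
    print(do_filter(do_sample, surviving))    # unchanged: [False, False, False]
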
  • Tgi-MindIE/tgi_npu/vlm_mind_models.py
    import re
    from typing import List, Type, Dict
    from dataclasses import dataclass

    import torch
    from loguru import logger
    from atb_llm.runner.tokenizer_wrapper import TokenizerWrapper
    from text_generation_server.pb import generate_pb2
    from transformers import PreTrainedTokenizerBase

    from tgi_npu.mind_models import MindFlashCausalLMBatch, MindModel

    IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


    def split(string) -> List[Dict[str, str]]:
        parts = []
        cursor = 0
        for pattern in IMAGES.finditer(string):
            start = pattern.start()
            if start != cursor:
                parts.append({"text": string[cursor:start]})
            parts.append({"image": pattern.group(1)})
            cursor = pattern.end()
        if cursor != len(string):
            parts.append({"text": string[cursor:]})
        return parts


    @dataclass
    class VlmMindFlashCausalLMBatch(MindFlashCausalLMBatch):
        @classmethod
        def batch_tokenized_inputs(cls, requests, tokenize):
            inputs = []
            for r in requests:
                splits = split(r.inputs)
                single_input = tokenize(splits).tolist()
                inputs.append(single_input)
            return inputs

        @classmethod
        def from_pb_processor(
                cls,
                pb: generate_pb2.Batch,
                tokenizer: PreTrainedTokenizerBase,
                tokenize,
                dtype: torch.dtype,
                device: torch.device,
        ) -> "VlmMindFlashCausalLMBatch":
            batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenize)
            batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
            return batch


    class VlmMindModel(MindModel):
        def __init__(
                self,
                model_id: str
        ):
            logger.warning("Initialize mindie-llm model for vlm.")
            super(VlmMindModel, self).__init__(model_id)
            self.tokenize = TokenizerWrapper(model_id).tokenize
            logger.warning("VlmMindModel from tgi_npu initialized.")

        @property
        def batch_type(self) -> Type[VlmMindFlashCausalLMBatch]:
            return VlmMindFlashCausalLMBatch
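
    batch_tokenized_inputs relies on the split helper above to cut a Markdown-style prompt into interleaved text and image segments before they are passed to the MindIE LLM tokenizer. The stand-alone snippet below simply exercises that regex; the prompt and image URL are invented for illustration.

    # Stand-alone demo of split(): a Markdown image reference is separated from the
    # surrounding text. The prompt and URL are placeholders, not real inputs.
    import re
    from typing import Dict, List

    IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

    def split(string) -> List[Dict[str, str]]:
        parts = []
        cursor = 0
        for pattern in IMAGES.finditer(string):
            start = pattern.start()
            if start != cursor:
                parts.append({"text": string[cursor:start]})
            parts.append({"image": pattern.group(1)})
            cursor = pattern.end()
        if cursor != len(string):
            parts.append({"text": string[cursor:]})
        return parts

    prompt = "What is in this picture? ![](https://example.com/cat.png) Answer briefly."
    print(split(prompt))
    # Expected parts:
    #   [{'text': 'What is in this picture? '},
    #    {'image': 'https://example.com/cat.png'},
    #    {'text': ' Answer briefly.'}]
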
  • Tgi-MindIE/pyproject.toml
    [tool.poetry]
    name = "tgi-npu"
    version = "0.1.0"
    description = "NPU MindIE Adapter for TGI v2.0.4"
    authors = ["Your Name <you@example.com>"]
    readme = "README.md"
    exclude = ["router", "cover"]
    
    [tool.poetry.dependencies]
    python = "^3.10"
    
    [build-system]
    requires = ["poetry-core"]
    build-backend = "poetry.core.masonry.api"