Reference Code

The directory structure of the Ascend adaptation code for vLLM 0.6.2 is as follows:

Vllm-MindIE-0.6.2
├─ cover
│  ├─ requirements-npu.txt
│  ├─ setup.py
│  └─ vllm
│     ├─ attention
│     │  ├─ backends
│     │  │  └─ mindie.py
│     │  └─ selector.py
│     ├─ config.py
│     ├─ distributed
│     │  └─ npu_utils.py
│     ├─ engine
│     │  ├─ arg_utils.py
│     │  ├─ async_llm_engine.py
│     │  └─ llm_engine.py
│     ├─ executor
│     │  ├─ distributed_npu_executor.py
│     │  ├─ npu_executor.py
│     │  ├─ ray_npu_executor.py
│     │  └─ ray_utils.py
│     ├─ model_executor
│     │  ├─ layers
│     │  │  └─ npu_sampler.py
│     │  └─ model_loader
│     │     └─ npu.py
│     ├─ platforms
│     │  ├─ __init__.py
│     │  ├─ interface.py
│     │  └─ npu.py
│     ├─ utils.py
│     └─ worker
│        ├─ cache_engine.py
│        ├─ npu_model_runner.py
│        └─ npu_worker.py
├─ examples
│  ├─ offline_inference.py
│  ├─ offline_inference.sh
│  └─ start_server.sh
├─ README.md
└─ install.sh

It consists of three main parts:

1. The cover directory contains the modifications to the vLLM framework source code.

2. The examples directory contains sample code for offline-mode and online (serving) usage.

3. install.sh is a one-click installation script. Once all of the code files have been restored, running this script installs the Ascend-adapted vLLM framework in one step: it automatically pulls the vLLM source, installs the native framework, and applies the adaptation patch.

Modifications to the vLLM framework source code under the cover directory:

  • cover/requirements-npu.txt: the dependency list adapted for the Ascend environment:
    cmake>=3.26
    torch==2.1.0
    ninja
    packaging
    setuptools>=61
    setuptools-scm>=8
    wheel
    jinja2
    psutil
    sentencepiece  # Required for LLaMA tokenizer.
    numpy < 2.0.0
    requests
    tqdm
    py-cpuinfo
    transformers == 4.45.0  # Required for Chameleon and Llama 3.1 hotfix.
    tokenizers >= 0.19.1  # Required for Llama 3.
    protobuf # Required by LlamaTokenizer.
    fastapi < 0.113.0; python_version < '3.9'
    fastapi >= 0.114.1; python_version >= '3.9'
    aiohttp
    openai >= 1.40.0 # Ensure modern openai package (ensure types module present)
    uvicorn[standard]
    pydantic >= 2.9  # Required for fastapi >= 0.113.0
    pillow  # Required for image processing
    prometheus_client >= 0.18.0
    prometheus-fastapi-instrumentator >= 7.0.0
    tiktoken >= 0.6.0  # Required for DBRX tokenizer
    lm-format-enforcer == 0.10.6
    outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
    typing_extensions >= 4.10
    filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
    partial-json-parser # used for parsing partial JSON outputs
    pyzmq
    msgspec
    gguf == 0.10.0
    importlib_metadata
    mistral_common >= 1.4.3
    pyyaml
    ray == 2.9.3
  • cover/setup.py: adds detection of the Ascend NPU environment and retrieval of the corresponding dependency file (a short pre-install check sketch follows the listing below).
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    import importlib.util
    import io
    import logging
    import os
    import re
    import subprocess
    import sys
    import warnings
    from pathlib import Path
    from shutil import which
    from typing import Dict, List
    
    import torch
    from packaging.version import Version, parse
    from setuptools import Extension, find_packages, setup
    from setuptools.command.build_ext import build_ext
    from setuptools_scm import get_version
    from torch.utils.cpp_extension import CUDA_HOME
    
    
    def load_module_from_path(module_name, path):
        spec = importlib.util.spec_from_file_location(module_name, path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
    
    
    ROOT_DIR = os.path.dirname(__file__)
    logger = logging.getLogger(__name__)
    
    # cannot import envs directly because it depends on vllm,
    #  which is not installed yet
    envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py'))
    
    VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
    
    if not sys.platform.startswith("linux"):
        logger.warning(
            "vLLM only supports Linux platform (including WSL). "
            "Building on %s, "
            "so vLLM may not be able to run correctly", sys.platform)
        VLLM_TARGET_DEVICE = "empty"
    
    MAIN_CUDA_VERSION = "12.1"
    
    
    def is_sccache_available() -> bool:
        return which("sccache") is not None
    
    
    def is_ccache_available() -> bool:
        return which("ccache") is not None
    
    
    def is_ninja_available() -> bool:
        return which("ninja") is not None
    
    
    def remove_prefix(text, prefix):
        if text.startswith(prefix):
            return text[len(prefix):]
        return text
    
    
    class CMakeExtension(Extension):
    
        def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
            super().__init__(name, sources=[], py_limited_api=True, **kwa)
            self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
    
    
    class cmake_build_ext(build_ext):
        # A dict of extension directories that have been configured.
        did_config: Dict[str, bool] = {}
    
        #
        # Determine number of compilation jobs and optionally nvcc compile threads.
        #
        def compute_num_jobs(self):
            # `num_jobs` is either the value of the MAX_JOBS environment variable
            # (if defined) or the number of CPUs available.
            num_jobs = envs.MAX_JOBS
            if num_jobs is not None:
                num_jobs = int(num_jobs)
                logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
            else:
                try:
                    # os.sched_getaffinity() isn't universally available, so fall
                    #  back to os.cpu_count() if we get an error here.
                    num_jobs = len(os.sched_getaffinity(0))
                except AttributeError:
                    num_jobs = os.cpu_count()
    
            nvcc_threads = None
            if _is_cuda() and get_nvcc_cuda_version() >= Version("11.2"):
                # `nvcc_threads` is either the value of the NVCC_THREADS
                # environment variable (if defined) or 1.
                # when it is set, we reduce `num_jobs` to avoid
                # overloading the system.
                nvcc_threads = envs.NVCC_THREADS
                if nvcc_threads is not None:
                    nvcc_threads = int(nvcc_threads)
                    logger.info(
                        "Using NVCC_THREADS=%d as the number of nvcc threads.",
                        nvcc_threads)
                else:
                    nvcc_threads = 1
                num_jobs = max(1, num_jobs // nvcc_threads)
    
            return num_jobs, nvcc_threads
    
        #
        # Perform cmake configuration for a single extension.
        #
        def configure(self, ext: CMakeExtension) -> None:
            # If we've already configured using the CMakeLists.txt for
            # this extension, exit early.
            if ext.cmake_lists_dir in cmake_build_ext.did_config:
                return
    
            cmake_build_ext.did_config[ext.cmake_lists_dir] = True
    
            # Select the build type.
            # Note: optimization level + debug info are set by the build type
            default_cfg = "Debug" if self.debug else "RelWithDebInfo"
            cfg = envs.CMAKE_BUILD_TYPE or default_cfg
    
            cmake_args = [
                '-DCMAKE_BUILD_TYPE={}'.format(cfg),
                '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
            ]
    
            verbose = envs.VERBOSE
            if verbose:
                cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']
    
            if is_sccache_available():
                cmake_args += [
                    '-DCMAKE_C_COMPILER_LAUNCHER=sccache',
                    '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
                    '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
                    '-DCMAKE_HIP_COMPILER_LAUNCHER=sccache',
                ]
            elif is_ccache_available():
                cmake_args += [
                    '-DCMAKE_C_COMPILER_LAUNCHER=ccache',
                    '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
                    '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
                    '-DCMAKE_HIP_COMPILER_LAUNCHER=ccache',
                ]
    
            # Pass the python executable to cmake so it can find an exact
            # match.
            cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
    
            # Pass the python path to cmake so it can reuse the build dependencies
            # on subsequent calls to python.
            cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))]
    
            #
            # Setup parallelism and build tool
            #
            num_jobs, nvcc_threads = self.compute_num_jobs()
    
            if nvcc_threads:
                cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)]
    
            if is_ninja_available():
                build_tool = ['-G', 'Ninja']
                cmake_args += [
                    '-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
                    '-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs),
                ]
            else:
                # Default build tool to whatever cmake picks.
                build_tool = []
            subprocess.check_call(
                ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
                cwd=self.build_temp)
    
        def build_extensions(self) -> None:
            # Ensure that CMake is present and working
            try:
                subprocess.check_output(['cmake', '--version'])
            except OSError as e:
                raise RuntimeError('Cannot find CMake executable') from e
    
            # Create build directory if it does not exist.
            if not os.path.exists(self.build_temp):
                os.makedirs(self.build_temp)
    
            targets = []
            target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
                                                  "vllm_flash_attn.")
            # Build all the extensions
            for ext in self.extensions:
                self.configure(ext)
                targets.append(target_name(ext.name))
    
            num_jobs, _ = self.compute_num_jobs()
    
            build_args = [
                "--build",
                ".",
                f"-j={num_jobs}",
                *[f"--target={name}" for name in targets],
            ]
    
            subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
    
            # Install the libraries
            for ext in self.extensions:
                # Install the extension into the proper location
                outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
    
                # Skip if the install directory is the same as the build directory
                if outdir == self.build_temp:
                    continue
    
                # CMake appends the extension prefix to the install path,
                # and outdir already contains that prefix, so we need to remove it.
                prefix = outdir
                for i in range(ext.name.count('.')):
                    prefix = prefix.parent
    
                # prefix here should actually be the same for all components
                install_args = [
                    "cmake", "--install", ".", "--prefix", prefix, "--component",
                    target_name(ext.name)
                ]
                subprocess.check_call(install_args, cwd=self.build_temp)
    
        def run(self):
            # First, run the standard build_ext command to compile the extensions
            super().run()
    
            # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
            # directory so that they can be included in the editable build
            import glob
            files = glob.glob(
                os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
            for file in files:
                dst_file = os.path.join("vllm/vllm_flash_attn",
                                        os.path.basename(file))
                print(f"Copying {file} to {dst_file}")
                self.copy_file(file, dst_file)
    
    
    def _no_device() -> bool:
        return VLLM_TARGET_DEVICE == "empty"
    
    
    
    def _is_npu() -> bool:
        try:
            import torch_npu
        except ImportError:
            warnings.warn("torch_npu is not installed in this environment; cannot install the NPU build of vllm.")
            return False
        return True
    
    
    def _is_cuda() -> bool:
        has_cuda = torch.version.cuda is not None
        return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
                and not (_is_neuron() or _is_tpu()))
    
    
    def _is_hip() -> bool:
        return (VLLM_TARGET_DEVICE == "cuda"
                or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None
    
    
    def _is_neuron() -> bool:
        torch_neuronx_installed = True
        try:
            subprocess.run(["neuron-ls"], capture_output=True, check=True)
        except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
            torch_neuronx_installed = False
        return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron"
    
    
    def _is_tpu() -> bool:
        return VLLM_TARGET_DEVICE == "tpu"
    
    
    def _is_cpu() -> bool:
        return VLLM_TARGET_DEVICE == "cpu"
    
    
    def _is_openvino() -> bool:
        return VLLM_TARGET_DEVICE == "openvino"
    
    
    def _is_xpu() -> bool:
        return VLLM_TARGET_DEVICE == "xpu"
    
    
    def _build_custom_ops() -> bool:
        return _is_cuda() or _is_hip() or _is_cpu()
    
    
    def _build_core_ext() -> bool:
        return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu() or _is_npu())
    
    
    def get_hipcc_rocm_version():
        # Run the hipcc --version command
        result = subprocess.run(['hipcc', '--version'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                text=True)
    
        # Check if the command was executed successfully
        if result.returncode != 0:
            print("Error running 'hipcc --version'")
            return None
    
        # Extract the version using a regular expression
        match = re.search(r'HIP version: (\S+)', result.stdout)
        if match:
            # Return the version string
            return match.group(1)
        else:
            print("Could not find HIP version in the output")
            return None
    
    
    def get_neuronxcc_version():
        import sysconfig
        site_dir = sysconfig.get_paths()["purelib"]
        version_file = os.path.join(site_dir, "neuronxcc", "version",
                                    "__init__.py")
    
        # Check if the command was executed successfully
        with open(version_file, "rt") as fp:
            content = fp.read()
    
        # Extract the version using a regular expression
        match = re.search(r"__version__ = '(\S+)'", content)
        if match:
            # Return the version string
            return match.group(1)
        else:
            raise RuntimeError("Could not find neuronxcc version in the output")
    
    
    def get_nvcc_cuda_version() -> Version:
        """Get the CUDA version from nvcc.
    
        Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
        """
        assert CUDA_HOME is not None, "CUDA_HOME is not set"
        nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
                                              universal_newlines=True)
        output = nvcc_output.split()
        release_idx = output.index("release") + 1
        nvcc_cuda_version = parse(output[release_idx].split(",")[0])
        return nvcc_cuda_version
    
    
    def get_path(*filepath) -> str:
        return os.path.join(ROOT_DIR, *filepath)
    
    
    def get_vllm_version() -> str:
        version = get_version(
            write_to="vllm/_version.py",  # TODO: move this to pyproject.toml
        )
    
        sep = "+" if "+" not in version else "."  # dev versions might contain +
    
        if _no_device():
            if envs.VLLM_TARGET_DEVICE == "empty":
                version += f"{sep}empty"
        elif _is_cuda():
            cuda_version = str(get_nvcc_cuda_version())
            if cuda_version != MAIN_CUDA_VERSION:
                cuda_version_str = cuda_version.replace(".", "")[:3]
                # skip this for source tarball, required for pypi
                if "sdist" not in sys.argv:
                    version += f"{sep}cu{cuda_version_str}"
        elif _is_hip():
            # Get the HIP version
            hipcc_version = get_hipcc_rocm_version()
            if hipcc_version != MAIN_CUDA_VERSION:
                rocm_version_str = hipcc_version.replace(".", "")[:3]
                version += f"{sep}rocm{rocm_version_str}"
        elif _is_neuron():
            # Get the Neuron version
            neuron_version = str(get_neuronxcc_version())
            if neuron_version != MAIN_CUDA_VERSION:
                neuron_version_str = neuron_version.replace(".", "")[:3]
                version += f"{sep}neuron{neuron_version_str}"
        elif _is_openvino():
            version += f"{sep}openvino"
        elif _is_tpu():
            version += f"{sep}tpu"
        elif _is_cpu():
            version += f"{sep}cpu"
        elif _is_xpu():
            version += f"{sep}xpu"
        elif _is_npu():
            version += f"{sep}npu"
        else:
            raise RuntimeError("Unknown runtime environment")
    
        return version
    
    
    def read_readme() -> str:
        """Read the README file if present."""
        p = get_path("README.md")
        if os.path.isfile(p):
            return io.open(get_path("README.md"), "r", encoding="utf-8").read()
        else:
            return ""
    
    
    def get_requirements() -> List[str]:
        """Get Python package dependencies from requirements.txt."""
    
        def _read_requirements(filename: str) -> List[str]:
            with open(get_path(filename)) as f:
                requirements = f.read().strip().split("\n")
            resolved_requirements = []
            for line in requirements:
                if line.startswith("-r "):
                    resolved_requirements += _read_requirements(line.split()[1])
                else:
                    resolved_requirements.append(line)
            return resolved_requirements
    
        if _no_device():
            requirements = _read_requirements("requirements-cuda.txt")
        elif _is_cuda():
            requirements = _read_requirements("requirements-cuda.txt")
            cuda_major, cuda_minor = torch.version.cuda.split(".")
            modified_requirements = []
            for req in requirements:
                if ("vllm-flash-attn" in req
                        and not (cuda_major == "12" and cuda_minor == "1")):
                    # vllm-flash-attn is built only for CUDA 12.1.
                    # Skip for other versions.
                    continue
                modified_requirements.append(req)
            requirements = modified_requirements
        elif _is_hip():
            requirements = _read_requirements("requirements-rocm.txt")
        elif _is_neuron():
            requirements = _read_requirements("requirements-neuron.txt")
        elif _is_openvino():
            requirements = _read_requirements("requirements-openvino.txt")
        elif _is_tpu():
            requirements = _read_requirements("requirements-tpu.txt")
        elif _is_cpu():
            requirements = _read_requirements("requirements-cpu.txt")
        elif _is_xpu():
            requirements = _read_requirements("requirements-xpu.txt")
        elif _is_npu():
            requirements = _read_requirements("requirements-npu.txt")
        else:
            raise ValueError(
                "Unsupported platform, please use CUDA, ROCm, Neuron, "
                "OpenVINO, NPU, or CPU.")
        return requirements
    
    
    ext_modules = []
    
    if _build_core_ext():
        ext_modules.append(CMakeExtension(name="vllm._core_C"))
    
    if _is_cuda() or _is_hip():
        ext_modules.append(CMakeExtension(name="vllm._moe_C"))
    
    if _is_hip():
        ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
    
    if _is_cuda():
        ext_modules.append(
            CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
    
    if _build_custom_ops():
        ext_modules.append(CMakeExtension(name="vllm._C"))
    
    package_data = {
        "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
    }
    if envs.VLLM_USE_PRECOMPILED:
        ext_modules = []
        package_data["vllm"].append("*.so")
    
    if _no_device():
        ext_modules = []
    
    setup(
        name="vllm",
        version=get_vllm_version(),
        author="vLLM Team",
        license="Apache 2.0",
        description=("A high-throughput and memory-efficient inference and "
                     "serving engine for LLMs"),
        long_description=read_readme(),
        long_description_content_type="text/markdown",
        url="https://github.com/vllm-project/vllm",
        project_urls={
            "Homepage": "https://github.com/vllm-project/vllm",
            "Documentation": "https://vllm.readthedocs.io/en/latest/",
        },
        classifiers=[
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
            "License :: OSI Approved :: Apache Software License",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
        packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples",
                                        "tests*")),
        python_requires=">=3.8",
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        extras_require={
            "tensorizer": ["tensorizer>=2.9.0"],
            "video": ["opencv-python"],  # Required for video processing
            "audio": ["librosa", "soundfile"]  # Required for audio processing
        },
        cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
        package_data=package_data,
        entry_points={
            "console_scripts": [
                "vllm=vllm.scripts:main",
            ],
        },
    )
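
    A quick pre-install check: before running install.sh you can confirm that torch_npu is importable, which is the condition _is_npu() above tests and which makes setup.py pick requirements-npu.txt. The snippet below is only an illustrative sketch (npu_ready is a hypothetical helper, not part of the repository):

    # Hypothetical helper mirroring setup.py's _is_npu() detection.
    import importlib.util

    def npu_ready() -> bool:
        # torch_npu must be importable for the NPU branch of setup.py to be taken.
        return importlib.util.find_spec("torch_npu") is not None

    if __name__ == "__main__":
        print("torch_npu available:", npu_ready())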
  • cover/vllm/attention/backends/mindie.py: based on vLLM 0.6.2, adds an Ascend-adapted AttentionBackend class that defines the key information required to interface with MindIE LLM on Ascend hardware, such as the attention computation data and the KV cache shape (a short illustrative sketch follows the listing below).
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    from dataclasses import dataclass
    from typing import TYPE_CHECKING, List, Optional, Tuple, Type
    
    import torch
    from atb_llm.utils.initial import NPUSocInfo
    from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata, AttentionMetadataBuilder
    
    if TYPE_CHECKING:
        from vllm.worker.npu_model_runner import ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata
    
    from vllm.attention.backends.utils import (
        PAD_SLOT_ID,
        CommonAttentionState,
        compute_slot_mapping,
        compute_slot_mapping_start_idx,
        is_block_tables_empty,
    )
    from vllm.utils import async_tensor_h2d, make_tensor_with_pad
    
    
    class MindIEAttentionBackend(AttentionBackend):
    
        @staticmethod
        def get_name() -> str:
            return "mindie-attn-backend"
    
        @staticmethod
        def get_impl_cls():
            return None
    
        @staticmethod
        def get_metadata_cls() -> Type["MindIEAttentionMetadata"]:
            return MindIEAttentionMetadata
    
        @staticmethod
        def get_builder_cls() -> Type["MindIEAttentionMetadataBuilder"]:
            return MindIEAttentionMetadataBuilder
    
        @staticmethod
        def get_state_cls() -> Type["CommonAttentionState"]:
            return CommonAttentionState
    
        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
        ) -> Tuple[int, ...]:
            if not NPUSocInfo().need_nz:
                return (num_blocks, block_size, num_kv_heads, head_size)
            else:
                return (num_blocks, num_kv_heads * head_size // 16, block_size, 16)
    
        @staticmethod
        def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor) -> None: ...
    
        @staticmethod
        def copy_blocks(kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor) -> None: 
            for pair in src_to_dists:
                src, dst = pair.tolist()  # Convert tensor elements to Python integers
                for key_cache, value_cache in kv_caches:
                    key_cache.data[dst, :] = key_cache.data[src, :]
                    value_cache.data[dst, :] = value_cache.data[src, :]
    
    
    @dataclass
    class MindIEAttentionMetadata(AttentionMetadata):
        """Metadata for AscendAttentionBackend."""
    
        # (batch_size,). The sequence length per sequence. Sequence length means
        # the computed tokens + new tokens None if it is a decoding.
        seq_lens: Optional[List[int]]
        # seq_lens stored as a tensor.
        seq_lens_tensor: Optional[torch.Tensor]
    
        # Maximum query length in the batch.
        max_query_len: Optional[int]
        # Maximum sequence length in the batch.
        max_seq_len: Optional[int]
        # Maximum sequence length among prefill batch. 0 if there are decoding
        # requests only.
        max_prefill_seq_len: int
        # Maximum sequence length among decode batch. 0 if there are prefill
        # requests only.
        max_decode_seq_len: int
        # (batch_size + 1,). The cumulative subquery lengths of the sequences in
        # the batch, used to index into subquery. E.g., if the subquery length
        # is [4, 6], it is [0, 4, 10].
        query_start_loc: Optional[torch.Tensor]
        # (batch_size + 1,). The cumulative sequence lengths of the sequences in
        # the batch, used to index into sequence. E.g., if the sequence length is
        # [4, 6], it is [0, 4, 10].
        seq_start_loc: Optional[torch.Tensor]
        # (batch_size,) A tensor of context lengths (tokens that are computed
        # so far).
        context_lens_tensor: Optional[torch.Tensor]
    
        block_tables: Optional[torch.Tensor]
    
        # Whether or not CUDA graph is enabled.
        use_cuda_graph: bool
    
        _cached_prefill_metadata: Optional["MindIEAttentionMetadata"] = None
        _cached_decode_metadata: Optional["MindIEAttentionMetadata"] = None
    
        @property
        def prefill_metadata(self) -> Optional["MindIEAttentionMetadata"]:
            if self.num_prefills == 0:
                return None
    
            if self._cached_prefill_metadata is not None:
                return self._cached_prefill_metadata
    
            assert self.seq_lens is not None
            assert self.seq_lens_tensor is not None
            assert self.query_start_loc is not None
            assert self.context_lens_tensor is not None
            assert self.block_tables is not None
            assert self.seq_start_loc is not None
    
            self._cached_prefill_metadata = MindIEAttentionMetadata(
                num_prefills=self.num_prefills,
                num_prefill_tokens=self.num_prefill_tokens,
                num_decode_tokens=0,
                slot_mapping=self.slot_mapping[: self.num_prefill_tokens],
                seq_lens=self.seq_lens[: self.num_prefills],
                seq_lens_tensor=self.seq_lens_tensor[: self.num_prefills],
                max_query_len=self.max_query_len,
                max_seq_len=max(self.seq_lens),
                max_prefill_seq_len=self.max_prefill_seq_len,
                max_decode_seq_len=0,
                query_start_loc=self.query_start_loc[: self.num_prefills + 1],
                seq_start_loc=self.seq_start_loc[: self.num_prefills + 1],
                context_lens_tensor=self.context_lens_tensor[: self.num_prefills],
                block_tables=self.block_tables[: self.num_prefills],
                use_cuda_graph=False,
            )
            return self._cached_prefill_metadata
    
        @property
        def decode_metadata(self) -> Optional["MindIEAttentionMetadata"]:
            if self.num_decode_tokens == 0:
                return None
    
            if self._cached_decode_metadata is not None:
                return self._cached_decode_metadata
            assert self.block_tables is not None
            assert self.seq_lens_tensor is not None
    
            self._cached_decode_metadata = MindIEAttentionMetadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=self.num_decode_tokens,
                slot_mapping=self.slot_mapping[self.num_prefill_tokens :],
                seq_lens=None,
                seq_lens_tensor=self.seq_lens_tensor[self.num_prefills :],
                max_query_len=None,
                max_seq_len=max(self.seq_lens),
                max_prefill_seq_len=0,
                max_decode_seq_len=self.max_decode_seq_len,
                query_start_loc=None,
                seq_start_loc=None,
                context_lens_tensor=None,
                block_tables=self.block_tables[self.num_prefills :],
                use_cuda_graph=self.use_cuda_graph,
            )
            return self._cached_decode_metadata
    
        def advance_step(
            self,
            model_input: "ModelInputForNPUWithSamplingMetadata",
            sampled_token_ids: Optional[torch.Tensor],
            block_size: int,
            num_seqs: int,
            num_queries: int,
        ):
            """
            Update metadata in-place to advance one decode step.
            """
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.num_decode_tokens == num_seqs
            assert self.slot_mapping.shape == (num_seqs,)
    
            assert self.seq_lens is not None
            assert len(self.seq_lens) == num_seqs
            assert self.seq_lens_tensor is not None
            assert self.seq_lens_tensor.shape == (num_seqs,)
            assert self.max_query_len == 1
            assert self.max_prefill_seq_len == 0
            assert self.max_decode_seq_len == max(self.seq_lens)
    
            assert self.query_start_loc is not None
            assert self.query_start_loc.shape == (num_queries + 1,)
            assert self.seq_start_loc is not None
            assert self.seq_start_loc.shape == (num_seqs + 1,)
    
            assert self.context_lens_tensor is not None
            assert self.context_lens_tensor.shape == (num_queries,)
    
            assert self.block_tables is not None
            assert self.block_tables.shape[0] == num_seqs
    
            for i in range(num_queries):
                self.seq_lens[i] += 1
            self.max_decode_seq_len = max(self.seq_lens)
    
            advance_step_flashattn(
                num_seqs=num_seqs,
                num_queries=num_queries,
                block_size=block_size,
                input_tokens=model_input.input_tokens,
                sampled_token_ids=sampled_token_ids,
                input_positions=model_input.input_positions,
                seq_lens=self.seq_lens_tensor,
                slot_mapping=self.slot_mapping,
                block_tables=self.block_tables,
                block_tables_stride=self.block_tables.stride(0),
            )
    
    
    def advance_step_flashattn(
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        block_tables_stride,
    ):
    
        # Update input_tokens: matching the shape of input_tokens and sampled_token_ids
        input_tokens[:num_queries] = sampled_token_ids[:num_queries].squeeze(1)
    
        # Update sequence lengths and input positions
        next_seq_len = seq_lens[:num_queries] + 1
        next_input_pos = next_seq_len - 1
        seq_lens[:num_queries] = next_seq_len
        input_positions[:num_queries] = next_input_pos
    
        # Compute block indices and offsets
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
    
        # Retrieve sequence-specific block tables
        seq_block_tables = block_tables[:num_queries, :block_tables_stride]
    
        # Use gather to map block indices to slots
        slot_num = seq_block_tables.gather(1, block_index.unsqueeze(1)).squeeze(1) * block_size + block_offset
        slot_mapping[:num_queries] = slot_num
    
    
    class MindIEAttentionMetadataBuilder(AttentionMetadataBuilder[MindIEAttentionMetadata]):
    
        def __init__(self, input_builder: "ModelInputForNPUBuilder"):
            self.slot_mapping: List[int] = []
            self.prefill_seq_lens: List[int] = []
            self.context_lens: List[int] = []
            self.block_tables: List[List[int]] = []
            self.curr_seq_lens: List[int] = []
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.num_decode_tokens = 0
            self.has_prefix_cache_hit = False
    
            self.input_builder = input_builder
            self.runner = input_builder.runner
            self.sliding_window = input_builder.sliding_window
            self.block_size = input_builder.block_size
            self.use_v2_block_manager = input_builder.scheduler_config.use_v2_block_manager
    
        def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int):
            """Build attention metadata with on-device tensors.
    
            Args:
                seq_lens: The maybe padded sequence lengths of the input sequences.
                query_lens: The query lengths of the input sequences.
                cuda_graph_pad_size: The padding size for cuda graph.
                                     -1 if cuda graph is not used.
                batch_size: The maybe padded batch size.
            """
            prefix_cache_hit = any([inter_data.prefix_cache_hit for inter_data in self.input_builder.inter_data_list])
            for inter_data in self.input_builder.inter_data_list:
                self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit)
    
            device = self.runner.device
            use_captured_graph = cuda_graph_pad_size != -1
    
            max_query_len = max(query_lens)
            max_seq_len = max(seq_lens)
            max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
            max_decode_seq_len = max(self.curr_seq_lens, default=0)
            num_decode_tokens = self.num_decode_tokens
    
            block_tables = make_tensor_with_pad(
                    self.block_tables,
                    pad=0,
                    dtype=torch.int,
                    device=device,
                )
            assert max_query_len > 0, "query_lens: {}".format(query_lens)
    
            assert device is not None
            context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, device, self.runner.pin_memory)
            seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory)
            query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, self.runner.pin_memory)
            slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory)
            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device)
            seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device)
            torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:])
            torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype, out=query_start_loc[1:])
    
            # TODO: Remove the unnecessary params
            return MindIEAttentionMetadata(
                num_prefills=self.num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=self.num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_seq_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )
    
        def _add_seq_group(
            self,
            inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool,
            prefix_cache_hit: bool,
        ):
            """Add a sequence group to the metadata. Specifically update/append
            1. context length.
            2. block table.
            3. slot mapping.
            """
            is_prompt = inter_data.is_prompt
            block_tables = inter_data.block_tables
    
            for seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, _ in zip(
                inter_data.seq_ids,
                [len(t) for t in inter_data.input_tokens],
                inter_data.orig_seq_lens,
                inter_data.seq_lens,
                inter_data.query_lens,
                inter_data.context_lens,
                inter_data.curr_sliding_window_blocks,
            ):
                self.context_lens.append(context_len)
    
                if is_prompt:
                    self.num_prefills += 1
                    self.num_prefill_tokens += token_len
                    self.prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len
                    )
                    self.num_decode_tokens += query_len
                    self.curr_seq_lens.append(curr_seq_len)
    
                # Compute block table.
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                block_table = []
    
                # Adapt for prefix caching
                if inter_data.block_tables:
                    block_table = inter_data.block_tables[seq_id]
                self.block_tables.append(block_table)
    
                # Compute slot mapping.
                is_profile_run = is_block_tables_empty(block_tables)
                start_idx = compute_slot_mapping_start_idx(
                    is_prompt, query_len, context_len, self.sliding_window, self.use_v2_block_manager
                )
                compute_slot_mapping(
                    is_profile_run,
                    self.slot_mapping,
                    seq_id,
                    seq_len,
                    context_len,
                    start_idx,
                    self.block_size,
                    inter_data.block_tables,
                )
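
    The KV cache layout returned by get_kv_cache_shape above depends on whether the Ascend SoC requires the NZ (fractal) data format. The standalone sketch below reproduces the two shape computations so the difference is easy to see; it is illustrative only (kv_cache_shape is a hypothetical name, and no NPU or atb_llm is needed), and both layouts hold the same total number of elements:

    # Illustrative re-implementation of MindIEAttentionBackend.get_kv_cache_shape.
    def kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                       head_size: int, need_nz: bool) -> tuple:
        if not need_nz:
            # Plain ND layout: one (num_kv_heads, head_size) slice per cache slot.
            return (num_blocks, block_size, num_kv_heads, head_size)
        # NZ layout: the KV hidden dimension is folded into 16-element fractal blocks.
        return (num_blocks, num_kv_heads * head_size // 16, block_size, 16)

    print(kv_cache_shape(1024, 128, 8, 128, need_nz=False))  # -> (1024, 128, 8, 128)
    print(kv_cache_shape(1024, 128, 8, 128, need_nz=True))   # -> (1024, 64, 128, 16)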
  • cover/vllm/attention/selector.py: based on vLLM 0.6.2, adds support for the MINDIE backend so that on Ascend hardware get_attn_backend selects the MindIEAttentionBackend class (a usage sketch follows the listing below).
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import enum
    import os
    from contextlib import contextmanager
    from functools import lru_cache
    from typing import Generator, Optional, Type
    
    import torch
    
    import vllm.envs as envs
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.logger import init_logger
    from vllm.platforms import current_platform
    from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu, is_npu
    
    logger = init_logger(__name__)
    
    
    class _Backend(enum.Enum):
        FLASH_ATTN = enum.auto()
        XFORMERS = enum.auto()
        ROCM_FLASH = enum.auto()
        TORCH_SDPA = enum.auto()
        OPENVINO = enum.auto()
        FLASHINFER = enum.auto()
        PALLAS = enum.auto()
        IPEX = enum.auto()
        MINDIE = enum.auto()
    
    
    def backend_name_to_enum(backend_name: str) -> _Backend:
        assert backend_name is not None
    
        backend_members = _Backend.__members__
        if backend_name not in backend_members:
            raise ValueError(f"Invalid attention backend '{backend_name}'. "
                             f"Available backends: {', '.join(backend_members)} "
                             "(case-sensitive).")
    
        return _Backend[backend_name]
    
    
    def get_env_variable_attn_backend() -> Optional[_Backend]:
        '''
        Get the backend override specified by the vLLM attention
        backend environment variable, if one is specified.
    
        Returns:
    
        * _Backend enum value if an override is specified
        * None otherwise
        '''
        backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
        return (None
                if backend_name is None else backend_name_to_enum(backend_name))
    
    
    # Global state allows a particular choice of backend
    # to be forced, overriding the logic which auto-selects
    # a backend based on system & workload configuration
    # (default behavior if this variable is None)
    #
    # THIS SELECTION TAKES PRECEDENCE OVER THE
    # VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
    forced_attn_backend: Optional[_Backend] = None
    
    
    def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
        '''
        Force all attention operations to use a specified backend.
    
        Passing `None` for the argument re-enables automatic
        backend selection.,
    
        Arguments:
    
        * attn_backend: backend selection (None to revert to auto)
        '''
        global forced_attn_backend
        forced_attn_backend = attn_backend
    
    
    def get_global_forced_attn_backend() -> Optional[_Backend]:
        '''
        Get the currently-forced choice of attention backend,
        or None if auto-selection is currently enabled.
        '''
        return forced_attn_backend
    
    
    @lru_cache(maxsize=None)
    def get_attn_backend(
        num_heads: int,
        head_size: int,
        num_kv_heads: int,
        sliding_window: Optional[int],
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        is_blocksparse: bool = False,
    ) -> Type[AttentionBackend]:
        """Selects which attention backend to use and lazily imports it."""
    
        if is_blocksparse:
            logger.info("Using BlocksparseFlashAttention backend.")
            from vllm.attention.backends.blocksparse_attn import (
                BlocksparseFlashAttentionBackend)
            return BlocksparseFlashAttentionBackend
    
        backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
                                    sliding_window, dtype, kv_cache_dtype,
                                    block_size)
        if backend == _Backend.FLASH_ATTN:
            from vllm.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)
            return FlashAttentionBackend
        if backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            from vllm.attention.backends.xformers import (  # noqa: F401
                XFormersBackend)
            return XFormersBackend
        elif backend == _Backend.ROCM_FLASH:
            logger.info("Using ROCmFlashAttention backend.")
            from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
                ROCmFlashAttentionBackend)
            return ROCmFlashAttentionBackend
        elif backend == _Backend.TORCH_SDPA:
            assert is_cpu(), RuntimeError(
                "Torch SDPA backend is only used for the CPU device.")
            logger.info("Using Torch SDPA backend.")
            from vllm.attention.backends.torch_sdpa import TorchSDPABackend
            return TorchSDPABackend
        elif backend == _Backend.OPENVINO:
            logger.info("Using OpenVINO Attention backend.")
            from vllm.attention.backends.openvino import OpenVINOAttentionBackend
            return OpenVINOAttentionBackend
        elif backend == _Backend.IPEX:
            assert is_xpu(), RuntimeError(
                "IPEX attention backend is only used for the XPU device.")
            logger.info("Using IPEX attention backend.")
            from vllm.attention.backends.ipex_attn import IpexAttnBackend
            return IpexAttnBackend
        elif backend == _Backend.FLASHINFER:
            logger.info("Using Flashinfer backend.")
            from vllm.attention.backends.flashinfer import FlashInferBackend
            return FlashInferBackend
        elif backend == _Backend.PALLAS:
            logger.info("Using Pallas backend.")
            from vllm.attention.backends.pallas import PallasAttentionBackend
            return PallasAttentionBackend
        elif backend == _Backend.MINDIE:
            logger.info("Using MindIE backend.")
            from vllm.attention.backends.mindie import MindIEAttentionBackend
            return MindIEAttentionBackend
        else:
            raise ValueError("Invalid attention backend.")
    
    
    def which_attn_to_use(
        num_heads: int,
        head_size: int,
        num_kv_heads: int,
        sliding_window: Optional[int],
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
    ) -> _Backend:
        """Returns which flash attention backend to use."""
        # Default case.
        selected_backend = _Backend.FLASH_ATTN
    
        # Check whether a particular choice of backend was
        # previously forced.
        #
        # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
        # ENVIRONMENT VARIABLE.
        backend_by_global_setting: Optional[_Backend] = (
            get_global_forced_attn_backend())
        if backend_by_global_setting is not None:
            selected_backend = backend_by_global_setting
        else:
            # Check the environment variable and override if specified
            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
            if backend_by_env_var is not None:
                selected_backend = backend_name_to_enum(backend_by_env_var)
    
        if is_cpu():
            if selected_backend != _Backend.TORCH_SDPA:
                logger.info("Cannot use %s backend on CPU.", selected_backend)
            return _Backend.TORCH_SDPA
    
        if is_openvino():
            if selected_backend != _Backend.OPENVINO:
                logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
            return _Backend.OPENVINO
    
        if is_xpu():
            if selected_backend != _Backend.IPEX:
                logger.info("Cannot use %s backend on XPU.", selected_backend)
            return _Backend.IPEX
    
        if current_platform.is_tpu():
            if selected_backend != _Backend.PALLAS:
                logger.info("Cannot use %s backend on TPU.", selected_backend)
            return _Backend.PALLAS
    
        if is_hip():
            # AMD GPUs.
            selected_backend = (_Backend.ROCM_FLASH if selected_backend
                                == _Backend.FLASH_ATTN else selected_backend)
            if selected_backend == _Backend.ROCM_FLASH:
                if not current_platform.has_device_capability(90):
                    # not Instinct series GPUs.
                    logger.info("flash_attn is not supported on NAVI GPUs.")
            else:
                logger.info("%s is not supported in AMD GPUs.", selected_backend)
            return _Backend.ROCM_FLASH
    
        if is_npu():
            if selected_backend != _Backend.MINDIE:
                logger.info("Cannot use %s backend on NPU.", selected_backend)
            return _Backend.MINDIE
    
        # FlashAttn in NVIDIA GPUs.
        if selected_backend == _Backend.FLASH_ATTN:
            if not current_platform.has_device_capability(80):
                # Volta and Turing NVIDIA GPUs.
                logger.info(
                    "Cannot use FlashAttention-2 backend for Volta and Turing "
                    "GPUs.")
                selected_backend = _Backend.XFORMERS
            elif dtype not in (torch.float16, torch.bfloat16):
                logger.info(
                    "Cannot use FlashAttention-2 backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
                selected_backend = _Backend.XFORMERS
            elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
                logger.info(
                    "Cannot use FlashAttention-2 backend for FP8 KV cache.")
                logger.warning(
                    "Please use FlashInfer backend with FP8 KV Cache for "
                    "better performance by setting environment variable  "
                    "VLLM_ATTENTION_BACKEND=FLASHINFER")
                selected_backend = _Backend.XFORMERS
            elif block_size % 16 != 0:
                logger.info(
                    "Cannot use FlashAttention-2 backend for block size not "
                    "divisible by 16.")
                selected_backend = _Backend.XFORMERS
            elif sliding_window is not None:
                logger.info(
                    "Cannot use FlashAttention-2 backend due to sliding window.")
                selected_backend = _Backend.XFORMERS
    
        # FlashAttn is valid for the model, checking if the package is installed.
        if selected_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend)
    
                supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    selected_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                selected_backend = _Backend.XFORMERS
    
        return selected_backend
    
    @contextmanager
    def global_force_attn_backend_context_manager(
            attn_backend: _Backend) -> Generator[None, None, None]:
        '''
        Globally force a vLLM attention backend override within a
        context manager, reverting the global attention backend
        override to its prior state upon exiting the context
        manager.
    
        Arguments:
    
        * attn_backend: attention backend to force
    
        Returns:
    
        * Generator
        '''
    
        # Save the current state of the global backend override (if any)
        original_value = get_global_forced_attn_backend()
    
        # Globally force the new backend override
        global_force_attn_backend(attn_backend)
    
        # Yield control back to the enclosed code block
        try:
            yield
        finally:
            # Revert the original global backend override, if any
            global_force_attn_backend(original_value)
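
    On Ascend hardware which_attn_to_use() returns _Backend.MINDIE automatically, but the backend can also be pinned explicitly. The sketch below assumes the Ascend-adapted vLLM from this repository is installed; both mechanisms it uses are defined in the listing above:

    import os

    # Option 1: environment variable consumed by which_attn_to_use()
    # (typically exported before launching vLLM, e.g. in start_server.sh).
    os.environ.setdefault("VLLM_ATTENTION_BACKEND", "MINDIE")

    # Option 2: programmatic override via the context manager defined above.
    from vllm.attention.selector import (_Backend,
                                         global_force_attn_backend_context_manager)

    with global_force_attn_backend_context_manager(_Backend.MINDIE):
        # Build the engine here; get_attn_backend() now returns MindIEAttentionBackend.
        ...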
    
  • cover/vllm/config.py: adds recognition of Ascend NPU devices, adding npu to device_type.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    import enum
    import json
    from dataclasses import dataclass, field, fields
    from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
                        Optional, Tuple, Type, Union)
    
    import torch
    from transformers import PretrainedConfig
    
    import vllm.envs as envs
    from vllm.logger import init_logger
    from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
    from vllm.model_executor.models import ModelRegistry
    from vllm.platforms import current_platform
    from vllm.tracing import is_otel_available, otel_import_error_traceback
    from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                                get_hf_image_processor_config,
                                                get_hf_text_config)
    from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                            is_hip, is_neuron, is_openvino, is_xpu, is_npu,
                            print_warning_once)
    
    if TYPE_CHECKING:
        from ray.util.placement_group import PlacementGroup
    
        from vllm.executor.executor_base import ExecutorBase
        from vllm.model_executor.model_loader.loader import BaseModelLoader
        from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
            BaseTokenizerGroup)
    
    logger = init_logger(__name__)
    
    _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
    
    _PP_SUPPORTED_MODELS = [
        "AquilaForCausalLM",
        "AquilaModel",
        "DeepseekV2ForCausalLM",
        "GPT2LMHeadModel",
        "InternLM2ForCausalLM",
        "InternLMForCausalLM",
        "InternVLChatModel",
        "JAISLMHeadModel",
        "LlamaForCausalLM",
        "LLaMAForCausalLM",
        "MistralForCausalLM",
        "MixtralForCausalLM",
        "NemotronForCausalLM",
        "Phi3ForCausalLM",
        "Qwen2ForCausalLM",
        "Qwen2MoeForCausalLM",
        "QWenLMHeadModel",
        "Qwen2VLForConditionalGeneration",
    ]
    
    
    class ModelConfig:
        """Configuration for the model.
    
        Args:
            model: Name or path of the huggingface model to use.
                It is also used as the content for `model_name` tag in metrics 
                output when `served_model_name` is not specified. 
            tokenizer: Name or path of the huggingface tokenizer to use.
            tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
                available, "slow" will always use the slow tokenizer, and
                "mistral" will always use the tokenizer from `mistral_common`.
            trust_remote_code: Trust remote code (e.g., from HuggingFace) when
                downloading the model and tokenizer.
            dtype: Data type for model weights and activations. The "auto" option
                will use FP16 precision for FP32 and FP16 models, and BF16 precision
                for BF16 models.
            seed: Random seed for reproducibility.
            revision: The specific model version to use. It can be a branch name,
                a tag name, or a commit id. If unspecified, will use the default
                version.
            code_revision: The specific revision to use for the model code on
                Hugging Face Hub. It can be a branch name, a tag name, or a
                commit id. If unspecified, will use the default version.
            rope_scaling: Dictionary containing the scaling configuration for the
                RoPE embeddings. When using this flag, don't update
                `max_position_embeddings` to the expected new maximum.
            tokenizer_revision: The specific tokenizer version to use. It can be a
                branch name, a tag name, or a commit id. If unspecified, will use
                the default version.
            max_model_len: Maximum length of a sequence (including prompt and
                output). If None, will be derived from the model.
            quantization: Quantization method that was used to quantize the model
                weights. If None, we assume the model weights are not quantized.
            quantization_param_path: Path to JSON file containing scaling factors.
                Used to load KV cache scaling factors into the model when KV cache
                type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
                be used to load activation and weight scaling factors when the
                model dtype is FP8_E4M3 on ROCm.
            enforce_eager: Whether to enforce eager execution. If True, we will
                disable CUDA graph and always execute the model in eager mode.
                If False, we will use CUDA graph and eager execution in hybrid mode.
                If None, the user did not specify, so default to False.
            max_context_len_to_capture: Maximum context len covered by CUDA graphs.
                When a sequence has context length larger than this, we fall back
                to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
            max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
                When a sequence has context length larger than this, we fall back
                to eager mode. Additionally for encoder-decoder models, if the
                sequence length of the encoder input is larger than this, we fall
                back to the eager mode.
            disable_sliding_window: Whether to disable sliding window. If True,
                we will disable the sliding window functionality of the model.
                If the model does not support sliding window, this argument is
                ignored.
            skip_tokenizer_init: If true, skip initialization of tokenizer and
                detokenizer.
            served_model_name: The model name used in metrics tag `model_name`,
                matches the model name exposed via the APIs. If multiple model
                names are provided, the first name will be used. If not
                specified, the model name will be the same as `model`.
            limit_mm_per_prompt: Maximum number of data instances per modality 
                per prompt. Only applicable for multimodal models.
            override_neuron_config: Initialize a non-default neuron config or
                override the default neuron config that is specific to Neuron
                devices. This argument is used to configure neuron settings
                that cannot be gathered from the vLLM arguments.
            config_format: The config format which shall be loaded.
                Defaults to 'auto' which defaults to 'hf'.
            mm_processor_kwargs: Arguments to be forwarded to the model's processor
                for multi-modal data, e.g., image processor.
        """
    
        def __init__(self,
                     model: str,
                     tokenizer: str,
                     tokenizer_mode: str,
                     trust_remote_code: bool,
                     dtype: Union[str, torch.dtype],
                     seed: int,
                     revision: Optional[str] = None,
                     code_revision: Optional[str] = None,
                     rope_scaling: Optional[dict] = None,
                     rope_theta: Optional[float] = None,
                     tokenizer_revision: Optional[str] = None,
                     max_model_len: Optional[int] = None,
                     spec_target_max_model_len: Optional[int] = None,
                     quantization: Optional[str] = None,
                     quantization_param_path: Optional[str] = None,
                     enforce_eager: Optional[bool] = None,
                     max_context_len_to_capture: Optional[int] = None,
                     max_seq_len_to_capture: Optional[int] = None,
                     max_logprobs: int = 20,
                     disable_sliding_window: bool = False,
                     skip_tokenizer_init: bool = False,
                     served_model_name: Optional[Union[str, List[str]]] = None,
                     limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                     use_async_output_proc: bool = True,
                     override_neuron_config: Optional[Dict[str, Any]] = None,
                     config_format: ConfigFormat = ConfigFormat.AUTO,
                     mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
            self.model = model
            self.tokenizer = tokenizer
            self.tokenizer_mode = tokenizer_mode
            self.trust_remote_code = trust_remote_code
            self.seed = seed
            self.revision = revision
            self.code_revision = code_revision
            self.rope_scaling = rope_scaling
            self.rope_theta = rope_theta
            # The tokenizer version is consistent with the model version by default.
            if tokenizer_revision is None:
                self.tokenizer_revision = revision
            else:
                self.tokenizer_revision = tokenizer_revision
            self.quantization = quantization
            self.quantization_param_path = quantization_param_path
            self.enforce_eager = enforce_eager
            if max_context_len_to_capture is not None:
                raise ValueError("`max_context_len_to_capture` is deprecated. "
                                 "Use `max_seq_len_to_capture` instead.")
            self.max_seq_len_to_capture = max_seq_len_to_capture
            self.max_logprobs = max_logprobs
            self.disable_sliding_window = disable_sliding_window
            self.skip_tokenizer_init = skip_tokenizer_init
    
            self.hf_config = get_config(self.model, trust_remote_code, revision,
                                        code_revision, rope_scaling, rope_theta,
                                        config_format)
            self.hf_text_config = get_hf_text_config(self.hf_config)
            self.hf_image_processor_config = get_hf_image_processor_config(
                self.model, revision)
            self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
            self.use_async_output_proc = use_async_output_proc
            self.mm_processor_kwargs = mm_processor_kwargs
    
            # Set enforce_eager to False if the value is unset.
            if self.enforce_eager is None:
                self.enforce_eager = False
    
            if (not self.disable_sliding_window
                    and self.hf_text_config.model_type == "gemma2"
                    and self.hf_text_config.sliding_window is not None):
                print_warning_once(
                    "Gemma 2 uses sliding window attention for every odd layer, "
                    "which is currently not supported by vLLM. Disabling sliding "
                    "window and capping the max length to the sliding window size "
                    f"({self.hf_text_config.sliding_window}).")
                self.disable_sliding_window = True
    
            self.max_model_len = _get_and_verify_max_len(
                hf_config=self.hf_text_config,
                max_model_len=max_model_len,
                disable_sliding_window=self.disable_sliding_window,
                sliding_window_len=self.get_hf_config_sliding_window(),
                spec_target_max_model_len=spec_target_max_model_len)
            self.served_model_name = get_served_model_name(model,
                                                           served_model_name)
            self.multimodal_config = self._init_multimodal_config(
                limit_mm_per_prompt)
            if not self.skip_tokenizer_init:
                self._verify_tokenizer_mode()
    
            self.override_neuron_config = override_neuron_config if is_neuron(
            ) else None
            self._verify_embedding_mode()
            self._verify_quantization()
            self._verify_cuda_graph()
            self._verify_bnb_config()
    
        @property
        def is_encoder_decoder_model(self) -> bool:
            """Extract the HF encoder/decoder model flag."""
            return getattr(self.hf_config, "is_encoder_decoder", False) or (
                (hasattr(self.hf_config, "text_config") and getattr(
                    self.hf_config.text_config, "is_encoder_decoder", False)))
    
        @property
        def is_embedding_model(self) -> bool:
            """Extract the embedding model flag."""
            return self.embedding_mode
    
        @property
        def is_multimodal_model(self) -> bool:
            return self.multimodal_config is not None
    
        def verify_async_output_proc(self, parallel_config, speculative_config,
                                     device_config) -> None:
            if not self.use_async_output_proc:
                # Nothing to check
                return
    
            if parallel_config.pipeline_parallel_size > 1:
                logger.warning("Async output processing can not be enabled "
                               "with pipeline parallel")
                self.use_async_output_proc = False
                return
    
            if device_config.device_type not in ("cuda", "tpu", "npu"):
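                # Note: "npu" is included in the accepted device types above so
                # that async output processing stays enabled on Ascend NPU devices.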
                logger.warning(
                    "Async output processing is only supported for CUDA, NPU or TPU. "
                    "Disabling it for other platforms.")
                self.use_async_output_proc = False
                return
    
            if envs.VLLM_USE_RAY_SPMD_WORKER:
                logger.warning(
                    "Async output processing can not be enabled with ray spmd")
                self.use_async_output_proc = False
                return
    
            if device_config.device_type == "cuda" and self.enforce_eager:
                logger.warning(
                    "To see benefits of async output processing, enable CUDA "
                    "graph. Since enforce-eager is enabled, async output "
                    "processor cannot be used")
                self.use_async_output_proc = not self.enforce_eager
                return
    
            # Async postprocessor is not necessary with embedding mode
            # since there is no token generation
            if self.embedding_mode:
                self.use_async_output_proc = False
    
            if speculative_config:
                logger.warning("Async output processing is not supported with"
                               " speculative decoding currently.")
                self.use_async_output_proc = False
    
        def verify_with_parallel_config(
            self,
            parallel_config: "ParallelConfig",
        ) -> None:
            total_num_attention_heads = getattr(self.hf_text_config,
                                                "num_attention_heads", 0)
            tensor_parallel_size = parallel_config.tensor_parallel_size
            if total_num_attention_heads % tensor_parallel_size != 0:
                raise ValueError(
                    f"Total number of attention heads ({total_num_attention_heads})"
                    " must be divisible by tensor parallel size "
                    f"({tensor_parallel_size}).")
    
            pipeline_parallel_size = parallel_config.pipeline_parallel_size
            architectures = getattr(self.hf_config, "architectures", [])
            if not all(arch in _PP_SUPPORTED_MODELS
                       for arch in architectures) and pipeline_parallel_size > 1:
                raise NotImplementedError(
                    "Pipeline parallelism is only supported for the following "
                    f"architectures: {_PP_SUPPORTED_MODELS}.")
    
            if pipeline_parallel_size > 1 and self.use_async_output_proc:
                logger.warning("Async output processor is not supported with "
                               "pipeline parallelism currently. Disabling it.")
                self.use_async_output_proc = False
    
        def get_hf_config_sliding_window(self) -> Optional[int]:
            """Get the sliding window size, or None if disabled."""
    
            # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
            # addition to sliding window size. We check if that field is present
            # and if it's False, return None.
            if (hasattr(self.hf_text_config, "use_sliding_window")
                    and not self.hf_text_config.use_sliding_window):
                return None
            return getattr(self.hf_text_config, "sliding_window", None)
    
        def get_sliding_window(self) -> Optional[int]:
            """Get the sliding window size, or None if disabled.
            """
            # If user disables sliding window, return None.
            if self.disable_sliding_window:
                return None
            # Otherwise get the value from the hf config.
            return self.get_hf_config_sliding_window()
    
        def get_vocab_size(self) -> int:
            return self.hf_text_config.vocab_size
    
        def get_hidden_size(self) -> int:
            return self.hf_text_config.hidden_size
    
        def get_head_size(self) -> int:
            # TODO: remove this hard-coded head size
            if hasattr(self.hf_text_config, "model_type"
                       ) and self.hf_text_config.model_type == 'deepseek_v2':
                # FlashAttention supports only head_size 32, 64, 128, 256,
                # we need to pad head_size 192 to 256
                return 256
            if hasattr(self.hf_text_config, "head_dim"):
                return self.hf_text_config.head_dim
            # FIXME(woosuk): This may not be true for all models.
            return (self.hf_text_config.hidden_size //
                    self.hf_text_config.num_attention_heads)
    
        def get_total_num_kv_heads(self) -> int:
            """Returns the total number of KV heads."""
            # For GPTBigCode & Falcon:
            # NOTE: for falcon, when new_decoder_architecture is True, the
            # multi_query flag is ignored and we use n_head_kv for the number of
            # KV heads.
            falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
            new_decoder_arch_falcon = (
                self.hf_config.model_type in falcon_model_types
                and getattr(self.hf_config, "new_decoder_architecture", False))
            if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                       "multi_query", False):
                # Multi-query attention, only one KV head.
                # Currently, tensor parallelism is not supported in this case.
                return 1
    
            # For DBRX and MPT
            if self.hf_config.model_type == "mpt":
                if "kv_n_heads" in self.hf_config.attn_config:
                    return self.hf_config.attn_config["kv_n_heads"]
                return self.hf_config.num_attention_heads
            if self.hf_config.model_type == "dbrx":
                return getattr(self.hf_config.attn_config, "kv_n_heads",
                               self.hf_config.num_attention_heads)
    
            attributes = [
                # For Falcon:
                "n_head_kv",
                "num_kv_heads",
                # For LLaMA-2:
                "num_key_value_heads",
                # For ChatGLM:
                "multi_query_group_num",
            ]
            for attr in attributes:
                num_kv_heads = getattr(self.hf_text_config, attr, None)
                if num_kv_heads is not None:
                    return num_kv_heads
    
            # For non-grouped-query attention models, the number of KV heads is
            # equal to the number of attention heads.
            return self.hf_text_config.num_attention_heads
    
        def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
            """Returns the number of KV heads per GPU."""
            total_num_kv_heads = self.get_total_num_kv_heads()
            # If tensor parallelism is used, we divide the number of KV heads by
            # the tensor parallel size. We will replicate the KV heads in the
            # case where the number of KV heads is smaller than the tensor
            # parallel size so each GPU has at least one KV head.
            return max(1,
                       total_num_kv_heads // parallel_config.tensor_parallel_size)
    
        def get_num_attention_heads(self,
                                    parallel_config: "ParallelConfig") -> int:
            num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
            return num_heads // parallel_config.tensor_parallel_size
    
        def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
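            # Number of hidden layers assigned to this pipeline-parallel rank,
            # based on the partitioning computed by get_pp_indices.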
            from vllm.distributed.utils import get_pp_indices
            total_num_hidden_layers = getattr(self.hf_text_config,
                                              "num_hidden_layers", 0)
            pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
            pp_size = parallel_config.pipeline_parallel_size
            start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
            return end - start
    
        def contains_seqlen_agnostic_layers(
                self, parallel_config: "ParallelConfig") -> bool:
            """True for Mamba/SSM models (Jamba)"""
            return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
    
        def get_layers_block_type(self,
                                  parallel_config: "ParallelConfig") -> List[str]:
            num_layers = self.get_num_layers(parallel_config)
            # Transformers supports layers_block_type @property
            return getattr(self.hf_config, "layers_block_type",
                           ["attention"] * num_layers)
    
        def get_num_attention_layers(self,
                                     parallel_config: "ParallelConfig") -> int:
            return len([
                t for t in self.get_layers_block_type(parallel_config)
                if t == "attention"
            ])
    
        def get_multimodal_config(self) -> "MultiModalConfig":
            """
            Get the multimodal configuration of the model.
    
            Raises:
                ValueError: If the model is not multimodal.
            """
            if self.multimodal_config is None:
                raise ValueError("The model is not multimodal.")
    
            return self.multimodal_config
    
        def _init_multimodal_config(
            self, limit_mm_per_prompt: Optional[Mapping[str, int]]
        ) -> Optional["MultiModalConfig"]:
            architectures = getattr(self.hf_config, "architectures", [])
            if any(
                    ModelRegistry.is_multimodal_model(arch)
                    for arch in architectures):
                return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
            else:
                if limit_mm_per_prompt:
                    raise ValueError(
                        "limit_mm_per_prompt is only supported for multimodal "
                        "models.")
                return None
    
        def _verify_tokenizer_mode(self) -> None:
            tokenizer_mode = self.tokenizer_mode.lower()
            if tokenizer_mode not in ["auto", "slow", "mistral"]:
                raise ValueError(
                    f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                    "either 'auto', 'slow' or 'mistral'.")
            self.tokenizer_mode = tokenizer_mode
    
        def _verify_embedding_mode(self) -> None:
            architectures = getattr(self.hf_config, "architectures", [])
            self.embedding_mode = any(
                ModelRegistry.is_embedding_model(arch) for arch in architectures)
    
        def _parse_quant_hf_config(self):
            quant_cfg = getattr(self.hf_config, "quantization_config", None)
            if quant_cfg is None:
                # compressed-tensors uses a "compression_config" key
                quant_cfg = getattr(self.hf_config, "compression_config", None)
            return quant_cfg
    
        def _verify_quantization(self) -> None:
            supported_quantization = [*QUANTIZATION_METHODS]
            rocm_supported_quantization = [
                "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
                "fbgemm_fp8"
            ]
            optimized_quantization_methods = [
                "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
                "awq_marlin", "fbgemm_fp8", "compressed_tensors",
                "compressed-tensors", "experts_int8"
            ]
            tpu_supported_quantization = ["tpu_int8"]
            neuron_supported_quantization = ["neuron_quant"]
            if self.quantization is not None:
                self.quantization = self.quantization.lower()
    
            # Parse quantization method from the HF model config, if available.
            quant_cfg = self._parse_quant_hf_config()
    
            if quant_cfg is not None:
                quant_method = quant_cfg.get("quant_method", "").lower()
    
                # Detect which checkpoint it is
                for _, method in QUANTIZATION_METHODS.items():
                    quantization_override = method.override_quantization_method(
                        quant_cfg, self.quantization)
                    if quantization_override:
                        quant_method = quantization_override
                        self.quantization = quantization_override
                        break
    
                # Verify quantization configurations.
                if self.quantization is None:
                    self.quantization = quant_method
                elif self.quantization != quant_method:
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization}).")
    
            if self.quantization is not None:
                if self.quantization not in supported_quantization:
                    raise ValueError(
                        f"Unknown quantization method: {self.quantization}. Must "
                        f"be one of {supported_quantization}.")
                if is_hip(
                ) and self.quantization not in rocm_supported_quantization:
                    raise ValueError(
                        f"{self.quantization} quantization is currently not "
                        f"supported in ROCm.")
                if current_platform.is_tpu(
                ) and self.quantization not in tpu_supported_quantization:
                    raise ValueError(
                        f"{self.quantization} quantization is currently not "
                        f"supported in TPU Backend.")
                if self.quantization not in optimized_quantization_methods:
                    logger.warning(
                        "%s quantization is not fully "
                        "optimized yet. The speed can be slower than "
                        "non-quantized models.", self.quantization)
                if (self.quantization == "awq" and is_hip()
                        and not envs.VLLM_USE_TRITON_AWQ):
                    logger.warning(
                        "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                        " is not set, enabling VLLM_USE_TRITON_AWQ.")
                    envs.VLLM_USE_TRITON_AWQ = True
                if is_neuron(
                ) and self.quantization not in neuron_supported_quantization:
                    raise ValueError(
                        f"{self.quantization} quantization is currently not "
                        f"supported in Neuron Backend.")
    
        def _verify_cuda_graph(self) -> None:
            if self.max_seq_len_to_capture is None:
                self.max_seq_len_to_capture = self.max_model_len
            self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                              self.max_model_len)
    
        def _verify_bnb_config(self) -> None:
            """
            The current version of bitsandbytes (0.44.0) with 8-bit models does not 
            yet support CUDA graph.
            """
            is_bitsandbytes = self.quantization == "bitsandbytes"
            has_quantization_config = (getattr(self.hf_config,
                                               "quantization_config", None)
                                       is not None)
            is_8bit = (self.hf_config.quantization_config.get(
                "load_in_8bit", False) if has_quantization_config else False)
            if all([
                    is_bitsandbytes,
                    has_quantization_config,
                    is_8bit,
                    not self.enforce_eager,
            ]):
                logger.warning(
                    "CUDA graph is not supported on BitsAndBytes 8-bit yet; "
                    "falling back to eager mode.")
                self.enforce_eager = True
    
        def _get_num_seqlen_agnostic_layers(
                self, parallel_config: "ParallelConfig") -> int:
            return len([
                t for t in self.get_layers_block_type(parallel_config)
                if t != "attention"
            ])
    
    
    class CacheConfig:
        """Configuration for the KV cache.
    
        Args:
            block_size: Size of a cache block in number of tokens.
            gpu_memory_utilization: Fraction of GPU memory to use for the
                vLLM execution.
            swap_space: Size of the CPU swap space per GPU (in GiB).
            cache_dtype: Data type for kv cache storage.
            num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
                profiled num_gpu_blocks if specified. Does nothing if None.
        """
    
        def __init__(
            self,
            block_size: int,
            gpu_memory_utilization: float,
            swap_space: float,
            cache_dtype: str,
            num_gpu_blocks_override: Optional[int] = None,
            sliding_window: Optional[int] = None,
            enable_prefix_caching: bool = False,
            cpu_offload_gb: float = 0,
        ) -> None:
            self.block_size = block_size
            self.gpu_memory_utilization = gpu_memory_utilization
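            # swap_space is given in GiB (see the docstring above); store it in bytes.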
            self.swap_space_bytes = swap_space * GiB_bytes
            self.num_gpu_blocks_override = num_gpu_blocks_override
            self.cache_dtype = cache_dtype
            self.sliding_window = sliding_window
            self.enable_prefix_caching = enable_prefix_caching
            self.cpu_offload_gb = cpu_offload_gb
            self._verify_args()
            self._verify_cache_dtype()
            self._verify_prefix_caching()
    
            # Will be set after profiling.
            self.num_gpu_blocks = None
            self.num_cpu_blocks = None
    
        def metrics_info(self):
            # convert cache_config to dict(key: str, value: str) for prometheus
            # metrics info
            return {key: str(value) for key, value in self.__dict__.items()}
    
        def verify_with_parallel_config(
            self,
            parallel_config: "ParallelConfig",
        ) -> None:
            total_cpu_memory = get_cpu_memory()
            # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
            # group are in the same node. However, the GPUs may span multiple nodes.
            num_gpus_per_node = parallel_config.tensor_parallel_size
            cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
    
            msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
                   f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
                   "is allocated for the swap space.")
            if cpu_memory_usage > 0.7 * total_cpu_memory:
                raise ValueError("Too large swap space. " + msg)
            elif cpu_memory_usage > 0.4 * total_cpu_memory:
                logger.warning("Possibly too large swap space. %s", msg)
    
        def _verify_args(self) -> None:
            if self.gpu_memory_utilization > 1.0:
                raise ValueError(
                    "GPU memory utilization must not be greater than 1.0. Got "
                    f"{self.gpu_memory_utilization}.")
    
        def _verify_cache_dtype(self) -> None:
            if self.cache_dtype == "auto":
                pass
            elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
                logger.info(
                    "Using fp8 data type to store kv cache. It reduces the GPU "
                    "memory footprint and boosts the performance. "
                    "Meanwhile, it may cause an accuracy drop without a proper "
                    "scaling factor")
            else:
                raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
    
        def _verify_prefix_caching(self) -> None:
            if not self.enable_prefix_caching:
                return
    
            if self.sliding_window is not None:
                raise NotImplementedError(
                    "Prefix caching is not supported with sliding window. "
                    "Run with --disable-sliding-window to use prefix caching.")
    
    
    @dataclass
    class TokenizerPoolConfig:
        """Configuration for the tokenizer pool.
    
        Args:
            pool_size: Number of tokenizer workers in the pool.
            pool_type: Type of the pool.
            extra_config: Additional config for the pool.
                The way the config will be used depends on the
                pool type.
        """
        pool_size: int
        pool_type: Union[str, Type["BaseTokenizerGroup"]]
        extra_config: dict
    
        def __post_init__(self):
            if self.pool_type not in ("ray", ) and not isinstance(
                    self.pool_type, type):
                raise ValueError(f"Unknown pool type: {self.pool_type}")
            if not isinstance(self.extra_config, dict):
                raise ValueError("extra_config must be a dictionary.")
    
        @classmethod
        def create_config(
            cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
            tokenizer_pool_extra_config: Optional[Union[str, dict]]
        ) -> Optional["TokenizerPoolConfig"]:
            """Create a TokenizerPoolConfig from the given parameters.
    
            If tokenizer_pool_size is 0, return None.
    
            Args:
                tokenizer_pool_size: Number of tokenizer workers in the pool.
                tokenizer_pool_type: Type of the pool.
                tokenizer_pool_extra_config: Additional config for the pool.
                    The way the config will be used depends on the
                    pool type. This can be a JSON string (will be parsed).
            """
            if tokenizer_pool_size:
                if isinstance(tokenizer_pool_extra_config, str):
                    tokenizer_pool_extra_config_parsed = json.loads(
                        tokenizer_pool_extra_config)
                else:
                    tokenizer_pool_extra_config_parsed = (
                        tokenizer_pool_extra_config or {})
                tokenizer_pool_config = cls(tokenizer_pool_size,
                                            tokenizer_pool_type,
                                            tokenizer_pool_extra_config_parsed)
            else:
                tokenizer_pool_config = None
            return tokenizer_pool_config
    
    
    class LoadFormat(str, enum.Enum):
        AUTO = "auto"
        PT = "pt"
        SAFETENSORS = "safetensors"
        NPCACHE = "npcache"
        DUMMY = "dummy"
        TENSORIZER = "tensorizer"
        SHARDED_STATE = "sharded_state"
        GGUF = "gguf"
        BITSANDBYTES = "bitsandbytes"
        MISTRAL = "mistral"
    
    
    @dataclass
    class LoadConfig:
        """
            download_dir: Directory to download and load the weights; defaults
                to the default Hugging Face cache directory.
            load_format: The format of the model weights to load:
                "auto" will try to load the weights in the safetensors format and
                    fall back to the pytorch bin format if safetensors format is
                    not available.
                "pt" will load the weights in the pytorch bin format.
                "safetensors" will load the weights in the safetensors format.
                "npcache" will load the weights in pytorch format and store
                    a numpy cache to speed up the loading.
                "dummy" will initialize the weights with random values, which is
                    mainly for profiling.
                "tensorizer" will use CoreWeave's tensorizer library for
                    fast weight loading.
                "bitsandbytes" will load nf4 type weights.
            ignore_patterns: The list of patterns to ignore when loading the model.
                Defaults to "original/**/*" to avoid repeated loading of Llama's
                checkpoints.
    
        """
    
        load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
        download_dir: Optional[str] = None
        model_loader_extra_config: Optional[Union[str, dict]] = field(
            default_factory=dict)
        ignore_patterns: Optional[Union[List[str], str]] = None
    
        def __post_init__(self):
            model_loader_extra_config = self.model_loader_extra_config or {}
            if isinstance(model_loader_extra_config, str):
                self.model_loader_extra_config = json.loads(
                    model_loader_extra_config)
            self._verify_load_format()
    
            if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
                logger.info(
                    "Ignoring the following patterns when downloading weights: %s",
                    self.ignore_patterns)
            else:
                self.ignore_patterns = ["original/**/*"]
    
        def _verify_load_format(self) -> None:
            if not isinstance(self.load_format, str):
                return
    
            load_format = self.load_format.lower()
            self.load_format = LoadFormat(load_format)
    
            rocm_not_supported_load_format: List[str] = []
            if is_hip() and load_format in rocm_not_supported_load_format:
                rocm_supported_load_format = [
                    f for f in LoadFormat.__members__
                    if (f not in rocm_not_supported_load_format)
                ]
                raise ValueError(
                    f"load format '{load_format}' is not supported in ROCm. "
                    f"Supported load formats are "
                    f"{rocm_supported_load_format}")
    
    
    class ParallelConfig:
        """Configuration for the distributed execution.
    
        Args:
            pipeline_parallel_size: Number of pipeline parallel groups.
            tensor_parallel_size: Number of tensor parallel groups.
            worker_use_ray: Deprecated, use distributed_executor_backend instead.
            max_parallel_loading_workers: Maximum number of batches to use when
                loading the model sequentially, to avoid RAM OOM when using
                tensor parallelism with large models.
            disable_custom_all_reduce: Disable the custom all-reduce kernel and
                fall back to NCCL.
            tokenizer_pool_config: Config for the tokenizer pool.
                If None, will use synchronous tokenization.
            ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
                https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
            placement_group: ray distributed model workers placement group.
            distributed_executor_backend: Backend to use for distributed model
                workers, either "ray" or "mp" (multiprocessing). If either
                pipeline_parallel_size or tensor_parallel_size is greater than 1,
                will default to "ray" if Ray is installed or "mp" otherwise.
        """
    
        def __init__(
            self,
            pipeline_parallel_size: int,
            tensor_parallel_size: int,
            worker_use_ray: Optional[bool] = None,
            max_parallel_loading_workers: Optional[int] = None,
            disable_custom_all_reduce: bool = False,
            tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
            ray_workers_use_nsight: bool = False,
            placement_group: Optional["PlacementGroup"] = None,
            distributed_executor_backend: Optional[Union[
                str, Type["ExecutorBase"]]] = None,
        ) -> None:
            self.pipeline_parallel_size = pipeline_parallel_size
            self.tensor_parallel_size = tensor_parallel_size
            self.distributed_executor_backend = distributed_executor_backend
            self.max_parallel_loading_workers = max_parallel_loading_workers
            self.disable_custom_all_reduce = disable_custom_all_reduce
            self.tokenizer_pool_config = tokenizer_pool_config
            self.ray_workers_use_nsight = ray_workers_use_nsight
            self.placement_group = placement_group
            self.world_size = pipeline_parallel_size * self.tensor_parallel_size
    
            if worker_use_ray:
                if self.distributed_executor_backend is None:
                    self.distributed_executor_backend = "ray"
                elif not self.use_ray:
                    raise ValueError(f"worker-use-ray can't be used with "
                                     f"distributed executor backend "
                                     f"'{self.distributed_executor_backend}'.")
    
            if current_platform.is_tpu() and self.world_size > 1:
                if self.distributed_executor_backend is None:
                    self.distributed_executor_backend = "ray"
                if self.distributed_executor_backend != "ray":
                    raise ValueError(
                        "TPU backend only supports Ray for distributed inference.")
    
            if self.distributed_executor_backend is None and self.world_size > 1:
                # We use multiprocessing by default if world_size fits on the
                # current node and we aren't in a ray placement group.
    
                from vllm.executor import ray_utils
                backend = "mp"
                ray_found = ray_utils.ray_is_available()
                if (current_platform.is_cuda()
                        and cuda_device_count_stateless() < self.world_size):
                    if not ray_found:
                        raise ValueError("Unable to load Ray which is "
                                         "required for multi-node inference, "
                                         "please install Ray with `pip install "
                                         "ray`.") from ray_utils.ray_import_err
                    backend = "ray"
                elif ray_found:
                    if self.placement_group:
                        backend = "ray"
                    else:
                        from ray import is_initialized as ray_is_initialized
                        if ray_is_initialized():
                            from ray.util import get_current_placement_group
                            if get_current_placement_group():
                                backend = "ray"
                self.distributed_executor_backend = backend
                logger.info("Defaulting to use %s for distributed inference",
                            backend)
    
            self._verify_args()
            self.rank: int = 0
    
        @property
        def use_ray(self) -> bool:
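            # True when the backend is the string "ray" or a custom ExecutorBase
            # subclass whose uses_ray flag is set.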
            return self.distributed_executor_backend == "ray" or (
                isinstance(self.distributed_executor_backend, type)
                and self.distributed_executor_backend.uses_ray)
    
        def _verify_args(self) -> None:
            # Lazy import to avoid circular import
            from vllm.executor.executor_base import ExecutorBase
    
            if self.distributed_executor_backend not in (
                    "ray", "mp", None) and not (isinstance(
                        self.distributed_executor_backend, type) and issubclass(
                            self.distributed_executor_backend, ExecutorBase)):
                raise ValueError(
                    "Unrecognized distributed executor backend "
                    f"{self.distributed_executor_backend}. Supported "
                    "values are 'ray', 'mp' or custom ExecutorBase subclass.")
            if self.use_ray:
                from vllm.executor import ray_utils
                ray_utils.assert_ray_available()
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            if self.ray_workers_use_nsight and not self.use_ray:
                raise ValueError("Unable to use nsight profiling unless workers "
                                 "run with Ray.")
    
    
    class SchedulerConfig:
        """Scheduler configuration.
    
        Args:
            max_num_batched_tokens: Maximum number of tokens to be processed in
                a single iteration.
            max_num_seqs: Maximum number of sequences to be processed in a single
                iteration.
            max_model_len: Maximum length of a sequence (including prompt
                and generated text).
            use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
            num_lookahead_slots: The number of slots to allocate per sequence per
                step, beyond the known token ids. This is used in speculative
                decoding to store KV activations of tokens which may or may not be
                accepted.
            delay_factor: Apply a delay (of delay factor multiplied by previous
                prompt latency) before scheduling next prompt.
            enable_chunked_prefill: If True, prefill requests can be chunked based
                on the remaining max_num_batched_tokens.
            embedding_mode: Whether the running model is for embedding.
            preemption_mode: Whether to perform preemption by swapping or 
                recomputation. If not specified, we determine the mode as follows:
                We use recomputation by default since it incurs lower overhead than
                swapping. However, when the sequence group has multiple sequences
                (e.g., beam search), recomputation is not currently supported. In
                such a case, we use swapping instead.
            send_delta_data: Private API. If used, the scheduler sends delta data
                to workers instead of the entire data. It should be enabled only
                when the SPMD worker architecture is enabled, i.e.,
                VLLM_USE_RAY_SPMD_WORKER=1.
            policy: The scheduling policy to use. "fcfs" (default) or "priority".
        """
    
        def __init__(self,
                     max_num_batched_tokens: Optional[int],
                     max_num_seqs: int,
                     max_model_len: int,
                     use_v2_block_manager: bool = False,
                     num_lookahead_slots: int = 0,
                     delay_factor: float = 0.0,
                     enable_chunked_prefill: bool = False,
                     embedding_mode: bool = False,
                     is_multimodal_model: bool = False,
                     preemption_mode: Optional[str] = None,
                     num_scheduler_steps: int = 1,
                     multi_step_stream_outputs: bool = False,
                     send_delta_data: bool = False,
                     policy: str = "fcfs") -> None:
            if max_num_batched_tokens is None:
                if enable_chunked_prefill:
                    # This value gives the best balance between ITL and TTFT
                    # on A100. Note that it is not optimized for throughput.
                    max_num_batched_tokens = 512
                else:
                    # If max_model_len is too short, use 2048 as the default value
                    # for higher throughput.
                    max_num_batched_tokens = max(max_model_len, 2048)
    
                if embedding_mode:
                    # For embedding, choose specific value for higher throughput
                    max_num_batched_tokens = max(
                        max_num_batched_tokens,
                        _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
                    )
                if is_multimodal_model:
                    # The value needs to be at least the number of multimodal tokens
                    max_num_batched_tokens = max(
                        max_num_batched_tokens,
                        _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                    )
    
            self.max_num_batched_tokens = max_num_batched_tokens
    
            if enable_chunked_prefill:
                logger.info(
                    "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                    self.max_num_batched_tokens)
    
            self.max_num_seqs = max_num_seqs
            self.max_model_len = max_model_len
            self.use_v2_block_manager = use_v2_block_manager
            self.num_lookahead_slots = num_lookahead_slots
            self.delay_factor = delay_factor
            self.chunked_prefill_enabled = enable_chunked_prefill
            self.embedding_mode = embedding_mode
            self.preemption_mode = preemption_mode
            self.num_scheduler_steps = num_scheduler_steps
            self.multi_step_stream_outputs = multi_step_stream_outputs
            self.send_delta_data = send_delta_data
            self.policy = policy
            self._verify_args()
    
        @property
        def is_multi_step(self) -> bool:
            return self.num_scheduler_steps > 1
    
        def _verify_args(self) -> None:
            if (self.max_num_batched_tokens < self.max_model_len
                    and not self.chunked_prefill_enabled):
                raise ValueError(
                    f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                    f"smaller than max_model_len ({self.max_model_len}). "
                    "This effectively limits the maximum sequence length to "
                    "max_num_batched_tokens and makes vLLM reject longer "
                    "sequences. Please increase max_num_batched_tokens or "
                    "decrease max_model_len.")
    
            if self.max_num_batched_tokens < self.max_num_seqs:
                raise ValueError(
                    f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                    "be greater than or equal to max_num_seqs "
                    f"({self.max_num_seqs}).")
    
            if self.num_lookahead_slots < 0:
                raise ValueError(
                    "num_lookahead_slots "
                    f"({self.num_lookahead_slots}) must be greater than or "
                    "equal to 0.")
    
            if self.num_scheduler_steps < 1:
                raise ValueError(
                    "num_scheduler_steps "
                    f"({self.num_scheduler_steps}) must be greater than or "
                    "equal to 1.")
    
    
    class DeviceConfig:
        device: Optional[torch.device]
    
        def __init__(self, device: str = "auto") -> None:
            if device == "auto":
                # Automated device type detection
                if current_platform.is_cuda_alike():
                    self.device_type = "cuda"
                elif is_neuron():
                    self.device_type = "neuron"
                elif is_openvino():
                    self.device_type = "openvino"
                elif current_platform.is_tpu():
                    self.device_type = "tpu"
                elif current_platform.is_cpu():
                    self.device_type = "cpu"
                elif is_xpu():
                    self.device_type = "xpu"
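                # Ascend NPU detection (is_npu), added as part of this adaptation.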
                elif is_npu():
                    self.device_type = "npu"
                else:
                    raise RuntimeError("Failed to infer device type")
            else:
                # Device type is assigned explicitly
                self.device_type = device
    
            # Some device types require processing inputs on CPU
            if self.device_type in ["neuron", "openvino"]:
                self.device = torch.device("cpu")
            elif self.device_type in ["tpu"]:
                self.device = None
            else:
                # Set device with device type
                self.device = torch.device(self.device_type)
    
    
    class SpeculativeConfig:
        """Configuration for speculative decoding.
    
        The configuration is currently specialized to draft-model speculative
        decoding with top-1 proposals.
        """
    
        def __init__(
            self,
            draft_model_config: ModelConfig,
            draft_parallel_config: ParallelConfig,
            num_speculative_tokens: int,
            speculative_disable_by_batch_size: Optional[int],
            ngram_prompt_lookup_max: Optional[int],
            ngram_prompt_lookup_min: Optional[int],
            draft_token_acceptance_method: str,
            typical_acceptance_sampler_posterior_threshold: float,
            typical_acceptance_sampler_posterior_alpha: float,
            disable_logprobs: bool,
            disable_log_stats: bool,
        ):
            """Create a SpeculativeConfig object.
    
            Args:
                draft_model_config: ModelConfig for the draft model.
                draft_parallel_config: ParallelConfig for the draft model.
                num_speculative_tokens: The number of tokens to sample from the
                    draft model before scoring with the target model.
                speculative_disable_by_batch_size: Disable speculative
                    decoding for new incoming requests when the number of
                    enqueued requests is larger than this value.
                ngram_prompt_lookup_max: Max size of ngram token window.
                ngram_prompt_lookup_min: Min size of ngram token window.
                draft_token_acceptance_method (str): The method to use for
                    accepting draft tokens. This can take two possible
                    values 'rejection_sampler' and 'typical_acceptance_sampler'
                    for RejectionSampler and TypicalAcceptanceSampler
                    respectively.
                typical_acceptance_sampler_posterior_threshold (Optional[float]):
                    A threshold value that sets a lower bound on the posterior
                    probability of a token in the target model for it to be
                    accepted. This threshold is used only when we use the 
                    TypicalAcceptanceSampler for token acceptance.
                typical_acceptance_sampler_posterior_alpha (Optional[float]):
                    A scaling factor for the entropy-based threshold in the
                    TypicalAcceptanceSampler.
                disable_logprobs: If set to True, token log probabilities will not
                    be returned even if requested by sampling parameters. This 
                    reduces latency by skipping logprob calculation in proposal
                    sampling, target sampling, and after accepted tokens are
                    determined. If set to False, log probabilities will be
                    returned.
                disable_log_stats: Whether to disable periodic printing of stage
                    times in speculative decoding.
            """
            self.draft_model_config = draft_model_config
            self.draft_parallel_config = draft_parallel_config
            self.num_speculative_tokens = num_speculative_tokens
            self.speculative_disable_by_batch_size = \
                speculative_disable_by_batch_size
            self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
            self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
            self.draft_token_acceptance_method = draft_token_acceptance_method
            self.typical_acceptance_sampler_posterior_threshold = \
                typical_acceptance_sampler_posterior_threshold
            self.typical_acceptance_sampler_posterior_alpha = \
                typical_acceptance_sampler_posterior_alpha
            self.disable_logprobs = disable_logprobs
            self.disable_log_stats = disable_log_stats
    
            self._verify_args()
    
        def __repr__(self) -> str:
            if self.ngram_prompt_lookup_max > 0:
                draft_model = "[ngram]"
            else:
                draft_model = self.draft_model_config.model
            num_spec_tokens = self.num_speculative_tokens
            return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
    
        @property
        def num_lookahead_slots(self) -> int:
            """The number of additional slots the scheduler should allocate per
            step, in addition to the slots allocated for each known token.
    
            This is equal to the number of speculative tokens, as each speculative
            token must be scored.
            """
            return self.num_speculative_tokens
    
        @staticmethod
        def maybe_create_spec_config(
            target_model_config: ModelConfig,
            target_parallel_config: ParallelConfig,
            target_dtype: str,
            speculative_model: Optional[str],
            speculative_model_quantization: Optional[str],
            speculative_draft_tensor_parallel_size: Optional[int],
            num_speculative_tokens: Optional[int],
            speculative_max_model_len: Optional[int],
            enable_chunked_prefill: bool,
            use_v2_block_manager: bool,
            disable_log_stats: bool,
            speculative_disable_by_batch_size: Optional[int],
            ngram_prompt_lookup_max: Optional[int],
            ngram_prompt_lookup_min: Optional[int],
            draft_token_acceptance_method: str,
            typical_acceptance_sampler_posterior_threshold: Optional[float],
            typical_acceptance_sampler_posterior_alpha: Optional[float],
            disable_logprobs: Optional[bool],
        ) -> Optional["SpeculativeConfig"]:
            """Create a SpeculativeConfig if possible, else return None.
    
            This function attempts to create a SpeculativeConfig object based on the
            provided parameters. If the necessary conditions are met, it returns an
            instance of SpeculativeConfig. Otherwise, it returns None.
    
            Args:
                target_model_config (ModelConfig): The configuration of the target
                    model.
                target_parallel_config (ParallelConfig): The parallel configuration
                    for the target model.
                target_dtype (str): The data type used for the target model.
                speculative_model (Optional[str]): The name of the speculative
                    model, if provided.
                speculative_model_quantization (Optional[str]): Quantization method
                    that was used to quantize the speculative model weights. If
                    None, we assume the model weights are not quantized.
                speculative_draft_tensor_parallel_size (Optional[int]): The degree
                    of the tensor parallelism for the draft model.
                num_speculative_tokens (Optional[int]): The number of speculative
                    tokens, if provided. Will default to the number in the draft
                    model config if present, otherwise is required.
                speculative_max_model_len (Optional[int]): The maximum model len of
                    the speculative model. Used when testing the ability to skip
                    speculation for some sequences.
                enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since it's not
                yet compatible with spec decode.
                use_v2_block_manager (bool): Whether vLLM is configured to use the
                    v2 block manager or not. Used for raising an error since the v2
                    block manager is required with spec decode.
                speculative_disable_by_batch_size (Optional[int]): Disable
                    speculative decoding for new incoming requests when the number
                of enqueued requests is larger than this value, if provided.
                ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                    window, if provided.
                ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                    window, if provided.
                draft_token_acceptance_method (str): The method to use for
                    accepting draft tokens. This can take two possible
                    values 'rejection_sampler' and 'typical_acceptance_sampler'
                    for RejectionSampler and TypicalAcceptanceSampler
                    respectively.
                typical_acceptance_sampler_posterior_threshold (Optional[float]):
                    A threshold value that sets a lower bound on the posterior
                    probability of a token in the target model for it to be
                    accepted. This threshold is used only when we use the 
                    TypicalAcceptanceSampler for token acceptance.
                typical_acceptance_sampler_posterior_alpha (Optional[float]):
                    A scaling factor for the entropy-based threshold in the
                    TypicalAcceptanceSampler.
                disable_logprobs (Optional[bool]): If set to True, token log
                    probabilities are not returned during speculative decoding.
                    If set to False, token log probabilities are returned
                    according to the log probability settings in SamplingParams.
                    If not specified, it defaults to True.
    
            Returns:
                Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                    the necessary conditions are met, else None.
            """
    
            if speculative_model is None:
                if num_speculative_tokens is not None:
                    raise ValueError("num_speculative_tokens was provided without "
                                     "speculative_model.")
                return None
    
            if (speculative_disable_by_batch_size is not None
                    and speculative_disable_by_batch_size < 2):
                raise ValueError("Expect the batch size threshold of disabling "
                                 "speculative decoding is > 1, but got "
                                 f"{speculative_disable_by_batch_size=}")
    
            if enable_chunked_prefill:
                raise ValueError(
                    "Speculative decoding and chunked prefill are "
                    f"currently mutually exclusive ({enable_chunked_prefill=}).")
    
            if not use_v2_block_manager:
                raise ValueError(
                    "Speculative decoding requires usage of the V2 "
                    "block manager. Enable it with --use-v2-block-manager.")
    
            # TODO: The user should be able to specify revision/max model len
            # for the draft model. It is not currently supported.
            draft_revision = None
            draft_code_revision = None
            draft_quantization = speculative_model_quantization
    
            if speculative_model == "[ngram]":
                if ngram_prompt_lookup_min is None:
                    ngram_prompt_lookup_min = 1
                if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                    raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
                if ngram_prompt_lookup_min < 1:
                    raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
                if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                    raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                     f"larger than {ngram_prompt_lookup_max=}")
    
                # TODO: currently we still need to extract vocab_size from the
                # target model config; in the future, we may refactor this out
                # and set the draft-related config to None here.
                draft_model_config = target_model_config
                draft_parallel_config = target_parallel_config
            else:
                ngram_prompt_lookup_max = 0
                ngram_prompt_lookup_min = 0
                draft_model_config = ModelConfig(
                    model=speculative_model,
                    tokenizer=target_model_config.tokenizer,
                    tokenizer_mode=target_model_config.tokenizer_mode,
                    trust_remote_code=target_model_config.trust_remote_code,
                    dtype=target_model_config.dtype,
                    seed=target_model_config.seed,
                    revision=draft_revision,
                    code_revision=draft_code_revision,
                    tokenizer_revision=target_model_config.tokenizer_revision,
                    max_model_len=None,
                    spec_target_max_model_len=target_model_config.max_model_len,
                    quantization=draft_quantization,
                    enforce_eager=target_model_config.enforce_eager,
                    max_seq_len_to_capture=target_model_config.
                    max_seq_len_to_capture,
                    max_logprobs=target_model_config.max_logprobs,
                )
    
                draft_hf_config = draft_model_config.hf_config
    
                if (num_speculative_tokens is not None
                        and hasattr(draft_hf_config, "num_lookahead_tokens")):
                    draft_hf_config.num_lookahead_tokens = num_speculative_tokens
    
                n_predict = getattr(draft_hf_config, "n_predict", None)
                if n_predict is not None:
                    if num_speculative_tokens is None:
                        # Default to max value defined in draft model config.
                        num_speculative_tokens = n_predict
                    elif num_speculative_tokens > n_predict:
                        # Verify provided value doesn't exceed the maximum
                        # supported by the draft model.
                        raise ValueError(
                            "This speculative model supports a maximum of "
                            f"num_speculative_tokens={n_predict}, but "
                            f"{num_speculative_tokens=} was provided.")
    
                draft_model_config.max_model_len = (
                    SpeculativeConfig._maybe_override_draft_max_model_len(
                        speculative_max_model_len,
                        draft_model_config.max_model_len,
                        target_model_config.max_model_len,
                    ))
    
                draft_parallel_config = (
                    SpeculativeConfig.create_draft_parallel_config(
                        target_parallel_config,
                        speculative_draft_tensor_parallel_size, draft_hf_config))
    
            if num_speculative_tokens is None:
                raise ValueError(
                    "num_speculative_tokens must be provided with "
                    "speculative_model unless the draft model config contains an "
                    "n_predict parameter.")
    
            if typical_acceptance_sampler_posterior_threshold is None:
                typical_acceptance_sampler_posterior_threshold = 0.09
            if typical_acceptance_sampler_posterior_alpha is None:
                typical_acceptance_sampler_posterior_alpha = 0.3
            if disable_logprobs is None:
                disable_logprobs = True
    
            return SpeculativeConfig(
                draft_model_config,
                draft_parallel_config,
                num_speculative_tokens,
                speculative_disable_by_batch_size,
                ngram_prompt_lookup_max,
                ngram_prompt_lookup_min,
                draft_token_acceptance_method=draft_token_acceptance_method,
                typical_acceptance_sampler_posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                typical_acceptance_sampler_posterior_alpha=\
                    typical_acceptance_sampler_posterior_alpha,
                disable_logprobs=disable_logprobs,
                disable_log_stats=disable_log_stats,
            )
    
        @staticmethod
        def _maybe_override_draft_max_model_len(
            speculative_max_model_len: Optional[int],
            draft_max_model_len: int,
            target_max_model_len: int,
        ) -> int:
            """Determine the max sequence len for the draft model. This is usually
            the draft_max_model_len, but may be the target_max_model_len if it is
            less than the draft_max_model_len, or may be speculative_max_model_len
            if it is specified.
    
            This is necessary so that sequences do not exceed the capacity of the
            draft model or the target model.
    
            speculative_max_model_len is mainly used for testing that sequences can
            skip speculation.
            """
    
            if speculative_max_model_len is not None:
    
                if speculative_max_model_len > draft_max_model_len:
                    raise ValueError(f"{speculative_max_model_len=} cannot be "
                                     f"larger than {draft_max_model_len=}")
    
                if speculative_max_model_len > target_max_model_len:
                    raise ValueError(f"{speculative_max_model_len=} cannot be "
                                     f"larger than {target_max_model_len=}")
    
                return speculative_max_model_len
    
            return min(
                draft_max_model_len,
                target_max_model_len,
            )
    
        @staticmethod
        def create_draft_parallel_config(
            target_parallel_config: ParallelConfig,
            speculative_draft_tensor_parallel_size: Optional[int],
            draft_hf_config: PretrainedConfig,
        ) -> ParallelConfig:
            """Create a parallel config for use by the draft worker.
    
            This is mostly a copy of the target parallel config, except the tp_size.
            """
            if speculative_draft_tensor_parallel_size is None:
                if draft_hf_config.model_type == "mlp_speculator":
                    speculative_draft_tensor_parallel_size = 1
                    if target_parallel_config.tensor_parallel_size > 1:
                        logger.warning(
                            "MLPSpeculator cannot currently be run with tp>1; "
                            "setting speculative_draft_tensor_parallel_size=1")
                else:
                    speculative_draft_tensor_parallel_size = \
                        target_parallel_config.tensor_parallel_size
            elif speculative_draft_tensor_parallel_size != 1:
                # TODO(wooyeon): allow tp values larger than 1
                raise ValueError(
                    f"{speculative_draft_tensor_parallel_size=} cannot be "
                    f"other value than 1")
    
            draft_parallel_config = ParallelConfig(
                pipeline_parallel_size=target_parallel_config.
                pipeline_parallel_size,
                tensor_parallel_size=speculative_draft_tensor_parallel_size,
                distributed_executor_backend=target_parallel_config.
                distributed_executor_backend,
                max_parallel_loading_workers=target_parallel_config.
                max_parallel_loading_workers,
                disable_custom_all_reduce=target_parallel_config.
                disable_custom_all_reduce,
                tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
                ray_workers_use_nsight=target_parallel_config.
                ray_workers_use_nsight,
                placement_group=target_parallel_config.placement_group,
            )
    
            return draft_parallel_config
    
        def _verify_args(self) -> None:
            if self.num_speculative_tokens <= 0:
                raise ValueError("Expected num_speculative_tokens to be greater "
                                 f"than zero ({self.num_speculative_tokens}).")
    
            if self.draft_model_config:
                self.draft_model_config.verify_with_parallel_config(
                    self.draft_parallel_config)
                # Validate and set draft token acceptance related settings.
    
            if (self.draft_token_acceptance_method is None):
                raise ValueError("draft_token_acceptance_method is not set. "
                                 "Expected values are rejection_sampler or "
                                 "typical_acceptance_sampler.")
    
            if (self.draft_token_acceptance_method != 'rejection_sampler'
                    and self.draft_token_acceptance_method !=
                    'typical_acceptance_sampler'):
                raise ValueError(
                    "Expected draft_token_acceptance_method to be either "
                    "rejection_sampler or typical_acceptance_sampler. Instead it "
                    f"is {self.draft_token_acceptance_method}")
    
            if (self.typical_acceptance_sampler_posterior_threshold < 0
                    or self.typical_acceptance_sampler_posterior_alpha < 0):
                raise ValueError(
                    "Expected typical_acceptance_sampler_posterior_threshold "
                    "and typical_acceptance_sampler_posterior_alpha to be > 0. "
                    "Instead found "
                    f"typical_acceptance_sampler_posterior_threshold = "
                    f"{self.typical_acceptance_sampler_posterior_threshold} and "
                    f"typical_acceptance_sampler_posterior_alpha = "
                    f"{self.typical_acceptance_sampler_posterior_alpha}")
    
    
    @dataclass
    class LoRAConfig:
        max_lora_rank: int
        max_loras: int
        fully_sharded_loras: bool = False
        max_cpu_loras: Optional[int] = None
        lora_dtype: Optional[torch.dtype] = None
        lora_extra_vocab_size: int = 256
        # This is a constant.
        lora_vocab_padding_size: ClassVar[int] = 256
        long_lora_scaling_factors: Optional[Tuple[float]] = None
    
        def __post_init__(self):
            # Setting the maximum rank to 256 should be able to satisfy the vast
            # majority of applications.
            possible_max_ranks = (8, 16, 32, 64, 128, 256)
            possible_lora_extra_vocab_size = (0, 256, 512)
            if self.max_lora_rank not in possible_max_ranks:
                raise ValueError(
                    f"max_lora_rank ({self.max_lora_rank}) must be one of "
                    f"{possible_max_ranks}.")
            if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
                raise ValueError(
                    f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                    f"must be one of {possible_lora_extra_vocab_size}.")
            if self.max_loras < 1:
                raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
            if self.max_cpu_loras is None:
                self.max_cpu_loras = self.max_loras
            elif self.max_cpu_loras < self.max_loras:
                raise ValueError(
                    f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                    f"max_loras ({self.max_loras})")
    
        def verify_with_model_config(self, model_config: ModelConfig):
            if self.lora_dtype in (None, "auto"):
                self.lora_dtype = model_config.dtype
            elif isinstance(self.lora_dtype, str):
                self.lora_dtype = getattr(torch, self.lora_dtype)
            if model_config.quantization and model_config.quantization not in [
                    "awq", "gptq"
            ]:
                # TODO support marlin
                logger.warning("%s quantization is not tested with LoRA yet.",
                               model_config.quantization)
    
        def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
            if scheduler_config.chunked_prefill_enabled:
                raise ValueError("LoRA is not supported with chunked prefill yet.")
    
    
    @dataclass
    class PromptAdapterConfig:
        max_prompt_adapters: int
        max_prompt_adapter_token: int
        max_cpu_prompt_adapters: Optional[int] = None
        prompt_adapter_dtype: Optional[torch.dtype] = None
    
        def __post_init__(self):
    
            if self.max_prompt_adapters < 1:
                raise ValueError(f"max_prompt_adapters "
                                 f"({self.max_prompt_adapters}) must be >= 1.")
            if self.max_prompt_adapter_token == 0:
                raise ValueError("max_prompt_adapter_token must be set.")
            if self.max_cpu_prompt_adapters is None:
                self.max_cpu_prompt_adapters = self.max_prompt_adapters
    
        def verify_with_model_config(self, model_config: ModelConfig):
            if self.prompt_adapter_dtype in (None, "auto"):
                self.prompt_adapter_dtype = model_config.dtype
            elif isinstance(self.prompt_adapter_dtype, str):
                self.prompt_adapter_dtype = getattr(torch,
                                                    self.prompt_adapter_dtype)
    
    
    @dataclass
    class MultiModalConfig:
        """Controls the behavior of multimodal models."""
    
        limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
        """
        The maximum number of multi-modal input instances allowed per prompt
        for each :class:`~vllm.multimodal.MultiModalPlugin`.
        """
    
        # TODO: Add configs to init vision tower or not.
    
    
    _STR_DTYPE_TO_TORCH_DTYPE = {
        "half": torch.float16,
        "float16": torch.float16,
        "float": torch.float32,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    
    _ROCM_NOT_SUPPORTED_DTYPE: List[str] = []
    
    
    def _get_and_verify_dtype(
        config: PretrainedConfig,
        dtype: Union[str, torch.dtype],
    ) -> torch.dtype:
        # NOTE: `getattr(config, "torch_dtype", torch.float32)` is not correct
        # because config.torch_dtype can be None.
        config_dtype = getattr(config, "torch_dtype", None)
        if config_dtype is None:
            config_dtype = torch.float32
    
        if isinstance(dtype, str):
            dtype = dtype.lower()
            if dtype == "auto":
                if config_dtype == torch.float32:
                    if config.model_type == "gemma2":
                        logger.info(
                            "For Gemma 2, we downcast float32 to bfloat16 instead "
                            "of float16 by default. Please specify `dtype` if you "
                            "want to use float16.")
                        torch_dtype = torch.bfloat16
                    else:
                        # Following the common practice, we use float16 for float32
                        # models.
                        torch_dtype = torch.float16
                else:
                    torch_dtype = config_dtype
            else:
                if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                    raise ValueError(f"Unknown dtype: {dtype}")
                torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        elif isinstance(dtype, torch.dtype):
            torch_dtype = dtype
        else:
            raise ValueError(f"Unknown dtype: {dtype}")
    
        # Verify the dtype.
        if torch_dtype != config_dtype:
            if torch_dtype == torch.float32:
                # Upcasting to float32 is allowed.
                logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
                pass
            elif config_dtype == torch.float32:
                # Downcasting from float32 to float16 or bfloat16 is allowed.
                logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
                pass
            else:
                # Casting between float16 and bfloat16 is allowed with a warning.
                logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
    
        return torch_dtype
    
    
    def _get_and_verify_max_len(
        hf_config: PretrainedConfig,
        max_model_len: Optional[int],
        disable_sliding_window: bool,
        sliding_window_len: Optional[int],
        spec_target_max_model_len: Optional[int] = None,
    ) -> int:
        """Get and verify the model's maximum length."""
        derived_max_model_len = float("inf")
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # ChatGLM2
            "seq_length",
            # Command-R
            "model_max_length",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        # Choose the smallest "max_length" from the possible keys.
        max_len_key = None
        for key in possible_keys:
            max_len = getattr(hf_config, key, None)
            if max_len is not None:
                max_len_key = key if max_len < derived_max_model_len \
                    else max_len_key
                derived_max_model_len = min(derived_max_model_len, max_len)
    
        # If sliding window is manually disabled, max_length should be less
        # than the sliding window length in the model config.
        if disable_sliding_window and sliding_window_len is not None:
            max_len_key = "sliding_window" \
                if sliding_window_len < derived_max_model_len else max_len_key
            derived_max_model_len = min(derived_max_model_len, sliding_window_len)
    
        # If none of the keys were found in the config, use a default and
        # log a warning.
        if derived_max_model_len == float("inf"):
            if max_model_len is not None:
                # If max_model_len is specified, we use it.
                return max_model_len
    
            if spec_target_max_model_len is not None:
                # If this is a speculative draft model, we use the max model len
                # from the target model.
                return spec_target_max_model_len
    
            default_max_len = 2048
            logger.warning(
                "The model's config.json does not contain any of the following "
                "keys to determine the original maximum length of the model: "
                "%s. Assuming the model's maximum length is %d.", possible_keys,
                default_max_len)
            derived_max_model_len = default_max_len
    
        rope_scaling = getattr(hf_config, "rope_scaling", None)
        if rope_scaling is not None:
            if "type" in rope_scaling:
                rope_type = rope_scaling["type"]
            elif "rope_type" in rope_scaling:
                rope_type = rope_scaling["rope_type"]
            else:
                raise ValueError(
                    "rope_scaling must have a 'type' or 'rope_type' key.")
    
            # The correct one should be "longrope"; "su" is kept here
            # for backward compatibility.
            if rope_type not in ("su", "longrope", "llama3"):
                if disable_sliding_window:
                    # TODO(robertgshaw): Find a model that supports rope_scaling
                    # with sliding window to see if this case should be allowed.
                    raise NotImplementedError(
                        "Disabling sliding window is not supported for models "
                        "with rope_scaling. Please raise an issue so we can "
                        "investigate.")
    
                if rope_type == "mrope":
                    scaling_factor = 1
                else:
                    assert "factor" in rope_scaling
                    scaling_factor = rope_scaling["factor"]
                if rope_type == "yarn":
                    derived_max_model_len = rope_scaling[
                        "original_max_position_embeddings"]
                derived_max_model_len *= scaling_factor
    
        # If the user specified a max length, make sure it is smaller than the
        # derived length from the HF model config.
        if max_model_len is None:
            max_model_len = int(derived_max_model_len)
        elif max_model_len > derived_max_model_len:
            # Some models might have a separate key for specifying model_max_length
            # that will be bigger than derived_max_model_len. We compare user input
            # with model_max_length and allow this override when it's smaller.
            model_max_length = getattr(hf_config, "model_max_length", None)
            if model_max_length is not None and max_model_len <= model_max_length:
                if disable_sliding_window:
                    # TODO(robertgshaw): Find a model that has model_max_length
                    # with sliding window to see if this case should be allowed.
                    raise NotImplementedError(
                        "Disabling sliding window is not supported for models "
                        "with model_max_length in the config. Please raise an issue "
                        "so we can investigate.")
            else:
                msg = (
                    f"User-specified max_model_len ({max_model_len}) is greater "
                    f"than the derived max_model_len ({max_len_key}="
                    f"{derived_max_model_len} or model_max_length="
                    f"{model_max_length} in model's config.json). This may lead "
                    "to incorrect model outputs or CUDA errors.")
                if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                    logger.warning(
                        "%s Make sure the value is correct and within the "
                        "model context size.", msg)
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set "
                        "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
        return int(max_model_len)
    
    
    def get_served_model_name(model: str,
                              served_model_name: Optional[Union[str, List[str]]]):
        """
        If the input is a non-empty list, the first model_name in
        `served_model_name` is taken.
        If the input is a non-empty string, it is used directly.
        For cases where the input is either an empty string or an
        empty list, the fallback is to use `model`.
        """
        if not served_model_name:
            return model
        if isinstance(served_model_name, list):
            return served_model_name[0]
        return served_model_name
    
    
    @dataclass
    class DecodingConfig:
        """Dataclass which contains the decoding strategy of the engine"""
    
        # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
        guided_decoding_backend: str = 'outlines'
    
        def __post_init__(self):
            valid_guided_backends = ['outlines', 'lm-format-enforcer']
            backend = self.guided_decoding_backend
            if backend not in valid_guided_backends:
                raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                                 f"must be one of {valid_guided_backends}")
    
    
    @dataclass
    class ObservabilityConfig:
        """Configuration for observability."""
        otlp_traces_endpoint: Optional[str] = None
    
        # Collecting detailed timing information for each request can be expensive.
    
        # If set, collects the model forward time for the request.
        collect_model_forward_time: bool = False
    
        # If set, collects the model execute time for the request.
        collect_model_execute_time: bool = False
    
        def __post_init__(self):
            if not is_otel_available() and self.otlp_traces_endpoint is not None:
                raise ValueError(
                    "OpenTelemetry is not available. Unable to configure "
                    "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                    f"installed. Original error:\n{otel_import_error_traceback}")
    
            if ((self.collect_model_forward_time
                 or self.collect_model_execute_time)
                    and self.otlp_traces_endpoint is None):
                raise ValueError(
                    "collect_model_forward_time or collect_model_execute_time "
                    "requires --otlp-traces-endpoint to be set.")
    
    
    @dataclass(frozen=True)
    class EngineConfig:
        """Dataclass which contains all engine-related configuration. This
        simplifies passing around the distinct configurations in the codebase.
        """
    
        model_config: ModelConfig
        cache_config: CacheConfig
        parallel_config: ParallelConfig
        scheduler_config: SchedulerConfig
        device_config: DeviceConfig
        load_config: LoadConfig
        lora_config: Optional[LoRAConfig]
        speculative_config: Optional[SpeculativeConfig]
        decoding_config: Optional[DecodingConfig]
        observability_config: Optional[ObservabilityConfig]
        prompt_adapter_config: Optional[PromptAdapterConfig]
    
        def __post_init__(self):
            """Verify configs are valid & consistent with each other.
            """
            self.model_config.verify_async_output_proc(self.parallel_config,
                                                       self.speculative_config,
                                                       self.device_config)
            self.model_config.verify_with_parallel_config(self.parallel_config)
            self.cache_config.verify_with_parallel_config(self.parallel_config)
    
            if self.lora_config:
                self.lora_config.verify_with_model_config(self.model_config)
                self.lora_config.verify_with_scheduler_config(
                    self.scheduler_config)
            if self.prompt_adapter_config:
                self.prompt_adapter_config.verify_with_model_config(
                    self.model_config)
    
        def to_dict(self):
            """Return the configs as a dictionary, for use in **kwargs.
            """
            return dict(
                (field.name, getattr(self, field.name)) for field in fields(self))
    
  • cover/vllm/distributed/npu_utils.py: implements an optimized broadcast mechanism for multi-card (multi-NPU) scenarios based on vLLM 0.6.2; a usage sketch follows the listing.
    from typing import Any, Dict, Optional, Union
    from torch.distributed import ProcessGroup
    import torch
    
    
    def get_dimension_and_size(x):
        if x is not None:
            return len(x), list(x)
        else:
            return 0, []
    
    
    def get_true_or_false(x):
        if x:
            return 1, [1]
        else:
            return 1, [0]
    
    
    def get_dimension_and_size_of_single_value(x):
        if x is not None:
            return 1, [int(x)]
        else:
            return 0, []
    
    
    def get_size_or_none(x: Optional[torch.Tensor]):
        return x.size() if x is not None else None
    
    
    bool_keys = ["use_cuda_graph"]
    single_value_keys = [
        "num_seq_groups",
        "virtual_engine",
        "num_steps",
        "num_prefill_tokens",
        "num_decode_tokens",
        "num_prefills",
        "max_query_len",
        "max_seq_len",
        "max_prefill_seq_len",
        "max_decode_seq_len",
    ]
    tensor_keys = [
        "input_tokens",
        "input_positions",
        "selected_token_indices",
        "slot_mapping",
        "seq_lens_tensor",
        "query_start_loc",
        "seq_start_loc",
        "context_lens_tensor",
        "block_tables",
        "blocks_to_swap_in",
        "blocks_to_swap_out",
        "blocks_to_copy",
    ]
    other_data_keys = [
        "lora_requests",
        "lora_mapping",
        "multi_modal_kwargs",
        "prompt_adapter_requests",
        "prompt_adapter_mapping",
        "request_ids_to_seq_ids",
        "finished_requests_ids",
        "_cached_prefill_metadata",
        "_cached_decode_metadata",
    ]
    metadata_keys = tensor_keys + bool_keys + single_value_keys + ["seq_lens"]
    total_key_num = (
        len(metadata_keys) - 1
    )  # seq_lens can be obtained from seq_lens_tensor, so it does not need to be broadcast
    total_size_data_num = 50
    
    
    def broadcast(input_: torch.Tensor, src: int = 0, group: Optional[ProcessGroup] = None):
        """Broadcast the input tensor."""
        group = group or torch.distributed.group.WORLD
        ranks = torch.distributed.get_process_group_ranks(group)
        assert src in ranks, f"Invalid src rank ({src})"
    
        # Bypass the function if we are using only 1 GPU.
        world_size = torch.distributed.get_world_size(group=group)
        if world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_, src=src, group=group)
        return input_
    
    
    def prepare_dim_and_size_tensor(
        data_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]]
    ):
        dim_list = []
        size_list = []
        for key in metadata_keys:
            data = data_dict[key]
            if key in bool_keys:
                dim, size = get_true_or_false(data)
            elif key in single_value_keys:
                dim, size = get_dimension_and_size_of_single_value(data)
            elif key == "seq_lens":
                continue
            else:
                data_size = get_size_or_none(data)
                dim, size = get_dimension_and_size(data_size)
            dim_list.append(dim)
            size_list.extend(size)
        assert len(dim_list) == total_key_num, "the length of dim_list is wrong"
        dim_and_size_list = dim_list + size_list
        if len(dim_and_size_list) < total_size_data_num:
            dim_and_size_list += [-1] * (total_size_data_num - len(dim_and_size_list))
        dim_and_size_tensor = torch.tensor(dim_and_size_list, dtype=torch.int, device="npu")
        return dim_and_size_tensor
    
    
    def concat_tensor_data(data_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]]):
        concat_data_list = []
        for key in tensor_keys:
            data = data_dict[key]
            if data is not None:
                concat_data_list.extend(data.view(-1).tolist())
        concat_data_tensor = torch.tensor(concat_data_list, dtype=torch.int, device="npu")
        return concat_data_tensor
    
    
    def get_sizedata_and_singlevalues_from_total(dim_and_size_tensor: torch.Tensor):
        dim_list = dim_and_size_tensor[:total_key_num].tolist()
        size_list = dim_and_size_tensor[total_key_num:].tolist()
        dim_idx = 0
        idx = 0
        size_dict = {}
        single_value_dict = {}
        for key in metadata_keys:
            if key in bool_keys:
                bool_data = True if size_list[idx] == 1 else False
                single_value_dict[key] = bool_data
            elif key in single_value_keys:
                single_value_data = size_list[idx] if dim_list[dim_idx] > 0 else None
                single_value_dict[key] = single_value_data
            elif key == "seq_lens":
                continue
            else:
                size_data = (
                    torch.Size(size_list[idx : idx + dim_list[dim_idx]])
                    if dim_list[dim_idx] > 0
                    else None
                )
                size_dict[key] = size_data
            idx += dim_list[dim_idx]
            dim_idx += 1
    
        return size_dict, single_value_dict
    
    
    def construct_empty_concat_tensor(size_dict):
        total_element_num = 0
        for key in tensor_keys:
            if key not in size_dict:
                raise ValueError(f"missing key {key} in received size data")
            if size_dict[key]:
                total_element_num += size_dict[key].numel()
        return torch.empty(total_element_num, dtype=torch.int, device="npu")
    
    
    def get_tensor_dict_from_concat_tensor(concat_tensor: torch.Tensor, size_dict):
        tensor_dict = {}
        idx = 0
        for key in tensor_keys:
            data_size = size_dict[key]
            if data_size is not None:
                tensor_dict[key] = concat_tensor[idx : idx + data_size.numel()].view(
                    *data_size
                )
                idx += data_size.numel()
            else:
                tensor_dict[key] = None
        return tensor_dict
    
    
    def ascend_broadcast_data_dict(
        data_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
    ):
        group = torch.distributed.group.WORLD
        world_size = torch.distributed.get_world_size(group=group)
        if world_size == 1:
            return data_dict
    
        rank = torch.distributed.get_rank()
    
        if rank == src:
            other_data_list = []
            pure_data_dict = {}
            for k, v in data_dict.items():
                if k in other_data_keys:
                    other_data_list.append((k, v))
                else:
                    pure_data_dict[k] = v
            torch.distributed.broadcast_object_list([other_data_list], src=src)
            dim_and_size_tensor = prepare_dim_and_size_tensor(pure_data_dict)
            handle1 = torch.distributed.broadcast(
                dim_and_size_tensor, src=src, group=group, async_op=True
            )
            concat_tensor = concat_tensor_data(pure_data_dict)
            handle2 = torch.distributed.broadcast(
                concat_tensor, src=src, group=group, async_op=True
            )
            async_handles = [handle1, handle2]
            for async_handle in async_handles:
                async_handle.wait()
        else:
            data_dict = {}
            recv = [None]
            torch.distributed.broadcast_object_list(recv, src=src)
            dim_and_size_tensor = torch.empty(
                total_size_data_num, dtype=torch.int, device="npu"
            )
            other_data_list = recv[0]
            handle1 = torch.distributed.broadcast(
                dim_and_size_tensor, src=src, group=group, async_op=True
            )
            handle1.wait()
            size_dict, single_value_dict = get_sizedata_and_singlevalues_from_total(
                dim_and_size_tensor
            )
            concat_tensor = construct_empty_concat_tensor(size_dict)
            handle2 = torch.distributed.broadcast(
                concat_tensor, src=src, group=group, async_op=True
            )
            data_dict.update(single_value_dict)
            for k, v in other_data_list:
                data_dict[k] = v
            handle2.wait()
            tensor_dict = get_tensor_dict_from_concat_tensor(concat_tensor, size_dict)
            data_dict.update(tensor_dict)
            data_dict["seq_lens"] = data_dict["seq_lens_tensor"].tolist()
        return data_dict
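
  A minimal usage sketch of ascend_broadcast_data_dict follows. It assumes torch_npu is installed, that the default process group has been initialized with the HCCL backend (for example via torchrun), and that build_driver_metadata() is a hypothetical helper that fills every key in metadata_keys as well as the entries in other_data_keys.

    # Usage sketch; build_driver_metadata() is a hypothetical helper and an
    # initialized HCCL process group is assumed (e.g. launched via torchrun).
    import torch.distributed as dist

    from vllm.distributed.npu_utils import ascend_broadcast_data_dict


    def sync_model_input(build_driver_metadata, src: int = 0):
        if dist.get_rank() == src:
            # Driver rank: the dims/sizes tensor and the concatenated data
            # tensor are broadcast instead of many per-field object broadcasts.
            return ascend_broadcast_data_dict(build_driver_metadata(), src=src)
        # Non-driver ranks pass nothing and rebuild the dict from the broadcasts.
        return ascend_broadcast_data_dict(None, src=src)

  Every rank returns the same metadata dictionary: only the small non-tensor entries (other_data_keys) go through broadcast_object_list, while all shapes and tensor contents travel in the fixed-layout dims/sizes tensor plus one concatenated integer tensor.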
    
  • cover/vllm/engine/arg_utils.py: changes the default block_size to 128 to match the adapted backend's default block size of 128; a quick check of the new default and the adapted file follow.
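
  A quick way to confirm the new default is sketched below; this assumes the patched package is installed (upstream vLLM 0.6.2 defaults block_size to 16).

    # Sketch: check the adapted default block size (assumes the patched vllm).
    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="facebook/opt-125m")
    print(args.block_size)  # expected: 128 with this patch

  The adapted arg_utils.py is listed below: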
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import argparse
    import dataclasses
    import json
    from dataclasses import dataclass
    from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                        Type, Union)
    
    import torch
    
    import vllm.envs as envs
    from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                             DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
                             LoRAConfig, ModelConfig, ObservabilityConfig,
                             ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                             SpeculativeConfig, TokenizerPoolConfig)
    from vllm.executor.executor_base import ExecutorBase
    from vllm.logger import init_logger
    from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
    from vllm.transformers_utils.utils import check_gguf_file
    from vllm.utils import FlexibleArgumentParser
    
    if TYPE_CHECKING:
        from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
    
    logger = init_logger(__name__)
    
    ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
    
    DEVICE_OPTIONS = [
        "auto",
        "cuda",
        "neuron",
        "cpu",
        "openvino",
        "tpu",
        "xpu",
    ]
    
    
    def nullable_str(val: str):
        if not val or val == "None":
            return None
        return val
    
    
    def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
        """Parses a string containing comma-separated key [str] to value [int]
        pairs into a dictionary.
    
        Args:
            val: String value to be parsed.
    
        Returns:
            Dictionary with parsed values.
        """
        if len(val) == 0:
            return None
    
        out_dict: Dict[str, int] = {}
        for item in val.split(","):
            kv_parts = [part.lower().strip() for part in item.split("=")]
            if len(kv_parts) != 2:
                raise argparse.ArgumentTypeError(
                    "Each item should be in the form KEY=VALUE")
            key, value = kv_parts
    
            try:
                parsed_value = int(value)
            except ValueError as exc:
                msg = f"Failed to parse value of item {key}={value}"
                raise argparse.ArgumentTypeError(msg) from exc
    
            if key in out_dict and out_dict[key] != parsed_value:
                raise argparse.ArgumentTypeError(
                    f"Conflicting values specified for key: {key}")
            out_dict[key] = parsed_value
    
        return out_dict
    
    
    @dataclass
    class EngineArgs:
        """Arguments for vLLM engine."""
        model: str = 'facebook/opt-125m'
        served_model_name: Optional[Union[str, List[str]]] = None
        tokenizer: Optional[str] = None
        skip_tokenizer_init: bool = False
        tokenizer_mode: str = 'auto'
        trust_remote_code: bool = False
        download_dir: Optional[str] = None
        load_format: str = 'auto'
        config_format: str = 'auto'
        dtype: str = 'auto'
        kv_cache_dtype: str = 'auto'
        quantization_param_path: Optional[str] = None
        seed: int = 0
        max_model_len: Optional[int] = None
        worker_use_ray: bool = False
        # Note: Specifying a custom executor backend by passing a class
        # is intended for expert use only. The API may change without
        # notice.
        distributed_executor_backend: Optional[Union[str,
                                                     Type[ExecutorBase]]] = None
        pipeline_parallel_size: int = 1
        tensor_parallel_size: int = 1
        max_parallel_loading_workers: Optional[int] = None
        block_size: int = 128
        enable_prefix_caching: bool = False
        disable_sliding_window: bool = False
        use_v2_block_manager: bool = False
        swap_space: float = 4  # GiB
        cpu_offload_gb: float = 0  # GiB
        gpu_memory_utilization: float = 0.90
        max_num_batched_tokens: Optional[int] = None
        max_num_seqs: int = 256
        max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
        disable_log_stats: bool = False
        revision: Optional[str] = None
        code_revision: Optional[str] = None
        rope_scaling: Optional[dict] = None
        rope_theta: Optional[float] = None
        tokenizer_revision: Optional[str] = None
        quantization: Optional[str] = None
        enforce_eager: Optional[bool] = None
        max_context_len_to_capture: Optional[int] = None
        max_seq_len_to_capture: int = 8192
        disable_custom_all_reduce: bool = False
        tokenizer_pool_size: int = 0
        # Note: Specifying a tokenizer pool by passing a class
        # is intended for expert use only. The API may change without
        # notice.
        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
        tokenizer_pool_extra_config: Optional[dict] = None
        limit_mm_per_prompt: Optional[Mapping[str, int]] = None
        enable_lora: bool = False
        max_loras: int = 1
        max_lora_rank: int = 16
        enable_prompt_adapter: bool = False
        max_prompt_adapters: int = 1
        max_prompt_adapter_token: int = 0
        fully_sharded_loras: bool = False
        lora_extra_vocab_size: int = 256
        long_lora_scaling_factors: Optional[Tuple[float]] = None
        lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
        max_cpu_loras: Optional[int] = None
        device: str = 'auto'
        num_scheduler_steps: int = 1
        multi_step_stream_outputs: bool = False
        ray_workers_use_nsight: bool = False
        num_gpu_blocks_override: Optional[int] = None
        num_lookahead_slots: int = 0
        model_loader_extra_config: Optional[dict] = None
        ignore_patterns: Optional[Union[str, List[str]]] = None
        preemption_mode: Optional[str] = None
    
        scheduler_delay_factor: float = 0.0
        enable_chunked_prefill: Optional[bool] = None
    
        guided_decoding_backend: str = 'outlines'
        # Speculative decoding configuration.
        speculative_model: Optional[str] = None
        speculative_model_quantization: Optional[str] = None
        speculative_draft_tensor_parallel_size: Optional[int] = None
        num_speculative_tokens: Optional[int] = None
        speculative_max_model_len: Optional[int] = None
        speculative_disable_by_batch_size: Optional[int] = None
        ngram_prompt_lookup_max: Optional[int] = None
        ngram_prompt_lookup_min: Optional[int] = None
        spec_decoding_acceptance_method: str = 'rejection_sampler'
        typical_acceptance_sampler_posterior_threshold: Optional[float] = None
        typical_acceptance_sampler_posterior_alpha: Optional[float] = None
        qlora_adapter_name_or_path: Optional[str] = None
        disable_logprobs_during_spec_decoding: Optional[bool] = None
    
        otlp_traces_endpoint: Optional[str] = None
        collect_detailed_traces: Optional[str] = None
        disable_async_output_proc: bool = False
        override_neuron_config: Optional[Dict[str, Any]] = None
        mm_processor_kwargs: Optional[Dict[str, Any]] = None
    
        def __post_init__(self):
            if self.tokenizer is None:
                self.tokenizer = self.model
    
        @staticmethod
        def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            """Shared CLI arguments for vLLM engine."""
    
            # Model arguments
            parser.add_argument(
                '--model',
                type=str,
                default=EngineArgs.model,
                help='Name or path of the huggingface model to use.')
            parser.add_argument(
                '--tokenizer',
                type=nullable_str,
                default=EngineArgs.tokenizer,
                help='Name or path of the huggingface tokenizer to use. '
                'If unspecified, model name or path will be used.')
            parser.add_argument(
                '--skip-tokenizer-init',
                action='store_true',
                help='Skip initialization of tokenizer and detokenizer')
            parser.add_argument(
                '--revision',
                type=nullable_str,
                default=None,
                help='The specific model version to use. It can be a branch '
                'name, a tag name, or a commit id. If unspecified, will use '
                'the default version.')
            parser.add_argument(
                '--code-revision',
                type=nullable_str,
                default=None,
                help='The specific revision to use for the model code on '
                'Hugging Face Hub. It can be a branch name, a tag name, or a '
                'commit id. If unspecified, will use the default version.')
            parser.add_argument(
                '--tokenizer-revision',
                type=nullable_str,
                default=None,
                help='Revision of the huggingface tokenizer to use. '
                'It can be a branch name, a tag name, or a commit id. '
                'If unspecified, will use the default version.')
            parser.add_argument(
                '--tokenizer-mode',
                type=str,
                default=EngineArgs.tokenizer_mode,
                choices=['auto', 'slow', 'mistral'],
                help='The tokenizer mode.\n\n* "auto" will use the '
                'fast tokenizer if available.\n* "slow" will '
                'always use the slow tokenizer. \n* '
                '"mistral" will always use the `mistral_common` tokenizer.')
            parser.add_argument('--trust-remote-code',
                                action='store_true',
                                help='Trust remote code from huggingface.')
            parser.add_argument('--download-dir',
                                type=nullable_str,
                                default=EngineArgs.download_dir,
                                help='Directory to download and load the weights, '
                                'default to the default cache dir of '
                                'huggingface.')
            parser.add_argument(
                '--load-format',
                type=str,
                default=EngineArgs.load_format,
                choices=[f.value for f in LoadFormat],
                help='The format of the model weights to load.\n\n'
                '* "auto" will try to load the weights in the safetensors format '
                'and fall back to the pytorch bin format if safetensors format '
                'is not available.\n'
                '* "pt" will load the weights in the pytorch bin format.\n'
                '* "safetensors" will load the weights in the safetensors format.\n'
                '* "npcache" will load the weights in pytorch format and store '
                'a numpy cache to speed up the loading.\n'
                '* "dummy" will initialize the weights with random values, '
                'which is mainly for profiling.\n'
                '* "tensorizer" will load the weights using tensorizer from '
                'CoreWeave. See the Tensorize vLLM Model script in the Examples '
                'section for more information.\n'
                '* "bitsandbytes" will load the weights using bitsandbytes '
                'quantization.\n')
            parser.add_argument(
                '--config-format',
                default=EngineArgs.config_format,
                choices=[f.value for f in ConfigFormat],
                help='The format of the model config to load.\n\n'
                '* "auto" will try to load the config in hf format '
                'if available else it will try to load in mistral format ')
            parser.add_argument(
                '--dtype',
                type=str,
                default=EngineArgs.dtype,
                choices=[
                    'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
                ],
                help='Data type for model weights and activations.\n\n'
                '* "auto" will use FP16 precision for FP32 and FP16 models, and '
                'BF16 precision for BF16 models.\n'
                '* "half" for FP16. Recommended for AWQ quantization.\n'
                '* "float16" is the same as "half".\n'
                '* "bfloat16" for a balance between precision and range.\n'
                '* "float" is shorthand for FP32 precision.\n'
                '* "float32" for FP32 precision.')
            parser.add_argument(
                '--kv-cache-dtype',
                type=str,
                choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
                default=EngineArgs.kv_cache_dtype,
                help='Data type for kv cache storage. If "auto", will use model '
                'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
                'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
            parser.add_argument(
                '--quantization-param-path',
                type=nullable_str,
                default=None,
                help='Path to the JSON file containing the KV cache '
                'scaling factors. This should generally be supplied, when '
                'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
                'default to 1.0, which may cause accuracy issues. '
                'FP8_E5M2 (without scaling) is only supported on cuda '
                'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
                'supported for common inference criteria.')
            parser.add_argument('--max-model-len',
                                type=int,
                                default=EngineArgs.max_model_len,
                                help='Model context length. If unspecified, will '
                                'be automatically derived from the model config.')
            parser.add_argument(
                '--guided-decoding-backend',
                type=str,
                default='outlines',
                choices=['outlines', 'lm-format-enforcer'],
                help='Which engine will be used for guided decoding'
                ' (JSON schema / regex etc) by default. Currently support '
                'https://github.com/outlines-dev/outlines and '
                'https://github.com/noamgat/lm-format-enforcer.'
                ' Can be overridden per request via guided_decoding_backend'
                ' parameter.')
            # Parallel arguments
            parser.add_argument(
                '--distributed-executor-backend',
                choices=['ray', 'mp'],
                default=EngineArgs.distributed_executor_backend,
                help='Backend to use for distributed serving. When more than 1 GPU '
                'is used, will be automatically set to "ray" if installed '
                'or "mp" (multiprocessing) otherwise.')
            parser.add_argument(
                '--worker-use-ray',
                action='store_true',
                help='Deprecated, use --distributed-executor-backend=ray.')
            parser.add_argument('--pipeline-parallel-size',
                                '-pp',
                                type=int,
                                default=EngineArgs.pipeline_parallel_size,
                                help='Number of pipeline stages.')
            parser.add_argument('--tensor-parallel-size',
                                '-tp',
                                type=int,
                                default=EngineArgs.tensor_parallel_size,
                                help='Number of tensor parallel replicas.')
            parser.add_argument(
                '--max-parallel-loading-workers',
                type=int,
                default=EngineArgs.max_parallel_loading_workers,
                help='Load model sequentially in multiple batches, '
                'to avoid RAM OOM when using tensor '
                'parallel and large models.')
            parser.add_argument(
                '--ray-workers-use-nsight',
                action='store_true',
                help='If specified, use nsight to profile Ray workers.')
            # KV cache arguments
            parser.add_argument('--block-size',
                                type=int,
                                default=EngineArgs.block_size,
                                choices=[8, 16, 32, 128],
                                help='Token block size for contiguous chunks of '
                                'tokens. This is ignored on neuron devices and '
                                'set to max-model-len')
    
            parser.add_argument('--enable-prefix-caching',
                                action='store_true',
                                help='Enables automatic prefix caching.')
            parser.add_argument('--disable-sliding-window',
                                action='store_true',
                                help='Disables sliding window, '
                                'capping to sliding window size')
            parser.add_argument('--use-v2-block-manager',
                                action='store_true',
                                help='Use BlockSpaceManagerV2.')
            parser.add_argument(
                '--num-lookahead-slots',
                type=int,
                default=EngineArgs.num_lookahead_slots,
                help='Experimental scheduling config necessary for '
                'speculative decoding. This will be replaced by '
                'speculative config in the future; it is present '
                'to enable correctness tests until then.')
    
            parser.add_argument('--seed',
                                type=int,
                                default=EngineArgs.seed,
                                help='Random seed for operations.')
            parser.add_argument('--swap-space',
                                type=float,
                                default=EngineArgs.swap_space,
                                help='CPU swap space size (GiB) per GPU.')
            parser.add_argument(
                '--cpu-offload-gb',
                type=float,
                default=0,
                help='The space in GiB to offload to CPU, per GPU. '
                'Default is 0, which means no offloading. Intuitively, '
                'this argument can be seen as a virtual way to increase '
                'the GPU memory size. For example, if you have one 24 GB '
                'GPU and set this to 10, virtually you can think of it as '
                'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
                'which requires at least 26GB GPU memory. Note that this '
                'requires fast CPU-GPU interconnect, as part of the model is '
                'loaded from CPU memory to GPU memory on the fly in each '
                'model forward pass.')
            parser.add_argument(
                '--gpu-memory-utilization',
                type=float,
                default=EngineArgs.gpu_memory_utilization,
                help='The fraction of GPU memory to be used for the model '
                'executor, which can range from 0 to 1. For example, a value of '
                '0.5 would imply 50%% GPU memory utilization. If unspecified, '
                'will use the default value of 0.9.')
            parser.add_argument(
                '--num-gpu-blocks-override',
                type=int,
                default=None,
                help='If specified, ignore GPU profiling result and use this number '
                'of GPU blocks. Used for testing preemption.')
            parser.add_argument('--max-num-batched-tokens',
                                type=int,
                                default=EngineArgs.max_num_batched_tokens,
                                help='Maximum number of batched tokens per '
                                'iteration.')
            parser.add_argument('--max-num-seqs',
                                type=int,
                                default=EngineArgs.max_num_seqs,
                                help='Maximum number of sequences per iteration.')
            parser.add_argument(
                '--max-logprobs',
                type=int,
                default=EngineArgs.max_logprobs,
                help=('Max number of log probs to return when logprobs '
                      'is specified in SamplingParams.'))
            parser.add_argument('--disable-log-stats',
                                action='store_true',
                                help='Disable logging statistics.')
            # Quantization settings.
            parser.add_argument('--quantization',
                                '-q',
                                type=nullable_str,
                                choices=[*QUANTIZATION_METHODS, None],
                                default=EngineArgs.quantization,
                                help='Method used to quantize the weights. If '
                                'None, we first check the `quantization_config` '
                                'attribute in the model config file. If that is '
                                'None, we assume the model weights are not '
                                'quantized and use `dtype` to determine the data '
                                'type of the weights.')
            parser.add_argument('--rope-scaling',
                                default=None,
                                type=json.loads,
                                help='RoPE scaling configuration in JSON format. '
                                'For example, {"type":"dynamic","factor":2.0}')
            parser.add_argument('--rope-theta',
                                default=None,
                                type=float,
                                help='RoPE theta. Use with `rope_scaling`. In '
                                'some cases, changing the RoPE theta improves the '
                                'performance of the scaled model.')
            parser.add_argument('--enforce-eager',
                                action='store_true',
                                help='Always use eager-mode PyTorch. If False, '
                                'will use eager mode and CUDA graph in hybrid '
                                'for maximal performance and flexibility.')
            parser.add_argument('--max-context-len-to-capture',
                                type=int,
                                default=EngineArgs.max_context_len_to_capture,
                                help='Maximum context length covered by CUDA '
                                'graphs. When a sequence has context length '
                                'larger than this, we fall back to eager mode. '
                                '(DEPRECATED. Use --max-seq-len-to-capture instead'
                                ')')
            parser.add_argument('--max-seq-len-to-capture',
                                type=int,
                                default=EngineArgs.max_seq_len_to_capture,
                                help='Maximum sequence length covered by CUDA '
                                'graphs. When a sequence has context length '
                                'larger than this, we fall back to eager mode. '
                                'Additionally for encoder-decoder models, if the '
                                'sequence length of the encoder input is larger '
                                'than this, we fall back to the eager mode.')
            parser.add_argument('--disable-custom-all-reduce',
                                action='store_true',
                                default=EngineArgs.disable_custom_all_reduce,
                                help='See ParallelConfig.')
            parser.add_argument('--tokenizer-pool-size',
                                type=int,
                                default=EngineArgs.tokenizer_pool_size,
                                help='Size of tokenizer pool to use for '
                                'asynchronous tokenization. If 0, will '
                                'use synchronous tokenization.')
            parser.add_argument('--tokenizer-pool-type',
                                type=str,
                                default=EngineArgs.tokenizer_pool_type,
                                help='Type of tokenizer pool to use for '
                                'asynchronous tokenization. Ignored '
                                'if tokenizer_pool_size is 0.')
            parser.add_argument('--tokenizer-pool-extra-config',
                                type=nullable_str,
                                default=EngineArgs.tokenizer_pool_extra_config,
                                help='Extra config for tokenizer pool. '
                                'This should be a JSON string that will be '
                                'parsed into a dictionary. Ignored if '
                                'tokenizer_pool_size is 0.')
    
            # Multimodal related configs
            parser.add_argument(
                '--limit-mm-per-prompt',
                type=nullable_kvs,
                default=EngineArgs.limit_mm_per_prompt,
                # The default value is given in
                # MultiModalRegistry.init_mm_limits_per_prompt
                help=('For each multimodal plugin, limit how many '
                      'input instances to allow for each prompt. '
                      'Expects a comma-separated list of items, '
                      'e.g.: `image=16,video=2` allows a maximum of 16 '
                      'images and 2 videos per prompt. Defaults to 1 for '
                      'each modality.'))
            parser.add_argument(
                '--mm-processor-kwargs',
                default=None,
                type=json.loads,
                help=('Overrides for the multimodal input mapping/processing, '
                      'e.g., image processor. For example: {"num_crops": 4}.'))
    
            # LoRA related configs
            parser.add_argument('--enable-lora',
                                action='store_true',
                                help='If True, enable handling of LoRA adapters.')
            parser.add_argument('--max-loras',
                                type=int,
                                default=EngineArgs.max_loras,
                                help='Max number of LoRAs in a single batch.')
            parser.add_argument('--max-lora-rank',
                                type=int,
                                default=EngineArgs.max_lora_rank,
                                help='Max LoRA rank.')
            parser.add_argument(
                '--lora-extra-vocab-size',
                type=int,
                default=EngineArgs.lora_extra_vocab_size,
                help=('Maximum size of extra vocabulary that can be '
                      'present in a LoRA adapter (added to the base '
                      'model vocabulary).'))
            parser.add_argument(
                '--lora-dtype',
                type=str,
                default=EngineArgs.lora_dtype,
                choices=['auto', 'float16', 'bfloat16', 'float32'],
                help=('Data type for LoRA. If auto, will default to '
                      'base model dtype.'))
            parser.add_argument(
                '--long-lora-scaling-factors',
                type=nullable_str,
                default=EngineArgs.long_lora_scaling_factors,
                help=('Specify multiple scaling factors (which can '
                      'be different from base model scaling factor '
                      '- see eg. Long LoRA) to allow for multiple '
                      'LoRA adapters trained with those scaling '
                      'factors to be used at the same time. If not '
                      'specified, only adapters trained with the '
                      'base model scaling factor are allowed.'))
            parser.add_argument(
                '--max-cpu-loras',
                type=int,
                default=EngineArgs.max_cpu_loras,
                help=('Maximum number of LoRAs to store in CPU memory. '
                      'Must be >= max_num_seqs. '
                      'Defaults to max_num_seqs.'))
            parser.add_argument(
                '--fully-sharded-loras',
                action='store_true',
                help=('By default, only half of the LoRA computation is '
                      'sharded with tensor parallelism. '
                      'Enabling this will use the fully sharded layers. '
                      'At high sequence length, max rank or '
                      'tensor parallel size, this is likely faster.'))
            parser.add_argument('--enable-prompt-adapter',
                                action='store_true',
                                help='If True, enable handling of PromptAdapters.')
            parser.add_argument('--max-prompt-adapters',
                                type=int,
                                default=EngineArgs.max_prompt_adapters,
                                help='Max number of PromptAdapters in a batch.')
            parser.add_argument('--max-prompt-adapter-token',
                                type=int,
                                default=EngineArgs.max_prompt_adapter_token,
                                help='Max number of PromptAdapter tokens.')
            parser.add_argument("--device",
                                type=str,
                                default=EngineArgs.device,
                                choices=DEVICE_OPTIONS,
                                help='Device type for vLLM execution.')
            parser.add_argument('--num-scheduler-steps',
                                type=int,
                                default=1,
                                help=('Maximum number of forward steps per '
                                      'scheduler call.'))
    
            parser.add_argument(
                '--multi-step-stream-outputs',
                action='store_true',
                help='If True, then multi-step will stream outputs for every step')
            parser.add_argument(
                '--scheduler-delay-factor',
                type=float,
                default=EngineArgs.scheduler_delay_factor,
                help='Apply a delay (of delay factor multiplied by previous '
                'prompt latency) before scheduling next prompt.')
            parser.add_argument(
                '--enable-chunked-prefill',
                action=StoreBoolean,
                default=EngineArgs.enable_chunked_prefill,
                nargs="?",
                const="True",
                help='If set, the prefill requests can be chunked based on the '
                'max_num_batched_tokens.')
    
            parser.add_argument(
                '--speculative-model',
                type=nullable_str,
                default=EngineArgs.speculative_model,
                help=
                'The name of the draft model to be used in speculative decoding.')
            # Quantization settings for speculative model.
            parser.add_argument(
                '--speculative-model-quantization',
                type=nullable_str,
                choices=[*QUANTIZATION_METHODS, None],
                default=EngineArgs.speculative_model_quantization,
                help='Method used to quantize the weights of speculative model. '
                'If None, we first check the `quantization_config` '
                'attribute in the model config file. If that is '
                'None, we assume the model weights are not '
                'quantized and use `dtype` to determine the data '
                'type of the weights.')
            parser.add_argument(
                '--num-speculative-tokens',
                type=int,
                default=EngineArgs.num_speculative_tokens,
                help='The number of speculative tokens to sample from '
                'the draft model in speculative decoding.')
            parser.add_argument(
                '--speculative-draft-tensor-parallel-size',
                '-spec-draft-tp',
                type=int,
                default=EngineArgs.speculative_draft_tensor_parallel_size,
                help='Number of tensor parallel replicas for '
                'the draft model in speculative decoding.')
    
            parser.add_argument(
                '--speculative-max-model-len',
                type=int,
                default=EngineArgs.speculative_max_model_len,
                help='The maximum sequence length supported by the '
                'draft model. Sequences over this length will skip '
                'speculation.')
    
            parser.add_argument(
                '--speculative-disable-by-batch-size',
                type=int,
                default=EngineArgs.speculative_disable_by_batch_size,
                help='Disable speculative decoding for new incoming requests '
                'if the number of enqueue requests is larger than this value.')
    
            parser.add_argument(
                '--ngram-prompt-lookup-max',
                type=int,
                default=EngineArgs.ngram_prompt_lookup_max,
                help='Max size of window for ngram prompt lookup in speculative '
                'decoding.')
    
            parser.add_argument(
                '--ngram-prompt-lookup-min',
                type=int,
                default=EngineArgs.ngram_prompt_lookup_min,
                help='Min size of window for ngram prompt lookup in speculative '
                'decoding.')
    
            parser.add_argument(
                '--spec-decoding-acceptance-method',
                type=str,
                default=EngineArgs.spec_decoding_acceptance_method,
                choices=['rejection_sampler', 'typical_acceptance_sampler'],
                help='Specify the acceptance method to use during draft token '
                'verification in speculative decoding. Two types of acceptance '
                'routines are supported: '
                '1) RejectionSampler which does not allow changing the '
                'acceptance rate of draft tokens, '
                '2) TypicalAcceptanceSampler which is configurable, allowing for '
                'a higher acceptance rate at the cost of lower quality, '
                'and vice versa.')
    
            parser.add_argument(
                '--typical-acceptance-sampler-posterior-threshold',
                type=float,
                default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
                help='Set the lower bound threshold for the posterior '
                'probability of a token to be accepted. This threshold is '
                'used by the TypicalAcceptanceSampler to make sampling decisions '
                'during speculative decoding. Defaults to 0.09')
    
            parser.add_argument(
                '--typical-acceptance-sampler-posterior-alpha',
                type=float,
                default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
                help='A scaling factor for the entropy-based threshold for token '
                'acceptance in the TypicalAcceptanceSampler. Typically defaults '
                'to sqrt of --typical-acceptance-sampler-posterior-threshold '
                'i.e. 0.3')
    
            parser.add_argument(
                '--disable-logprobs-during-spec-decoding',
                action=StoreBoolean,
                default=EngineArgs.disable_logprobs_during_spec_decoding,
                nargs="?",
                const="True",
                help='If set to True, token log probabilities are not returned '
                'during speculative decoding. If set to False, log probabilities '
                'are returned according to the settings in SamplingParams. If '
                'not specified, it defaults to True. Disabling log probabilities '
                'during speculative decoding reduces latency by skipping logprob '
                'calculation in proposal sampling, target sampling, and after '
                'accepted tokens are determined.')
    
            parser.add_argument('--model-loader-extra-config',
                                type=nullable_str,
                                default=EngineArgs.model_loader_extra_config,
                                help='Extra config for model loader. '
                                'This will be passed to the model loader '
                                'corresponding to the chosen load_format. '
                                'This should be a JSON string that will be '
                                'parsed into a dictionary.')
            parser.add_argument(
                '--ignore-patterns',
                action="append",
                type=str,
                default=[],
                help="The pattern(s) to ignore when loading the model."
                "Default to 'original/**/*' to avoid repeated loading of llama's "
                "checkpoints.")
            parser.add_argument(
                '--preemption-mode',
                type=str,
                default=None,
                help='If \'recompute\', the engine performs preemption by '
                'recomputing; If \'swap\', the engine performs preemption by '
                'block swapping.')
    
            parser.add_argument(
                "--served-model-name",
                nargs="+",
                type=str,
                default=None,
                help="The model name(s) used in the API. If multiple "
                "names are provided, the server will respond to any "
                "of the provided names. The model name in the model "
                "field of a response will be the first name in this "
                "list. If not specified, the model name will be the "
                "same as the `--model` argument. Noted that this name(s)"
                "will also be used in `model_name` tag content of "
                "prometheus metrics, if multiple names provided, metrics"
                "tag will take the first one.")
            parser.add_argument('--qlora-adapter-name-or-path',
                                type=str,
                                default=None,
                                help='Name or path of the QLoRA adapter.')
    
            parser.add_argument(
                '--otlp-traces-endpoint',
                type=str,
                default=None,
                help='Target URL to which OpenTelemetry traces will be sent.')
            parser.add_argument(
                '--collect-detailed-traces',
                type=str,
                default=None,
                help="Valid choices are " +
                ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
                ". It makes sense to set this only if --otlp-traces-endpoint is"
                " set. If set, it will collect detailed traces for the specified "
                "modules. This involves use of possibly costly and or blocking "
                "operations and hence might have a performance impact.")
    
            parser.add_argument(
                '--disable-async-output-proc',
                action='store_true',
                default=EngineArgs.disable_async_output_proc,
                help="Disable async output processing. This may result in "
                "lower performance.")
            parser.add_argument(
                '--override-neuron-config',
                type=lambda configs: {
                    str(key): value
                    for key, value in
                    (config.split(':') for config in configs.split(','))
                },
                default=None,
                help="override or set neuron device configuration.")
    
            return parser
    
        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # Get the list of attributes of this dataclass.
            attrs = [attr.name for attr in dataclasses.fields(cls)]
            # Set the attributes from the parsed arguments.
            engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
            return engine_args
    
        def create_model_config(self) -> ModelConfig:
            return ModelConfig(
                model=self.model,
                tokenizer=self.tokenizer,
                tokenizer_mode=self.tokenizer_mode,
                trust_remote_code=self.trust_remote_code,
                dtype=self.dtype,
                seed=self.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                rope_scaling=self.rope_scaling,
                rope_theta=self.rope_theta,
                tokenizer_revision=self.tokenizer_revision,
                max_model_len=self.max_model_len,
                quantization=self.quantization,
                quantization_param_path=self.quantization_param_path,
                enforce_eager=self.enforce_eager,
                max_context_len_to_capture=self.max_context_len_to_capture,
                max_seq_len_to_capture=self.max_seq_len_to_capture,
                max_logprobs=self.max_logprobs,
                disable_sliding_window=self.disable_sliding_window,
                skip_tokenizer_init=self.skip_tokenizer_init,
                served_model_name=self.served_model_name,
                limit_mm_per_prompt=self.limit_mm_per_prompt,
                use_async_output_proc=not self.disable_async_output_proc,
                override_neuron_config=self.override_neuron_config,
                config_format=self.config_format,
                mm_processor_kwargs=self.mm_processor_kwargs,
            )
    
        def create_load_config(self) -> LoadConfig:
            return LoadConfig(
                load_format=self.load_format,
                download_dir=self.download_dir,
                model_loader_extra_config=self.model_loader_extra_config,
                ignore_patterns=self.ignore_patterns,
            )
    
        def create_engine_config(self) -> EngineConfig:
            # gguf file needs a specific model loader and doesn't use hf_repo
            if check_gguf_file(self.model):
                self.quantization = self.load_format = "gguf"
    
            # bitsandbytes quantization needs a specific model loader
            # so we make sure the quant method and the load format are consistent
            if (self.quantization == "bitsandbytes" or
               self.qlora_adapter_name_or_path is not None) and \
               self.load_format != "bitsandbytes":
                raise ValueError(
                    "BitsAndBytes quantization and QLoRA adapter only support "
                    f"'bitsandbytes' load format, but got {self.load_format}")
    
            if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
                raise ValueError(
                    "BitsAndBytes load format and QLoRA adapter only support "
                    f"'bitsandbytes' quantization, but got {self.quantization}")
    
            assert self.cpu_offload_gb >= 0, (
                "CPU offload space must be non-negative"
                f", but got {self.cpu_offload_gb}")
    
            device_config = DeviceConfig(device=self.device)
            model_config = self.create_model_config()
    
            if model_config.is_multimodal_model:
                if self.enable_prefix_caching:
                    logger.warning(
                        "--enable-prefix-caching is currently not "
                        "supported for multimodal models and has been disabled.")
                self.enable_prefix_caching = False
    
            cache_config = CacheConfig(
                block_size=self.block_size if self.device != "neuron" else
                self.max_model_len,  # neuron needs block_size = max_model_len
                gpu_memory_utilization=self.gpu_memory_utilization,
                swap_space=self.swap_space,
                cache_dtype=self.kv_cache_dtype,
                num_gpu_blocks_override=self.num_gpu_blocks_override,
                sliding_window=model_config.get_sliding_window(),
                enable_prefix_caching=self.enable_prefix_caching,
                cpu_offload_gb=self.cpu_offload_gb,
            )
            parallel_config = ParallelConfig(
                pipeline_parallel_size=self.pipeline_parallel_size,
                tensor_parallel_size=self.tensor_parallel_size,
                worker_use_ray=self.worker_use_ray,
                max_parallel_loading_workers=self.max_parallel_loading_workers,
                disable_custom_all_reduce=self.disable_custom_all_reduce,
                tokenizer_pool_config=TokenizerPoolConfig.create_config(
                    self.tokenizer_pool_size,
                    self.tokenizer_pool_type,
                    self.tokenizer_pool_extra_config,
                ),
                ray_workers_use_nsight=self.ray_workers_use_nsight,
                distributed_executor_backend=self.distributed_executor_backend)
    
            max_model_len = model_config.max_model_len
            use_long_context = max_model_len > 32768
            if self.enable_chunked_prefill is None:
                # If not explicitly set, enable chunked prefill by default for
                # long context (> 32K) models. This is to avoid OOM errors in the
                # initial memory profiling phase.
    
                # Chunked prefill is currently disabled for multimodal models by
                # default.
                if use_long_context and not model_config.is_multimodal_model:
                    is_gpu = device_config.device_type == "cuda"
                    use_sliding_window = (model_config.get_sliding_window()
                                          is not None)
                    use_spec_decode = self.speculative_model is not None
                    has_seqlen_agnostic_layers = (
                        model_config.contains_seqlen_agnostic_layers(
                            parallel_config))
                    if (is_gpu and not use_sliding_window and not use_spec_decode
                            and not self.enable_lora
                            and not self.enable_prompt_adapter
                            and not has_seqlen_agnostic_layers):
                        self.enable_chunked_prefill = True
                        logger.warning(
                            "Chunked prefill is enabled by default for models with "
                            "max_model_len > 32K. Currently, chunked prefill might "
                            "not work with some features or models. If you "
                            "encounter any issues, please disable chunked prefill "
                            "by setting --enable-chunked-prefill=False.")
                if self.enable_chunked_prefill is None:
                    self.enable_chunked_prefill = False
    
            if not self.enable_chunked_prefill and use_long_context:
                logger.warning(
                    "The model has a long context length (%s). This may cause OOM "
                    "errors during the initial memory profiling phase, or result "
                    "in low performance due to small KV cache space. Consider "
                    "setting --max-model-len to a smaller value.", max_model_len)
    
            if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
                self.use_v2_block_manager = True
                logger.warning(
                    "Enabled BlockSpaceManagerV2 because it is "
                    "required for multi-step (--num-scheduler-steps > 1)")
    
            speculative_config = SpeculativeConfig.maybe_create_spec_config(
                target_model_config=model_config,
                target_parallel_config=parallel_config,
                target_dtype=self.dtype,
                speculative_model=self.speculative_model,
                speculative_model_quantization = \
                    self.speculative_model_quantization,
                speculative_draft_tensor_parallel_size = \
                    self.speculative_draft_tensor_parallel_size,
                num_speculative_tokens=self.num_speculative_tokens,
                speculative_disable_by_batch_size=self.
                speculative_disable_by_batch_size,
                speculative_max_model_len=self.speculative_max_model_len,
                enable_chunked_prefill=self.enable_chunked_prefill,
                use_v2_block_manager=self.use_v2_block_manager,
                disable_log_stats=self.disable_log_stats,
                ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
                ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
                draft_token_acceptance_method=\
                    self.spec_decoding_acceptance_method,
                typical_acceptance_sampler_posterior_threshold=self.
                typical_acceptance_sampler_posterior_threshold,
                typical_acceptance_sampler_posterior_alpha=self.
                typical_acceptance_sampler_posterior_alpha,
                disable_logprobs=self.disable_logprobs_during_spec_decoding,
            )
    
            if self.num_scheduler_steps > 1:
                if speculative_config is not None:
                    raise ValueError("Speculative decoding is not supported with "
                                     "multi-step (--num-scheduler-steps > 1)")
                if self.enable_chunked_prefill:
                    raise ValueError("Chunked prefill is not supported with "
                                     "multi-step (--num-scheduler-steps > 1)")
    
            # make sure num_lookahead_slots is set the higher value depending on
            # if we are using speculative decoding or multi-step
            num_lookahead_slots = max(self.num_lookahead_slots,
                                      self.num_scheduler_steps - 1)
            num_lookahead_slots = num_lookahead_slots \
                if speculative_config is None \
                else speculative_config.num_lookahead_slots
    
            scheduler_config = SchedulerConfig(
                max_num_batched_tokens=self.max_num_batched_tokens,
                max_num_seqs=self.max_num_seqs,
                max_model_len=model_config.max_model_len,
                use_v2_block_manager=self.use_v2_block_manager,
                num_lookahead_slots=num_lookahead_slots,
                delay_factor=self.scheduler_delay_factor,
                enable_chunked_prefill=self.enable_chunked_prefill,
                embedding_mode=model_config.embedding_mode,
                is_multimodal_model=model_config.is_multimodal_model,
                preemption_mode=self.preemption_mode,
                num_scheduler_steps=self.num_scheduler_steps,
                multi_step_stream_outputs=self.multi_step_stream_outputs,
                send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                                 and parallel_config.use_ray),
            )
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
                and self.max_cpu_loras > 0 else None) if self.enable_lora else None
    
            if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
                if self.model_loader_extra_config is None:
                    self.model_loader_extra_config = {}
                self.model_loader_extra_config[
                    "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
    
            load_config = self.create_load_config()
    
            prompt_adapter_config = PromptAdapterConfig(
                max_prompt_adapters=self.max_prompt_adapters,
                max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                            if self.enable_prompt_adapter else None
    
            decoding_config = DecodingConfig(
                guided_decoding_backend=self.guided_decoding_backend)
    
            detailed_trace_modules = []
            if self.collect_detailed_traces is not None:
                detailed_trace_modules = self.collect_detailed_traces.split(",")
            for m in detailed_trace_modules:
                if m not in ALLOWED_DETAILED_TRACE_MODULES:
                    raise ValueError(
                        f"Invalid module {m} in collect_detailed_traces. "
                        f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
            observability_config = ObservabilityConfig(
                otlp_traces_endpoint=self.otlp_traces_endpoint,
                collect_model_forward_time="model" in detailed_trace_modules
                or "all" in detailed_trace_modules,
                collect_model_execute_time="worker" in detailed_trace_modules
                or "all" in detailed_trace_modules,
            )
    
            if (model_config.get_sliding_window() is not None
                    and scheduler_config.chunked_prefill_enabled
                    and not scheduler_config.use_v2_block_manager):
                raise ValueError(
                    "Chunked prefill is not supported with sliding window. "
                    "Set --disable-sliding-window to disable sliding window.")
    
            return EngineConfig(
                model_config=model_config,
                cache_config=cache_config,
                parallel_config=parallel_config,
                scheduler_config=scheduler_config,
                device_config=device_config,
                lora_config=lora_config,
                speculative_config=speculative_config,
                load_config=load_config,
                decoding_config=decoding_config,
                observability_config=observability_config,
                prompt_adapter_config=prompt_adapter_config,
            )
    
    
    @dataclass
    class AsyncEngineArgs(EngineArgs):
        """Arguments for asynchronous vLLM engine."""
        disable_log_requests: bool = False
    
        @staticmethod
        def add_cli_args(parser: FlexibleArgumentParser,
                         async_args_only: bool = False) -> FlexibleArgumentParser:
            if not async_args_only:
                parser = EngineArgs.add_cli_args(parser)
            parser.add_argument('--disable-log-requests',
                                action='store_true',
                                help='Disable logging requests.')
            return parser
    
    
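    # Custom argparse action used above by --enable-chunked-prefill and
    # --disable-logprobs-during-spec-decoding: it accepts an explicit
    # 'true'/'false' value (e.g. --enable-chunked-prefill=False) instead of a
    # plain store_true flag.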
    class StoreBoolean(argparse.Action):
    
        def __call__(self, parser, namespace, values, option_string=None):
            if values.lower() == "true":
                setattr(namespace, self.dest, True)
            elif values.lower() == "false":
                setattr(namespace, self.dest, False)
            else:
                raise ValueError(f"Invalid boolean value: {values}. "
                                 "Expected 'true' or 'false'.")
    
    
    # These functions are used by sphinx to build the documentation
    def _engine_args_parser():
        return EngineArgs.add_cli_args(FlexibleArgumentParser())
    
    
    def _async_engine_args_parser():
        return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                            async_args_only=True)
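    
    # --- Illustrative usage sketch (not part of the adapted arg_utils.py) ---
    # A minimal example, under the assumption that the adapted package is
    # installed as `vllm`, of how the CLI arguments defined above can be
    # turned into an EngineConfig. "facebook/opt-125m" is only a placeholder;
    # building the config requires that model's Hugging Face config to be
    # resolvable locally or from the hub.
    #
    #   from vllm.engine.arg_utils import EngineArgs
    #   from vllm.utils import FlexibleArgumentParser
    #
    #   parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    #   args = parser.parse_args(["--model", "facebook/opt-125m",
    #                             "--tensor-parallel-size", "1",
    #                             "--block-size", "128"])
    #   engine_args = EngineArgs.from_cli_args(args)
    #   engine_config = engine_args.create_engine_config()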
    
  • cover/vllm/engine/async_llm_engine.py: Adds selection of the Ascend-adapted RayNPUExecutorAsync and NPUExecutorAsync executors.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of the code in this file was copied from project [vLLM Team][vllm]
    
    import asyncio
    import time
    import weakref
    from functools import partial
    from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                        Mapping, Optional, Set, Tuple, Type, Union)
    from weakref import ReferenceType
    
    import vllm.envs as envs
    from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                             ParallelConfig, SchedulerConfig)
    from vllm.core.scheduler import SchedulerOutputs
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_timeout import asyncio_timeout
    from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
    from vllm.engine.metrics_types import StatLoggerBase
    from vllm.executor.executor_base import ExecutorAsyncBase
    from vllm.executor.gpu_executor import GPUExecutorAsync
    from vllm.executor.ray_utils import initialize_ray_cluster
    from vllm.inputs import PromptInputs
    from vllm.logger import init_logger
    from vllm.lora.request import LoRARequest
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.outputs import EmbeddingRequestOutput, RequestOutput
    from vllm.pooling_params import PoolingParams
    from vllm.prompt_adapter.request import PromptAdapterRequest
    from vllm.sampling_params import SamplingParams
    from vllm.sequence import ExecuteModelRequest
    from vllm.transformers_utils.tokenizer import AnyTokenizer
    from vllm.usage.usage_lib import UsageContext
    from vllm.utils import weak_bind
    
    logger = init_logger(__name__)
    ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
    
    
    class AsyncEngineDeadError(RuntimeError):
        pass
    
    
    def _log_task_completion(task: asyncio.Task,
                             error_callback: Callable[[Exception], None]) -> None:
        """This function is only intended for the `engine.run_engine_loop()` task.
    
        In particular, that task runs a `while True` loop that can only exit if
        there is an exception.
        """
    
        exception = None
        try:
            return_value = task.result()
            raise AssertionError(
                f"The engine background task should never finish without an "
                f"exception. {return_value}")
        except asyncio.exceptions.CancelledError:
            # We assume that if the task is cancelled, we are gracefully shutting
            # down. This should only happen on program exit.
            logger.info("Engine is gracefully shutting down.")
        except Exception as e:
            exception = e
            logger.error("Engine background task failed", exc_info=e)
            error_callback(exception)
            raise AsyncEngineDeadError(
                "Task finished unexpectedly. This should never happen! "
                "Please open an issue on Github. See stack trace above for the "
                "actual cause.") from e
    
    
    STOP_ITERATION = Exception()  # Sentinel
    
    
    class AsyncStream:
        """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
        that can be iterated over asynchronously via an async generator."""
    
        def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
            self.request_id = request_id
            self._cancel = cancel
            self._queue: asyncio.Queue = asyncio.Queue()
            self._finished = False
    
        @property
        def finished(self) -> bool:
            return self._finished
    
        @staticmethod
        def _is_raisable(value: Any):
            return isinstance(value, BaseException) or \
                    (isinstance(value, type) and \
                     issubclass(value, BaseException))
    
        def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                                  Exception]) -> None:
            if not self._finished:
                self._queue.put_nowait(item)
    
        def finish(
            self,
            exception: Optional[Union[BaseException, Type[BaseException]]] = None,
        ) -> None:
            if not self._finished:
                self._finished = True
                self._queue.put_nowait(
                    exception if self._is_raisable(exception) else STOP_ITERATION)
    
        async def generator(
            self
        ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
            try:
                while True:
                    result = await self._queue.get()
                    if self._is_raisable(result):
                        if result == STOP_ITERATION:
                            return
                        raise result
                    yield result
            except GeneratorExit:
                self._cancel(self.request_id)
                raise asyncio.CancelledError from None
    
    
    class RequestTracker:
        """Synchronous abstraction for tracking requests."""
    
        def __init__(self) -> None:
            self._request_streams: Dict[str, AsyncStream] = {}
            self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
            self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                    dict]] = asyncio.Queue()
            self.new_requests_event = asyncio.Event()
    
        def __contains__(self, item):
            return item in self._request_streams
    
        def __len__(self) -> int:
            return len(self._request_streams)
    
        def propagate_exception(self,
                                exc: Exception,
                                request_id: Optional[str] = None) -> None:
            """Propagate an exception to request streams
            (all if request_id is None)."""
            if request_id is not None:
                self.abort_request(request_id, exception=exc)
            else:
                # NB: tuple() used here because self.abort_request pops the stream
                # out of self._request_streams, so we can't iterate on it directly
                for rid in tuple(self._request_streams.keys()):
                    self.abort_request(rid, exception=exc)
    
        def process_request_output(self,
                                   request_output: Union[RequestOutput,
                                                         EmbeddingRequestOutput],
                                   *,
                                   verbose: bool = False) -> None:
            """Process a request output from the engine."""
            request_id = request_output.request_id
            finished = request_output.finished
    
            if finished:
                stream = self._request_streams.pop(request_id, None)
            else:
                stream = self._request_streams.get(request_id)
            # Guard against a KeyError which can occur if the request was aborted
            # while the output was generated
            if stream is not None:
                stream.put(request_output)
                if finished:
                    stream.finish()
    
            if verbose and finished:
                logger.info("Finished request %s.", request_id)
    
        def process_exception(self,
                              request_id: str,
                              exception: BaseException,
                              *,
                              verbose: bool = False) -> None:
            """Propagate an exception from the engine."""
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id, exception=exception)
    
        def add_request(self,
                        request_id: str,
                        *,
                        verbose: bool = False,
                        **engine_add_request_kwargs) -> AsyncStream:
            """Add a request to be sent to the engine on the next background
            loop iteration."""
            if request_id in self._request_streams:
                raise KeyError(f"Request {request_id} already exists.")
    
            abort_request = partial(self.abort_request, verbose=verbose)
            stream = AsyncStream(request_id, abort_request)
            self._new_requests.put_nowait((stream, {
                "request_id": request_id,
                **engine_add_request_kwargs
            }))
    
            self.new_requests_event.set()
    
            if verbose:
                logger.info("Added request %s.", request_id)
    
            return stream
    
        def abort_request(self,
                          request_id: str,
                          *,
                          exception: Optional[Union[BaseException,
                                                    Type[BaseException]]] = None,
                          verbose: bool = False) -> None:
            """Abort a request during next background loop iteration."""
            if verbose:
                logger.info("Aborted request %s.", request_id)
    
            self._aborted_requests.put_nowait(request_id)
    
            stream = self._request_streams.pop(request_id, None)
            if stream is not None:
                stream.finish(exception=exception)
    
        def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
            """Get the new requests and finished requests to be
            sent to the engine."""
            new_requests: List[Dict] = []
            finished_requests: Set[str] = set()
    
            while not self._aborted_requests.empty():
                request_id = self._aborted_requests.get_nowait()
                finished_requests.add(request_id)
    
            while not self._new_requests.empty():
                stream, new_request = self._new_requests.get_nowait()
                request_id = stream.request_id
                if request_id in finished_requests:
                    # The request has already been aborted.
                    stream.finish(asyncio.CancelledError)
                    finished_requests.discard(request_id)
                else:
                    self._request_streams[request_id] = stream
                    new_requests.append(new_request)
    
            return new_requests, finished_requests
    
        async def wait_for_new_requests(self):
            if not self.has_new_requests():
                await self.new_requests_event.wait()
            self.new_requests_event.clear()
    
        def has_new_requests(self):
            return not self._new_requests.empty()
    
    
    class _AsyncLLMEngine(LLMEngine):
        """Extension of LLMEngine to add async methods."""
    
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
    
        async def step_async(
            self, virtual_engine: int
        ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
            """Performs one decoding iteration and returns newly generated results.
            The workers are run asynchronously if possible.
    
            This function performs one decoding iteration of the engine. It first
            schedules the sequences to be executed in the next iteration and the
            token blocks to be swapped in/out/copied. Then, it executes the model
            and updates the scheduler with the model outputs. Finally, it decodes
            the sequences and returns the newly generated results.
            """
            # these are cached outputs from previous iterations. None if on first
            # iteration
            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
            scheduler_outputs = cached_outputs.scheduler_outputs
            allow_async_output_proc = cached_outputs.allow_async_output_proc
    
            ctx = self.scheduler_contexts[virtual_engine]
    
            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()
    
            # skip the scheduler if there are any remaining steps in the seq groups.
            # This ensures that the scheduler is only called again when the current
            # batch has completed.
            if not self._has_remaining_steps(seq_group_metadata_list):
    
                # Schedule iteration
                (seq_group_metadata_list, scheduler_outputs,
                 allow_async_output_proc
                 ) = self.scheduler[virtual_engine].schedule()
    
                ctx.seq_group_metadata_list = seq_group_metadata_list
                ctx.scheduler_outputs = scheduler_outputs
    
                # Maybe switch from async mode to sync mode
                if not allow_async_output_proc and len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
    
                if (self.scheduler_config.is_multi_step
                        and scheduler_outputs.num_lookahead_slots > 0):
                    # cache the scheduler outputs for the next iteration if we have
                    # lookahead slots
                    self._cache_scheduler_outputs_for_multi_step(
                        virtual_engine, seq_group_metadata_list, scheduler_outputs,
                        allow_async_output_proc)
    
            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None
    
            if not scheduler_outputs.is_empty():
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()
    
                # Check if we have a cached last_output from the previous iteration.
                # For supporting PP this is probably the best way to pass the
                # sampled_token_ids, as a separate broadcast over all the PP stages
                # will cause one virtual engine's microbatch to block the pipeline.
                last_sampled_token_ids = \
                    self._get_last_sampled_token_ids(virtual_engine)
    
                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    virtual_engine=virtual_engine,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                    finished_requests_ids=finished_requests_ids,
                    # We use ExecuteModelRequest to pass the last sampled_token_ids
                    # to each of the non-last PP stages for in-place prepare_input.
                    last_sampled_token_ids=last_sampled_token_ids)
    
                if allow_async_output_proc:
                    execute_model_req.async_callback = self.async_callbacks[
                        virtual_engine]
    
                # Execute the model.
                outputs = await self.model_executor.execute_model_async(
                    execute_model_req)
    
                # we need to do this here so that last step's sampled_token_ids can
                # be passed to the next iteration for PP.
                if self.scheduler_config.is_multi_step:
                    self._update_cached_scheduler_output(virtual_engine, outputs)
            else:
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                outputs = []
    
            # Finish the current step for all the sequence groups.
            if self.scheduler_config.is_multi_step:
                for seq_group in seq_group_metadata_list:
                    seq_group.finish_step()
    
            if not self._has_remaining_steps(seq_group_metadata_list):
                # Clear the cache if we have finished all the steps
                if self.scheduler_config.is_multi_step:
                    self.cached_scheduler_outputs[
                        virtual_engine] = SchedulerOutputState()
    
                ctx.append_output(outputs=outputs,
                                  seq_group_metadata_list=seq_group_metadata_list,
                                  scheduler_outputs=scheduler_outputs,
                                  is_async=allow_async_output_proc,
                                  is_last_step=True)
    
                if outputs and allow_async_output_proc:
                    assert len(
                        outputs
                    ) == 1, "Async postprocessor expects only a single output set"
                    self._advance_to_next_step(
                        outputs[0], seq_group_metadata_list,
                        scheduler_outputs.scheduled_seq_groups)
    
                if not allow_async_output_proc:
                    self._process_model_outputs(ctx=ctx)
    
                    # Log stats.
                    self.do_log_stats(scheduler_outputs, outputs)
    
                    # Tracing
                    self.do_tracing(scheduler_outputs)
    
            else:
                # Multi-step case
                return ctx.request_outputs
    
            if not self.has_unfinished_requests():
                # Drain async postprocessor (if exists)
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                assert len(ctx.output_queue) == 0
    
            return ctx.request_outputs
    
        async def stop_remote_worker_execution_loop_async(self) -> None:
            """Stop the remote worker execution loop."""
            await self.model_executor.stop_remote_worker_execution_loop_async()
    
        async def add_request_async(
            self,
            request_id: str,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        ) -> None:
            """Async version of :meth:`add_request`."""
            if lora_request is not None and not self.lora_config:
                raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                                 "not enabled!")
            if arrival_time is None:
                arrival_time = time.time()
    
            preprocessed_inputs = await self.input_preprocessor.preprocess_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
            processed_inputs = self.input_processor(preprocessed_inputs)
    
            self._add_processed_request(
                request_id=request_id,
                processed_inputs=processed_inputs,
                params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                trace_headers=trace_headers,
            )
    
        async def check_health_async(self) -> None:
            if self.tokenizer:
                self.tokenizer.check_health()
            self.model_executor.check_health()
    
    
    class AsyncLLMEngine:
        """An asynchronous wrapper for :class:`LLMEngine`.
    
        This class is used to wrap the :class:`LLMEngine` class to make it
        asynchronous. It uses asyncio to create a background loop that keeps
        processing incoming requests. The :class:`LLMEngine` is kicked by the
        generate method when there are requests in the waiting queue. The generate
        method yields the outputs from the :class:`LLMEngine` to the caller.
    
        Args:
            log_requests: Whether to log the requests.
            start_engine_loop: If True, the background task to run the engine
                will be automatically started in the generate call.
            *args: Arguments for :class:`LLMEngine`.
            **kwargs: Arguments for :class:`LLMEngine`.
        """
    
        _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
    
        def __init__(self,
                     *args,
                     log_requests: bool = True,
                     start_engine_loop: bool = True,
                     **kwargs) -> None:
            self.log_requests = log_requests
            self.engine = self._engine_class(*args, **kwargs)
    
            # This ensures quick processing of request outputs
            # so the append to asyncio queues is not delayed,
            # especially for multi-step.
            self.use_process_request_outputs_callback = (
                self.engine.model_config.use_async_output_proc)
    
            if self.use_process_request_outputs_callback:
                self.engine.process_request_outputs_callback = \
                    weak_bind(self.process_request_outputs)
    
            self.background_loop: Optional[asyncio.Future] = None
            # We need to keep a reference to unshielded
            # task as well to prevent it from being garbage
            # collected
            self._background_loop_unshielded: Optional[asyncio.Task] = None
            self.start_engine_loop = start_engine_loop
            self._errored_with: Optional[BaseException] = None
    
            # Lazy initialized fields
            self._request_tracker: RequestTracker
    
        def __del__(self):
            if rt := getattr(self, "_request_tracker", None):
                # Wake up engine loop so that it will exit cleanly
                rt.new_requests_event.set()
    
        @classmethod
        def _get_executor_cls(
                cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
            distributed_executor_backend = (
                engine_config.parallel_config.distributed_executor_backend)
            if isinstance(distributed_executor_backend, type):
                if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                    raise TypeError(
                        "distributed_executor_backend must be a subclass of "
                        f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
                executor_class = distributed_executor_backend
            elif engine_config.device_config.device_type == "neuron":
                from vllm.executor.neuron_executor import NeuronExecutorAsync
                executor_class = NeuronExecutorAsync
            elif engine_config.device_config.device_type == "tpu":
                if distributed_executor_backend == "ray":
                    from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                    executor_class = RayTPUExecutorAsync
                else:
                    assert distributed_executor_backend is None
                    from vllm.executor.tpu_executor import TPUExecutorAsync
                    executor_class = TPUExecutorAsync
            elif engine_config.device_config.device_type == "cpu":
                from vllm.executor.cpu_executor import CPUExecutorAsync
                executor_class = CPUExecutorAsync
            elif engine_config.device_config.device_type == "openvino":
                assert distributed_executor_backend is None, (
                    "Distributed execution is not supported with "
                    "the OpenVINO backend.")
                from vllm.executor.openvino_executor import OpenVINOExecutorAsync
                executor_class = OpenVINOExecutorAsync
            elif engine_config.device_config.device_type == "xpu":
                if distributed_executor_backend is None:
                    from vllm.executor.xpu_executor import XPUExecutorAsync
                    executor_class = XPUExecutorAsync
                elif distributed_executor_backend == "ray":
                    from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                    executor_class = RayXPUExecutorAsync
                elif distributed_executor_backend == "mp":
                    from vllm.executor.multiproc_xpu_executor import (
                        MultiprocessingXPUExecutorAsync)
                    executor_class = MultiprocessingXPUExecutorAsync
                else:
                    raise RuntimeError(
                        "Not supported distributed execution model on XPU device.")
            elif engine_config.device_config.device_type == "npu":
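                # Ascend adaptation: use the Ray-based async NPU executor when Ray
                # is enabled; otherwise fall back to the single-process async NPU
                # executor.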
                if engine_config.parallel_config.use_ray:
                    initialize_ray_cluster(engine_config.parallel_config)
                    from vllm.executor.ray_npu_executor import RayNPUExecutorAsync
                    executor_class = RayNPUExecutorAsync
                else:
                    from vllm.executor.npu_executor import NPUExecutorAsync
                    executor_class = NPUExecutorAsync
            elif distributed_executor_backend == "ray":
                from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
                executor_class = RayGPUExecutorAsync
            elif distributed_executor_backend == "mp":
                from vllm.executor.multiproc_gpu_executor import (
                    MultiprocessingGPUExecutorAsync)
                executor_class = MultiprocessingGPUExecutorAsync
            else:
                from vllm.executor.gpu_executor import GPUExecutorAsync
                executor_class = GPUExecutorAsync
            return executor_class
    
        @classmethod
        def from_engine_args(
            cls,
            engine_args: AsyncEngineArgs,
            engine_config: Optional[EngineConfig] = None,
            start_engine_loop: bool = True,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        ) -> "AsyncLLMEngine":
            """Creates an async LLM engine from the engine arguments."""
            # Create the engine configs.
            if engine_config is None:
                engine_config = engine_args.create_engine_config()
    
            executor_class = cls._get_executor_cls(engine_config)
    
            if executor_class.uses_ray:
                initialize_ray_cluster(engine_config.parallel_config)
    
            # Create the async LLM engine.
            engine = cls(
                **engine_config.to_dict(),
                executor_class=executor_class,
                log_requests=not engine_args.disable_log_requests,
                log_stats=not engine_args.disable_log_stats,
                start_engine_loop=start_engine_loop,
                usage_context=usage_context,
                stat_loggers=stat_loggers,
            )
            return engine
    
        @property
        def is_running(self) -> bool:
            return (self.background_loop is not None
                    and self._background_loop_unshielded is not None
                    and not self._background_loop_unshielded.done())
    
        @property
        def is_stopped(self) -> bool:
            return self.errored or (self.background_loop is not None and
                                    self._background_loop_unshielded is not None
                                    and self._background_loop_unshielded.done())
    
        @property
        def errored(self) -> bool:
            return self._errored_with is not None
    
        @property
        def dead_error(self) -> BaseException:
            return AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")
    
        def set_errored(self, exc: Exception) -> None:
            self._errored_with = exc
    
        def _error_callback(self, exc: Exception) -> None:
            self.set_errored(exc)
            self._request_tracker.propagate_exception(exc)
    
        async def get_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None,
        ) -> AnyTokenizer:
            return await (self.engine.get_tokenizer_group().
                          get_lora_tokenizer_async(lora_request))
    
        def start_background_loop(self) -> None:
            """Start the background loop."""
            if self.errored:
                raise AsyncEngineDeadError(
                    "Background loop has errored already.") from self._errored_with
            if self.is_running:
                raise RuntimeError("Background loop is already running.")
            # Initialize the RequestTracker here so it uses the right event loop.
            self._request_tracker = RequestTracker()
    
            self._background_loop_unshielded = asyncio.get_event_loop(
            ).create_task(self.run_engine_loop(weakref.ref(self)))
            self._background_loop_unshielded.add_done_callback(
                partial(_log_task_completion, error_callback=self._error_callback))
            self.background_loop = asyncio.shield(self._background_loop_unshielded)
    
        def shutdown_background_loop(self) -> None:
            """
            Shut down the background loop.
    
            This method needs to be called during cleanup to remove
            references to `self` and properly GC the resources held
            by the async LLM engine (e.g., the executors as well as
            their resources).
            """
            if self._background_loop_unshielded is not None:
                self._background_loop_unshielded.cancel()
                self._background_loop_unshielded = None
            self.background_loop = None
    
        async def engine_step(self, virtual_engine: int) -> bool:
            """Kick the engine to process the waiting requests.
    
            Returns True if there are in-progress requests."""
    
            new_requests, aborted_requests = (
                self._request_tracker.get_new_and_aborted_requests())
    
            for new_request in new_requests:
                # Add the request into the vLLM engine's waiting queue.
                try:
                    await self.engine.add_request_async(**new_request)
                except ValueError as e:
                    # TODO: use a vLLM specific error for failed validation
                    self._request_tracker.process_exception(
                        new_request["request_id"],
                        e,
                        verbose=self.log_requests,
                    )
    
            if aborted_requests:
                await self._engine_abort(aborted_requests)
    
            request_outputs = await self.engine.step_async(virtual_engine)
    
            # Put the outputs into the corresponding streams.
            # If used as a callback, then already invoked inside
            # LLMEngine's _process_model_outputs
            if not self.use_process_request_outputs_callback:
                all_finished = self.process_request_outputs(request_outputs)
            else:
                # For callback case, we only need to detect when all
                # requests are finished
                all_finished = all(request_output.finished
                                   for request_output in request_outputs)
    
            return not all_finished
    
        def process_request_outputs(self, request_outputs) -> bool:
            # Put the outputs into the corresponding streams.
            all_finished = True
            for request_output in request_outputs:
                self._request_tracker.process_request_output(
                    request_output, verbose=self.log_requests)
                all_finished = all_finished and request_output.finished
    
            return all_finished
    
        async def _engine_abort(self, request_ids: Iterable[str]):
            self.engine.abort_request(request_ids)
    
        @staticmethod
        async def run_engine_loop(engine_ref: ReferenceType):
            """We use a weakref to the engine so that the running loop
            doesn't prevent the engine being garbage collected."""
            engine: Optional["AsyncLLMEngine"] = engine_ref()
            if not engine:
                return
    
            pipeline_parallel_size = \
                    engine.engine.parallel_config.pipeline_parallel_size
            has_requests_in_progress = [False] * pipeline_parallel_size
            while True:
                if not any(has_requests_in_progress):
                    logger.debug("Waiting for new requests...")
                    # Stop the execute model loop in parallel workers until there
                    # are more requests to process. This avoids waiting
                    # indefinitely in torch.distributed ops which may otherwise
                    # timeout, and unblocks the RPC thread in the workers so that
                    # they can process any other queued control plane messages,
                    # such as add/remove lora adapters.
                    await engine.engine.stop_remote_worker_execution_loop_async()
                    request_tracker = engine._request_tracker
                    # Allow engine to be garbage collected while
                    # waiting for new requests
                    del engine
                    await asyncio.sleep(0)
                    if engine_ref() is None:
                        return
                    await request_tracker.wait_for_new_requests()
                    engine = engine_ref()
                    if not engine:
                        return
                    logger.debug("Got new requests!")
                    requests_in_progress = [
                        asyncio.create_task(engine.engine_step(ve))
                        for ve in range(pipeline_parallel_size)
                    ]
                    has_requests_in_progress = [True] * pipeline_parallel_size
    
                # Abort if iteration takes too long due to unrecoverable errors
                # (eg. NCCL timeouts).
                try:
                    async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                        done, _ = await asyncio.wait(
                            requests_in_progress,
                            return_when=asyncio.FIRST_COMPLETED)
                        for _ in range(pipeline_parallel_size):
                            await asyncio.sleep(0)
                    for task in done:
                        result = task.result()
                        virtual_engine = requests_in_progress.index(task)
                        has_unfinished_requests = (
                            engine.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                        if result or has_unfinished_requests:
                            requests_in_progress[virtual_engine] = (
                                asyncio.create_task(
                                    engine.engine_step(virtual_engine)))
                            has_requests_in_progress[virtual_engine] = True
                        else:
                            has_requests_in_progress[virtual_engine] = False
                except asyncio.TimeoutError as exc:
                    logger.error(
                        "Engine iteration timed out. This should never happen!")
                    engine.set_errored(exc)
                    raise
                await asyncio.sleep(0)
    
        # This method does not need to be async, but kept that way
        # for backwards compatibility.
        async def add_request(
            self,
            request_id: str,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
        ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
            if not self.is_running:
                if self.start_engine_loop:
                    self.start_background_loop()
                else:
                    raise AsyncEngineDeadError(
                        "Background loop is not running. If it was running, "
                        "inspect the output to find the stacktrace of the "
                        "error that caused the background loop to stop "
                        "(AsyncEngineDeadError).")
    
            stream = self._request_tracker.add_request(
                request_id,
                verbose=self.log_requests,
                inputs=inputs,
                params=params,
                arrival_time=arrival_time or time.time(),
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request)
    
            return stream.generator()
    
        async def generate(
            self,
            inputs: PromptInputs,
            sampling_params: SamplingParams,
            request_id: str,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
        ) -> AsyncGenerator[RequestOutput, None]:
            """Generate outputs for a request.
    
            Generate outputs for a request. This method is a coroutine. It adds the
            request into the waiting queue of the LLMEngine and streams the outputs
            from the LLMEngine to the caller.
    
            Args:
                inputs: The inputs to the LLM. See
                    :class:`~vllm.inputs.PromptInputs`
                    for more details about the format of each input.
                sampling_params: The sampling parameters of the request.
                request_id: The unique id of the request.
                lora_request: LoRA request to use for generation, if any.
                trace_headers: OpenTelemetry trace headers.
                prompt_adapter_request: Prompt Adapter request to use
                                                for generation, if any.
    
            Yields:
                The output `RequestOutput` objects from the LLMEngine
                for the request.
    
            Details:
                - If the engine is not running, start the background loop,
                  which iteratively invokes
                  :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
                  to process the waiting requests.
                - Add the request to the engine's `RequestTracker`.
                  On the next background loop, this request will be sent to
                  the underlying engine.
                  Also, a corresponding `AsyncStream` will be created.
                - Wait for the request outputs from `AsyncStream` and yield them.
    
            Example:
                >>> # Please refer to entrypoints/api_server.py for
                >>> # the complete example.
                >>>
                >>> # initialize the engine and the example input
                >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
                >>> example_input = {
                >>>     "prompt": "What is LLM?",
                >>>     "stream": False, # assume the non-streaming case
                >>>     "temperature": 0.0,
                >>>     "request_id": 0,
                >>> }
                >>>
                >>> # start the generation
                >>> results_generator = engine.generate(
                >>>    example_input["prompt"],
                >>>    SamplingParams(temperature=example_input["temperature"]),
                >>>    example_input["request_id"])
                >>>
                >>> # get the results
                >>> final_output = None
                >>> async for request_output in results_generator:
                >>>     if await request.is_disconnected():
                >>>         # Abort the request if the client disconnects.
                >>>         await engine.abort(request_id)
                >>>         # Return or raise an error
                >>>         ...
                >>>     final_output = request_output
                >>>
                >>> # Process and return the final output
                >>> ...
            """
            async for output in await self.add_request(
                    request_id,
                    inputs,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
    
        async def encode(
            self,
            inputs: PromptInputs,
            pooling_params: PoolingParams,
            request_id: str,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
        ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
            """Generate outputs for a request from an embedding model.
    
            Generate outputs for a request. This method is a coroutine. It adds the
            request into the waiting queue of the LLMEngine and streams the outputs
            from the LLMEngine to the caller.
    
            Args:
                inputs: The inputs to the LLM. See
                    :class:`~vllm.inputs.PromptInputs`
                    for more details about the format of each input.
                pooling_params: The pooling parameters of the request.
                request_id: The unique id of the request.
                lora_request: LoRA request to use for generation, if any.
                trace_headers: OpenTelemetry trace headers.
    
            Yields:
                The output `EmbeddingRequestOutput` objects from the LLMEngine
                for the request.
    
            Details:
                - If the engine is not running, start the background loop,
                  which iteratively invokes
                  :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
                  to process the waiting requests.
                - Add the request to the engine's `RequestTracker`.
                  On the next background loop, this request will be sent to
                  the underlying engine.
                  Also, a corresponding `AsyncStream` will be created.
                - Wait for the request outputs from `AsyncStream` and yield them.
    
            Example:
                >>> # Please refer to entrypoints/api_server.py for
                >>> # the complete example.
                >>>
                >>> # initialize the engine and the example input
                >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
                >>> example_input = {
                >>>     "input": "What is LLM?",
                >>>     "request_id": 0,
                >>> }
                >>>
                >>> # start the generation
                >>> results_generator = engine.encode(
                >>>    example_input["input"],
                >>>    PoolingParams(),
                >>>    example_input["request_id"])
                >>>
                >>> # get the results
                >>> final_output = None
                >>> async for request_output in results_generator:
                >>>     if await request.is_disconnected():
                >>>         # Abort the request if the client disconnects.
                >>>         await engine.abort(request_id)
                >>>         # Return or raise an error
                >>>         ...
                >>>     final_output = request_output
                >>>
                >>> # Process and return the final output
                >>> ...
            """
            async for output in await self.add_request(
                    request_id,
                    inputs,
                    pooling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
            ):
                yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
    
        async def abort(self, request_id: str) -> None:
            """Abort a request.
    
            Abort a submitted request. If the request is finished or not found,
            this method will be a no-op.
    
            Args:
                request_id: The unique id of the request.
            """
            if not self.is_running:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
    
            return self._abort(request_id)
    
        def _abort(self, request_id: str) -> None:
            """Abort a request.
    
            Abort a submitted request. If the request is finished or not found,
            this method will be a no-op.
    
            Args:
                request_id: The unique id of the request.
            """
            self._request_tracker.abort_request(request_id,
                                                exception=asyncio.CancelledError,
                                                verbose=self.log_requests)
    
        async def get_model_config(self) -> ModelConfig:
            """Get the model configuration of the vLLM engine."""
            return self.engine.get_model_config()
    
        async def get_parallel_config(self) -> ParallelConfig:
            """Get the parallel configuration of the vLLM engine."""
            return self.engine.get_parallel_config()
    
        async def get_decoding_config(self) -> DecodingConfig:
            """Get the decoding configuration of the vLLM engine."""
            return self.engine.get_decoding_config()
    
        async def get_scheduler_config(self) -> SchedulerConfig:
            """Get the scheduling configuration of the vLLM engine."""
            return self.engine.get_scheduler_config()
    
        async def get_lora_config(self) -> LoRAConfig:
            """Get the lora configuration of the vLLM engine."""
            return self.engine.get_lora_config()
    
        async def do_log_stats(
                self,
                scheduler_outputs: Optional[SchedulerOutputs] = None,
                model_output: Optional[List[SamplerOutput]] = None) -> None:
            self.engine.do_log_stats()
    
        async def check_health(self) -> None:
            """Raises an error if engine is unhealthy."""
            t = time.perf_counter()
            logger.debug("Starting health check...")
            if self.is_stopped:
                raise AsyncEngineDeadError("Background loop is stopped.")
    
            await self.engine.check_health_async()
            logger.debug("Health check took %fs", time.perf_counter() - t)
    
        async def is_tracing_enabled(self) -> bool:
            return self.engine.is_tracing_enabled()
    
        def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
    
        def remove_logger(self, logger_name: str) -> None:
            self.engine.remove_logger(logger_name=logger_name)
    
        async def start_profile(self) -> None:
            # using type instead of isinstance to check to avoid capturing
            # inherited classes
            if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
                self.engine.model_executor.start_profile()
            else:
                self.engine.model_executor._run_workers("start_profile")
    
        async def stop_profile(self) -> None:
            # using type instead of isinstance to check to avoid capturing
            # inherited classes
            if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
                self.engine.model_executor.stop_profile()
            else:
                self.engine.model_executor._run_workers("stop_profile")
    
  • cover/vllm/engine/llm_engine.py: adds selection of RayNPUExecutor (or NPUExecutor when Ray is not used) once the Ascend environment is detected.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import time
    from collections import deque
    from contextlib import contextmanager
    from dataclasses import dataclass
    from functools import partial
    from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                        Iterable, List, Mapping, NamedTuple, Optional)
    from typing import Sequence as GenericSequence
    from typing import Set, Type, Union
    
    import torch
    from typing_extensions import TypeVar
    
    import vllm.envs as envs
    from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                             EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                             ObservabilityConfig, ParallelConfig,
                             PromptAdapterConfig, SchedulerConfig,
                             SpeculativeConfig)
    from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                     SchedulerOutputs)
    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.metrics_types import StatLoggerBase, Stats
    from vllm.engine.output_processor.interfaces import (
        SequenceGroupOutputProcessor)
    from vllm.engine.output_processor.stop_checker import StopChecker
    from vllm.engine.output_processor.util import create_output_by_sequence_group
    from vllm.executor.executor_base import ExecutorBase
    from vllm.executor.gpu_executor import GPUExecutor
    from vllm.executor.ray_utils import initialize_ray_cluster
    from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                             InputRegistry, LLMInputs, PromptInputs)
    from vllm.inputs.preprocess import InputPreprocessor
    from vllm.logger import init_logger
    from vllm.lora.request import LoRARequest
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                              RequestOutputFactory)
    from vllm.pooling_params import PoolingParams
    from vllm.prompt_adapter.request import PromptAdapterRequest
    from vllm.sampling_params import RequestOutputKind, SamplingParams
    from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                               Sequence, SequenceGroup, SequenceGroupMetadata,
                               SequenceStatus)
    from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                              init_tracer)
    from vllm.transformers_utils.config import try_get_generation_config
    from vllm.transformers_utils.detokenizer import Detokenizer
    from vllm.transformers_utils.tokenizer import AnyTokenizer
    from vllm.transformers_utils.tokenizer_group import (
        BaseTokenizerGroup, init_tokenizer_from_configs)
    from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                      usage_message)
    from vllm.utils import Counter, Device, weak_bind
    from vllm.version import __version__ as VLLM_VERSION
    
    logger = init_logger(__name__)
    _LOCAL_LOGGING_INTERVAL_SEC = 5
    
    
    def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
        config = try_get_generation_config(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            revision=model_config.revision,
        )
    
        if config is None:
            return {}
    
        return config.to_diff_dict()
    
    
    _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
    _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
    
    
    @dataclass
    class SchedulerOutputState:
        """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
        scheduler_outputs: Optional[SchedulerOutputs] = None
        allow_async_output_proc: bool = False
        last_output: Optional[SamplerOutput] = None
    
    
    class OutputData(NamedTuple):
        outputs: List[SamplerOutput]
        seq_group_metadata_list: List[SequenceGroupMetadata]
        scheduler_outputs: SchedulerOutputs
        is_async: bool
        is_last_step: bool
        skip: List[int]
    
    
    class SchedulerContext:
    
        def __init__(self, multi_step_stream_outputs: bool = False):
            self.output_queue: Deque[OutputData] = deque()
            self.request_outputs: List[Union[RequestOutput,
                                             EmbeddingRequestOutput]] = []
            self.seq_group_metadata_list: Optional[
                List[SequenceGroupMetadata]] = None
            self.scheduler_outputs: Optional[SchedulerOutputs] = None
    
            self.multi_step_stream_outputs: bool = multi_step_stream_outputs
    
        def append_output(self, outputs: List[SamplerOutput],
                          seq_group_metadata_list: List[SequenceGroupMetadata],
                          scheduler_outputs: SchedulerOutputs, is_async: bool,
                          is_last_step: bool):
            self.output_queue.append(
                OutputData(outputs=outputs,
                           seq_group_metadata_list=seq_group_metadata_list,
                           scheduler_outputs=scheduler_outputs,
                           is_async=is_async,
                           is_last_step=is_last_step,
                           skip=[]))
    
    
    class LLMEngine:
        """An LLM engine that receives requests and generates texts.
    
        This is the main class for the vLLM engine. It receives requests
        from clients and generates texts from the LLM. It includes a tokenizer, a
        language model (possibly distributed across multiple GPUs), and GPU memory
        space allocated for intermediate states (aka KV cache). This class utilizes
        iteration-level scheduling and efficient memory management to maximize the
        serving throughput.
    
        The :class:`~vllm.LLM` class wraps this class for offline batched inference
        and the :class:`AsyncLLMEngine` class wraps this class for online serving.
    
        The config arguments are derived from :class:`~vllm.EngineArgs`. (See
        :ref:`engine_args`)
    
        Args:
            model_config: The configuration related to the LLM model.
            cache_config: The configuration related to the KV cache memory
                management.
            parallel_config: The configuration related to distributed execution.
            scheduler_config: The configuration related to the request scheduler.
            device_config: The configuration related to the device.
            lora_config (Optional): The configuration related to serving multi-LoRA.
            speculative_config (Optional): The configuration related to speculative
                decoding.
            executor_class: The model executor class for managing distributed
                execution.
            prompt_adapter_config (Optional): The configuration related to serving
                prompt adapters.
            log_stats: Whether to log statistics.
            usage_context: Specified entry point, used for usage info collection.
        """
    
        DO_VALIDATE_OUTPUT: ClassVar[bool] = False
        """A flag to toggle whether to validate the type of request output."""
    
        tokenizer: Optional[BaseTokenizerGroup]
    
        def __init__(
            self,
            model_config: ModelConfig,
            cache_config: CacheConfig,
            parallel_config: ParallelConfig,
            scheduler_config: SchedulerConfig,
            device_config: DeviceConfig,
            load_config: LoadConfig,
            lora_config: Optional[LoRAConfig],
            speculative_config: Optional[SpeculativeConfig],
            decoding_config: Optional[DecodingConfig],
            observability_config: Optional[ObservabilityConfig],
            prompt_adapter_config: Optional[PromptAdapterConfig],
            executor_class: Type[ExecutorBase],
            log_stats: bool,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
            input_registry: InputRegistry = INPUT_REGISTRY,
            use_cached_outputs: bool = False,
        ) -> None:
            logger.info(
                "Initializing an LLM engine (v%s) with config: "
                "model=%r, speculative_config=%r, tokenizer=%r, "
                "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
                "override_neuron_config=%s, "
                "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
                "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
                "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
                "pipeline_parallel_size=%d, "
                "disable_custom_all_reduce=%s, quantization=%s, "
                "enforce_eager=%s, kv_cache_dtype=%s, "
                "quantization_param_path=%s, device_config=%s, "
                "decoding_config=%r, observability_config=%r, "
                "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
                "num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
                "enable_prefix_caching=%s, use_async_output_proc=%s, "
                "use_cached_outputs=%s, mm_processor_kwargs=%s)",
                VLLM_VERSION,
                model_config.model,
                speculative_config,
                model_config.tokenizer,
                model_config.skip_tokenizer_init,
                model_config.tokenizer_mode,
                model_config.revision,
                model_config.override_neuron_config,
                model_config.rope_scaling,
                model_config.rope_theta,
                model_config.tokenizer_revision,
                model_config.trust_remote_code,
                model_config.dtype,
                model_config.max_model_len,
                load_config.download_dir,
                load_config.load_format,
                parallel_config.tensor_parallel_size,
                parallel_config.pipeline_parallel_size,
                parallel_config.disable_custom_all_reduce,
                model_config.quantization,
                model_config.enforce_eager,
                cache_config.cache_dtype,
                model_config.quantization_param_path,
                device_config.device,
                decoding_config,
                observability_config,
                model_config.seed,
                model_config.served_model_name,
                scheduler_config.use_v2_block_manager,
                scheduler_config.num_scheduler_steps,
                scheduler_config.multi_step_stream_outputs,
                cache_config.enable_prefix_caching,
                model_config.use_async_output_proc,
                use_cached_outputs,
                model_config.mm_processor_kwargs,
            )
            # TODO(woosuk): Print more configs in debug mode.
            from vllm.plugins import load_general_plugins
            load_general_plugins()
    
            self.model_config = model_config
            self.cache_config = cache_config
            self.lora_config = lora_config
            self.parallel_config = parallel_config
            self.scheduler_config = scheduler_config
            self.device_config = device_config
            self.speculative_config = speculative_config
            self.load_config = load_config
            self.decoding_config = decoding_config or DecodingConfig()
            self.prompt_adapter_config = prompt_adapter_config
            self.observability_config = observability_config or ObservabilityConfig(
            )
            self.log_stats = log_stats
            self.use_cached_outputs = use_cached_outputs
    
            if not self.model_config.skip_tokenizer_init:
                self.tokenizer = self._init_tokenizer()
                self.detokenizer = Detokenizer(self.tokenizer)
                tokenizer_group = self.get_tokenizer_group()
            else:
                self.tokenizer = None
                self.detokenizer = None
                tokenizer_group = None
    
            # Ensure that the function doesn't contain a reference to self,
            # to avoid engine GC issues
            def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                assert tokenizer_group, ("tokenizer_group cannot be None, "
                                         "make sure skip_tokenizer_init is False")
                return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
    
            self.seq_counter = Counter()
            self.generation_config_fields = _load_generation_config_dict(
                model_config)
    
            self.input_preprocessor = InputPreprocessor(model_config,
                                                        self.tokenizer)
    
            self.input_registry = input_registry
            self.input_processor = input_registry.create_input_processor(
                model_config)
    
            self.model_executor = executor_class(
                model_config=model_config,
                cache_config=cache_config,
                parallel_config=parallel_config,
                scheduler_config=scheduler_config,
                device_config=device_config,
                lora_config=lora_config,
                speculative_config=speculative_config,
                load_config=load_config,
                prompt_adapter_config=prompt_adapter_config,
                observability_config=self.observability_config,
            )
    
            if not self.model_config.embedding_mode:
                self._initialize_kv_caches()
    
            # If usage stat is enabled, collect relevant info.
            if is_usage_stats_enabled():
                from vllm.model_executor.model_loader import (
                    get_architecture_class_name)
                usage_message.report_usage(
                    get_architecture_class_name(model_config),
                    usage_context,
                    extra_kvs={
                        # Common configuration
                        "dtype":
                        str(model_config.dtype),
                        "tensor_parallel_size":
                        parallel_config.tensor_parallel_size,
                        "block_size":
                        cache_config.block_size,
                        "gpu_memory_utilization":
                        cache_config.gpu_memory_utilization,
    
                        # Quantization
                        "quantization":
                        model_config.quantization,
                        "kv_cache_dtype":
                        str(cache_config.cache_dtype),
    
                        # Feature flags
                        "enable_lora":
                        bool(lora_config),
                        "enable_prompt_adapter":
                        bool(prompt_adapter_config),
                        "enable_prefix_caching":
                        cache_config.enable_prefix_caching,
                        "enforce_eager":
                        model_config.enforce_eager,
                        "disable_custom_all_reduce":
                        parallel_config.disable_custom_all_reduce,
                    })
    
            if self.tokenizer:
                # Ping the tokenizer to ensure liveness if it runs in a
                # different process.
                self.tokenizer.ping()
    
            self.cached_scheduler_outputs = [
                SchedulerOutputState()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]
    
            self.scheduler_contexts = [
                SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                                 multi_step_stream_outputs)
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]
    
            if model_config.use_async_output_proc:
                process_model_outputs = weak_bind(self._process_model_outputs)
    
                self.async_callbacks = [
                    partial(process_model_outputs,
                            ctx=self.scheduler_contexts[v_id])
                    for v_id in range(self.parallel_config.pipeline_parallel_size)
                ]
            else:
                self.async_callbacks = []
    
            # Currently used by AsyncLLMEngine to ensure quick append
            # of request outputs to asyncio queues
            self.process_request_outputs_callback: Optional[Callable] = None
    
            # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
            # GPU and CPU blocks, which are profiled in the distributed executor.
            self.scheduler = [
                Scheduler(
                    scheduler_config, cache_config, lora_config,
                    parallel_config.pipeline_parallel_size,
                    self.async_callbacks[v_id]
                    if model_config.use_async_output_proc else None)
                for v_id in range(parallel_config.pipeline_parallel_size)
            ]
    
            # Metric Logging.
            if self.log_stats:
                if stat_loggers is not None:
                    self.stat_loggers = stat_loggers
                else:
                    # Lazy import for prometheus multiprocessing.
                    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                    # before prometheus_client is imported.
                    # See https://prometheus.github.io/client_python/multiprocess/
                    from vllm.engine.metrics import (LoggingStatLogger,
                                                     PrometheusStatLogger)
    
                    self.stat_loggers = {
                        "logging":
                        LoggingStatLogger(
                            local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                        "prometheus":
                        PrometheusStatLogger(
                            local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                            labels=dict(model_name=model_config.served_model_name),
                            max_model_len=self.model_config.max_model_len),
                    }
                    self.stat_loggers["prometheus"].info("cache_config",
                                                         self.cache_config)
    
            self.tracer = None
            if self.observability_config.otlp_traces_endpoint:
                self.tracer = init_tracer(
                    "vllm.llm_engine",
                    self.observability_config.otlp_traces_endpoint)
    
            # Create sequence output processor, e.g. for beam search or
            # speculative decoding.
            self.output_processor = (
                SequenceGroupOutputProcessor.create_output_processor(
                    self.scheduler_config,
                    self.detokenizer,
                    self.scheduler,
                    self.seq_counter,
                    get_tokenizer_for_seq,
                    stop_checker=StopChecker(
                        self.scheduler_config.max_model_len,
                        get_tokenizer_for_seq,
                    ),
                ))
    
        def __reduce__(self):
            # This is to ensure that the LLMEngine is not referenced in
            # the closure used to initialize Ray worker actors
            raise RuntimeError("LLMEngine should not be pickled!")
    
        def __del__(self):
            # Shutdown model executor when engine is garbage collected
            # Use getattr since __init__ can fail before the field is set
            if model_executor := getattr(self, "model_executor", None):
                model_executor.shutdown()
    
        @staticmethod
        def _process_sequence_group_outputs(
            seq_group: SequenceGroup,
            outputs: List[EmbeddingSequenceGroupOutput],
        ) -> None:
            seq_group.embeddings = outputs[0].embeddings
    
            for seq in seq_group.get_seqs():
                seq.status = SequenceStatus.FINISHED_STOPPED
    
            return
    
        @classmethod
        @contextmanager
        def enable_output_validation(cls):
            cls.DO_VALIDATE_OUTPUT = True
    
            yield
    
            cls.DO_VALIDATE_OUTPUT = False
    
        @classmethod
        def validate_output(
            cls,
            output: object,
            output_type: Type[_O],
        ) -> _O:
            do_validate = cls.DO_VALIDATE_OUTPUT
    
            if ((TYPE_CHECKING or do_validate)
                    and not isinstance(output, output_type)):
                raise TypeError(f"Expected output of type {output_type}, "
                                f"but found type {type(output)}")
    
            return output
    
        @classmethod
        def validate_outputs(
            cls,
            outputs: GenericSequence[object],
            output_type: Type[_O],
        ) -> List[_O]:
            do_validate = cls.DO_VALIDATE_OUTPUT
    
            outputs_: List[_O]
            if TYPE_CHECKING or do_validate:
                outputs_ = []
                for output in outputs:
                    if not isinstance(output, output_type):
                        raise TypeError(f"Expected output of type {output_type}, "
                                        f"but found type {type(output)}")
    
                    outputs_.append(output)
            else:
                outputs_ = outputs
    
            return outputs_
    
        @classmethod
        def _get_executor_cls(cls,
                              engine_config: EngineConfig) -> Type[ExecutorBase]:
            distributed_executor_backend = (
                engine_config.parallel_config.distributed_executor_backend)
            # Initialize the cluster and specify the executor class.
            if isinstance(distributed_executor_backend, type):
                if not issubclass(distributed_executor_backend, ExecutorBase):
                    raise TypeError(
                        "distributed_executor_backend must be a subclass of "
                        f"ExecutorBase. Got {distributed_executor_backend}.")
                if distributed_executor_backend.uses_ray:  # type: ignore
                    initialize_ray_cluster(engine_config.parallel_config)
                executor_class = distributed_executor_backend
            elif engine_config.device_config.device_type == "neuron":
                from vllm.executor.neuron_executor import NeuronExecutor
                executor_class = NeuronExecutor
            elif engine_config.device_config.device_type == "tpu":
                if distributed_executor_backend == "ray":
                    initialize_ray_cluster(engine_config.parallel_config)
                    from vllm.executor.ray_tpu_executor import RayTPUExecutor
                    executor_class = RayTPUExecutor
                else:
                    assert distributed_executor_backend is None
                    from vllm.executor.tpu_executor import TPUExecutor
                    executor_class = TPUExecutor
            elif engine_config.device_config.device_type == "cpu":
                from vllm.executor.cpu_executor import CPUExecutor
                executor_class = CPUExecutor
            elif engine_config.device_config.device_type == "openvino":
                from vllm.executor.openvino_executor import OpenVINOExecutor
                executor_class = OpenVINOExecutor
            elif engine_config.device_config.device_type == "xpu":
                if distributed_executor_backend == "ray":
                    initialize_ray_cluster(engine_config.parallel_config)
                    from vllm.executor.ray_xpu_executor import RayXPUExecutor
                    executor_class = RayXPUExecutor
                elif distributed_executor_backend == "mp":
                    # FIXME(kunshang):
                # spawn requires calling `if __name__ == '__main__':`
                # fork is not supported for starting a new process on XPU.
                logger.error(
                    "Both start methods (spawn and fork) have issues "
                    "on XPU with the mp backend; please try ray instead.")
                else:
                    from vllm.executor.xpu_executor import XPUExecutor
                    executor_class = XPUExecutor
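            # Ascend adaptation: with Ray-based distribution use RayNPUExecutor;
            # otherwise fall back to the single-card NPUExecutor.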
            elif engine_config.device_config.device_type == "npu":
                if engine_config.parallel_config.use_ray:
                    initialize_ray_cluster(engine_config.parallel_config)
                    from vllm.executor.ray_npu_executor import RayNPUExecutor
                    executor_class = RayNPUExecutor
                else:
                    from vllm.executor.npu_executor import NPUExecutor
                    executor_class = NPUExecutor
            elif distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_gpu_executor import RayGPUExecutor
                executor_class = RayGPUExecutor
            elif distributed_executor_backend == "mp":
                from vllm.executor.multiproc_gpu_executor import (
                    MultiprocessingGPUExecutor)
                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
                    "multiprocessing distributed executor backend does not "
                    "support VLLM_USE_RAY_SPMD_WORKER=1")
                executor_class = MultiprocessingGPUExecutor
            else:
                from vllm.executor.gpu_executor import GPUExecutor
                executor_class = GPUExecutor
            return executor_class
    
        @classmethod
        def from_engine_args(
            cls,
            engine_args: EngineArgs,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        ) -> "LLMEngine":
            """Creates an LLM engine from the engine arguments."""
            # Create the engine configs.
            engine_config = engine_args.create_engine_config()
            executor_class = cls._get_executor_cls(engine_config)
            # Create the LLM engine.
            engine = cls(
                **engine_config.to_dict(),
                executor_class=executor_class,
                log_stats=not engine_args.disable_log_stats,
                usage_context=usage_context,
                stat_loggers=stat_loggers,
            )
    
            return engine
    
        def get_tokenizer_group(
            self,
            group_type: Type[_G] = BaseTokenizerGroup,
        ) -> _G:
            tokenizer_group = self.tokenizer
    
            if tokenizer_group is None:
                raise ValueError("Unable to get tokenizer because "
                                 "skip_tokenizer_init is True")
            if not isinstance(tokenizer_group, group_type):
                raise TypeError("Invalid type of tokenizer group. "
                                f"Expected type: {group_type}, but "
                                f"found type: {type(tokenizer_group)}")
    
            return tokenizer_group
    
        def get_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None,
        ) -> AnyTokenizer:
            return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
    
        def stop_remote_worker_execution_loop(self) -> None:
            self.model_executor.stop_remote_worker_execution_loop()
    
        def add_request(
            self,
            request_id: str,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
        ) -> None:
            """Add a request to the engine's request pool.
    
            The request is added to the request pool and will be processed by the
            scheduler as `engine.step()` is called. The exact scheduling policy is
            determined by the scheduler.
    
            Args:
                request_id: The unique ID of the request.
                inputs: The inputs to the LLM. See
                    :class:`~vllm.inputs.PromptInputs`
                    for more details about the format of each input.
                params: Parameters for sampling or pooling.
                    :class:`~vllm.SamplingParams` for text generation.
                    :class:`~vllm.PoolingParams` for pooling.
                arrival_time: The arrival time of the request. If None, we use
                    the current monotonic time.
                lora_request: The LoRA request to add, if any.
                trace_headers: OpenTelemetry trace headers.
                prompt_adapter_request: The prompt adapter request to add,
                    if any.
                priority: The priority of the request.
                    Only applicable with priority scheduling.
    
            Details:
                - Set arrival_time to the current time if it is None.
                - Set prompt_token_ids to the encoded prompt if it is None.
                - Create `best_of` number of :class:`~vllm.Sequence` objects.
                - Create a :class:`~vllm.SequenceGroup` object
                  from the list of :class:`~vllm.Sequence`.
                - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
    
            Example:
                >>> # initialize engine
                >>> engine = LLMEngine.from_engine_args(engine_args)
                >>> # set request arguments
                >>> example_prompt = "Who is the president of the United States?"
                >>> sampling_params = SamplingParams(temperature=0.0)
                >>> request_id = 0
                >>>
                >>> # add the request to the engine
                >>> engine.add_request(
                >>>    str(request_id),
                >>>    example_prompt,
                >>>    SamplingParams(temperature=0.0))
                >>> # continue the request processing
                >>> ...
            """
            if lora_request is not None and not self.lora_config:
                raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                                 "not enabled!")
    
            if priority > 0 and not self.scheduler_config.policy == "priority":
                raise ValueError(f"Got priority {priority} but "
                                 "Priority scheduling is not enabled.")
    
            if arrival_time is None:
                arrival_time = time.time()
    
            preprocessed_inputs = self.input_preprocessor.preprocess(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
            processed_inputs = self.input_processor(preprocessed_inputs)
    
            self._add_processed_request(
                request_id=request_id,
                processed_inputs=processed_inputs,
                params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                trace_headers=trace_headers,
                priority=priority,
            )
    
        def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
            """Aborts a request(s) with the given ID.
    
            Args:
                request_id: The ID(s) of the request to abort.
    
            Details:
                - Refer to the
                  :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
                  from class :class:`~vllm.core.scheduler.Scheduler`.
    
            Example:
                >>> # initialize engine and add a request with request_id
                >>> request_id = str(0)
                >>> # abort the request
                >>> engine.abort_request(request_id)
            """
            for scheduler in self.scheduler:
                scheduler.abort_seq_group(request_id)
    
        def get_model_config(self) -> ModelConfig:
            """Gets the model configuration."""
            return self.model_config
    
        def get_parallel_config(self) -> ParallelConfig:
            """Gets the parallel configuration."""
            return self.parallel_config
    
        def get_decoding_config(self) -> DecodingConfig:
            """Gets the decoding configuration."""
            return self.decoding_config
    
        def get_scheduler_config(self) -> SchedulerConfig:
            """Gets the scheduler configuration."""
            return self.scheduler_config
    
        def get_lora_config(self) -> LoRAConfig:
            """Gets the LoRA configuration."""
            return self.lora_config
    
        def get_num_unfinished_requests(self) -> int:
            """Gets the number of unfinished requests."""
            return sum(scheduler.get_num_unfinished_seq_groups()
                       for scheduler in self.scheduler)
    
        def has_unfinished_requests(self) -> bool:
            """Returns True if there are unfinished requests."""
            return any(scheduler.has_unfinished_seqs()
                       for scheduler in self.scheduler)
    
        def has_unfinished_requests_for_virtual_engine(
                self, virtual_engine: int) -> bool:
            """
            Returns True if there are unfinished requests for the virtual engine.
            """
            return self.scheduler[virtual_engine].has_unfinished_seqs()
    
        def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
            """Performs one decoding iteration and returns newly generated results.
    
            .. figure:: https://i.imgur.com/sv2HssD.png
                :alt: Overview of the step function
                :align: center
    
                Overview of the step function.
    
            Details:
                - Step 1: Schedules the sequences to be executed in the next
                  iteration and the token blocks to be swapped in/out/copy.
    
                    - Depending on the scheduling policy,
                      sequences may be `preempted/reordered`.
                    - A Sequence Group (SG) refers to a group of sequences
                      that are generated from the same prompt.
    
                - Step 2: Calls the distributed executor to execute the model.
                - Step 3: Processes the model output. This mainly includes:
    
                    - Decodes the relevant outputs.
                    - Updates the scheduled sequence groups with model outputs
                      based on its `sampling parameters` (`use_beam_search` or not).
                    - Frees the finished sequence groups.
    
                - Finally, it creates and returns the newly generated results.
    
            Example:
                >>> # Please see the example/ folder for more detailed examples.
                >>>
                >>> # initialize engine and request arguments
                >>> engine = LLMEngine.from_engine_args(engine_args)
                >>> example_inputs = [(0, "What is LLM?",
                >>>    SamplingParams(temperature=0.0))]
                >>>
                >>> # Start the engine with an event loop
                >>> while True:
                >>>     if example_inputs:
                >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
                >>>         engine.add_request(str(req_id),prompt,sampling_params)
                >>>
                >>>     # continue the request processing
                >>>     request_outputs = engine.step()
                >>>     for request_output in request_outputs:
                >>>         if request_output.finished:
                >>>             # return or show the request output
                >>>
                >>>     if not (engine.has_unfinished_requests() or example_inputs):
                >>>         break
            """
            if self.parallel_config.pipeline_parallel_size > 1:
                raise NotImplementedError(
                    "Pipeline parallelism is only supported through AsyncLLMEngine "
                    "as performance will be severely degraded otherwise.")
    
            # For llm_engine, there is no pipeline parallel support, so the engine
            # used is always 0.
            virtual_engine = 0
    
            # These are cached outputs from previous iterations. None if on first
            # iteration
            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
            scheduler_outputs = cached_outputs.scheduler_outputs
            allow_async_output_proc = cached_outputs.allow_async_output_proc
    
            ctx = self.scheduler_contexts[virtual_engine]
    
            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()
    
            # Skip the scheduler if there are any remaining steps in the seq groups.
            # This ensures that the scheduler is only called again when the current
            # batch has completed.
            if not self._has_remaining_steps(seq_group_metadata_list):
                # Schedule iteration
                (seq_group_metadata_list, scheduler_outputs,
                 allow_async_output_proc
                 ) = self.scheduler[virtual_engine].schedule()
    
                ctx.seq_group_metadata_list = seq_group_metadata_list
                ctx.scheduler_outputs = scheduler_outputs
    
                # Maybe switch from async mode to sync mode
                if not allow_async_output_proc and len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
    
                if (self.scheduler_config.is_multi_step
                        and scheduler_outputs.num_lookahead_slots > 0):
                    # cache the scheduler outputs for the next iteration if we have
                    # lookahead slots
                    self._cache_scheduler_outputs_for_multi_step(
                        virtual_engine, seq_group_metadata_list, scheduler_outputs,
                        allow_async_output_proc)
    
            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None
    
            if not scheduler_outputs.is_empty():
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()
    
                # Check if we have a cached last_output from the previous iteration.
                # For supporting PP this is probably the best way to pass the
                # sampled_token_ids, as a separate broadcast over all the PP stages
                # will cause one virtual engine's microbatch to block the pipeline.
                last_sampled_token_ids = \
                    self._get_last_sampled_token_ids(virtual_engine)
    
                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                    finished_requests_ids=finished_requests_ids,
                    # We use ExecuteModelRequest to pass the last sampled_token_ids
                    # to each of the non-last PP stages for in-place prepare_input.
                    last_sampled_token_ids=last_sampled_token_ids)
    
                if allow_async_output_proc:
                    execute_model_req.async_callback = self.async_callbacks[
                        virtual_engine]
    
                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)
    
                # We need to do this here so that last step's sampled_token_ids can
                # be passed to the next iteration for PP.
                if self.scheduler_config.is_multi_step:
                    self._update_cached_scheduler_output(virtual_engine, outputs)
            else:
                # Nothing scheduled => if there is a pending async postprocessor,
                # then finish it here.
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                # No outputs in this case
                outputs = []
    
            # Finish the current step for all the sequence groups.
            if self.scheduler_config.is_multi_step:
                for seq_group in seq_group_metadata_list:
                    seq_group.finish_step()
    
            if not self._has_remaining_steps(seq_group_metadata_list):
                # clear the cache if we have finished all the steps.
                if self.scheduler_config.is_multi_step:
                    self.cached_scheduler_outputs[0] = SchedulerOutputState()
    
                # Add results to the output_queue
                ctx.append_output(outputs=outputs,
                                  seq_group_metadata_list=seq_group_metadata_list,
                                  scheduler_outputs=scheduler_outputs,
                                  is_async=allow_async_output_proc,
                                  is_last_step=True)
    
                if outputs and allow_async_output_proc:
                    assert len(outputs) == 1, (
                        "Async postprocessor expects only a single output set")
    
                    self._advance_to_next_step(
                        outputs[0], seq_group_metadata_list,
                        scheduler_outputs.scheduled_seq_groups)
    
                # Check if need to run the usual non-async path
                if not allow_async_output_proc:
                    self._process_model_outputs(ctx=ctx)
    
                    # Log stats.
                    self.do_log_stats(scheduler_outputs, outputs)
    
                    # Tracing
                    self.do_tracing(scheduler_outputs)
            else:
                # Multi-step case
                return ctx.request_outputs
    
            if not self.has_unfinished_requests():
                # Drain async postprocessor (if exists)
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                assert len(ctx.output_queue) == 0
    
                # Stop the execute model loop in parallel workers until there are
                # more requests to process. This avoids waiting indefinitely in
                # torch.distributed ops which may otherwise timeout, and unblocks
                # the RPC thread in the workers so that they can process any other
                # queued control plane messages, such as add/remove lora adapters.
                logger.debug("Stopping remote worker execution loop.")
                self.model_executor.stop_remote_worker_execution_loop()
    
            return ctx.request_outputs
    
        def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
            if not self.log_stats:
                raise RuntimeError(
                    "Stat logging is disabled. Set `disable_log_stats=False` "
                    "argument to enable.")
            if logger_name in self.stat_loggers:
                raise KeyError(f"Logger with name {logger_name} already exists.")
            self.stat_loggers[logger_name] = logger
    
        def remove_logger(self, logger_name: str) -> None:
            if not self.log_stats:
                raise RuntimeError(
                    "Stat logging is disabled. Set `disable_log_stats=False` "
                    "argument to enable.")
            if logger_name not in self.stat_loggers:
                raise KeyError(f"Logger with name {logger_name} does not exist.")
            del self.stat_loggers[logger_name]
    
        def do_log_stats(self,
                         scheduler_outputs: Optional[SchedulerOutputs] = None,
                         model_output: Optional[List[SamplerOutput]] = None,
                         finished_before: Optional[List[int]] = None,
                         skip: Optional[List[int]] = None) -> None:
            """Forced log when no requests active."""
            if self.log_stats:
                stats = self._get_stats(scheduler_outputs, model_output,
                                        finished_before, skip)
                for logger in self.stat_loggers.values():
                    logger.log(stats)
    
        def add_lora(self, lora_request: LoRARequest) -> bool:
            return self.model_executor.add_lora(lora_request)
    
        def remove_lora(self, lora_id: int) -> bool:
            return self.model_executor.remove_lora(lora_id)
    
        def list_loras(self) -> Set[int]:
            return self.model_executor.list_loras()
    
        def pin_lora(self, lora_id: int) -> bool:
            return self.model_executor.pin_lora(lora_id)
    
        def add_prompt_adapter(
                self, prompt_adapter_request: PromptAdapterRequest) -> bool:
            return self.model_executor.add_prompt_adapter(prompt_adapter_request)
    
        def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
            return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
    
        def list_prompt_adapters(self) -> List[int]:
            return self.model_executor.list_prompt_adapters()
    
        def check_health(self) -> None:
            if self.tokenizer:
                self.tokenizer.check_health()
            self.model_executor.check_health()
    
        def start_profile(self) -> None:
            # use type() rather than isinstance() so that inherited classes
            # (e.g. MultiprocessingGPUExecutor) are not captured
            if type(self.model_executor) == GPUExecutor:  # noqa: E721
                self.model_executor.start_profile()
            else:
                self.model_executor._run_workers("start_profile")
    
        def stop_profile(self) -> None:
            # use type() rather than isinstance() so that inherited classes
            # (e.g. MultiprocessingGPUExecutor) are not captured
            if type(self.model_executor) == GPUExecutor:  # noqa: E721
                self.model_executor.stop_profile()
            else:
                self.model_executor._run_workers("stop_profile")
    
        def is_tracing_enabled(self) -> bool:
            return self.tracer is not None
    
        def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
            if self.tracer is None:
                return
    
            for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
                seq_group = scheduled_seq_group.seq_group
                if seq_group.is_finished():
                    self.create_trace_span(seq_group)
    
        def create_trace_span(self, seq_group: SequenceGroup) -> None:
            if self.tracer is None or seq_group.sampling_params is None:
                return
            arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)
    
            trace_context = extract_trace_context(seq_group.trace_headers)
    
            with self.tracer.start_as_current_span(
                    "llm_request",
                    kind=SpanKind.SERVER,
                    context=trace_context,
                    start_time=arrival_time_nano_seconds) as seq_span:
                metrics = seq_group.metrics
                ttft = metrics.first_token_time - metrics.arrival_time
                e2e_time = metrics.finished_time - metrics.arrival_time
                # attribute names are based on
                # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
                seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
                                       self.model_config.model)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
                                       seq_group.request_id)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
                                       seq_group.sampling_params.temperature)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
                                       seq_group.sampling_params.top_p)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                       seq_group.sampling_params.max_tokens)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
                                       seq_group.sampling_params.best_of)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                       seq_group.sampling_params.n)
                seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
                                       seq_group.num_seqs())
                seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
                                       len(seq_group.prompt_token_ids))
                seq_span.set_attribute(
                    SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
                    sum([
                        seq.get_output_len()
                        for seq in seq_group.get_finished_seqs()
                    ]))
                seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
                                       metrics.time_in_queue)
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
                seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
                if metrics.scheduler_time is not None:
                    seq_span.set_attribute(
                        SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                        metrics.scheduler_time)
                if metrics.model_forward_time is not None:
                    seq_span.set_attribute(
                        SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                        metrics.model_forward_time / 1000.0)
                if metrics.model_execute_time is not None:
                    seq_span.set_attribute(
                        SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                        metrics.model_execute_time)
    
        def is_encoder_decoder_model(self):
            return self.input_preprocessor.is_encoder_decoder_model()
    
        def is_embedding_model(self):
            return self.model_config.is_embedding_model
    
        def _init_tokenizer(self) -> BaseTokenizerGroup:
            return init_tokenizer_from_configs(
                model_config=self.model_config,
                scheduler_config=self.scheduler_config,
                parallel_config=self.parallel_config,
                enable_lora=bool(self.lora_config))
    
        def _verify_args(self) -> None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
            self.cache_config.verify_with_parallel_config(self.parallel_config)
            if self.lora_config:
                self.lora_config.verify_with_model_config(self.model_config)
                self.lora_config.verify_with_scheduler_config(
                    self.scheduler_config)
            if self.prompt_adapter_config:
                self.prompt_adapter_config.verify_with_model_config(
                    self.model_config)
    
        def _add_processed_request(
            self,
            request_id: str,
            processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
            params: Union[SamplingParams, PoolingParams],
            arrival_time: float,
            lora_request: Optional[LoRARequest],
            prompt_adapter_request: Optional[PromptAdapterRequest],
            trace_headers: Optional[Mapping[str, str]] = None,
            priority: int = 0,
        ) -> None:
            self._validate_model_inputs(processed_inputs)
            # Create the sequences.
            block_size = self.cache_config.block_size
            seq_id = next(self.seq_counter)
            eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
    
            seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                           lora_request, prompt_adapter_request)
    
            encoder_seq = None
            if 'encoder_prompt_token_ids' in processed_inputs:
                encoder_seq = Sequence(seq_id,
                                       processed_inputs,
                                       block_size,
                                       eos_token_id,
                                       lora_request,
                                       prompt_adapter_request,
                                       from_decoder_prompt=False)
    
            # Create a SequenceGroup based on SamplingParams or PoolingParams
            if isinstance(params, SamplingParams):
                seq_group = self._create_sequence_group_with_sampling(
                    request_id,
                    seq,
                    params,
                    arrival_time=arrival_time,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    encoder_seq=encoder_seq,
                    priority=priority)
            elif isinstance(params, PoolingParams):
                seq_group = self._create_sequence_group_with_pooling(
                    request_id,
                    seq,
                    params,
                    arrival_time=arrival_time,
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request,
                    encoder_seq=encoder_seq,
                    priority=priority)
            else:
                raise ValueError(
                    "Either SamplingParams or PoolingParams must be provided.")
    
            # Add the sequence group to the scheduler with the fewest unfinished seqs.
            costs = [
                scheduler.get_num_unfinished_seq_groups()
                for scheduler in self.scheduler
            ]
            min_cost_scheduler = self.scheduler[costs.index(min(costs))]
            min_cost_scheduler.add_seq_group(seq_group)
    
        def _initialize_kv_caches(self) -> None:
            """Initialize the KV cache in the worker(s).
    
            The workers will determine the number of blocks in both the GPU cache
            and the swap CPU cache.
            """
            num_gpu_blocks, num_cpu_blocks = (
                self.model_executor.determine_num_available_blocks())
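            # NOTE: in the Ascend adaptation these counts are reported by the NPU
            # worker through NPUExecutor.determine_num_available_blocks() (see
            # npu_executor.py below); the gpu/cpu naming is kept for compatibility
            # with upstream vLLM.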
    
            if self.cache_config.num_gpu_blocks_override is not None:
                num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
                logger.info(
                    "Overriding num_gpu_blocks=%d with "
                    "num_gpu_blocks_override=%d", num_gpu_blocks,
                    num_gpu_blocks_override)
                num_gpu_blocks = num_gpu_blocks_override
    
            self.cache_config.num_gpu_blocks = num_gpu_blocks
            self.cache_config.num_cpu_blocks = num_cpu_blocks
    
            self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
    
        def _create_sequence_group_with_sampling(
            self,
            request_id: str,
            seq: Sequence,
            sampling_params: SamplingParams,
            arrival_time: float,
            lora_request: Optional[LoRARequest],
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            encoder_seq: Optional[Sequence] = None,
            priority: int = 0,
        ) -> SequenceGroup:
            """Creates a SequenceGroup with SamplingParams."""
            max_logprobs = self.get_model_config().max_logprobs
            if (sampling_params.logprobs
                    and sampling_params.logprobs > max_logprobs) or (
                        sampling_params.prompt_logprobs
                        and sampling_params.prompt_logprobs > max_logprobs):
                raise ValueError(f"Cannot request more than "
                                 f"{max_logprobs} logprobs.")
    
            # Defensive copy of SamplingParams, which are used by the sampler;
            # note that this doesn't deep-copy LogitsProcessor objects.
            sampling_params = sampling_params.clone()
    
            sampling_params.update_from_generation_config(
                self.generation_config_fields, seq.eos_token_id)
    
            # Create the sequence group.
            seq_group = SequenceGroup(
                request_id=request_id,
                seqs=[seq],
                arrival_time=arrival_time,
                sampling_params=sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
    
            return seq_group
    
        def _create_sequence_group_with_pooling(
            self,
            request_id: str,
            seq: Sequence,
            pooling_params: PoolingParams,
            arrival_time: float,
            lora_request: Optional[LoRARequest],
            prompt_adapter_request: Optional[PromptAdapterRequest],
            encoder_seq: Optional[Sequence] = None,
            priority: int = 0,
        ) -> SequenceGroup:
            """Creates a SequenceGroup with PoolingParams."""
            # Defensive copy of PoolingParams, which are used by the pooler
            pooling_params = pooling_params.clone()
            # Create the sequence group.
            seq_group = SequenceGroup(
                request_id=request_id,
                seqs=[seq],
                arrival_time=arrival_time,
                lora_request=lora_request,
                pooling_params=pooling_params,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
            return seq_group
    
        def _process_model_outputs(self,
                                   ctx: SchedulerContext,
                                   request_id: Optional[str] = None) -> None:
            """Apply the model output to the sequences in the scheduled seq groups
            and return responses.
    
            ctx: The virtual engine context to work on
            request_id: If provided, then only this request is going to be processed
    
            """
            now = time.time()
    
            if len(ctx.output_queue) == 0:
                return None
    
            # Get pending async postprocessor
            if request_id:
                # When we process only one request, no pop is required
                # (since later we will process all of the rest)
                (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
                 is_last_step, skip) = ctx.output_queue[0]
            else:
                (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
                 is_last_step, skip) = ctx.output_queue.popleft()
    
            # Sanity check
            assert len(seq_group_metadata_list) == len(
                scheduler_outputs.scheduled_seq_groups)
    
            # Organize outputs by [step][sequence group] instead of
            # [sequence group][step].
            if len(outputs) > 1:
                outputs_by_sequence_group = create_output_by_sequence_group(
                    outputs, num_seq_groups=len(seq_group_metadata_list))
            else:
                outputs_by_sequence_group = outputs
    
            # Determine the requests we need to operate on
            if request_id:
                indices = []
                for i, seq_group_meta in enumerate(seq_group_metadata_list):
                    if seq_group_meta.request_id == request_id:
                        assert i not in skip  # Cannot be called twice
                        indices.append(i)
                        break
    
                # If the request_id was not found, then it means that
                # this is a new request that has no pending async
                # postprocessor
                if not indices:
                    return
            else:
                indices = range(len(seq_group_metadata_list))  # type: ignore
    
            finished_before: List[int] = []
            finished_now: List[int] = []
            for i in indices:
                if i in skip:
                    continue
    
                seq_group_meta = seq_group_metadata_list[i]
                scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
    
                seq_group = scheduled_seq_group.seq_group
    
                if seq_group.is_finished():
                    finished_before.append(i)
                    continue
    
                if len(outputs) > 1:
                    output = outputs_by_sequence_group[i]
                else:
                    output = [outputs_by_sequence_group[0][i]]
    
                if not is_async:
                    seq_group.update_num_computed_tokens(
                        scheduled_seq_group.token_chunk_size)
    
                if outputs:
                    for o in outputs:
                        if (isinstance(o, SamplerOutput)
                                and seq_group.metrics is not None):
                            if seq_group.metrics.model_forward_time is not None:
                                seq_group.metrics.model_forward_time += (
                                    o.model_forward_time)
                            else:
                                seq_group.metrics.model_forward_time = (
                                    o.model_forward_time)
                            if seq_group.metrics.model_execute_time is not None:
                                seq_group.metrics.model_execute_time += (
                                    o.model_execute_time)
                            else:
                                seq_group.metrics.model_execute_time = (
                                    o.model_execute_time)
    
                if self.model_config.embedding_mode:
                    self._process_sequence_group_outputs(seq_group, output)
                else:
                    self.output_processor.process_prompt_logprob(seq_group, output)
                    if seq_group_meta.do_sample:
                        self.output_processor.process_outputs(
                            seq_group, output, is_async)
    
                if seq_group.is_finished():
                    finished_now.append(i)
    
            # Generate outputs for the requests that finished this iteration
            for i in finished_now:
                scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
    
                seq_group = scheduled_seq_group.seq_group
                seq_group.maybe_set_first_token_time(now)
                request_output = RequestOutputFactory.create(
                    seq_group, use_cache=self.use_cached_outputs)
                if request_output:
                    ctx.request_outputs.append(request_output)
    
            # When we process a single request, we skip it for the next time,
            # and invoke the request output callback (if there was final output)
            if request_id:
                assert len(indices) == 1
                skip.append(indices[0])
    
                if (finished_now
                        and self.process_request_outputs_callback is not None):
                    self.process_request_outputs_callback(ctx.request_outputs)
                    ctx.request_outputs.clear()
                return
    
            # Free currently finished requests
            if finished_now:
                for scheduler in self.scheduler:
                    scheduler.free_finished_seq_groups()
    
            # For multi-step without streaming, don't create outputs each iteration
            if not is_last_step and not ctx.multi_step_stream_outputs:
                # Immediately process request outputs here (if callback is given)
                if (finished_now
                        and self.process_request_outputs_callback is not None):
                    self.process_request_outputs_callback(ctx.request_outputs)
                    ctx.request_outputs.clear()
                return
    
            # Create the outputs
            for i in indices:
                if i in skip or i in finished_before or i in finished_now:
                    continue  # Avoids double processing
    
                scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
    
                seq_group = scheduled_seq_group.seq_group
                seq_group.maybe_set_first_token_time(now)
                request_output = RequestOutputFactory.create(
                    seq_group, use_cache=self.use_cached_outputs)
                if request_output:
                    ctx.request_outputs.append(request_output)
    
            # For multi-step with streaming, create outputs each iteration
            if not is_last_step and ctx.multi_step_stream_outputs:
                # Immediately process request outputs here (if callback is given)
                if self.process_request_outputs_callback is not None:
                    self.process_request_outputs_callback(ctx.request_outputs)
                    ctx.request_outputs.clear()
                return
    
            for seq_group in scheduler_outputs.ignored_seq_groups:
                params = seq_group.sampling_params
                if params is not None and params.output_kind == (
                        RequestOutputKind.DELTA) and not seq_group.is_finished():
                    continue
    
                request_output = RequestOutputFactory.create(
                    seq_group, use_cache=self.use_cached_outputs)
                if request_output:
                    ctx.request_outputs.append(request_output)
    
            # Immediately process request outputs here (if callback is given)
            if (ctx.request_outputs
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
    
            # For async case, we need to record the stats here.
            # For non-async case, the stats are done in the
            # LLMEngine/AsyncLLMEngine directly
            if is_async:
                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs, finished_before,
                                  skip)
    
                # Tracing
                self.do_tracing(scheduler_outputs)
    
            return None
    
        def _advance_to_next_step(
                self, output: List[SamplerOutput],
                seq_group_metadata_list: List[SequenceGroupMetadata],
                scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
            """Given model output from a single run, append the tokens to the
            sequences. This is normally done inside output processor, but it is
            required if the worker is to perform async forward pass to next step.
            """
            for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
                zip(seq_group_metadata_list, output, scheduled_seq_groups):
                seq_group = scheduled_seq_group.seq_group
    
                if seq_group.is_finished():
                    continue
    
                seq_group.update_num_computed_tokens(
                    seq_group_metadata.token_chunk_size)
    
                if seq_group_metadata.do_sample:
                    assert len(sequence_group_outputs.samples) == 1, (
                        "Async output processor expects a single sample"
                        " (i.e. sampling_params.n == 1 and no "
                        "sampling_params.best_of > 1)")
                    sample = sequence_group_outputs.samples[0]
    
                    assert len(seq_group.seqs) == 1
                    seq = seq_group.seqs[0]
                    seq.append_token_id(sample.output_token, sample.logprobs)
    
        def _has_remaining_steps(
            self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
        ) -> bool:
            if (not self.scheduler_config.is_multi_step
                    or not seq_group_metadata_list):
                return False
    
            # TODO(will) this is a sanity check for now to make sure that all the
            # seqs are on the same steps. Eventually we will want to do some sort of
            # dynamic scheduling when doing multi-step decoding.
            ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
            if any([
                    seq_group.state.remaining_steps != ref_remaining_steps
                    for seq_group in seq_group_metadata_list[1:]
            ]):
                raise AssertionError(("All running sequence groups should "
                                      "have the same remaining steps."))
    
            return ref_remaining_steps > 0
    
        def _cache_scheduler_outputs_for_multi_step(
                self, virtual_engine: int,
                seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
                scheduler_outputs: SchedulerOutputs,
                allow_async_output_proc: bool) -> None:
            co = self.cached_scheduler_outputs[virtual_engine]
    
            co.seq_group_metadata_list = seq_group_metadata_list
            co.scheduler_outputs = scheduler_outputs
            co.allow_async_output_proc = allow_async_output_proc
            co.last_output = None
    
        def _update_cached_scheduler_output(
                self, virtual_engine: int,
                output: List[Optional[SamplerOutput]]) -> None:
            if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                    and output[0] is not None):
                last_output = output[-1]
                assert last_output is not None
                assert last_output.sampled_token_ids_cpu is not None
                assert last_output.sampled_token_ids is None
                assert last_output.sampled_token_probs is None
                self.cached_scheduler_outputs[
                    virtual_engine].last_output = last_output
    
        def _get_last_sampled_token_ids(
                self, virtual_engine: int) -> Optional[torch.Tensor]:
            cached_last_output = self.cached_scheduler_outputs[
                virtual_engine].last_output
            if (self.scheduler_config.is_multi_step
                    and self.parallel_config.pipeline_parallel_size > 1
                    and cached_last_output is not None
                    and cached_last_output.sampled_token_ids_cpu is not None):
                return cached_last_output.sampled_token_ids_cpu
            return None
    
        def _get_stats(self,
                       scheduler_outputs: Optional[SchedulerOutputs],
                       model_output: Optional[List[SamplerOutput]] = None,
                       finished_before: Optional[List[int]] = None,
                       skip: Optional[List[int]] = None) -> Stats:
            """Get Stats to be Logged to Prometheus.
    
            Args:
                scheduler_outputs: Optional, used to populate metrics related to
                    the scheduled batch,
                model_output: Optional, used to emit speculative decoding metrics
                    which are created by the workers.
                finished_before: Optional, indices of sequences that were finished
                    before. These sequences will be ignored.
                skip: Optional, indices of sequences that were preempted. These
                    sequences will be ignored.
            """
            now = time.time()
    
            # System State
            #   Scheduler State
            num_running_sys = sum(
                len(scheduler.running) for scheduler in self.scheduler)
            num_swapped_sys = sum(
                len(scheduler.swapped) for scheduler in self.scheduler)
            num_waiting_sys = sum(
                len(scheduler.waiting) for scheduler in self.scheduler)
    
            # KV Cache Usage in %
            num_total_gpu = self.cache_config.num_gpu_blocks
            gpu_cache_usage_sys = 0.
            if num_total_gpu is not None:
                num_free_gpu = sum(
                    scheduler.block_manager.get_num_free_gpu_blocks()
                    for scheduler in self.scheduler)
                gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
    
            num_total_cpu = self.cache_config.num_cpu_blocks
            cpu_cache_usage_sys = 0.
            if num_total_cpu is not None and num_total_cpu > 0:
                num_free_cpu = sum(
                    scheduler.block_manager.get_num_free_cpu_blocks()
                    for scheduler in self.scheduler)
                cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
    
            # Prefix Cache Hit Rate. Note that we always use
            # the cache hit rate of the first virtual engine.
            cpu_prefix_cache_hit_rate = self.scheduler[
                0].get_prefix_cache_hit_rate(Device.CPU)
            gpu_prefix_cache_hit_rate = self.scheduler[
                0].get_prefix_cache_hit_rate(Device.GPU)
    
            # Iteration stats
            num_prompt_tokens_iter = 0
            num_generation_tokens_iter = 0
            time_to_first_tokens_iter: List[float] = []
            time_per_output_tokens_iter: List[float] = []
            num_preemption_iter = (0 if scheduler_outputs is None else
                                   scheduler_outputs.preempted)
    
            # Request stats
            #   Latency
            time_e2e_requests: List[float] = []
            #   Metadata
            num_prompt_tokens_requests: List[int] = []
            num_generation_tokens_requests: List[int] = []
            best_of_requests: List[int] = []
            n_requests: List[int] = []
            finished_reason_requests: List[str] = []
    
            # NOTE: This loop assumes prefill seq_groups are before
            # decode seq_groups in scheduled_seq_groups.
            if scheduler_outputs is not None:
                # For the async postprocessor, sequences that already finished
                # must not be counted again (to avoid double counting).
                actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore
    
                num_generation_tokens_from_prefill_groups = 0.
                # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
                # len(scheduler_outputs.scheduled_seq_groups) !=
                # scheduler_outputs.num_prefill_groups, chunked prefill
                # is in use.
    
                for idx, scheduled_seq_group in enumerate(
                        scheduler_outputs.scheduled_seq_groups):
                    # Skip double logging when using async output proc
                    if finished_before and idx in finished_before:
                        actual_num_batched_tokens -= 1
                        continue
    
                    # Currently, skip == preempted sequences, so we need to skip
                    # their log stats
                    if skip and idx in skip:
                        continue
    
                    group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                    seq_group = scheduled_seq_group.seq_group
    
                    # NOTE: a seq_group that completed all of its prefill tokens
                    # in the last iteration will have seq_group.is_prefill() = False
                    # with group_was_prefill = True
                    if group_was_prefill:
                        # Number of prompt tokens.
                        num_prompt_tokens_iter += (
                            scheduled_seq_group.token_chunk_size)
    
                        # If the seq_group just finished the prefill state
                        # get TTFT.
                        if not seq_group.is_prefill():
                            latency = seq_group.get_last_latency(now)
                            time_to_first_tokens_iter.append(latency)
    
                            # One generation token per finished prefill.
                            num_generation_tokens_from_prefill_groups += (
                                seq_group.num_seqs())
                    else:
                        # TPOTs.
                        latency = seq_group.get_last_latency(now)
                        time_per_output_tokens_iter.append(latency)
    
                    # Because of chunked prefill, we can have a single sequence
                    # group that does multiple prompt_runs. To prevent logging
                    # the same metadata more than once per request, we standardize
                    # on logging request level information for finished requests,
                    # which can only happen once.
                    if seq_group.is_finished():
                        # Latency timings
                        time_e2e_requests.append(now -
                                                 seq_group.metrics.arrival_time)
                        # Metadata
                        num_prompt_tokens_requests.append(
                            len(seq_group.prompt_token_ids))
                        num_generation_tokens_requests.extend([
                            seq.get_output_len()
                            for seq in seq_group.get_finished_seqs()
                        ])
                        if seq_group.sampling_params is not None:
                            best_of_requests.append(
                                seq_group.sampling_params.best_of)
                            n_requests.append(seq_group.sampling_params.n)
                        finished_reason_requests.extend([
                            SequenceStatus.get_finished_reason(seq.status)
                            for seq in seq_group.get_finished_seqs()
                        ])
                num_generation_tokens_iter = (
                    actual_num_batched_tokens - num_prompt_tokens_iter +
                    num_generation_tokens_from_prefill_groups)
    
            # Spec decode, if enabled, emits specialized metrics from the worker in
            # sampler output.
            if model_output and (model_output[0].spec_decode_worker_metrics
                                 is not None):
                spec_decode_metrics = model_output[0].spec_decode_worker_metrics
            else:
                spec_decode_metrics = None
    
            return Stats(
                now=now,
                # System stats
                #   Scheduler State
                num_running_sys=num_running_sys,
                num_swapped_sys=num_swapped_sys,
                num_waiting_sys=num_waiting_sys,
                #   KV Cache Usage in %
                gpu_cache_usage_sys=gpu_cache_usage_sys,
                cpu_cache_usage_sys=cpu_cache_usage_sys,
                #   Prefix Cache Hit Rate
                cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
                gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
    
                # Iteration stats
                num_prompt_tokens_iter=num_prompt_tokens_iter,
                num_generation_tokens_iter=num_generation_tokens_iter,
                time_to_first_tokens_iter=time_to_first_tokens_iter,
                time_per_output_tokens_iter=time_per_output_tokens_iter,
                spec_decode_metrics=spec_decode_metrics,
                num_preemption_iter=num_preemption_iter,
    
                # Request stats
                #   Latency
                time_e2e_requests=time_e2e_requests,
                #   Metadata
                num_prompt_tokens_requests=num_prompt_tokens_requests,
                num_generation_tokens_requests=num_generation_tokens_requests,
                best_of_requests=best_of_requests,
                n_requests=n_requests,
                finished_reason_requests=finished_reason_requests,
            )
    
        def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                       EncoderDecoderLLMInputs]):
            if self.model_config.is_multimodal_model:
                # For encoder-decoder multimodal models, the max_prompt_len
                # restricts the decoder prompt length
                prompt_ids = inputs.get("prompt_token_ids")
            elif self.is_encoder_decoder_model():
                prompt_ids = inputs.get("encoder_prompt_token_ids")
            else:
                prompt_ids = inputs.get("prompt_token_ids")
    
            if prompt_ids is None or len(prompt_ids) == 0:
                raise ValueError("Prompt cannot be empty")
    
            if self.model_config.is_multimodal_model:
                max_prompt_len = self.model_config.max_model_len
    
                if len(prompt_ids) > max_prompt_len:
                    raise ValueError(
                        f"The prompt (total length {len(prompt_ids)}) is too long "
                        f"to fit into the model (context length {max_prompt_len}). "
                        "Make sure that `max_model_len` is no smaller than the "
                        "number of text tokens plus multimodal tokens. For image "
                        "inputs, the number of image tokens depends on the number "
                        "of images, and possibly their aspect ratios as well.")
    
  • cover/vllm/executor/npu_executor.py:基于vLLM 0.6.2版本,在executor模块中继承ExecutorBase实现NPUExecutor和NPUExecutorAsync,分别用于单卡环境下同步模式和异步模式的推理(文件源码之后附有一个简单的离线推理示意)。
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    from typing import List, Set, Tuple, Optional, Callable, Type, Dict, Any
    
    from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
    from vllm.logger import init_logger
    from vllm.prompt_adapter.request import PromptAdapterRequest
    from vllm.lora.request import LoRARequest
    from vllm.sequence import ExecuteModelRequest
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.utils import get_distributed_init_method, get_ip, get_open_port, make_async
    from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
    from vllm.worker.npu_worker import NPUWorker
    
    
    logger = init_logger(__name__)
    
    
    def create_worker(
        worker_module_name: str, worker_class_name: str, worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], **kwargs
    ):
        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            worker_class_fn=worker_class_fn,
        )
        wrapper.init_worker(**kwargs)
        return wrapper.worker
    
    
    class NPUExecutor(ExecutorBase):
    
        def determine_num_available_blocks(self) -> Tuple[int, int]:
            """Determine the number of available KV blocks by invoking the
            underlying worker.
            """
            return self.driver_worker.determine_num_available_blocks()
    
        def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
            """Initialize the KV cache by invoking the underlying worker."""
            self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
    
        def execute_model(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
            output = self.driver_worker.execute_model(execute_model_req)
            return output
    
        def add_lora(self, lora_request: LoRARequest) -> bool:
            return self.driver_worker.add_lora(lora_request)
    
        def remove_lora(self, lora_id: int) -> bool:
            return self.driver_worker.remove_lora(lora_id)
    
        def list_loras(self) -> Set[int]:
            return self.driver_worker.list_loras()
    
        def check_health(self) -> None:
            # The NPU executor will always be healthy as long as
            # it's running.
            return
    
        def pin_lora(self, lora_id: int) -> bool: ...
    
        def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: ...
    
        def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: ...
    
        def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: ...
    
        def list_prompt_adapters(self) -> Set[int]: ...
    
        def _init_executor(self) -> None:
            assert not self.speculative_config, "Speculative decoding is not yet supported for Ascend backend."
            self.driver_worker = self._create_worker()
            self.driver_worker.init_device()
            self.driver_worker.load_model()
    
        def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None):
            return create_worker(
                **self._get_create_worker_kwargs(
                    local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method
                )
            )
    
        def _get_create_worker_kwargs(
            self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None
        ) -> Dict:
            worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method)
    
            (worker_module_name, worker_class_name, worker_class_fn) = self._get_worker_module_and_class()
            worker_kwargs.update(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                worker_class_fn=worker_class_fn,
            )
    
            return worker_kwargs
    
        def _get_worker_kwargs(
            self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None
        ) -> Dict[str, Any]:
            """Return worker init args for a given rank."""
            if distributed_init_method is None:
                distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
            return dict(
                model_config=self.model_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                device_config=self.device_config,
                cache_config=self.cache_config,
                load_config=self.load_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                lora_config=self.lora_config,
                is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0),
                # TODO: Add support for speculative_config, prompt_adapter_config, observability_config.
            )
    
        def _get_worker_module_and_class(self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
            worker_class_fn = None
            if self.scheduler_config.is_multi_step:
                worker_module_name = "vllm.worker.npu_worker"
                worker_class_name = "MultiStepNPUWorker"
            else:
                worker_module_name = "vllm.worker.npu_worker"
                worker_class_name = "NPUWorker"
            return (worker_module_name, worker_class_name, worker_class_fn)
    
    
    class NPUExecutorAsync(NPUExecutor, ExecutorAsyncBase):
    
        async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest,
        ) -> List[SamplerOutput]:
            output = await make_async(self.driver_worker.execute_model)(
                execute_model_req=execute_model_req,
            )
            return output
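    下面是一段示意性的单卡离线推理示例(非上述文件内容;模型路径为占位符,假设已通过install.sh安装昇腾适配版vllm)。按本适配的设计,tensor_parallel_size=1时引擎走上述NPUExecutor的同步推理路径:

    from vllm import LLM, SamplingParams

    # 单卡场景:引擎内部由NPUExecutor驱动NPUWorker完成推理
    llm = LLM(model="/path/to/model", tensor_parallel_size=1)

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
    outputs = llm.generate(["请简要介绍昇腾NPU。"], sampling_params)
    for output in outputs:
        print(output.prompt, output.outputs[0].text)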
  • cover/vllm/executor/distributed_npu_executor.py:实现DistributedNPUExecutor和DistributedNPUExecutorAsync两个抽象基类,作为多卡分布式执行器的公共父类,支持分布式环境下的同步和异步推理(文件源码之后附有一个说明继承方式的最小示意)。
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import asyncio
    from abc import abstractmethod
    from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
    
    from vllm.executor.executor_base import ExecutorAsyncBase
    from vllm.executor.npu_executor import NPUExecutor
    from vllm.logger import init_logger
    from vllm.lora.request import LoRARequest
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.sequence import ExecuteModelRequest
    
    logger = init_logger(__name__)
    
    
    class DistributedNPUExecutor(NPUExecutor):
        """Abstract superclass of multi-NPU executor implementations."""
    
        def __init__(self, *args, **kwargs):
            # This is non-None when the execute model loop is running
            # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
            self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
            # Updated by implementations that require additional args to be passed
            # to the _run_workers execute_model call
            self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
    
            super().__init__(*args, **kwargs)
    
        def determine_num_available_blocks(self) -> Tuple[int, int]:
            """Determine the number of available KV blocks.
    
            This invokes `determine_num_available_blocks` on each worker and takes
            the min of the results, guaranteeing that the selected cache sizes are
            compatible with all workers.
    
            Returns:
                - tuple[num_gpu_blocks, num_cpu_blocks]
            """
            # Get the maximum number of blocks that can be allocated on NPU and CPU.
            num_blocks = self._run_workers("determine_num_available_blocks", )
    
            # Since we use a shared centralized controller, we take the minimum
            # number of blocks across all workers to make sure all the memory
            # operators can be applied to all workers.
            num_gpu_blocks = min(b[0] for b in num_blocks)
            num_cpu_blocks = min(b[1] for b in num_blocks)
    
            return num_gpu_blocks, num_cpu_blocks
    
        def initialize_cache(self, num_gpu_blocks: int,
                             num_cpu_blocks: int) -> None:
            """Initialize the KV cache in all workers.
            """
    
            # NOTE: We log here to avoid multiple logs when number of workers is
            # greater than one. We could log in the engine, but not all executors
            # have NPUs.
            logger.info("# NPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                        num_cpu_blocks)
    
            self.cache_config.num_gpu_blocks = num_gpu_blocks
            self.cache_config.num_cpu_blocks = num_cpu_blocks
    
            self._run_workers("initialize_cache",
                              num_gpu_blocks=num_gpu_blocks,
                              num_cpu_blocks=num_cpu_blocks)
    
        def execute_model(
            self,
            execute_model_req: ExecuteModelRequest,
        ) -> List[SamplerOutput]:
            if self.parallel_worker_tasks is None:
                self.parallel_worker_tasks = self._run_workers(
                    "start_worker_execution_loop",
                    async_run_tensor_parallel_workers_only=True,
                    **self.extra_execute_model_run_workers_kwargs)
    
            # Only the driver worker returns the sampling results.
            driver_outputs = self._driver_execute_model(execute_model_req)
            assert driver_outputs is not None
            return driver_outputs
    
        def stop_remote_worker_execution_loop(self) -> None:
            if self.parallel_worker_tasks is None:
                return
    
            self._driver_execute_model(execute_model_req=None)
            parallel_worker_tasks = self.parallel_worker_tasks
            self.parallel_worker_tasks = None
            # Ensure that workers exit model loop cleanly
            # (this will raise otherwise)
            self._wait_for_tasks_completion(parallel_worker_tasks)
    
        def add_lora(self, lora_request: LoRARequest) -> bool:
            assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
            return self._run_workers(
                "add_lora",
                lora_request=lora_request,
            )
    
        def remove_lora(self, lora_id: int) -> bool:
            assert lora_id > 0, "lora_id must be greater than 0."
            return self._run_workers(
                "remove_lora",
                lora_id=lora_id,
            )
    
        def pin_lora(self, lora_id: int) -> bool:
            assert lora_id > 0, "lora_id must be greater than 0."
            return self._run_workers(
                "pin_lora",
                lora_id=lora_id,
            )
    
        def list_loras(self) -> Set[int]:
            return self._run_workers("list_loras")
    
        def save_sharded_state(
            self,
            path: str,
            pattern: Optional[str] = None,
            max_size: Optional[int] = None,
        ) -> None:
            self._run_workers("save_sharded_state",
                              path=path,
                              pattern=pattern,
                              max_size=max_size)
    
        @abstractmethod
        def _driver_execute_model(
            self, execute_model_req: Optional[ExecuteModelRequest]
        ) -> Optional[List[SamplerOutput]]:
            """Run execute_model in the driver worker.
    
            Passing None will cause the driver to stop the model execution loop
            running in each of the remote workers. In this case, this method
            returns None. Otherwise, this method returns the model output.
            """
            raise NotImplementedError
    
        @abstractmethod
        def _run_workers(
            self,
            method: str,
            *args,
            async_run_tensor_parallel_workers_only: bool = False,
            max_concurrent_workers: Optional[int] = None,
            **kwargs,
        ) -> Any:
            """Runs the given method on all workers.
    
            Args:
                async_run_tensor_parallel_workers_only: If True the method will be
                    run only in the remote TP workers, not the driver worker.
                    It will also be run asynchronously and return a list of futures
                    rather than blocking on the results.
            """
            raise NotImplementedError
    
        @abstractmethod
        def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
            """Wait for futures returned from _run_workers() with
            async_run_remote_workers_only to complete."""
            raise NotImplementedError
    
    
    class DistributedNPUExecutorAsync(DistributedNPUExecutor, ExecutorAsyncBase):
    
        async def execute_model_async(
                self,
                execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
            if self.parallel_worker_tasks is None:
                # Start model execution loop running in the parallel workers
                self.parallel_worker_tasks = asyncio.create_task(
                    self._start_worker_execution_loop())
    
            # Only the driver worker returns the sampling results.
            return await self._driver_execute_model_async(execute_model_req)
    
        async def stop_remote_worker_execution_loop_async(self) -> None:
            if self.parallel_worker_tasks is None:
                return
    
            await self._driver_execute_model_async()
            parallel_worker_tasks = self.parallel_worker_tasks
            self.parallel_worker_tasks = None
            # Ensure that workers exit model loop cleanly
            # (this will raise otherwise)
            await parallel_worker_tasks
    
        @abstractmethod
        async def _driver_execute_model_async(
            self,
            execute_model_req: Optional[ExecuteModelRequest] = None,
        ) -> List[SamplerOutput]:
            """Execute the model asynchronously in the driver worker.
    
            Passing None will cause the driver to stop the model execution
            loop running in each of the remote workers.
            """
            raise NotImplementedError
    
        @abstractmethod
        async def _start_worker_execution_loop(self):
            """Run execution loop on all workers. It guarantees all workers run
            the loop or None of them is running the loop. Loop can be stopped by
            `stop_remote_worker_execution_loop`.
            The API is idempotent (guarantee only 1 loop run at any moment)."""
            raise NotImplementedError
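    DistributedNPUExecutor采用模板方法模式:execute_model、initialize_cache等公共流程在基类中实现,具体的并行机制则由子类通过三个抽象方法提供。下面是一个仅用于说明继承方式的最小示意(非仓库实际代码,其中self.workers、self.driver_worker等属性为假设的成员):

    from typing import Any, List, Optional

    from vllm.executor.distributed_npu_executor import DistributedNPUExecutor
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.sequence import ExecuteModelRequest


    class DemoDistributedNPUExecutor(DistributedNPUExecutor):
        """示意:自定义多卡执行器只需补齐以下三个抽象方法。"""

        def _driver_execute_model(
            self, execute_model_req: Optional[ExecuteModelRequest]
        ) -> Optional[List[SamplerOutput]]:
            # 在驱动worker上执行一步;传入None用于通知远端worker退出执行循环
            return self.driver_worker.execute_model(execute_model_req)

        def _run_workers(self, method: str, *args, **kwargs) -> Any:
            # 将method广播给所有远端worker并收集返回值(此处省略异步与并发控制)
            kwargs.pop("async_run_tensor_parallel_workers_only", None)
            kwargs.pop("max_concurrent_workers", None)
            return [getattr(worker, method)(*args, **kwargs) for worker in self.workers]

        def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
            # 等待异步启动的worker执行循环结束
            for task in parallel_worker_tasks:
                task.result()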
  • cover/vllm/executor/ray_npu_executor.py:实现RayNPUExecutor和RayNPUExecutorAsync,用于多卡Ray分布式环境下同步和异步调用模式的推理(文件源码之后附有一个多卡推理的使用示意)。
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import asyncio
    import os
    import pickle
    from collections import defaultdict
    from itertools import islice, repeat
    from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
    
    import vllm.envs as envs
    from vllm.executor.distributed_npu_executor import DistributedNPUExecutor, DistributedNPUExecutorAsync
    from vllm.executor.ray_utils import RayWorkerWrapper, ray
    from vllm.logger import init_logger
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.sequence import ExecuteModelRequest
    from vllm.utils import (
        get_distributed_init_method,
        get_ip,
        get_open_port,
        get_vllm_instance_id,
        make_async,
        _run_task_with_lock,
    )
    
    if ray is not None:
        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
    
    if TYPE_CHECKING:
        from ray.util.placement_group import PlacementGroup
    
    logger = init_logger(__name__)
    
    USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
    
    
    class RayNPUExecutor(DistributedNPUExecutor):
        uses_ray: bool = True
        def execute_model(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
            all_outputs = self._run_workers(
                "execute_model",
                driver_kwargs={"execute_model_req": execute_model_req},
                use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
            )
    
            # Only the driver worker returns the sampling results.
            return all_outputs[0]
    
        def check_health(self) -> None:
            """Raises an error if engine is unhealthy."""
            self._check_if_any_actor_is_dead()
    
        def _init_executor(self) -> None:
            self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
            # If the env var is set, it uses the Ray's compiled DAG API
            # which optimizes the control plane overhead.
            # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
            # Currently, this requires USE_RAY_SPMD_WORKER=True.
            self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
            # If the env var is set, then we do not distinguish between the
            # "driver worker" vs other workers. Also, the rank 0 worker will
            # be executed in a remote Ray worker. Currently this requires
            # USE_RAY_COMPILED_DAG=True.
            self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
            if self.use_ray_compiled_dag:
                assert self.use_ray_spmd_worker, "VLLM_USE_RAY_COMPILED_DAG=1 requires " "VLLM_USE_RAY_SPMD_WORKER=1"
            if self.use_ray_spmd_worker:
                # TODO: Support SPMD worker for non-DAG Ray executor.
                assert self.use_ray_compiled_dag, "VLLM_USE_RAY_SPMD_WORKER=1 requires " "VLLM_USE_RAY_COMPILED_DAG=1"
    
            assert not self.speculative_config, "Speculative decoding not yet supported for RayNPU backend."
    
            assert self.parallel_config.use_ray
            placement_group = self.parallel_config.placement_group
    
            # Disable Ray usage stats collection.
            ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
            if ray_usage != "1":
                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
    
            # Create the parallel NPU workers.
            self._init_workers_ray(placement_group)
    
            self.forward_dag = None
            if USE_RAY_COMPILED_DAG:
                self.forward_dag = self._compiled_ray_dag()
    
        def _get_worker_wrapper_args(self) -> Dict[str, Any]:
            (worker_module_name, worker_class_name, worker_class_fn) = self._get_worker_module_and_class()
    
            return dict(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                worker_class_fn=worker_class_fn,
                trust_remote_code=self.model_config.trust_remote_code,
            )
    
        def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
            if self.parallel_config.tensor_parallel_size == 1:
                # For single NPU case, we use a ray worker with constrained memory.
                num_gpus = self.cache_config.gpu_memory_utilization
            else:
                # Otherwise, the ray workers are allocated with a full NPU.
                num_gpus = 1
    
            # The driver dummy worker does not actually use any resources.
            # It holds the resource for the driver worker.
            self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
            # The remaining workers are the actual ray actors.
            self.workers: List[RayWorkerWrapper] = []
    
            if self.parallel_config.ray_workers_use_nsight:
                ray_remote_kwargs = self._configure_ray_workers_use_nsight(ray_remote_kwargs)
    
            # Create the workers.
            driver_ip = get_ip()
            worker_wrapper_kwargs = self._get_worker_wrapper_args()
            for bundle_id, _ in enumerate(placement_group.bundle_specs):
                scheduling_strategy = PlacementGroupSchedulingStrategy(
                    placement_group=placement_group,
                    placement_group_capture_child_tasks=True,
                    placement_group_bundle_index=bundle_id,
                )
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
    
                if self.use_ray_spmd_worker:
                    self.workers.append(worker)
                else:
                    worker_ip = ray.get(worker.get_node_ip.remote())
                    if worker_ip == driver_ip and self.driver_dummy_worker is None:
                        # If the worker is on the same node as the driver, we use it
                        # as the resource holder for the driver process.
                        self.driver_dummy_worker = worker
                        self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
                    else:
                        # Else, added to the list of workers.
                        self.workers.append(worker)
    
            logger.debug("workers: %s", self.workers)
            logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
            if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
                raise ValueError(
                    "Ray does not allocate any NPUs on the driver node. Consider "
                    "adjusting the Ray placement group or running the driver on a "
                    "NPU node."
                )
    
            # Get the set of NPU IDs used on each node.
            worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True)
    
            node_workers = defaultdict(list)
            node_gpus = defaultdict(list)
    
            for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
                node_workers[node_id].append(i)
                node_gpus[node_id].extend(gpu_ids)
            for node_id, gpu_ids in node_gpus.items():
                node_gpus[node_id] = sorted(gpu_ids)
    
            VLLM_INSTANCE_ID = get_vllm_instance_id()
    
            # Set environment variables for the driver and workers.
            all_args_to_update_environment_variables = [
                (
                    {
                        "ASCEND_RT_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])),
                        "VLLM_INSTANCE_ID": VLLM_INSTANCE_ID,
                        "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION),
                    },
                )
                for (node_id, _) in worker_node_and_gpu_ids
            ]
            self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables)
    
            distributed_init_method = get_distributed_init_method(driver_ip, get_open_port())
    
            # Initialize the actual workers inside worker wrapper.
            init_worker_all_kwargs = [
                self._get_worker_kwargs(
                    local_rank=node_workers[node_id].index(rank),
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                )
                for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
            ]
            self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    
            self._run_workers("init_device")
            self._run_workers("load_model", max_concurrent_workers=self.parallel_config.max_parallel_loading_workers)
    
            # Organize workers as [pipeline stage][tp rank]; only needed in
            # SPMD mode where every rank runs inside a Ray worker.
            self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
            if self.use_ray_spmd_worker:
                for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                    self.pp_tp_workers.append([])
                    for tp_rank in range(self.parallel_config.tensor_parallel_size):
                        # PP=2, TP=4
                        # pp_tp_workers will be [[0, 1, 2, 3], [4, 5, 6, 7]]
                        rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                        assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                        assert pp_rank < len(self.pp_tp_workers)
                        self.pp_tp_workers[pp_rank].append(self.workers[rank])
    
            # This is the list of workers that are rank 0 of each TP group EXCEPT
            # global rank 0. These are the workers that will broadcast to the
            # rest of the workers.
            self.tp_driver_workers: List[RayWorkerWrapper] = []
            # This is the list of workers that are not drivers and not the first
            # worker in a TP group. These are the workers that will be
            # broadcasted to.
            self.non_driver_workers: List[RayWorkerWrapper] = []
    
            # Enforce rank order for correct rank to return final output.
            for index, worker in enumerate(self.workers):
                # The driver worker is rank 0 and not in self.workers.
                rank = index + 1
                if rank % self.parallel_config.tensor_parallel_size == 0:
                    self.tp_driver_workers.append(worker)
                else:
                    self.non_driver_workers.append(worker)
    
        def _driver_execute_model(self, execute_model_req: Optional[ExecuteModelRequest]) -> Optional[List[SamplerOutput]]:
            """Run execute_model in the driver worker.
    
            Passing None will cause the driver to stop the model execution
            loop running in each of the remote workers.
            """
            # assert not self.use_ray_spmd_worker, (
            #     "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
            return self.driver_worker.execute_method("execute_model", execute_model_req)
    
        def _run_workers(
            self,
            method: str,
            *args,
            driver_args: Optional[Tuple[Any, ...]] = None,
            driver_kwargs: Optional[Dict[str, Any]] = None,
            all_args: Optional[List[Tuple[Any, ...]]] = None,
            all_kwargs: Optional[List[Dict[str, Any]]] = None,
            use_dummy_driver: bool = False,
            max_concurrent_workers: Optional[int] = None,
            use_ray_compiled_dag: bool = False,
            **kwargs,
        ) -> Any:
            """Runs the given method on all workers. Can be used in the following
            ways:
    
            - args/kwargs: All workers share the same args/kwargs
            - args/kwargs and driver_args/driver_kwargs: Driver worker has
              different args
            - all_args/all_kwargs: args/kwargs for each worker are specified
              individually
            """
    
            if max_concurrent_workers:
                raise NotImplementedError("max_concurrent_workers is not supported yet.")
    
            if driver_args is None:
                driver_args = args if all_args is None else all_args[0]
            if driver_kwargs is None:
                driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
    
            count = len(self.workers)
            all_worker_args = repeat(args, count) if all_args is None else islice(all_args, 1, None)
            all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None else islice(all_kwargs, 1, None)
    
            if use_ray_compiled_dag:
                # Right now, compiled DAG can only accept a single
                # input. TODO(sang): Fix it.
                assert self.forward_dag is not None
                output_channels = self.forward_dag.execute(1)
            else:
                # Start the ray workers first.
                ray_worker_outputs = [
                    worker.execute_method.remote(method, *worker_args, **worker_kwargs)
                    for (worker, worker_args, worker_kwargs) in zip(self.workers, all_worker_args, all_worker_kwargs)
                ]
    
            # Start the driver worker after all the ray workers.
            if not use_dummy_driver:
                driver_worker_output = self.driver_worker.execute_method(method, *driver_args, **driver_kwargs)
            else:
                assert self.driver_dummy_worker is not None
                driver_worker_output = ray.get(
                    self.driver_dummy_worker.execute_method.remote(method, *driver_args, **driver_kwargs)
                )
            # Get the results of the ray workers.
            if self.workers:
                if use_ray_compiled_dag:
                    try:
                        ray_worker_outputs = [pickle.loads(chan.begin_read()) for chan in output_channels]
                    finally:
                        # Has to call end_read in order to reuse the DAG.
                        for chan in output_channels:
                            chan.end_read()
                else:
                    ray_worker_outputs = ray.get(ray_worker_outputs)
    
            return [driver_worker_output] + ray_worker_outputs
    
        def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
            """Wait for futures returned from _run_workers() with
            async_run_remote_workers_only to complete."""
            ray.get(parallel_worker_tasks)
    
        def _compiled_ray_dag(self):
            assert self.parallel_config.worker_use_ray
            self._check_ray_adag_installation()
            from ray.dag import InputNode, MultiOutputNode
    
            # Right now, compiled DAG requires at least 1 arg. We send
            # a dummy value for now. It will be fixed soon.
            with InputNode() as input_data:
                forward_dag = MultiOutputNode(
                    [
                        worker.execute_model_compiled_dag_remote.bind(input_data)  # type: ignore[attr-defined]
                        for worker in self.workers
                    ]
                )
            return forward_dag.experimental_compile()
    
        def _check_if_any_actor_is_dead(self):
            if not self.workers:
                return
    
            dead_actors = []
            for actor in self.workers:
                actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
                if actor_state["State"] == "DEAD":
                    dead_actors.append(actor)
            if dead_actors:
                raise RuntimeError("At least one Worker is dead. " f"Dead Workers: {dead_actors}. ")
    
    
    class RayNPUExecutorAsync(RayNPUExecutor, DistributedNPUExecutorAsync):
    
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Lazily created per-pipeline-stage locks; see _driver_execute_model_async.
            self.pp_locks: Optional[List[asyncio.Lock]] = None
            self.driver_executor = make_async(self.driver_worker.execute_method)
            if not self.use_ray_compiled_dag:
                self.driver_exec_method = make_async(self.driver_worker.execute_method)
    
        def __del__(self):
            self.shutdown()
    
        async def _driver_execute_model_async(
            self, execute_model_req: Optional[ExecuteModelRequest] = None
        ) -> List[SamplerOutput]:
            assert not self.use_ray_spmd_worker, "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
            if not self.tp_driver_workers:
                return await self.driver_exec_method("execute_model", execute_model_req)
            if self.pp_locks is None:
                # This locks each pipeline parallel stage so multiple virtual
                # engines can't execute on the same stage at the same time
                # We create the locks here to avoid creating them in the constructor
                # which uses a different asyncio loop.
                self.pp_locks = [asyncio.Lock() for _ in range(self.parallel_config.pipeline_parallel_size)]
    
            tasks = [
                asyncio.create_task(
                    _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], "execute_model", execute_model_req)
                )
            ]
            for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1):
                tasks.append(
                    asyncio.create_task(
                        _run_task_with_lock(
                            driver_worker.execute_method.remote, self.pp_locks[pp_rank], "execute_model", execute_model_req
                        )
                    )
                )
    
            results = await asyncio.gather(*tasks)
    
            # Only the last PP stage has the final results.
            return results[-1]
    
        async def _start_worker_execution_loop(self):
            assert not self.use_ray_spmd_worker, "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
            coros = [worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers]
            return await asyncio.gather(*coros)
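    下面是一段示意性的多卡推理示例(非上述文件内容;模型路径为占位符)。tensor_parallel_size大于1且使用ray作为分布式执行后端时,同步场景由RayNPUExecutor完成多卡推理,在线异步场景则由RayNPUExecutorAsync承接:

    from vllm import LLM, SamplingParams

    # 4卡张量并行,使用Ray作为分布式执行后端
    llm = LLM(
        model="/path/to/model",              # 占位路径
        tensor_parallel_size=4,
        distributed_executor_backend="ray",
    )

    outputs = llm.generate(
        ["介绍一下MindIE与vLLM的适配方式。"],
        SamplingParams(temperature=0.7, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)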
  • cover/vllm/executor/ray_utils.py:在原生框架基础上修改initialize_ray_cluster函数,通过手动指定NPU卡数量,解决昇腾环境下Ray无法自动识别卡数目的问题(文件源码之后附有一个资源上报的示意)。
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import os
    import time
    from collections import defaultdict
    from typing import Dict, List, Optional, Tuple, Union
    
    import msgspec
    
    from vllm.config import ParallelConfig
    from vllm.executor.msgspec_utils import decode_hook, encode_hook
    from vllm.logger import init_logger
    from vllm.platforms import current_platform
    from vllm.sequence import ExecuteModelRequest, IntermediateTensors
    from vllm.utils import get_ip, is_hip, is_xpu, is_npu
    from vllm.worker.worker_base import WorkerWrapperBase
    
    logger = init_logger(__name__)
    PG_WAIT_TIMEOUT = 1800
    
    try:
        import ray
        from ray.util import placement_group_table
        from ray.util.placement_group import PlacementGroup
        try:
            from ray._private.state import available_resources_per_node
        except ImportError:
            # Ray 2.9.x doesn't expose `available_resources_per_node`
            from ray._private.state import state as _state
            available_resources_per_node = _state._available_resources_per_node
    
        class RayWorkerWrapper(WorkerWrapperBase):
            """Ray wrapper for vllm.worker.Worker, allowing Worker to be
            lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
    
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                # Since the compiled DAG runs the main execution
                # in a different thread that calls cuda.set_device,
                # this flag indicates whether set_device has already
                # been called on that thread.
                self.compiled_dag_cuda_device_set = False
    
                self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                                             dec_hook=decode_hook)
                self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
    
            def get_node_ip(self) -> str:
                return get_ip()
    
            def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
                node_id = ray.get_runtime_context().get_node_id()
                gpu_ids = ray.get_gpu_ids()
                return node_id, gpu_ids
    
            def execute_model_spmd(
                self, req_or_tuple: Union[bytes,
                                          Tuple[bytes,
                                                Optional[IntermediateTensors]]]
            ) -> bytes:
                """Execute model in SPMD fashion: used only when SPMD worker and
                compiled DAG are both enabled.
    
                Args:
                    req_or_tuple: A request or a tuple containing the
                        request and intermediate tensors. Intermediate tensors
                        are only provided for pipeline stages after the first
                        one. The request is serialized by msgspec.
                """
                if isinstance(req_or_tuple, bytes):
                    serialized_req, intermediate_tensors = req_or_tuple, None
                else:
                    serialized_req, intermediate_tensors = req_or_tuple
    
                execute_model_req = self.input_decoder.decode(serialized_req)
    
                # TODO(swang): This is needed right now because Ray aDAG executes
                # on a background thread, so we need to reset torch's current
                # device.
                import torch
                if not self.compiled_dag_cuda_device_set:
                    torch.cuda.set_device(self.worker.device)
                    self.compiled_dag_cuda_device_set = True
    
                output = self.worker._execute_model_spmd(execute_model_req,
                                                         intermediate_tensors)
                # Pipeline model request and output to the next pipeline stage.
                if isinstance(output, IntermediateTensors):
                    output = serialized_req, output
                else:
                    output = self.output_encoder.encode(output)
    
                return output
    
            def override_env_vars(self, vars: Dict[str, str]):
                os.environ.update(vars)
    
        ray_import_err = None
    
    except ImportError as e:
        ray = None  # type: ignore
        ray_import_err = e
        RayWorkerWrapper = None  # type: ignore
    
    
    def ray_is_available() -> bool:
        """Returns True if Ray is available."""
        return ray is not None
    
    
    def assert_ray_available():
        """Raise an exception if Ray is not available."""
        if ray is None:
            raise ValueError("Failed to import Ray, please install Ray with "
                             "`pip install ray`.") from ray_import_err
    
    
    def _verify_bundles(placement_group: "PlacementGroup",
                        parallel_config: ParallelConfig, device_str: str):
        """Verify a given placement group has bundles located in the right place.
    
        There are 2 rules.
        - Warn if all tensor parallel workers cannot fit in a single node.
        - Fail if driver node is not included in a placement group.
        """
        assert ray.is_initialized(), (
            "Ray is not initialized although distributed-executor-backend is ray.")
        pg_data = placement_group_table(placement_group)
        # bundle_idx -> node_id
        bundle_to_node_ids = pg_data["bundles_to_node_id"]
        # bundle_idx -> bundle (e.g., {"GPU": 1})
        bundles = pg_data["bundles"]
        # node_id -> List of bundle (e.g., {"GPU": 1})
        node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
    
        for bundle_idx, node_id in bundle_to_node_ids.items():
            node_id_to_bundle[node_id].append(bundles[bundle_idx])
        driver_node_id = ray.get_runtime_context().get_node_id()
    
        if driver_node_id not in node_id_to_bundle:
            raise RuntimeError(
                f"driver node id {driver_node_id} is not included in a placement "
                f"group {placement_group.id}. Node id -> bundles "
                f"{node_id_to_bundle}. "
                "You don't have enough GPUs available in a current node. Check "
                "`ray status` to see if you have available GPUs in a node "
                f"{driver_node_id} before starting an vLLM engine.")
    
        for node_id, bundles in node_id_to_bundle.items():
            if len(bundles) < parallel_config.tensor_parallel_size:
                logger.warning(
                    "tensor_parallel_size=%d "
                    "is bigger than a reserved number of %ss (%d "
                    "%ss) in a node %s. Tensor parallel workers can be "
                    "spread out to 2+ nodes which can degrade the performance "
                    "unless you have fast interconnect across nodes, like "
                    "Infiniband. To resolve this issue, make sure you have more "
                    "than %d GPUs available at each node.",
                    parallel_config.tensor_parallel_size, device_str, len(bundles),
                    device_str, node_id, parallel_config.tensor_parallel_size)
    
    
    def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
        """Wait until a placement group is ready.
    
        It prints the informative log messages if the placement group is
        not created within time.
    
        """
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # if they cannot be provisioned.
        placement_group_specs = current_placement_group.bundle_specs
    
        s = time.time()
        pg_ready_ref = current_placement_group.ready()
        wait_interval = 10
        while time.time() - s < PG_WAIT_TIMEOUT:
            ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)
            if len(ready) > 0:
                break
    
            # Exponential backoff for warning print.
            wait_interval *= 2
            logger.info(
                "Waiting for creating a placement group of specs for "
                "%d seconds. specs=%s. Check "
                "`ray status` to see if you have enough resources.",
                int(time.time() - s), placement_group_specs)
    
        try:
            ray.get(pg_ready_ref, timeout=0)
        except ray.exceptions.GetTimeoutError:
            raise ValueError(
                "Cannot provide a placement group of "
                f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
                "`ray status` to make sure the cluster has enough resources."
            ) from None
    
    
    def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
        ray.util.remove_placement_group(current_placement_group)
        s = time.time()
        wait_interval = 10
        while time.time() - s < PG_WAIT_TIMEOUT:
            pg = ray.util.get_current_placement_group()
            if pg is None:
                break
    
            # Exponential backoff for warning print.
            wait_interval *= 2
            logger.info(
                "Waiting for removing a placement group of specs for "
                "%d seconds.", int(time.time() - s))
            time.sleep(wait_interval)
    
    
    def initialize_ray_cluster(
        parallel_config: ParallelConfig,
        ray_address: Optional[str] = None,
    ):
        """Initialize the distributed cluster with Ray.
    
        it will connect to the Ray cluster and create a placement group
        for the workers, which includes the specification of the resources
        for each distributed worker.
    
        Args:
            parallel_config: The configurations for parallel execution.
            ray_address: The address of the Ray cluster. If None, uses
                the default Ray cluster address.
        """
        assert_ray_available()
    
        # Connect to a ray cluster.
        if is_hip() or is_xpu() or is_npu():
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
        else:
            ray.init(address=ray_address, ignore_reinit_error=True)
    
        if parallel_config.placement_group:
            # Placement group is already set.
            return
    
        device_str = "GPU" if not current_platform.is_tpu() else "TPU"
        # Create placement group for worker processes
        current_placement_group = ray.util.get_current_placement_group()
        if current_placement_group:
            # We are in a placement group
            bundles = current_placement_group.bundle_specs
            # Verify that we can use the placement group.
            device_bundles = 0
            for bundle in bundles:
                bundle_devices = bundle.get(device_str, 0)
                if bundle_devices > 1:
                    raise ValueError(
                        "Placement group bundle cannot have more than 1 "
                        f"{device_str}.")
                if bundle_devices:
                    device_bundles += 1
            if parallel_config.world_size > device_bundles:
                raise ValueError(
                    f"The number of required {device_str}s exceeds the total "
                    f"number of available {device_str}s in the placement group."
                    f"Required number of devices: {parallel_config.world_size}. "
                    f"Total number of devices: {device_bundles}.")
        else:
            num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
            if parallel_config.world_size > num_devices_in_cluster:
                raise ValueError(
                    f"The number of required {device_str}s exceeds the total "
                    f"number of available {device_str}s in the placement group.")
            # Create a new placement group
            placement_group_specs: List[Dict[str, float]] = ([{
                device_str: 1.0
            } for _ in range(parallel_config.world_size)])
    
            # vLLM engine is also a worker to execute model with an accelerator,
            # so it requires to have the device in a current node. Check if
            # the current node has at least one device.
            current_ip = get_ip()
            # This way, at least one bundle is required to be created in the
            # current node.
            placement_group_specs[0][f"node:{current_ip}"] = 0.001
    
            # By default, Ray packs resources as much as possible.
            current_placement_group = ray.util.placement_group(
                placement_group_specs, strategy="PACK")
            _wait_until_pg_ready(current_placement_group)
    
        assert current_placement_group is not None
        _verify_bundles(current_placement_group, parallel_config, device_str)
        # Set the placement group in the parallel config
        parallel_config.placement_group = current_placement_group
    
    
    def get_num_tpu_nodes() -> int:
        from ray._private.accelerators import TPUAcceleratorManager
        cluster_resources = ray.cluster_resources()
        total_tpus = int(cluster_resources["TPU"])
        tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
        assert total_tpus % tpus_per_node == 0
        return total_tpus // tpus_per_node
    
    
    def get_num_nodes_in_placement_group() -> int:
        pg_table = ray.util.placement_group_table()
        current_pg = ray.util.get_current_placement_group()
        num_nodes = 0
    
        if current_pg:
            nodes_in_pg = set()
            for pg_key, pg in pg_table.items():
                if pg_key == current_pg.id.hex():
                    for _, node in pg["bundles_to_node_id"].items():
                        nodes_in_pg.add(node)
            num_nodes = len(nodes_in_pg)
    
        return num_nodes
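    由于Ray无法自动探测昇腾NPU,上面的initialize_ray_cluster在is_npu()分支中通过num_gpus=parallel_config.world_size显式上报卡数。若需预先手动拉起Ray并确认资源上报情况,可参考下面的本地示意(以8卡为例,数值仅作演示):

    import ray

    # 昇腾环境下将NPU数量作为"GPU"资源显式上报给Ray(示意)
    ray.init(num_gpus=8, ignore_reinit_error=True)

    # 预期能在返回结果中看到 {"GPU": 8.0, ...}
    print(ray.cluster_resources())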
  • cover/vllm/model_executor/model_loader/npu.py:实现MindIELlmWrapper类,该类对MindIE LLM提供的GeneratorTorch统一接口进行实例化,并从vLLM原生框架的数据结构中拆解出MindIE LLM所需的模型推理参数,再传给统一接口完成模型调用。此外,该模块重写了vLLM框架中的get_model和get_architecture_class_name函数,从而将MindIELlmWrapper类接入vLLM框架。
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    import contextlib
    import math
    from typing import List, Optional, Tuple
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch
    from atb_llm.utils.initial import NPUSocInfo
    from torch import nn
    from vllm.attention import AttentionMetadata
    from vllm.config import DeviceConfig, LoadConfig, LoadFormat, ModelConfig
    from vllm.logger import init_logger
    from vllm.lora.request import LoRARequest
    from vllm.model_executor import SamplingMetadata
    from vllm.model_executor.layers.npu_sampler import MindIESampler
    from vllm.model_executor.layers.sampler import SamplerOutput, Sampler
    from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
    from vllm.sequence import IntermediateTensors
    
    logger = init_logger(__name__)
    
    KVCache = Tuple[torch.Tensor, torch.Tensor]
    
    
    # TODO: Refactor this to other file
    class MindIELlmWrapper(nn.Module):
        """
        A wrapper class for the MindIE model. It provides functionality for forward pass, sampling, 
        and model weight loading.
    
        Attributes:
            mindie_config: Configuration dictionary containing model parameters.
            rank: Rank of the current device in the distributed setup.
            local_rank: Local rank of the device.
            npu_id: NPU device ID.
            world_size: Total number of devices in the distributed setup.
            mindie_model: Instance of the generator model, initialized with the provided configuration.
            sampler: Sampler instance for token generation.
            dummy_block_num: Number of dummy blocks for cache creation.
        """
        def __init__(self, mindie_config, linear_method=None, lora_config=None):
            """
            Initializes the MindIELlmWrapper with the provided configuration and optional LoRA setup.
    
            Args:
                mindie_config: Configuration dictionary for the model, including rank, local_rank, world_size, etc.
                linear_method (optional): Method to apply linear transformations, default is None.
                lora_config (optional): Configuration for LoRA adapters, default is None.
            """
    
            super(MindIELlmWrapper, self).__init__()
    
            self.mindie_config = mindie_config
            self.rank = mindie_config["rank"]
            self.local_rank = mindie_config["local_rank"]
            self.npu_id = self.local_rank
            self.world_size = mindie_config["world_size"]
            self.mindie_model = None
            self.sampler = None
            self.need_nz = NPUSocInfo().need_nz
    
        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[KVCache],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors],
            lora_requests: List[LoRARequest],
        ) -> torch.Tensor:
            """
            Performs the forward pass through the model, applying attention and token generation.
    
            Args:
                input_ids (torch.Tensor): Input tensor containing token IDs.
                positions (torch.Tensor): Indicate the position of each token in the input sequence.
                kv_caches (List[KVCache]): List of key-value caches for attention layers.
                attn_metadata (AttentionMetadata): Metadata related to attention mechanisms, including
                  information relevant to the prefill and decode phases.
                intermediate_tensors (optional): Store intermediate states such as hidden states and residuals 
                  during model execution, facilitating operations like gradient checkpointing and model 
                  parallelism, default is None.
                lora_requests (List[LoRARequest]): List of LoRA requests to apply during forward pass.
    
            Returns:
                torch.Tensor: Logits or generated token predictions from the model.
            """
            is_prompt = attn_metadata.prefill_metadata is not None
    
            if kv_caches[0] is None:
                kv_caches, block_tables, slots = self.create_dummy_kv_cache(attn_metadata, input_ids)
            else:
                block_tables = self.create_block_tables(attn_metadata)
                slots = attn_metadata.slot_mapping
    
            if attn_metadata.prefill_metadata is None:
                input_lengths = attn_metadata.decode_metadata.seq_lens_tensor
                max_seq_len = attn_metadata.decode_metadata.max_seq_len
                query_lens = []
                lm_head_indices = None
            else:
                input_lengths = attn_metadata.prefill_metadata.seq_lens_tensor
                max_seq_len = attn_metadata.prefill_metadata.max_seq_len
                query_start_loc = attn_metadata.prefill_metadata.query_start_loc
                query_lens_tensor = query_start_loc[1:] - query_start_loc[:-1]
                if attn_metadata.decode_metadata is not None:
                    input_lengths = torch.cat((input_lengths, attn_metadata.decode_metadata.seq_lens_tensor), dim=0)
                    max_seq_len = max(max_seq_len, attn_metadata.decode_metadata.max_seq_len)
                    query_lens_tensor = F.pad(query_lens_tensor, (0, attn_metadata.num_decode_tokens), "constant", 1)
                query_lens = query_lens_tensor.tolist()
                lm_head_indices = query_lens_tensor.cumsum(dim=-1) - 1
    
            if not lora_requests:
                adapter_ids = ["base"] * len(input_lengths)
            else:
                adapter_ids = [lora_request.lora_name if lora_request else "base" for lora_request in lora_requests]
    
            # TODO: Can MindIE take advantage of intermediate_tensors?
            logits = self.mindie_model.forward_tensor(
                input_ids,
                positions,
                is_prompt,
                kv_caches,
                block_tables,
                slots,
                input_lengths,
                max_seq_len,
                lm_head_indices,
                adapter_ids=adapter_ids,
                q_lens=query_lens,
            )
    
            return logits
    
        def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor:
            return hidden_states
    
        def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype, device: torch.device
        ) -> IntermediateTensors:
            ...
    
        def sample(
            self,
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata,
        ) -> Optional[SamplerOutput]:
            """
            Samples tokens from the logits based on the provided sampling metadata.
    
            Args:
                logits (torch.Tensor): Logits tensor from which tokens will be sampled.
                sampling_metadata (SamplingMetadata): Metadata defining how sampling should be performed.
    
            Returns:
                Optional[SamplerOutput]: The sampler output from sampling.
            """
            next_tokens = self.sampler(logits, sampling_metadata) # hidden_states is logits
            return next_tokens
    
        def load_weights(self):
            """
            Loads the weights into the model, initializing the MindIE model and MindIE sampler.
            """
            self.weight_dtype = torch.get_default_dtype()
            torch.set_default_dtype(torch.float32)
    
            self.mindie_model = GeneratorTorch(self.mindie_config)
            self.sampler = MindIESampler(self.mindie_model)
    
            torch.set_default_dtype(self.weight_dtype)
    
        # when warmup, create dummy kvcache, block_tables, slot_mapping
        def create_dummy_kv_cache(self, attn_metadata, input_ids):
            """
            Creates a dummy key-value cache for attention during warmup phase.
    
            Args:
                attn_metadata (AttentionMetadata): Metadata related to attention for the current batch.
                input_ids (torch.Tensor): Input token IDs for the current batch.
    
            Returns:
                Tuple: A tuple containing the key-value cache, block tables, and slot mappings.
            """        
            dummy_block_size = 128
            max_s = max(attn_metadata.prefill_metadata.seq_lens_tensor)
            max_need_block = math.ceil(max_s / dummy_block_size)
            batch_size = len(attn_metadata.prefill_metadata.seq_lens_tensor)
            self.dummy_block_num = max_need_block * batch_size
    
            model_runner = self.mindie_model.model_wrapper.model_runner
            if not self.need_nz:
                dummy_kv_cache_shape = (
                    self.dummy_block_num,
                    dummy_block_size,
                    model_runner.num_kv_heads,
                    model_runner.head_size
                )
            else:
                dummy_kv_cache_shape = (
                    self.dummy_block_num,
                    model_runner.num_kv_heads * model_runner.head_size // 16,
                    dummy_block_size,
                    16
                )
            kv_cache = [
                (
                    torch.empty(
                        size=dummy_kv_cache_shape,
                        dtype=self.weight_dtype,
                        device="npu",
                    ),
                    torch.empty(
                        size=dummy_kv_cache_shape,
                        dtype=self.weight_dtype,
                        device="npu",
                    ),
                )
                for _ in range(model_runner.num_layers)
            ]
    
            block_tables = torch.zeros(batch_size, max_need_block, dtype=int, device="npu")
            slot = [i for i in range(dummy_block_size)]
            slots = []
            warm_up_len = len(input_ids)
            while warm_up_len > 0:
                if warm_up_len > dummy_block_size:
                    slots.extend(slot)
                    warm_up_len -= dummy_block_size
                else:
                    slots.extend(slot[:warm_up_len])
                    warm_up_len = 0
            slots = torch.tensor(slots, dtype=torch.long, device="npu")
            return kv_cache, block_tables, slots
    
        def create_block_tables(self, attn_metadata):
            """
            Creates block tables for attention, based on prefill and decode metadata.
            """
            if attn_metadata.prefill_metadata is None:
                return attn_metadata.decode_metadata.block_tables
            prefill_block_tables = attn_metadata.prefill_metadata.block_tables
            if prefill_block_tables.numel() == 0:
                return torch.tensor([0], dtype=torch.int32, device="npu")
            if attn_metadata.decode_metadata is None:
                return prefill_block_tables
    
            decode_block_tables = attn_metadata.decode_metadata.block_tables
            pad_size = prefill_block_tables.size(1) - decode_block_tables.size(1)
            if pad_size > 0:
                decode_block_tables = F.pad(decode_block_tables, (0, pad_size), "constant", 0)
            elif pad_size < 0:
                prefill_block_tables = F.pad(prefill_block_tables, (0, -pad_size), "constant", 0)
            return torch.cat((prefill_block_tables, decode_block_tables), dim=0)
    
    
    def get_architecture_class_name(model_config: ModelConfig) -> str:
        """
        Determines and returns the architecture class name based on the provided model configuration.
    
        This function checks the architecture type in the model's configuration and adjusts
        the architecture name in case quantization is enabled and the model is of type "MixtralForCausalLM".
        If quantization is enabled and not set to "fp8", the architecture name is updated to "QuantMixtralForCausalLM".
    
        Args:
            model_config (ModelConfig): The configuration object containing model-specific settings.
    
        Returns:
            str: The name of the model architecture class.
        """    
        architectures = getattr(model_config.hf_config, "architectures", [])
        if (
            model_config.quantization is not None
            and model_config.quantization != "fp8"
            and "MixtralForCausalLM" in architectures
        ):
            architectures = ["QuantMixtralForCausalLM"]
        return architectures[0]
    
    
    @contextlib.contextmanager
    def _set_default_torch_dtype(dtype: torch.dtype):
        """Sets the default torch dtype to the given dtype."""
        old_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        yield
        torch.set_default_dtype(old_dtype)
    
    
    def get_model(
        model_config: ModelConfig, device_config: DeviceConfig, load_config: LoadConfig, mindie_config, **kwargs
    ) -> nn.Module:
        """
        Loads and initializes a model based on the given configuration and prepares it for inference.
    
        This function instantiates the `MindIELlmWrapper` model with the provided `mindie_config`, 
        and loads the model weights based on the specified `load_config`. It also supports loading 
        LoRA configurations if provided. The model is moved to the appropriate device (e.g., NPU).
    
        Args:
            model_config (ModelConfig): The configuration object containing model-specific settings.
            device_config (DeviceConfig): The configuration object that defines the device settings (e.g., NPU).
            load_config (LoadConfig): The configuration object that specifies how to load the model weights.
            mindie_config: The configuration for MindIE specific parameters.
    
        Returns:
            nn.Module: The initialized MindIE model.
    
        """
        if kwargs.get("lora_config"):
            logger.info(
                "Using LoRA(s) with MindIE backend:\n"
                "Please make sure your '--lora-modules' matches with your 'lora_adapter.json' in the model directory!\n"
                "Current config for LoRA(s): %s",
                kwargs.get("lora_config"),
            )
    
        with _set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = MindIELlmWrapper(mindie_config)
            if load_config.load_format == LoadFormat.DUMMY:
                initialize_dummy_weights(model)
            else:
                model.load_weights()
            model = model.npu()
    
        return model.eval()
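
    To make the prefill bookkeeping in forward() above concrete, here is a minimal, self-contained sketch (CPU tensors and made-up sequence lengths, purely illustrative) of how query_lens and lm_head_indices are derived from query_start_loc:

    import torch
    import torch.nn.functional as F

    # Hypothetical prefill batch: three prompts of lengths 3, 4 and 5 tokens.
    query_start_loc = torch.tensor([0, 3, 7, 12])

    # Per-sequence query lengths, as computed in forward().
    query_lens_tensor = query_start_loc[1:] - query_start_loc[:-1]  # tensor([3, 4, 5])

    # If two decode sequences are batched after the prefills, each contributes one query token.
    num_decode_tokens = 2
    query_lens_tensor = F.pad(query_lens_tensor, (0, num_decode_tokens), "constant", 1)

    # Index of the last token of each sequence in the flattened batch; these are the
    # positions whose logits are passed to the LM head.
    lm_head_indices = query_lens_tensor.cumsum(dim=-1) - 1  # tensor([ 2,  6, 11, 12, 13])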
    
  • cover/vllm/model_executor/layers/npu_sampler.py: implements the MindIESampler class, which bridges the data structures of the native vLLM framework and the underlying data structures of the MindIE LLM model repository, and post-processes the model's inference results. (A small worked example of the token padding it performs follows the listing.)
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    import random
    from typing import Dict, List, Optional, Tuple
    
    import numpy as np
    import torch
    import torch.nn as nn
    from mindie_llm.text_generator.utils.sampling_metadata import (
        SamplingData,
        SamplingParam,
    )
    from vllm.model_executor.layers.sampler import (
        SamplerOutput,
        get_logprobs,
        _get_sampled_logprob_if_needed,
        _build_sampler_output,
    )
    from vllm.model_executor.sampling_metadata import (
        SamplingMetadata,
        SequenceGroupToSample,
    )
    from vllm.sampling_params import SamplingType
    from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
    
    # from loguru import logger
    _SAMPLING_EPS = 1e-5
    
    SampleResultType = List[Tuple[List[int], List[int]]]
    
    
    # TODO: Figure out how to remove _get_logprobs
    def _to_tensor(data, dtype=None):
        if dtype:
            return torch.tensor(data, dtype=dtype, device=torch.device("npu"))
        else:
            return torch.tensor(data, device=torch.device("npu"))
    
    
    def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """
        This function calculates the ranks of the chosen tokens in a logprob tensor.
    
        Args:
            x (torch.Tensor): 2D logprob tensor of shape (N, M)
                            where N is the no. of tokens and M is the vocab dim.
            indices (torch.Tensor): List of chosen token indices.
    
        Returns:
            torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                        Each element in the returned tensor represents the rank
                        of the chosen token in the input logprob tensor.
        """
        vals = x[
            torch.arange(0, len(x), device=x.device, dtype=indices.dtype), indices
        ]
        return (x > vals[:, None]).long().sum(1).add_(1)
    
    
    def _get_prompt_logprob_if_needed(
        seq_group: SequenceGroupToSample,
        selected_logprobs: torch.Tensor,
        ranks: torch.Tensor,
        top_token_ids: torch.Tensor,
        top_logprobs: torch.Tensor,
        selected_logprobs_idx: int,
        top_logprob_idx: int,
    ):
        """Compute the prompt logprob from a sequence group if needed."""
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
    
        # Find prompt logprobs
        prompt_logprobs: Optional[PromptLogprobs] = None
        if is_prompt and sampling_params.prompt_logprobs is not None:
            prompt_logprobs = []
            num_logprobs = sampling_params.prompt_logprobs
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            for token_id in next_prompt_tokens:
                # Calculate the prompt logprob of the real prompt tokens.
                # Use tuple here for performance (to use to_list()).
                # {token_id: (logprob, rank_from_vocab)}
                prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
                    token_id: (
                        selected_logprobs[selected_logprobs_idx].item(),
                        ranks[selected_logprobs_idx].item(),
                    )
                }
    
                # Add top K prompt logprobs along with its rank.
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(
                            top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
                            zip(
                                top_logprobs[
                                    top_logprob_idx, :num_logprobs
                                ].tolist(),
                                # This is ranks. Since top_logprob is sorted,
                                # we can just use a range here.
                                range(1, num_logprobs + 1),
                            ),
                        )
                    )
                prompt_logprobs.append(
                    {
                        token_id: Logprob(*logprob_and_rank)
                        for token_id, logprob_and_rank in prompt_logprobs_dict.items()
                    }
                )
                # + 1 to go to the next prompt token.
                top_logprob_idx += 1
                selected_logprobs_idx += 1
        return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
    
    
    def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
        """Get a list of next prompt tokens to compute logprob from a
            given sequence group.
    
        It is used to compute prompt logprob. Imagine you have logprob for each
        query token. Query token needs to know the next prompt token id to compute
        prompt logprob. This is a helper to obtain next prompt token ids.
    
        This API has to be used only when the caller knows seq_group is in prefill
        stage.
    
        Returns:
            A list of next prompt tokens to compute logprob.
        """
        assert (
            seq_group.is_prompt
        ), "Caller should ensure the sequence group is in a prefill stage."
        seq_ids = seq_group.seq_ids
        query_len = seq_group.query_len
        assert query_len is not None
        # prompt has only 1 seq id.
        assert len(seq_ids) == 1
        seq_data = seq_group.seq_data[seq_ids[0]]
        computed_len = seq_data.get_num_computed_tokens()
        prompt_tokens = seq_data.prompt_token_ids
        # +1 because we are looking for a next prompt token.
        next_token_index_start = computed_len + 1
        next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens))
        next_prompt_tokens = prompt_tokens[
            next_token_index_start:next_token_index_end
        ]
        return next_prompt_tokens
    
    
    class MindIESampler(nn.Module):
        """
        A sampler class for generating tokens using the MindIE Sampler.
    
        This class performs sampling over token logits, generating tokens based on the
        sampling configurations defined in `sampling_metadata`. It integrates with the
        `mindie_model`, a token generation model, to handle different sampling strategies
        such as greedy, random, and beam search (although beam search is not yet implemented).
    
        Attributes:
            mindie_model (GeneratorTorch): The MindIE model, initialized from the configuration
              `mindie_config`; its `sample` method implements the core sampling logic for
              generating the next token in the sequence.
    
            include_gpu_probs_tensor (bool): Flag indicating whether to include GPU-based
                probabilities in the returned output tensor.
    
        Methods:
            forward:
                Performs token sampling and returns the results, including log probabilities and
                  sampled tokens based on the provided logits and sampling metadata.
    
            construct_data:
                Constructs the necessary data and parameters for sampling based on the provided
                  metadata, including configuration for temperature, penalties, and sampling type.
    
            recover_data:
                Post-processes the sampled tokens and log probabilities, categorizing the results
                  according to the sampling types (Greedy, Random). It also constructs the final
                  sample results and optionally includes GPU-based probabilities if requested.
        """
    
        def __init__(self, mindie_model):
            """
            Initializes the MindIESampler with the given configuration and optional GPU probability flag.
    
            Args:
                mindie_config (MindIESamplerConfig): Configuration object containing the parameters
                  for the MindIE model.
    
                include_gpu_probs_tensor (bool, optional): If set to True, the method will include
                  GPU-based probabilities in the returned output tensor. Default is False.
    
            """
            super().__init__()
            self.mindie_model = mindie_model
            self.include_gpu_probs_tensor = False
    
        def forward(
            self,
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata,
        ) -> Optional[SamplerOutput]:
            """
            Performs token sampling based on the provided logits and sampling metadata.
    
            This method uses the `mindie_model` to generate token samples from the logits and metadata.
            The generated tokens and their associated log probabilities are returned as a result.
    
            Args:
                logits (torch.Tensor): A tensor containing the logits for token generation.
                    This tensor should be of shape `(seq_length, vocab_size)`.
    
                sampling_metadata (SamplingMetadata): Metadata containing information about 
                    the sampling configuration, including sampling types, sequence groups, 
                    and other sampling parameters.
    
            Returns:
                Optional[SamplerOutput]: The output of the token sampling process, which contains
                the sampled tokens and their associated log probabilities. 
            """
            _, vocab_size = logits.shape
            expanded_logits_lst = []
            idx = 0
            for seq_group in sampling_metadata.seq_groups:
                best_of = seq_group.sampling_params.best_of
                num_seqs = len(seq_group.seq_ids)
                seq_group_logits = logits[idx:idx + num_seqs]
                if seq_group.is_prompt:
                    if seq_group_logits.dim() == 1:
                        seq_group_logits = seq_group_logits.unsqueeze(0)
                    expanded_logits = seq_group_logits.repeat_interleave(
                        best_of, dim=0
                    )
                else:
                    expanded_logits = seq_group_logits
                expanded_logits_lst.append(expanded_logits)
                idx += num_seqs
            expanded_logits = torch.cat(expanded_logits_lst, dim=0)
    
            mindie_sampling_data, mindie_sampling_param = self.construct_data(
                sampling_metadata, vocab_size
            )
            probs = torch.softmax(expanded_logits, dim=-1, dtype=torch.float)
            logprobs = torch.log_softmax(expanded_logits, dim=-1, dtype=torch.float)
            if mindie_sampling_param:
                sampling_mask = (
                    mindie_sampling_param.do_sample_meta.do_sample_tensor.tolist()
                )
            else:
                sampling_mask = [
                    seq_group.do_sample
                    for seq_group in sampling_metadata.seq_groups
                ]
    
            filtered_logits = expanded_logits[sampling_mask]
    
            if filtered_logits.size(0) > 0:
                next_tokens, _ = self.mindie_model.sample(
                    filtered_logits,
                    sampling_data=mindie_sampling_data,
                    sampling_param=mindie_sampling_param,
                )
            else:
                next_tokens = None
    
            sample_results, maybe_sampled_tokens_tensor = recover_data(
                sampling_metadata=sampling_metadata,
                sampled_tokens=next_tokens,
                logprobs=logprobs,
                include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            )
            if self.include_gpu_probs_tensor:
                if maybe_sampled_tokens_tensor is None:
                    raise RuntimeError("maybe_sampled_tokens_tensor is None")
                on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
            else:
                on_device_tensors = None
    
            # Get the logprobs query results.
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, sample_results
            )
            return _build_sampler_output(
                sample_results,
                sampling_metadata,
                prompt_logprobs,
                sample_logprobs,
                on_device_tensors=on_device_tensors,
            )
    
        def construct_data(
            self,
            sampling_metadata: SamplingMetadata,
            vocab_size: int,
        ) -> Tuple[SamplingData, SamplingParam]:
            all_input_tokens: List[List[int]] = []
            prompt_tokens: List[List[int]] = []
            output_tokens: List[List[int]] = []
            top_ks: List[int] = []
            temperatures: List[float] = []
            top_ps: List[float] = []
            min_ps: List[float] = []
            presence_penalties: List[float] = []
            frequency_penalties: List[float] = []
            repetition_penalties: List[float] = []
            sampling_seeds: List[int] = []
            sample_indices: List[int] = []
            do_samples: List[bool] = []  # To Do
            do_penalties = False
            do_top_p_top_k = False
            do_min_p = False
            greedy_flag = False
            non_greedy_flag = False
    
            if sampling_metadata.seq_groups is None:
                raise RuntimeError(
                    "sampling_metadata.seq_group is None, no data received."
                )
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                temperature = sampling_params.temperature
                p = sampling_params.presence_penalty
                f = sampling_params.frequency_penalty
                r = sampling_params.repetition_penalty
                top_p = sampling_params.top_p
                min_p = sampling_params.min_p
                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
                best_of = sampling_params.best_of
                if not seq_group.is_prompt:
                    do_samples.extend([seq_group.do_sample] * len(seq_ids))  # TODO
                else:
                    do_samples.extend([seq_group.do_sample] * best_of)
                # seed = sampling_params.seed
                if not is_greedy:
                    non_greedy_flag = True
                if is_greedy:
                    seed = 0
                    greedy_flag = True
                else:
                    seed = sampling_params.seed
    
                # k should not be greater than the vocab size.
                top_k = min(sampling_params.top_k, vocab_size)
                top_k = vocab_size if top_k == -1 else top_k
                if temperature < _SAMPLING_EPS:
                    temperature = 1.0
                if not do_top_p_top_k and (
                    top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size
                ):
                    do_top_p_top_k = True
                if not do_min_p and min_p > _SAMPLING_EPS:
                    do_min_p = True
                if not do_penalties:
                    if abs(p) >= _SAMPLING_EPS:
                        do_penalties = True
                    elif abs(f) >= _SAMPLING_EPS:
                        do_penalties = True
                    elif abs(r - 1.0) >= _SAMPLING_EPS:
                        do_penalties = True
    
                if (
                    seq_group.is_prompt
                    and sampling_params.prompt_logprobs is not None
                ):
                    # For tokens in the prompt that we only need to get
                    # their logprobs
                    query_len = seq_group.query_len
                    if query_len is None:
                        raise RuntimeError("query_len is None")
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    temperatures += [temperature] * prefill_len
                    sampling_seeds += [seed] * prefill_len
                    top_ps += [top_p] * prefill_len
                    top_ks += [top_k] * prefill_len
                    min_ps += [min_p] * prefill_len
                    presence_penalties += [0] * prefill_len
                    frequency_penalties += [0] * prefill_len
                    repetition_penalties += [1] * prefill_len
                    prompt_tokens.extend([] for _ in range(prefill_len))
                    output_tokens.extend([] for _ in range(prefill_len))
                    all_input_tokens.extend([] for _ in range(prefill_len))
    
                if seq_group.do_sample:
                    sample_lens = len(seq_group.sample_indices)
                    if sample_lens != len(seq_ids):
                        raise ValueError("sample_lens != len(seq_ids)")
                    for seq_id in seq_ids:
                        if seq_group.is_prompt:
                            seq_data = seq_group.seq_data[seq_id]
                            prompt_tokens.extend(
                                [seq_data.prompt_token_ids] * best_of
                            )
                            output_tokens.extend(
                                [seq_data.output_token_ids] * best_of
                            )
                            all_input_tokens.extend(
                                [
                                    seq_data.prompt_token_ids
                                    + seq_data.output_token_ids
                                ]
                                * best_of
                            )
                            if seed is None:
                                lo, hi = (
                                    torch.iinfo(torch.long).min,
                                    torch.iinfo(torch.long).max,
                                )
                                seeds = [
                                    random.randint(lo, hi) for _ in range(best_of)
                                ]
                            else:
                                seeds = [seed] * best_of
                            temperatures += [temperature] * best_of
                            sampling_seeds += seeds
                            top_ps += [top_p] * best_of
                            top_ks += [top_k] * best_of
                            min_ps += [min_p] * best_of
                            presence_penalties += [p] * best_of
                            frequency_penalties += [f] * best_of
                            repetition_penalties += [r] * best_of
    
                        else:
                            seq_data = seq_group.seq_data[seq_id]
                            prompt_tokens.append(seq_data.prompt_token_ids)
                            output_tokens.append(seq_data.output_token_ids)
                            all_input_tokens.append(
                                seq_data.prompt_token_ids
                                + seq_data.output_token_ids
                            )
                            if seed is None:
                                lo, hi = (
                                    torch.iinfo(torch.long).min,
                                    torch.iinfo(torch.long).max,
                                )
                                seeds = [random.randint(lo, hi)]
                            else:
                                seeds = [seed]
                            temperatures += [temperature]
                            sampling_seeds += seeds
                            top_ps += [top_p]
                            top_ks += [top_k]
                            min_ps += [min_p]
                            presence_penalties += [p]
                            frequency_penalties += [f]
                            repetition_penalties += [r]
    
            repetition_penalties = np.array(repetition_penalties, dtype=np.float32)
            frequency_penalties = np.array(frequency_penalties, dtype=np.float32)
            presence_penalties = np.array(presence_penalties, dtype=np.float32)
            temperatures = np.array(temperatures, dtype=np.float32)
            top_ks = np.array(top_ks, dtype=np.int32)
            top_ps = np.array(top_ps, dtype=np.float32)
            sampling_seeds = np.array(sampling_seeds)
            do_samples = np.array(do_samples)
    
            max_tokens_len = max(
                [len(tokens) for tokens in all_input_tokens], default=0
            )
            # TODO: tokens are tuple now
            padded_all_input_tokens = [
                list(tokens) + [vocab_size] * (max_tokens_len - len(tokens))
                for tokens in all_input_tokens
            ]
            padded_all_input_tokens = np.array(
                padded_all_input_tokens, dtype=np.int32
            )
            output_max_len = max(
                [len(tokens) for tokens in output_tokens], default=0
            )
            # TODO: tokens are tuple now
            padded_output_tokens = [
                list(tokens) + [vocab_size] * (output_max_len - len(tokens))
                for tokens in output_tokens
            ]
            padded_output_tokens = np.array(padded_output_tokens, dtype=np.int32)
    
            all_input_ids_tensor = (
                _to_tensor(padded_all_input_tokens, torch.int32)
                if padded_all_input_tokens is not None
                else None
            )
            output_ids_tensor = (
                _to_tensor(padded_output_tokens, torch.int32)
                if padded_output_tokens is not None
                else None
            )
            mindie_sampling_data = SamplingData(
                all_input_ids=all_input_ids_tensor, output_ids=output_ids_tensor
            )
    
            if not non_greedy_flag:
                mindie_sampling_param = None
            else:
                mindie_sampling_param = SamplingParam.from_numpy(
                    repetition_penalty=repetition_penalties,
                    frequency_penalty=frequency_penalties,
                    presence_penalty=presence_penalties,
                    temperature=temperatures,
                    top_k=top_ks,
                    top_p=top_ps,
                    seed=sampling_seeds,
                    do_sample=do_samples,
                    to_tensor=_to_tensor,
                )
            return (mindie_sampling_data, mindie_sampling_param)
    
    
    def recover_data(
        sampling_metadata: SamplingMetadata,
        sampled_tokens: np.ndarray,
        logprobs: torch.Tensor,
        include_gpu_probs_tensor: bool,
    ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
        categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
            t: [] for t in SamplingType
        }
        categorized_sample_indices = sampling_metadata.categorized_sample_indices
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            sampling_params = seq_group.sampling_params
            sampling_type = sampling_params.sampling_type
            categorized_seq_group_ids[sampling_type].append(i)
    
        sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
        sample_metadata = {}
    
        # Create output tensor for sampled token ids.
        sampled_tokens = sampled_tokens.tolist()
        if include_gpu_probs_tensor:
            sampled_token_ids_tensor = torch.empty(
                logprobs.shape[0], 1, dtype=torch.long, device=logprobs.device
            )
        else:
            sampled_token_ids_tensor = None
    
        for sampling_type in SamplingType:
            sample_indices = categorized_sample_indices[sampling_type]
            num_tokens = len(sample_indices)
            if num_tokens == 0:
                continue
    
            seq_group_id = categorized_seq_group_ids[sampling_type]
            seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
            sample_metadata[sampling_type] = (seq_group_id, seq_groups)
    
        greedy_samples = []
        random_samples = []
        beam_samples = []
        idx = 0
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            if sampling_params.sampling_type == SamplingType.GREEDY:
                for seq_id in seq_ids:
                    greedy_samples.extend([sampled_tokens[idx] for i in seq_ids])
                    idx += 1
            elif sampling_params.sampling_type in (
                SamplingType.RANDOM,
                SamplingType.RANDOM_SEED,
            ):
                if seq_group.is_prompt:
                    for seq_id in seq_ids:
                        random_samples.extend(
                            [
                                sampled_tokens[idx + i]
                                for i in range(sampling_params.best_of)
                            ]
                        )
                        idx += sampling_params.best_of
                else:
                    for seq_id in seq_ids:
                        random_samples.append(sampled_tokens[idx])
                        idx += 1
            elif sampling_params.sampling_type == SamplingType.BEAM:
                if seq_group.is_prompt:
                    for seq_id in seq_ids:
                        beam_samples.extend(
                            [
                                sampled_tokens[idx + i]
                                for i in range(sampling_params.best_of)
                            ]
                        )
                        idx += sampling_params.best_of
                else:
                    for seq_id in seq_ids:
                        beam_samples.append(sampled_tokens[idx])
                        idx += 1
    
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
                sample_results = _random_sample(seq_groups, random_samples)
            elif sampling_type == SamplingType.BEAM:
                sample_results = beam_wrap(seq_groups, beam_samples)
            sample_results_dict.update(zip(seq_group_id, sample_results))
    
        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
        return sample_results, sampled_token_ids_tensor
    
    
    def _greedy_sample(
        selected_seq_groups: List[SequenceGroupToSample],
        samples: np.ndarray,
    ):
        samples_lst = samples
        sample_idx = 0
        results: SampleResultType = []
        for seq_group in selected_seq_groups:
            if not seq_group.do_sample:
                results.append(([], []))
                continue
            seq_ids = seq_group.seq_ids
            num_parent_seqs = len(seq_ids)
            assert num_parent_seqs == 1, "Greedy sampling should have only one seq."
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = [samples_lst[sample_idx]]
            results.append((next_token_ids, parent_ids))
            sample_idx += num_parent_seqs
        return results
    
    
    def _random_sample(
        selected_seq_groups: List[SequenceGroupToSample],
        samples: np.ndarray,
    ):
        sample_idx = 0
        results: SampleResultType = []
        for seq_group in selected_seq_groups:
            if not seq_group.do_sample:
                results.append(([], []))
                continue
    
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            is_prompt = seq_group.is_prompt
            num_parent_seqs = len(seq_ids)
            if is_prompt:
                parent_ids = [0] * sampling_params.best_of
                next_token_ids = samples[
                    sample_idx:sample_idx + sampling_params.best_of
                ]
                sample_idx += sampling_params.best_of
            else:
                parent_ids = list(range(num_parent_seqs))
                next_token_ids = samples[sample_idx:sample_idx + num_parent_seqs]
                sample_idx += num_parent_seqs
            results.append((next_token_ids, parent_ids))
        return results
    
    
    def beam_wrap(
        selected_seq_groups: List[SequenceGroupToSample],
        samples: np.ndarray,
    ):
        raise ValueError(f"Unsupported sampling type: beam search")
  • cover/vllm/platforms/__init__.py: adds detection of the Ascend NPU environment and registers NpuPlatform as the current platform. (A brief usage check follows the listing.)
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    from .interface import Platform, PlatformEnum, UnspecifiedPlatform
    
    current_platform: Platform
    
    # NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
    # they only indicate the build configuration, not the runtime environment.
    # For example, people can install a cuda build of pytorch but run on tpu.
    
    is_tpu = False
    try:
        # While it's technically possible to install libtpu on a non-TPU machine,
        # this is a very uncommon scenario. Therefore, we assume that libtpu is
        # installed if and only if the machine has TPUs.
        import libtpu  # noqa: F401
        is_tpu = True
    except Exception:
        pass
    
    is_cuda = False
    
    try:
        import pynvml
        pynvml.nvmlInit()
        try:
            if pynvml.nvmlDeviceGetCount() > 0:
                is_cuda = True
        finally:
            pynvml.nvmlShutdown()
    except Exception:
        pass
    
    is_rocm = False
    
    try:
        import amdsmi
        amdsmi.amdsmi_init()
        try:
            if len(amdsmi.amdsmi_get_processor_handles()) > 0:
                is_rocm = True
        finally:
            amdsmi.amdsmi_shut_down()
    except Exception:
        pass
    
    is_cpu = False
    try:
        from importlib.metadata import version
        is_cpu = "cpu" in version("vllm")
    except Exception:
        pass
    
    
    is_npu = False
    try:
        import torch_npu
        is_npu = True
    except Exception:
        pass
    
    if is_tpu:
        # people might install pytorch built with cuda but run on tpu
        # so we need to check tpu first
        from .tpu import TpuPlatform
        current_platform = TpuPlatform()
    elif is_cuda:
        from .cuda import CudaPlatform
        current_platform = CudaPlatform()
    elif is_rocm:
        from .rocm import RocmPlatform
        current_platform = RocmPlatform()
    elif is_npu:
        from .npu import NpuPlatform
        current_platform = NpuPlatform()
    elif is_cpu:
        from .cpu import CpuPlatform
        current_platform = CpuPlatform()
    else:
        current_platform = UnspecifiedPlatform()
    
    __all__ = ['Platform', 'PlatformEnum', 'current_platform']
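
    A brief usage check of the platform detection above, assuming the patched vllm and torch_npu are installed (illustrative only):

    from vllm.platforms import current_platform

    # On an Ascend machine with torch_npu importable, the NPU branch above is taken.
    if current_platform.is_npu():
        print("Detected platform:", type(current_platform).__name__)  # NpuPlatform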
    
  • cover/vllm/platforms/interface.py: extends the platform abstraction with the NPU platform type (PlatformEnum.NPU and Platform.is_npu()). (A small example of the capability encoding follows the listing.)
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    import enum
    from typing import NamedTuple, Optional, Tuple, Union
    
    import torch
    
    
    class PlatformEnum(enum.Enum):
        CUDA = enum.auto()
        ROCM = enum.auto()
        TPU = enum.auto()
        CPU = enum.auto()
        NPU = enum.auto()
        UNSPECIFIED = enum.auto()
    
    
    class DeviceCapability(NamedTuple):
        major: int
        minor: int
    
        def as_version_str(self) -> str:
            return f"{self.major}.{self.minor}"
    
        def to_int(self) -> int:
            """
            Express device capability as an integer ``<major><minor>``.
    
            It is assumed that the minor version is always a single digit.
            """
            assert 0 <= self.minor < 10
            return self.major * 10 + self.minor
    
    
    class Platform:
        _enum: PlatformEnum
    
        @classmethod
        def get_device_capability(
            cls,
            device_id: int = 0,
        ) -> Optional[DeviceCapability]:
            """Stateless version of :func:`torch.cuda.get_device_capability`."""
            return None
    
        @classmethod
        def has_device_capability(
            cls,
            capability: Union[Tuple[int, int], int],
            device_id: int = 0,
        ) -> bool:
            """
            Test whether this platform is compatible with a device capability.
    
            The ``capability`` argument can either be:
    
            - A tuple ``(major, minor)``.
            - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
            """
            current_capability = cls.get_device_capability(device_id=device_id)
            if current_capability is None:
                return False
    
            if isinstance(capability, tuple):
                return current_capability >= capability
    
            return current_capability.to_int() >= capability
    
        @classmethod
        def get_device_name(cls, device_id: int = 0) -> str:
            raise NotImplementedError
    
        @classmethod
        def inference_mode(cls):
            """A device-specific wrapper of `torch.inference_mode`.
    
            This wrapper is recommended because some hardware backends such as TPU
            do not support `torch.inference_mode`. In such a case, they will fall
            back to `torch.no_grad` by overriding this method.
            """
            return torch.inference_mode(mode=True)
    
        def is_cuda(self) -> bool:
            return self._enum == PlatformEnum.CUDA
    
        def is_rocm(self) -> bool:
            return self._enum == PlatformEnum.ROCM
    
        def is_tpu(self) -> bool:
            return self._enum == PlatformEnum.TPU
    
        def is_cpu(self) -> bool:
            return self._enum == PlatformEnum.CPU
    
        def is_npu(self) -> bool:
            return self._enum == PlatformEnum.NPU
    
        def is_cuda_alike(self) -> bool:
            """Stateless version of :func:`torch.cuda.is_available`."""
            return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
    
    
    class UnspecifiedPlatform(Platform):
        _enum = PlatformEnum.UNSPECIFIED
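
    A small example of the capability encoding defined above, assuming the patched vllm.platforms.interface is importable (on NPU this code path is unused, since NpuPlatform reports no device capability):

    from vllm.platforms.interface import DeviceCapability

    cap = DeviceCapability(major=8, minor=0)
    print(cap.as_version_str())  # "8.0"
    print(cap.to_int())          # 80, directly comparable against integer thresholds such as 75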
    
  • cover/vllm/platforms/npu.py: implements NpuPlatform, which reports no device capability and uses torch.no_grad() as its inference mode. (A short usage sketch follows the listing.)
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    from typing import Tuple
    
    import torch
    
    from .interface import Platform, PlatformEnum
    
    
    class NpuPlatform(Platform):
        _enum = PlatformEnum.NPU
    
        @staticmethod
        def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
            raise RuntimeError("NPU does not have device capability.")
    
        @staticmethod
        def inference_mode():
            return torch.no_grad()
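
    A minimal sketch of how the NPU platform's inference_mode wrapper might be used (illustrative; on NPU it simply falls back to torch.no_grad() instead of torch.inference_mode()):

    import torch

    from vllm.platforms import current_platform

    with current_platform.inference_mode():
        x = torch.ones(2, 2)
        y = x @ x  # no autograd graph is recorded inside this context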
    
  • cover/vllm/utils.py: general utilities, extended with an is_npu() check based on the availability of torch_npu.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    import argparse
    import asyncio
    import contextlib
    import datetime
    import enum
    import gc
    import inspect
    import ipaddress
    import os
    import random
    import socket
    import subprocess
    import sys
    import tempfile
    import threading
    import uuid
    import warnings
    import weakref
    from asyncio import FIRST_COMPLETED, ensure_future
    from functools import lru_cache, partial, wraps
    from platform import uname
    from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                        Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
                        Type, TypeVar, Union, overload)
    from uuid import uuid4
    
    import numpy as np
    import numpy.typing as npt
    import psutil
    import torch
    import torch.types
    import yaml
    from packaging.version import Version
    from typing_extensions import ParamSpec, TypeIs, assert_never
    
    import vllm.envs as envs
    from vllm.logger import enable_trace_function_call, init_logger
    from vllm.platforms import current_platform
    
    logger = init_logger(__name__)
    
    # Exception strings for non-implemented encoder/decoder scenarios
    
    STR_NOT_IMPL_ENC_DEC_SWA = \
        "Sliding window attention for encoder/decoder models " + \
                        "is not currently supported."
    
    STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
        "Prefix caching for encoder/decoder models " + \
                        "is not currently supported."
    
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
        "Chunked prefill for encoder/decoder models " + \
                        "is not currently supported."
    
    STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
        "Models with logits_soft_cap "
        "require FlashInfer backend, which is "
        "currently not supported for encoder/decoder "
        "models.")
    
    STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently "
                                 "supported with encoder/decoder "
                                 "models.")
    
    STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                               "currently supported with "
                               "encoder/decoder models.")
    
    STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                               "supported with encoder/decoder "
                               "models.")
    
    STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                     "currently supported with encoder/"
                                     "decoder models.")
    
    STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
                                    "currently supported with encoder/"
                                    "decoder models.")
    
    STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                           "currently supported with encoder/"
                                           "decoder models.")
    
    STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with "
                                "encoder/decoder models.")
    
    # Efficiently import all enc/dec error strings
    # rather than having to import all of the above
    STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
        "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
        "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
        "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
        STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
        "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
        "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
        "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
        "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
        "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
        "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
        "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
        "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
    }
    
    # Constants related to forcing the attention backend selection
    
    # String name of register which may be set in order to
    # force auto-selection of attention backend by Attention
    # wrapper
    STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
    
    # Possible string values of STR_BACKEND_ENV_VAR
    # register, corresponding to possible backends
    STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
    STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
    STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
    STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
    STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
    STR_INVALID_VAL: str = "INVALID"
    
    GiB_bytes = 1 << 30
    """The number of bytes in one gibibyte (GiB)."""
    
    STR_DTYPE_TO_TORCH_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.uint8,
        "fp8_e4m3": torch.uint8,
        "fp8_e5m2": torch.uint8,
    }
    
    TORCH_DTYPE_TO_NUMPY_DTYPE = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.uint8: np.uint8,
        torch.int32: np.int32,
        torch.int64: np.int64,
    }
    
    P = ParamSpec('P')
    K = TypeVar("K")
    T = TypeVar("T")
    U = TypeVar("U")
    
    
    class _Sentinel:
        ...
    
    
    ALL_PINNED_SENTINEL = _Sentinel()
    
    
    class Device(enum.Enum):
        GPU = enum.auto()
        CPU = enum.auto()
    
    
    class Counter:
    
        def __init__(self, start: int = 0) -> None:
            self.counter = start
    
        def __next__(self) -> int:
            i = self.counter
            self.counter += 1
            return i
    
        def reset(self) -> None:
            self.counter = 0
    
    
    class LRUCache(Generic[T]):
    
        def __init__(self, capacity: int):
            self.cache: OrderedDict[Hashable, T] = OrderedDict()
            self.pinned_items: Set[Hashable] = set()
            self.capacity = capacity
    
        def __contains__(self, key: Hashable) -> bool:
            return key in self.cache
    
        def __len__(self) -> int:
            return len(self.cache)
    
        def __getitem__(self, key: Hashable) -> T:
            value = self.cache[key]  # Raise KeyError if not exists
            self.cache.move_to_end(key)
            return value
    
        def __setitem__(self, key: Hashable, value: T) -> None:
            self.put(key, value)
    
        def __delitem__(self, key: Hashable) -> None:
            self.pop(key)
    
        def touch(self, key: Hashable) -> None:
            self.cache.move_to_end(key)
    
        def get(self,
                key: Hashable,
                default_value: Optional[T] = None) -> Optional[T]:
            value: Optional[T]
            if key in self.cache:
                value = self.cache[key]
                self.cache.move_to_end(key)
            else:
                value = default_value
            return value
    
        def put(self, key: Hashable, value: T) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            self._remove_old_if_needed()
    
        def pin(self, key: Hashable) -> None:
            """
            Pins a key in the cache preventing it from being
            evicted in the LRU order.
            """
            if key not in self.cache:
                raise ValueError(f"Cannot pin key: {key} not in cache.")
            self.pinned_items.add(key)
    
        def remove_oldest(self, remove_pinned=False):
            if not self.cache:
                return
    
            if not remove_pinned:
                # pop the oldest item in the cache that is not pinned
                lru_key = next(
                    (key for key in self.cache if key not in self.pinned_items),
                    ALL_PINNED_SENTINEL)
                if lru_key is ALL_PINNED_SENTINEL:
                    raise RuntimeError("All items are pinned, "
                                       "cannot remove oldest from the cache.")
            else:
                lru_key = next(iter(self.cache))
            self.pop(lru_key)
    
        def pop(self,
                key: Hashable,
                default_value: Optional[T] = None) -> Optional[T]:
            run_on_remove = key in self.cache
            value: Optional[T] = self.cache.pop(key, default_value)
            # remove from pinned items
            if key in self.pinned_items:
                self._unpin(key)
            if run_on_remove:
                self._on_remove(key, value)
            return value
    
        def clear(self):
            while len(self.cache) > 0:
                self.remove_oldest(remove_pinned=True)
            self.cache.clear()
    
        def _unpin(self, key: Hashable) -> None:
            self.pinned_items.remove(key)
    
        def _on_remove(self, key: Hashable, value: Optional[T]):
            pass
    
        def _remove_old_if_needed(self) -> None:
            while len(self.cache) > self.capacity:
                self.remove_oldest()
    
    
    class PyObjectCache:
        """Used to cache python objects to avoid object allocations
        across scheduler iterations.
        """
    
        def __init__(self, obj_builder):
            self._obj_builder = obj_builder
            self._index = 0
    
            self._obj_cache = []
            for _ in range(128):
                self._obj_cache.append(self._obj_builder())
    
        def get_object(self):
            """Returns a pre-allocated cached object. If there is not enough
            objects, then the cache size will double.
            """
            if self._index >= len(self._obj_cache):
                self._grow_cache()
                assert self._index < len(self._obj_cache)
    
            obj = self._obj_cache[self._index]
            self._index += 1
    
            return obj
    
        def reset(self):
            """Makes all cached-objects available for the next scheduler iteration.
            """
            self._index = 0
    
        def _grow_cache(self):
            # Double the size of the cache
            num_objs = len(self._obj_cache)
            for _ in range(num_objs):
                self._obj_cache.append(self._obj_builder())
    
    
    def is_hip() -> bool:
        return torch.version.hip is not None
    
    
    @lru_cache(maxsize=None)
    def is_cpu() -> bool:
        from importlib.metadata import PackageNotFoundError, version
        try:
            return "cpu" in version("vllm")
        except PackageNotFoundError:
            return False
    
    
    @lru_cache(maxsize=None)
    def is_openvino() -> bool:
        from importlib.metadata import PackageNotFoundError, version
        try:
            return "openvino" in version("vllm")
        except PackageNotFoundError:
            return False
    
    
    @lru_cache(maxsize=None)
    def is_neuron() -> bool:
        try:
            import transformers_neuronx
        except ImportError:
            transformers_neuronx = None
        return transformers_neuronx is not None
    
    
    @lru_cache(maxsize=None)
    def is_xpu() -> bool:
        from importlib.metadata import PackageNotFoundError, version
        try:
            is_xpu_flag = "xpu" in version("vllm")
        except PackageNotFoundError:
            return False
        # vllm is not build with xpu
        if not is_xpu_flag:
            return False
        try:
            import intel_extension_for_pytorch as ipex  # noqa: F401
            _import_ipex = True
        except ImportError as e:
            logger.warning("Import Error for IPEX: %s", e.msg)
            _import_ipex = False
        # ipex dependency is not ready
        if not _import_ipex:
            logger.warning("not found ipex lib")
            return False
        return hasattr(torch, "xpu") and torch.xpu.is_available()
    
    
    @lru_cache(maxsize=None)
    def is_npu() -> bool:
        try:
            import torch_npu
        except ImportError:
            torch_npu = None
        return torch_npu is not None
    
    
    @lru_cache(maxsize=None)
    def get_max_shared_memory_bytes(gpu: int = 0) -> int:
        """Returns the maximum shared memory per thread block in bytes."""
        from vllm import _custom_ops as ops
        max_shared_mem = (
            ops.get_max_shared_memory_per_block_device_attribute(gpu))
        # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
        # will fail
        assert max_shared_mem > 0, "max_shared_mem can not be zero"
        return int(max_shared_mem)
    
    
    def get_cpu_memory() -> int:
        """Returns the total CPU memory of the node in bytes."""
        return psutil.virtual_memory().total
    
    
    def seed_everything(seed: int) -> None:
        """
        Set the seed of each random module.
    
        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
    
        if current_platform.is_cuda_alike():
            torch.cuda.manual_seed_all(seed)
    
        if is_xpu():
            torch.xpu.manual_seed_all(seed)
    
    
    def random_uuid() -> str:
        return str(uuid.uuid4().hex)
    
    
    @lru_cache(maxsize=None)
    def get_vllm_instance_id() -> str:
        """
        If the environment variable VLLM_INSTANCE_ID is set, return it.
        Otherwise, return a random UUID.
        Instance id represents an instance of the VLLM. All processes in the same
        instance should have the same instance id.
        """
        return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
    
    
    @lru_cache(maxsize=None)
    def in_wsl() -> bool:
        # Reference: https://github.com/microsoft/WSL/issues/4071
        return "microsoft" in " ".join(uname()).lower()
    
    
    def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
        """Take a blocking function, and run it on in an executor thread.
    
        This function prevents the blocking function from blocking the
        asyncio event loop.
        The code in this function needs to be thread safe.
        """
    
        def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
            loop = asyncio.get_event_loop()
            p_func = partial(func, *args, **kwargs)
            return loop.run_in_executor(executor=None, func=p_func)
    
        return _async_wrapper
    
    
    async def iterate_with_cancellation(
        iterator: AsyncGenerator[T, None],
        is_cancelled: Callable[[], Awaitable[bool]],
    ) -> AsyncGenerator[T, None]:
        """Convert async iterator into one that polls the provided function
        at least once per second to check for client cancellation.
        """
    
        # Can use anext() in python >= 3.10
        awaits = [ensure_future(iterator.__anext__())]
        while True:
            done, pending = await asyncio.wait(awaits, timeout=1)
            if await is_cancelled():
                with contextlib.suppress(BaseException):
                    awaits[0].cancel()
                    await iterator.aclose()
                raise asyncio.CancelledError("client cancelled")
            if done:
                try:
                    item = await awaits[0]
                    awaits[0] = ensure_future(iterator.__anext__())
                    yield item
                except StopAsyncIteration:
                    # we are done
                    return
    
    
    async def merge_async_iterators(
        *iterators: AsyncGenerator[T, None],
        is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
    ) -> AsyncGenerator[Tuple[int, T], None]:
        """Merge multiple asynchronous iterators into a single iterator.
    
        This method handles the case where some iterators finish before others.
        When it yields, it yields a tuple (i, item) where i is the index of the
        iterator that yields the item.
    
        It also optionally polls a provided function at least once per second
        to check for client cancellation.
        """
    
        # Can use anext() in python >= 3.10
        awaits = {
            ensure_future(pair[1].__anext__()): pair
            for pair in enumerate(iterators)
        }
        timeout = None if is_cancelled is None else 1
        try:
            while awaits:
                done, pending = await asyncio.wait(awaits.keys(),
                                                   return_when=FIRST_COMPLETED,
                                                   timeout=timeout)
                if is_cancelled is not None and await is_cancelled():
                    raise asyncio.CancelledError("client cancelled")
                for d in done:
                    pair = awaits.pop(d)
                    try:
                        item = await d
                        i, it = pair
                        awaits[ensure_future(it.__anext__())] = pair
                        yield i, item
                    except StopAsyncIteration:
                        pass
        finally:
            # Cancel any remaining iterators
            for f, (_, it) in awaits.items():
                with contextlib.suppress(BaseException):
                    f.cancel()
                    await it.aclose()
    
    
    def get_ip() -> str:
        host_ip = envs.VLLM_HOST_IP
        if host_ip:
            return host_ip
    
        # IP is not set, try to get it from the network interface
    
        # try ipv4
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
        except Exception:
            pass
    
        # try ipv6
        try:
            s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            # Google's public DNS server, see
            # https://developers.google.com/speed/public-dns/docs/using#addresses
            s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
        except Exception:
            pass
    
        warnings.warn(
            "Failed to get the IP address, using 0.0.0.0 by default."
            "The value can be set by the environment variable"
            " VLLM_HOST_IP or HOST_IP.",
            stacklevel=2)
        return "0.0.0.0"
    
    
    def is_valid_ipv6_address(address: str) -> bool:
        try:
            ipaddress.IPv6Address(address)
            return True
        except ValueError:
            return False
    
    
    def get_distributed_init_method(ip: str, port: int) -> str:
        # Brackets are not permitted in ipv4 addresses,
        # see https://github.com/python/cpython/issues/103848
        return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
    
    
    def get_open_zmq_ipc_path() -> str:
        base_rpc_path = envs.VLLM_RPC_BASE_PATH
        return f"ipc://{base_rpc_path}/{uuid4()}"
    
    
    def get_open_port() -> int:
        port = envs.VLLM_PORT
        if port is not None:
            while True:
                try:
                    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                        s.bind(("", port))
                        return port
                except OSError:
                    port += 1  # Increment port number if already in use
                    logger.info("Port %d is already in use, trying port %d",
                                port - 1, port)
        # try ipv4
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", 0))
                return s.getsockname()[1]
        except OSError:
            # try ipv6
            with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
                s.bind(("", 0))
                return s.getsockname()[1]
    
    
    def find_process_using_port(port: int) -> Optional[psutil.Process]:
        for conn in psutil.net_connections():
            if conn.laddr.port == port:
                try:
                    return psutil.Process(conn.pid)
                except psutil.NoSuchProcess:
                    return None
        return None
    
    
    def update_environment_variables(envs: Dict[str, str]):
        for k, v in envs.items():
            if k in os.environ and os.environ[k] != v:
                logger.warning(
                    "Overwriting environment variable %s "
                    "from '%s' to '%s'", k, os.environ[k], v)
            os.environ[k] = v
    
    
    def chunk_list(lst: List[T], chunk_size: int):
        """Yield successive chunk_size chunks from lst."""
        for i in range(0, len(lst), chunk_size):
            yield lst[i:i + chunk_size]
    
    
    def cdiv(a: int, b: int) -> int:
        """Ceiling division."""
        return -(a // -b)
    
    
    def _generate_random_fp8(
        tensor: torch.Tensor,
        low: float,
        high: float,
    ) -> None:
        # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
        # it may occur Inf or NaN if we directly use torch.randint
        # to generate random data for fp8 data.
        # For example, s.11111.00 in fp8e5m2 format represents Inf.
        #     | E4M3        | E5M2
        #-----|-------------|-------------------
        # Inf | N/A         | s.11111.00
        # NaN | s.1111.111  | s.11111.{01,10,11}
        from vllm import _custom_ops as ops
        tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
        tensor_tmp.uniform_(low, high)
        ops.convert_fp8(tensor, tensor_tmp)
        del tensor_tmp
    
    
    def get_kv_cache_torch_dtype(
            cache_dtype: Optional[Union[str, torch.dtype]],
            model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
        if isinstance(cache_dtype, str):
            if cache_dtype == "auto":
                if isinstance(model_dtype, str):
                    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
                elif isinstance(model_dtype, torch.dtype):
                    torch_dtype = model_dtype
                else:
                    raise ValueError(f"Invalid model dtype: {model_dtype}")
            elif cache_dtype in ["half", "bfloat16", "float"]:
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
            elif cache_dtype == "fp8":
                torch_dtype = torch.uint8
            else:
                raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
        elif isinstance(cache_dtype, torch.dtype):
            torch_dtype = cache_dtype
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
        return torch_dtype
    
    
    def create_kv_caches_with_random_flash(
        num_blocks: int,
        block_size: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None,
        seed: int = 0,
        device: Optional[str] = "cuda",
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        seed_everything(seed)
    
        torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
        key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
        scale = head_size**-0.5
    
        key_caches: List[torch.Tensor] = []
        value_caches: List[torch.Tensor] = []
    
        for _ in range(num_layers):
            key_value_cache = torch.empty(size=key_value_cache_shape,
                                          dtype=torch_dtype,
                                          device=device)
            if cache_dtype in ["auto", "half", "bfloat16", "float"]:
                key_value_cache.uniform_(-scale, scale)
            elif cache_dtype == 'fp8':
                _generate_random_fp8(key_value_cache, -scale, scale)
            else:
                raise ValueError(
                    f"Does not support key cache of type {cache_dtype}")
            key_caches.append(key_value_cache[:, 0])
            value_caches.append(key_value_cache[:, 1])
        return key_caches, value_caches
    
    
    def create_kv_caches_with_random(
        num_blocks: int,
        block_size: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None,
        seed: int = 0,
        device: Optional[str] = "cuda",
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    
        if cache_dtype == "fp8" and head_size % 16:
            raise ValueError(
                f"Does not support key cache of type fp8 with head_size {head_size}"
            )
    
        seed_everything(seed)
    
        torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    
        scale = head_size**-0.5
        x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
        key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
        key_caches: List[torch.Tensor] = []
        for _ in range(num_layers):
            key_cache = torch.empty(size=key_cache_shape,
                                    dtype=torch_dtype,
                                    device=device)
            if cache_dtype in ["auto", "half", "bfloat16", "float"]:
                key_cache.uniform_(-scale, scale)
            elif cache_dtype == 'fp8':
                _generate_random_fp8(key_cache, -scale, scale)
            else:
                raise ValueError(
                    f"Does not support key cache of type {cache_dtype}")
            key_caches.append(key_cache)
    
        value_cache_shape = (num_blocks, num_heads, head_size, block_size)
        value_caches: List[torch.Tensor] = []
        for _ in range(num_layers):
            value_cache = torch.empty(size=value_cache_shape,
                                      dtype=torch_dtype,
                                      device=device)
            if cache_dtype in ["auto", "half", "bfloat16", "float"]:
                value_cache.uniform_(-scale, scale)
            elif cache_dtype == 'fp8':
                _generate_random_fp8(value_cache, -scale, scale)
            else:
                raise ValueError(
                    f"Does not support value cache of type {cache_dtype}")
            value_caches.append(value_cache)
        return key_caches, value_caches
    
    
    @lru_cache
    def print_warning_once(msg: str) -> None:
        logger.warning(msg)
    
    
    @lru_cache(maxsize=None)
    def is_pin_memory_available() -> bool:
    
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                               "This may slow down the performance.")
            return False
        elif is_xpu():
            print_warning_once("Pin memory is not supported on XPU.")
            return False
        elif is_neuron():
            print_warning_once("Pin memory is not supported on Neuron.")
            return False
        elif is_cpu() or is_openvino():
            return False
        return True
    
    
    class DeviceMemoryProfiler:
    
        def __init__(self, device: Optional[torch.types.Device] = None):
            self.device = device
    
        def current_memory_usage(self) -> float:
            # Return the memory usage in bytes.
            if current_platform.is_cuda_alike():
                torch.cuda.reset_peak_memory_stats(self.device)
                mem = torch.cuda.max_memory_allocated(self.device)
            elif is_xpu():
                torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
                mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
            elif is_npu():
                torch.npu.reset_peak_memory_stats(self.device)
                mem = torch.npu.max_memory_allocated(self.device)
            return mem
    
        def __enter__(self):
            self.initial_memory = self.current_memory_usage()
            # This allows us to call methods of the context manager if needed
            return self
    
        def __exit__(self, exc_type, exc_val, exc_tb):
            self.final_memory = self.current_memory_usage()
            self.consumed_memory = self.final_memory - self.initial_memory
    
            # Force garbage collection
            gc.collect()
    
    
    def make_ndarray_with_pad(
        x: List[List[T]],
        pad: T,
        dtype: npt.DTypeLike,
        *,
        max_len: Optional[int] = None,
    ) -> npt.NDArray:
        """
        Make a padded array from 2D inputs.
    
        The padding is applied to the end of each inner list until it reaches
        `max_len`.
        """
        if max_len is None:
            # Unlike for most functions, map is faster than a genexpr over `len`
            max_len = max(map(len, x), default=0)
    
        padded_x = np.full((len(x), max_len), pad, dtype=dtype)
        for ind, blocktb in enumerate(x):
            assert len(blocktb) <= max_len
            padded_x[ind, :len(blocktb)] = blocktb
    
        return padded_x
    
    
    def make_tensor_with_pad(
        x: List[List[T]],
        pad: T,
        dtype: torch.dtype,
        *,
        max_len: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        pin_memory: bool = False,
    ) -> torch.Tensor:
        """
        Make a padded tensor from 2D inputs.
    
        The padding is applied to the end of each inner list until it reaches
        `max_len`.
        """
        np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
        padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
    
        tensor = torch.from_numpy(padded_x).to(device)
        if pin_memory:
            tensor = tensor.pin_memory()
    
        return tensor
    
    
    def async_tensor_h2d(
        data: list,
        dtype: torch.dtype,
        target_device: Union[str, torch.device],
        pin_memory: bool,
    ) -> torch.Tensor:
        """Asynchronously create a tensor and copy it from host to device."""
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
        return t.to(device=target_device, non_blocking=True)
    
    
    def get_dtype_size(dtype: torch.dtype) -> int:
        """Get the size of the data type in bytes."""
        return torch.tensor([], dtype=dtype).element_size()
    
    
    # `collections` helpers
    def is_list_of(
        value: object,
        typ: Type[T],
        *,
        check: Literal["first", "all"] = "first",
    ) -> TypeIs[List[T]]:
        if not isinstance(value, list):
            return False
    
        if check == "first":
            return len(value) == 0 or isinstance(value[0], typ)
        elif check == "all":
            return all(isinstance(v, typ) for v in value)
    
        assert_never(check)
    
    
    JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                     Tuple["JSONTree[T]", ...], T]
    """A nested JSON structure where the leaves need not be JSON-serializable."""
    
    
    @overload
    def json_map_leaves(
        func: Callable[[T], U],
        value: Dict[str, JSONTree[T]],
    ) -> Dict[str, JSONTree[U]]:
        ...
    
    
    @overload
    def json_map_leaves(
        func: Callable[[T], U],
        value: List[JSONTree[T]],
    ) -> List[JSONTree[U]]:
        ...
    
    
    @overload
    def json_map_leaves(
        func: Callable[[T], U],
        value: Tuple[JSONTree[T], ...],
    ) -> Tuple[JSONTree[U], ...]:
        ...
    
    
    @overload
    def json_map_leaves(
        func: Callable[[T], U],
        value: JSONTree[T],
    ) -> JSONTree[U]:
        ...
    
    
    def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
        if isinstance(value, dict):
            return {k: json_map_leaves(func, v) for k, v in value.items()}
        elif isinstance(value, list):
            return [json_map_leaves(func, v) for v in value]
        elif isinstance(value, tuple):
            return tuple(json_map_leaves(func, v) for v in value)
        else:
            return func(value)
    
    
    def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
        """Flatten a list of lists to a single list."""
        return [item for sublist in lists for item in sublist]
    
    
    def init_cached_hf_modules() -> None:
        """
        Lazy initialization of the Hugging Face modules.
        """
        from transformers.dynamic_module_utils import init_hf_modules
        init_hf_modules()
    
    
    @lru_cache(maxsize=None)
    def find_library(lib_name: str) -> str:
        """
        Find the library file in the system.
        `lib_name` is full filename, with both prefix and suffix.
        This function resolves `lib_name` to the full path of the library.
        """
        # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
        # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
        # `/sbin/ldconfig` should exist in all Linux systems.
        # `/sbin/ldconfig` searches the library in the system
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
        # each line looks like the following:
        # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
        locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
        # `LD_LIBRARY_PATH` searches the library in the user-defined paths
        env_ld_library_path = envs.LD_LIBRARY_PATH
        if not locs and env_ld_library_path:
            locs = [
                os.path.join(dir, lib_name)
                for dir in env_ld_library_path.split(":")
                if os.path.exists(os.path.join(dir, lib_name))
            ]
        if not locs:
            raise ValueError(f"Cannot find {lib_name} in the system.")
        return locs[0]
    
    
    def find_nccl_library() -> str:
        """
        We either use the library file specified by the `VLLM_NCCL_SO_PATH`
        environment variable, or we find the library file brought by PyTorch.
        After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
        found by `ctypes` automatically.
        """
        so_file = envs.VLLM_NCCL_SO_PATH
    
        # manually load the nccl library
        if so_file:
            logger.info(
                "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
                so_file)
        else:
            if torch.version.cuda is not None:
                so_file = "libnccl.so.2"
            elif torch.version.hip is not None:
                so_file = "librccl.so.1"
            else:
                raise ValueError("NCCL only supports CUDA and ROCm backends.")
            logger.info("Found nccl from library %s", so_file)
        return so_file
    
    
    def enable_trace_function_call_for_thread() -> None:
        """Set up function tracing for the current thread,
        if enabled via the VLLM_TRACE_FUNCTION environment variable
        """
    
        if envs.VLLM_TRACE_FUNCTION:
            tmp_dir = tempfile.gettempdir()
            filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                        f"_thread_{threading.get_ident()}_"
                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
            log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
                                    filename)
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            enable_trace_function_call(log_path)
    
    
    # `functools` helpers
    def identity(value: T) -> T:
        return value
    
    
    F = TypeVar('F', bound=Callable[..., Any])
    
    
    def deprecate_kwargs(
            *kws: str,
            is_deprecated: Union[bool, Callable[[], bool]] = True,
            additional_message: Optional[str] = None) -> Callable[[F], F]:
        deprecated_kws = set(kws)
    
        if not callable(is_deprecated):
            is_deprecated = partial(identity, is_deprecated)
    
        def wrapper(fn: F) -> F:
    
            @wraps(fn)
            def inner(*args, **kwargs):
                if is_deprecated():
                    deprecated_kwargs = kwargs.keys() & deprecated_kws
                    if deprecated_kwargs:
                        msg = (
                            f"The keyword arguments {deprecated_kwargs} are "
                            "deprecated and will be removed in a future update.")
                        if additional_message is not None:
                            msg += f" {additional_message}"
    
                        warnings.warn(
                            DeprecationWarning(msg),
                            stacklevel=3,  # The inner function takes up one level
                        )
    
                return fn(*args, **kwargs)
    
            return inner  # type: ignore
    
        return wrapper
    
    
    @lru_cache(maxsize=8)
    def _cuda_device_count_stateless(
            cuda_visible_devices: Optional[str] = None) -> int:
        # Note: cuda_visible_devices is not used, but we keep it as an argument for
        # LRU Cache purposes.
    
        # Code below is based on
        # https://github.com/pytorch/pytorch/blob/
        # c1cd946818442aca8c7f812b16d187ce1586c3bc/
        # torch/cuda/__init__.py#L831C1-L831C17
        import torch.cuda
        import torch.version
    
        if not torch.cuda._is_compiled():
            return 0
        if is_hip():
            # ROCm uses amdsmi instead of nvml for stateless device count
            # This requires a sufficiently modern version of Torch 2.4.0
            raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
                torch.cuda, "_device_count_amdsmi")) else -1
        else:
            raw_count = torch.cuda._device_count_nvml()
        r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
        return r
    
    
    def cuda_device_count_stateless() -> int:
        """Get number of CUDA devices, caching based on the value of
        CUDA_VISIBLE_DEVICES at the time of call.
    
        This should be used instead of torch.cuda.device_count()
        unless CUDA_VISIBLE_DEVICES has already been set to the desired
        value."""
    
        # This can be removed and simply replaced with torch.cuda.get_device_count
        # after https://github.com/pytorch/pytorch/pull/122815 is released.
        return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
    
    
    def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
        """Make an instance method that weakly references
        its associated instance and no-ops once that
        instance is collected."""
        ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
        unbound = bound_method.__func__  # type: ignore[attr-defined]
    
        def weak_bound(*args, **kwargs) -> None:
            if inst := ref():
                unbound(inst, *args, **kwargs)
    
        return weak_bound
    
    
    #From: https://stackoverflow.com/a/4104188/2749989
    def run_once(f: Callable[P, None]) -> Callable[P, None]:
    
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)
    
        wrapper.has_run = False  # type: ignore[attr-defined]
        return wrapper
    
    
    class FlexibleArgumentParser(argparse.ArgumentParser):
        """ArgumentParser that allows both underscore and dash in names."""
    
        @staticmethod
        def _pull_args_from_config(args: List[str]) -> List[str]:
            """Method to pull arguments specified in the config file
            into the command-line args variable.
    
            The arguments in the config file will be inserted into
            the argument list.
    
            example:
            ```yaml
                port: 12323
                tensor-parallel-size: 4
            ```
            ```python
            $: vllm {serve,chat,complete} "facebook/opt-12B" \
                --config config.yaml -tp 2
            $: args = [
                "serve,chat,complete",
                "facebook/opt-12B",
                '--config', 'config.yaml',
                '-tp', '2'
            ]
            $: args = [
                "serve,chat,complete",
                "facebook/opt-12B",
                '--port', '12323',
                '--tensor-parallel-size', '4',
                '-tp', '2'
                ]
            ```
    
            Please note how the config args are inserted after the sub command.
            This way the order of priorities is maintained when these args
            are parsed by super().
            """
            assert args.count(
                '--config') <= 1, "More than one config file specified!"
    
            index = args.index('--config')
            if index == len(args) - 1:
                raise ValueError("No config file specified! \
                                 Please check your command-line arguments.")
    
            file_path = args[index + 1]
    
            config_args = FlexibleArgumentParser._load_config_file(file_path)
    
            # 0th index is for {serve,chat,complete}
            # followed by config args
            # followed by rest of cli args.
            # maintaining this order will enforce the precedence
            # of cli > config > defaults
            args = [args[0]] + config_args + args[1:index] + args[index + 2:]
    
            return args
    
        @staticmethod
        def _load_config_file(file_path: str) -> List[str]:
            """Loads a yaml file and returns the key value pairs as a
            flattened list with argparse like pattern
            ```yaml
                port: 12323
                tensor-parallel-size: 4
            ```
            returns:
                processed_args: list[str] = [
                    '--port': '12323',
                    '--tensor-parallel-size': '4'
                ]
    
            """
    
            extension: str = file_path.split('.')[-1]
            if extension not in ('yaml', 'yml'):
                raise ValueError(
                    "Config file must be of a yaml/yml type.\
                                  %s supplied", extension)
    
            # only expecting a flat dictionary of atomic types
            processed_args: List[str] = []
    
            config: Dict[str, Union[int, str]] = {}
            try:
                with open(file_path, 'r') as config_file:
                    config = yaml.safe_load(config_file)
            except Exception as ex:
                logger.error(
                    "Unable to read the config file at %s. \
                    Make sure path is correct", file_path)
                raise ex
    
            for key, value in config.items():
                processed_args.append('--' + key)
                processed_args.append(str(value))
    
            return processed_args
    
        def parse_args(self, args=None, namespace=None):
            if args is None:
                args = sys.argv[1:]
    
            if '--config' in args:
                args = FlexibleArgumentParser._pull_args_from_config(args)
    
            # Convert underscores to dashes and vice versa in argument names
            processed_args = []
            for arg in args:
                if arg.startswith('--'):
                    if '=' in arg:
                        key, value = arg.split('=', 1)
                        key = '--' + key[len('--'):].replace('_', '-')
                        processed_args.append(f'{key}={value}')
                    else:
                        processed_args.append('--' +
                                              arg[len('--'):].replace('_', '-'))
                else:
                    processed_args.append(arg)
    
            return super().parse_args(processed_args, namespace)
    
    
    async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                                  **kwargs):
        """Utility function to run async task in a lock"""
        async with lock:
            return await task(*args, **kwargs)
    
    
    def get_allowed_kwarg_only_overrides(
        callable: Callable[..., object],
        overrides: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """
        Given a callable which has one or more keyword only params and a dict
        mapping param names to values, drop values that cannot be kwarg
        expanded to overwrite one or more keyword-only args. This is used in a
        few places to handle custom processor overrides for multimodal models,
        e.g., for profiling when processor options provided by the user
        may affect the number of mm tokens per instance.
    
        Args:
            callable: Callable which takes 0 or more keyword only arguments.
            overrides: Potential overrides to be used when invoking the callable.
    
        Returns:
            Dictionary containing the kwargs to be leveraged which may be used
            to overwrite one or more keyword only arguments when invoking the
            callable.
        """
        if not overrides:
            return {}
    
        allowed_override_names = [
            name for name, param in inspect.signature(callable).parameters.items()
            if param.kind == inspect.Parameter.KEYWORD_ONLY
        ]
    
        # Drop any mm_processor_kwargs provided by the user that are
        # not kwarg names accepted by the provided input processor.
        filtered_overrides = {
            kwarg_name: val
            for kwarg_name, val in overrides.items()
            if kwarg_name in allowed_override_names
        }
    
        # If anything is dropped, log a warning
        dropped_keys = overrides.keys() - filtered_overrides.keys()
        if dropped_keys:
            logger.warning(
                "The following intended overrides are not keyword-only args "
                "and and will be dropped: %s", dropped_keys)
    
        return filtered_overrides
    
    
    # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
    # In particular, the FakeScalarType is not supported for earlier versions of
    # PyTorch which breaks dynamo for any ops registered using ScalarType.
    def supports_dynamo() -> bool:
        base_torch_version = Version(Version(torch.__version__).base_version)
        return base_torch_version >= Version("2.4.0")
    
    
    # Some backends use pytorch version < 2.4.0 which doesn't
    # support `torch.library.custom_op`.
    def supports_custom_op() -> bool:
        return hasattr(torch.library, "custom_op")
    
    
    class AtomicCounter:
        """An atomic, thread-safe counter"""
    
        def __init__(self, initial=0):
            """Initialize a new atomic counter to given initial value"""
            self._value = initial
            self._lock = threading.Lock()
    
        @property
        def value(self):
            return self._value
    
        def inc(self, num=1):
            """Atomically increment the counter by num and return the new value"""
            with self._lock:
                self._value += num
                return self._value
    
        def dec(self, num=1):
            """Atomically decrement the counter by num and return the new value"""
            with self._lock:
                self._value -= num
                return self._value
  • cover/vllm/worker/cache_engine.py: Rewrites the _allocate_kv_cache function of CacheEngine in the native framework, mainly changing the data format of the generated kv_cache from torch.Tensor to Tuple[torch.Tensor, torch.Tensor]; a usage sketch follows the listing below.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    """CacheEngine class for managing the KV cache."""
    from typing import List, Tuple, Union

    import torch

    from vllm.attention import get_attn_backend
    from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
    from vllm.logger import init_logger
    from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
                            is_pin_memory_available)

    logger = init_logger(__name__)


    class CacheEngine:
        """Manages the KV cache.

        This class is responsible for initializing and managing the GPU and CPU KV
        caches. It also provides methods for performing KV cache operations, such
        as swapping and copying.
        """

        def __init__(
            self,
            cache_config: CacheConfig,
            model_config: ModelConfig,
            parallel_config: ParallelConfig,
            device_config: DeviceConfig,
        ) -> None:
            self.cache_config = cache_config
            self.model_config = model_config
            self.parallel_config = parallel_config
            self.device_config = device_config
            self.head_size = model_config.get_head_size()
            # Models like Jamba, have mixed typed layers, E.g Mamba
            self.num_attention_layers = model_config.get_num_attention_layers(
                parallel_config)
            self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
            self.block_size = cache_config.block_size
            self.num_gpu_blocks = cache_config.num_gpu_blocks
            if self.num_gpu_blocks:
                self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
            self.num_cpu_blocks = cache_config.num_cpu_blocks
            if self.num_cpu_blocks:
                self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
            if cache_config.cache_dtype == "auto":
                self.dtype = model_config.dtype
            else:
                self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
            # Get attention backend.
            self.attn_backend = get_attn_backend(
                model_config.get_num_attention_heads(parallel_config),
                self.head_size,
                self.num_kv_heads,
                model_config.get_sliding_window(),
                model_config.dtype,
                cache_config.cache_dtype,
                self.block_size,
            )
            # Initialize the cache.
            self.gpu_cache = self._allocate_kv_cache(
                self.num_gpu_blocks, self.device_config.device_type)
            self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

        @staticmethod
        def get_cache_block_size(
            cache_config: CacheConfig,
            model_config: ModelConfig,
            parallel_config: ParallelConfig,
        ) -> int:
            head_size = model_config.get_head_size()
            num_heads = model_config.get_num_kv_heads(parallel_config)
            num_attention_layers = model_config.get_num_attention_layers(
                parallel_config)
            key_cache_block = cache_config.block_size * num_heads * head_size
            value_cache_block = key_cache_block
            total = num_attention_layers * (key_cache_block + value_cache_block)
            if cache_config.cache_dtype == "auto":
                dtype = model_config.dtype
            else:
                dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
            dtype_size = get_dtype_size(dtype)
            return dtype_size * total

        def swap_in(self, src_to_dst: torch.Tensor) -> None:
            for i in range(self.num_attention_layers):
                self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                              src_to_dst)

        def swap_out(self, src_to_dst: torch.Tensor) -> None:
            for i in range(self.num_attention_layers):
                self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                              src_to_dst)

        def copy(self, src_to_dsts: torch.Tensor) -> None:
            device = self.device_config.device_type
            self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

        def _allocate_kv_cache(
            self,
            num_blocks: int,
            device: str,
        ) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
            """Allocates KV cache on the specified device."""
            kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                num_blocks, self.block_size, self.num_kv_heads, self.head_size)
            pin_memory = is_pin_memory_available() if device == "cpu" else False
            kv_cache: List[torch.Tensor] = []
            for _ in range(self.num_attention_layers):
                # null block in CpuGpuBlockAllocator requires at least that
                # block to be zeroed-out.
                # We zero-out everything for simplicity.
                if self.device_config.device_type == "npu":
                    # Cannot set 5-dim tensor on NPU platform
                    key_blocks = torch.zeros(size=kv_cache_shape, dtype=self.dtype, device=device)
                    value_blocks = torch.zeros(size=kv_cache_shape, dtype=self.dtype, device=device)
                    kv_cache.append((key_blocks, value_blocks))
                else:
                    kv_cache.append(torch.zeros(kv_cache_shape, dtype=self.dtype, pin_memory=pin_memory, device=device))
            return kv_cache
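
    The tuple format above means code that consumes the cache can no longer assume a single stacked tensor per layer. The following minimal sketch is illustrative only: it is not part of the adaptation code, the helper name split_kv_cache_entry is hypothetical, and it assumes the native layout stacks keys and values along a leading dimension of size 2, as in vLLM's flash-attention backend.

    # Illustrative sketch: unpack one layer's KV cache entry for both layouts
    # produced by _allocate_kv_cache above. `split_kv_cache_entry` is a
    # hypothetical helper used only for this example.
    from typing import Tuple, Union

    import torch


    def split_kv_cache_entry(
        entry: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (key_blocks, value_blocks) for a single attention layer."""
        if isinstance(entry, tuple):
            # NPU path: keys and values are allocated as two separate tensors.
            return entry
        # Native path (assumed): one tensor whose leading dimension of size 2
        # stacks key blocks and value blocks.
        return entry[0], entry[1]


    if __name__ == "__main__":
        # Tiny shapes purely for demonstration.
        npu_entry = (torch.zeros(4, 16, 2, 8), torch.zeros(4, 16, 2, 8))
        native_entry = torch.zeros(2, 4, 16, 2, 8)
        for entry in (npu_entry, native_entry):
            key_blocks, value_blocks = split_kv_cache_entry(entry)
            print(key_blocks.shape, value_blocks.shape)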
  • cover/vllm/worker/npu_worker.py: Implements NPUWorker, which is called by the executor classes in the executor module.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of the code in this file was copied from project [vLLM Team][vllm]
    
    """An NPU worker class."""
    import dataclasses
    from dataclasses import dataclass
    import gc
    from typing import List, Optional, Set, Tuple, Dict
    
    import torch
    import torch.distributed
    from vllm.config import (
        CacheConfig,
        DeviceConfig,
        LoadConfig,
        LoRAConfig,
        ModelConfig,
        ObservabilityConfig,
        ParallelConfig,
        PromptAdapterConfig,
        SchedulerConfig,
        SpeculativeConfig,
    )
    from vllm.sequence import ExecuteModelRequest
    from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
    from vllm.lora.request import LoRARequest
    from vllm.model_executor import set_random_seed
    from vllm.prompt_adapter.request import PromptAdapterRequest
    from vllm.worker.cache_engine import CacheEngine
    from vllm.worker.npu_model_runner import ModelRunner
    from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput, extract_previous_hidden_states
    from vllm.worker.model_runner_base import BroadcastableModelInput, ModelRunnerInputBase
    from vllm.worker.npu_model_runner import MultiStepNPUModelRunner, StatefulModelInput
    from vllm.distributed import broadcast_tensor_dict, get_pp_group
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.distributed.npu_utils import ascend_broadcast_data_dict
    
    
    class NPUWorker(LocalOrDistributedWorkerBase):
        """A worker class that executes the model on a group of Ascend NPUs."""
    
        def __init__(
            self,
            model_config: ModelConfig,
            parallel_config: ParallelConfig,
            scheduler_config: SchedulerConfig,
            device_config: DeviceConfig,
            cache_config: CacheConfig,
            load_config: LoadConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            lora_config: Optional[LoRAConfig] = None,
            speculative_config: Optional[SpeculativeConfig] = None,
            prompt_adapter_config: Optional[PromptAdapterConfig] = None,
            is_driver_worker: bool = False,
            observability_config: Optional[ObservabilityConfig] = None,
        ) -> None:
            self.model_config = model_config
            self.parallel_config = parallel_config
            self.scheduler_config = scheduler_config
            self.device_config = device_config
            self.cache_config = cache_config
            self.local_rank = local_rank
            self.rank = rank
            self.distributed_init_method = distributed_init_method
            self.lora_config = lora_config
            self.load_config = load_config
            self.is_driver_worker = is_driver_worker
            if parallel_config and is_driver_worker:
                assert (
                    rank % parallel_config.tensor_parallel_size == 0
                ), "Driver worker should be rank 0 of tensor parallel group."
    
            if self.model_config.trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
    
                init_cached_hf_modules()
    
            mindie_config = {
                "backend_type": "atb",
                "model_id": model_config.model,
                "rank": rank,
                "local_rank": local_rank,
                "world_size": parallel_config.world_size,
                "npu_device_id": local_rank,
                "trust_remote_code": model_config.trust_remote_code,
                "inference_mode": (
                    2 if scheduler_config.chunked_prefill_enabled or cache_config.enable_prefix_caching else 0
                ),
            }
            self.model_runner = ModelRunner(
                model_config,
                parallel_config,
                scheduler_config,
                device_config,
                cache_config,
                load_config,
                lora_config,
                mindie_config,
                kv_cache_dtype=self.cache_config.cache_dtype,
                is_driver_worker=is_driver_worker,
            )
            # Uninitialized cache engine. Will be initialized by
            # self.initialize_cache().
            self.cache_engine: List[CacheEngine]
            self.npu_cache: List[torch.Tensor]
    
        @property
        def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
            return self.npu_cache
    
        @property
        def do_metadata_broadcast(self) -> bool:
            return self.parallel_config.tensor_parallel_size > 1
    
        def init_device(self) -> None:
            self.device = torch.device(f"npu:{self.local_rank}")
            torch.npu.set_device(self.device)
            gc.collect()
            # Initialize the distributed environment.
            init_worker_distributed_environment(
                self.parallel_config, self.rank, self.distributed_init_method, self.local_rank
            )
            # Initialize the model.
            set_random_seed(self.model_config.seed)
    
        def load_model(self):
            self.model_runner.load_model()
    
        @torch.inference_mode()
        def determine_num_available_blocks(self) -> Tuple[int, int]:
            """Profiles the peak memory usage of the model and returns the maximum
            number of NPU and CPU cache blocks that can be allocated.
            """
            # Profile the memory usage of the model and get the maximum number of
            # cache blocks that can be allocated with the remaining free memory.
            torch.npu.empty_cache()
            torch.npu.reset_peak_memory_stats()
    
            # Execute a forward pass with dummy inputs to profile the memory usage
            # of the model.
            self.model_runner.profile_run()
            block_size = self.cache_config.block_size
            dummy_block_size = 128
            dummy_num_blocks = dummy_block_size // block_size * self.model_runner.model.dummy_block_num
    
            # Calculate the number of blocks that can be allocated with the
            # profiled peak memory.
            torch.npu.synchronize()
            peak_memory = torch.npu.max_memory_allocated()
    
            total_gpu_memory = torch.npu.get_device_properties(self.rank).total_memory
            cache_block_size = CacheEngine.get_cache_block_size(self.cache_config, self.model_config, self.parallel_config)
    
            num_gpu_blocks = (
                int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)
                + dummy_num_blocks
            )
            num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
            num_gpu_blocks = max(num_gpu_blocks, 0)
            num_cpu_blocks = max(num_cpu_blocks, 0)
    
            if self.model_runner.lora_manager:
                self.model_runner.remove_all_loras()
            gc.collect()
            torch.npu.empty_cache()
            return num_gpu_blocks, num_cpu_blocks
    
        def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
            raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, self.model_config.max_model_len)
            self.cache_config.num_gpu_blocks = num_gpu_blocks
            self.cache_config.num_cpu_blocks = num_cpu_blocks
    
            self._init_cache_engine()
            self._warm_up_model()
    
        def _get_worker_input_from_broadcast(
            self
        ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
                str, torch.Tensor]]]:
            """ Get the worker input from the broadcasted tensor dict. """
            assert self.do_metadata_broadcast
            assert not self.is_driver_worker
            broadcast_data = ascend_broadcast_data_dict(src=0)
            if not broadcast_data:
                return None
    
            worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
            model_input = (
                self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                    broadcast_data))
    
            kwargs = extract_previous_hidden_states(broadcast_data)
    
            return model_input, worker_input, kwargs
    
        def _get_driver_input_and_broadcast(
            self, execute_model_req: ExecuteModelRequest
        ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
            """ Get the driver input and broadcast it to other workers.  """
            assert self.is_driver_worker
    
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            model_input: ModelRunnerInputBase = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))
    
            kwargs = extract_previous_hidden_states(execute_model_req)
    
            if self.do_metadata_broadcast:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(model_input.as_broadcastable_tensor_dict())
                broadcast_data.update(kwargs)
                ascend_broadcast_data_dict(broadcast_data, src=0)
            if execute_model_req.async_callback:
                model_input = dataclasses.replace(  # type: ignore
                    model_input,
                    async_callback=execute_model_req.async_callback)
    
            return model_input, worker_input, kwargs
    
        @torch.inference_mode()
        def prepare_worker_input(self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
            virtual_engine = execute_model_req.virtual_engine
            num_steps = execute_model_req.num_steps
            num_seq_groups = len(execute_model_req.seq_group_metadata_list)
            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
            # they contain parameters to launch cudamemcpyasync.
            blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, device="cpu", dtype=torch.int64).view(
                -1, 2
            )
            blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, device="cpu", dtype=torch.int64).view(
                -1, 2
            )
            # `blocks_to_copy` is a gpu tensor. The src and tgt of
            # blocks to copy are in the same device, and `blocks_to_copy`
            # can be used directly within cuda kernels.
            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(
                -1, 2
            )
    
            return WorkerInput(
                num_seq_groups=num_seq_groups,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                virtual_engine=virtual_engine,
                num_steps=num_steps,
            )
    
        @torch.inference_mode()
        def execute_worker(self, worker_input: WorkerInput) -> None:
            virtual_engine = worker_input.virtual_engine
            # Issue cache operations.
            if worker_input.blocks_to_swap_in is not None and worker_input.blocks_to_swap_in.numel() > 0:
                self.cache_engine[virtual_engine].swap_in(worker_input.blocks_to_swap_in)
            if worker_input.blocks_to_swap_out is not None and worker_input.blocks_to_swap_out.numel() > 0:
                self.cache_engine[virtual_engine].swap_out(worker_input.blocks_to_swap_out)
            if worker_input.blocks_to_copy is not None and worker_input.blocks_to_copy.numel() > 0:
                self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
    
        def add_lora(self, lora_request: LoRARequest) -> bool:
            return self.model_runner.add_lora(lora_request)
    
        def remove_lora(self, lora_id: int) -> bool:
            return self.model_runner.remove_lora(lora_id)
    
        def pin_lora(self, lora_id: int) -> bool:
            return self.model_runner.pin_lora(lora_id)
    
        def list_loras(self) -> Set[int]:
            return self.model_runner.list_loras()
    
        def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool:
            return self.model_runner.add_prompt_adapter(prompt_adapter_request)
    
        def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
            return self.model_runner.remove_lora(prompt_adapter_id)
    
        def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
            return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
    
        def list_prompt_adapters(self) -> Set[int]:
            return self.model_runner.list_prompt_adapters()
    
        def get_cache_block_size_bytes(self) -> int:
            """Get the size of the KV cache block size in bytes."""
            return CacheEngine.get_cache_block_size(self.cache_config, self.model_config, self.parallel_config)
    
        def _init_cache_engine(self):
            assert self.cache_config.num_gpu_blocks is not None
            self.cache_engine = [
                CacheEngine(self.cache_config, self.model_config, self.parallel_config, self.device_config)
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]
            self.npu_cache = [self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size)]
    
        def _warm_up_model(self) -> None:
            pass
    
    
    def init_worker_distributed_environment(
        parallel_config: ParallelConfig,
        rank: int,
        distributed_init_method: Optional[str] = None,
        local_rank: int = -1,
    ) -> None:
        """Initialize the distributed environment."""
        init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, "hccl")
    
        ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)
    
    
    def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len) -> None:
        if num_gpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `gpu_memory_utilization` when "
                "initializing the engine."
            )
        max_seq_len = block_size * num_gpu_blocks
        if max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`gpu_memory_utilization` or decreasing `max_model_len` when "
                "initializing the engine."
            )
    
    
    @dataclass
    class MultiStepState:
        worker_input: WorkerInput
        model_input: StatefulModelInput
    
    
    class MultiStepNPUWorker(NPUWorker):
    
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            base_model_runner = self.model_runner
            # For multi-step execution, wrap the base model runner with MultiStepNPUModelRunner.
            mindie_config = {
                "backend_type": "atb",
                "model_id": kwargs.get("model_config").model,
                "rank": kwargs.get("rank"),
                "local_rank": kwargs.get("local_rank"),
                "world_size": kwargs.get("parallel_config").world_size,
                "npu_device_id": kwargs.get("local_rank"),
                "inference_mode": 2 if kwargs.get("scheduler_config").chunked_prefill_enabled else 0,
            }
            self.model_runner = MultiStepNPUModelRunner(
                base_model_runner,
                base_model_runner.model_config,
                base_model_runner.parallel_config,
                base_model_runner.scheduler_config,
                base_model_runner.device_config,
                base_model_runner.cache_config,
                load_config=base_model_runner.load_config,
                lora_config=self.lora_config,
                mindie_config=mindie_config,
                kv_cache_dtype=self.cache_config.cache_dtype,
                is_driver_worker=base_model_runner.is_driver_worker,
                prompt_adapter_config=base_model_runner.prompt_adapter_config,
                observability_config=base_model_runner.observability_config,
            )
    
            pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
            self.multi_step_states: List[Optional[MultiStepState]] = [None] * pipeline_parallel_size
            self.temp_output = None
    
        def prepare_input(
            self,
            execute_model_req: Optional[ExecuteModelRequest] = None,
        ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, torch.Tensor]]]:
            """
            Depending on the current state of the request and multi step worker,
            this method may skip the normal _prepare_model_input and
            _prepare_worker_input methods and instead use cached values.
            """
            if self.is_driver_worker:
                if execute_model_req is None:
                    if self.do_metadata_broadcast:
                        # This signals that there are no more requests to process for
                        # now. All workers run an infinite loop with
                        # broadcast_tensor_dict, and the loop stops when the
                        # driver broadcasts an empty input. Send an empty input to
                        # notify all other workers to stop their execution loop.
                        broadcast_tensor_dict({}, src=0)
                    return None
    
                virtual_engine = execute_model_req.virtual_engine
                (model_input, worker_input, kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
                assert isinstance(model_input, StatefulModelInput)
                if execute_model_req.is_first_multi_step:
                    # cache the worker input and model input for the next steps
                    self.multi_step_states[virtual_engine] = MultiStepState(
                        worker_input=worker_input, model_input=model_input
                    )
            # if TP workers
            else:
                broadcast_data = self._get_worker_input_from_broadcast()
                # if the driver has sent an empty input, we should stop the worker
                # loop
                if broadcast_data is None:
                    return None
                model_input, worker_input, kwargs = broadcast_data
                assert isinstance(model_input, StatefulModelInput)
                virtual_engine = worker_input.virtual_engine
                if model_input.is_first_multi_step:
                    pass
                    # TODO(will) Can cache the worker input and model input for the
                    # next steps. See below for details
                else:
                    # TODO(will) possible to also cache and reuse the cached worker
                    # input and model input. The idea is essentially the delta
                    # optimization for model_inputs: the TP workers can cache
                    # the model input states and we only broadcast the delta needed
                    # for the next step (sampled_token_ids from the previous step)
    
                    assert isinstance(model_input, StatefulModelInput)
                    # we need to update the last sampled token ids in the model
                    # input for the workers so that they can run inplace
                    # advance_step
                    model_input.add_sampler_output(
                        SamplerOutput(outputs=[], sampled_token_ids=None), model_input.last_sampled_token_ids
                    )
    
            assert model_input is not None
            assert worker_input is not None
            return model_input, worker_input, kwargs
    
        def _get_worker_input_from_broadcast(
            self
        ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
                str, torch.Tensor]]]:
            """ Get the worker input from the broadcasted tensor dict. """
            assert self.do_metadata_broadcast
            assert not self.is_driver_worker
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None
    
            worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
            model_input = (
                self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                    broadcast_data))
    
            kwargs = extract_previous_hidden_states(broadcast_data)
    
            return model_input, worker_input, kwargs
    
        def _get_driver_input_and_broadcast(
            self, execute_model_req: ExecuteModelRequest
        ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
            """
            Get the driver input and broadcast it to other workers.
            """
            assert self.is_driver_worker
            virtual_engine = execute_model_req.virtual_engine
            is_first_multi_step = execute_model_req.is_first_multi_step
            if is_first_multi_step:
                # on first step we prepare the worker input and model input normally
                worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req)
                model_input: StatefulModelInput = self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids,
                )
    
                if execute_model_req.async_callback:
                    model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                        model_input.frozen_model_input, async_callback=execute_model_req.async_callback
                    )
            else:
                # on subsequent steps we reuse the worker input and model input
                multi_step_state = self.multi_step_states[virtual_engine]
                worker_input = multi_step_state.worker_input
                model_input = multi_step_state.model_input
                frozen_model_input = model_input.frozen_model_input
                assert frozen_model_input is not None
                assert frozen_model_input.attn_metadata is not None
                # clear the cached decode metadata so that it can be recomputed on
                # the workers
                frozen_model_input.attn_metadata._cached_decode_metadata = None
    
            model_input.is_first_multi_step = is_first_multi_step
            model_input.is_last_step = execute_model_req.is_last_step
    
            if not is_first_multi_step:
                # we broadcast the last sampled token ids to all TP workers so they
                # can update their model input metadata in-place.
                self._prepare_last_sampled_token_ids_for_tp_workers(
                    execute_model_req=execute_model_req, model_input=model_input
                )
    
            if self.do_metadata_broadcast:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(broadcast_data, src=0)
    
            # Returning empty dict here to keep this compatible with
            # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
            return model_input, worker_input, {}
    
        def _prepare_last_sampled_token_ids_for_tp_workers(
            self,
            execute_model_req: ExecuteModelRequest,
            model_input: StatefulModelInput,
        ) -> None:
            """
            Prepare the last sampled token ids for TP workers. If it's the last
            PP rank, then the last sampled token ids are already in the model_input.
            If it is NOT the last PP rank, then we need to get the last sampled
            token ids that are cached in the execute_model_req.
            """
            if get_pp_group().is_last_rank:
                assert model_input.cached_outputs[-1].sampler_output.sampled_token_ids is None
                assert model_input.cached_outputs[-1].sampled_token_ids is not None
                model_input.last_sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
                # free sampled token ids from the previous step if it has been
                # pythonized. Cannot free the last sampled token ids because
                # we need it for NPU advance_step.
                for output in model_input.cached_outputs[:-1]:
                    if output.pythonized:
                        output.sampled_token_ids = None
            else:
                # otherwise we need to get the cached sampled token ids from the
                # execute_model_req
                assert execute_model_req.last_sampled_token_ids is not None
                model_input.last_sampled_token_ids = execute_model_req.last_sampled_token_ids.cuda()
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None), model_input.last_sampled_token_ids
                )
    
                # free sampled token ids from the previous step.
                # TODO(will) we could reuse the sampled token ids tensor from
                # the previous step instead.
                for output in model_input.cached_outputs[:-1]:
                    output.sampled_token_ids = None
                assert model_input.cached_outputs[-1].sampled_token_ids is not None
  • cover/vllm/worker/npu_model_runner.py: Implements the NPUModelRunner class, which is invoked by NPUWorker. NPUModelRunner inherits from the native framework's ModelRunner class, mainly in order to override the native load_model, execute_model, and profile_run functions. In newer vLLM versions, model execution first calls the model to produce hidden_states, then a separate processing step turns the hidden_states into logits, and a final sample step produces the result; in the MindIE model repository the first two steps are fused into a single model call, so execute_model is modified here accordingly. The changes to profile_run mainly construct the fake data used for warmup. A simplified sketch of the two execution flows is given right below; the full file content follows it.
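As a quick orientation (not part of the reference code), the following minimal sketch contrasts the two execution flows described above. The classes ToyTwoStepModel and ToyFusedModel and the greedy_sample helper are hypothetical stand-ins, not vLLM or MindIE APIs:

    # Illustrative sketch only -- NOT part of npu_model_runner.py.
    # It contrasts the upstream flow (model -> hidden_states -> logits -> sample)
    # with a single-call flow in which the model already returns logits.
    import torch
    import torch.nn as nn


    class ToyTwoStepModel(nn.Module):
        """Upstream-style model: forward() returns hidden states only."""

        def __init__(self, vocab_size: int = 32, hidden: int = 16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            return self.embed(token_ids)  # hidden_states

        def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.lm_head(hidden_states)  # separate logits step


    class ToyFusedModel(nn.Module):
        """MindIE-style model: a single call already yields logits."""

        def __init__(self, vocab_size: int = 32, hidden: int = 16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            return self.lm_head(self.embed(token_ids))  # logits directly


    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
        return torch.argmax(logits, dim=-1)


    if __name__ == "__main__":
        tokens = torch.tensor([[1, 2, 3]])

        two_step = ToyTwoStepModel()
        hidden = two_step(tokens)                 # step 1: hidden states
        logits = two_step.compute_logits(hidden)  # step 2: logits
        print(greedy_sample(logits))              # step 3: sample

        fused = ToyFusedModel()
        print(greedy_sample(fused(tokens)))       # steps 1 and 2 fused

The actual content of cover/vllm/worker/npu_model_runner.py is as follows: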
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of code in this file was copied from project [vLLM Team][vllm] for adapting usage
    
    import dataclasses
    import functools
    import itertools
    import weakref
    from dataclasses import dataclass, field
    from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
    
    import torch
    import torch.distributed
    import torch.nn as nn
    import vllm.envs as envs
    from vllm.attention import AttentionMetadata, get_attn_backend
    from vllm.attention.backends.utils import CommonAttentionState
    from vllm.config import (
        CacheConfig,
        DeviceConfig,
        LoadConfig,
        LoRAConfig,
        ModelConfig,
        ObservabilityConfig,
        ParallelConfig,
        PromptAdapterConfig,
        SchedulerConfig,
    )
    from vllm.core.scheduler import SchedulerOutputs
    from vllm.distributed import get_pp_group
    from vllm.inputs import INPUT_REGISTRY, InputRegistry
    from vllm.logger import init_logger
    from vllm.lora.layers import LoRAMapping
    from vllm.lora.request import LoRARequest
    from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
    from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
    from vllm.model_executor.layers.sampler import (
        PromptLogprobs,
        SampleLogprobs,
        SamplerOutput,
        get_logprobs,
        get_pythonized_sample_results,
    )
    from vllm.model_executor.model_loader.npu import get_model
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.model_executor.models.interfaces import supports_lora, supports_multimodal
    from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
    from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry
    from vllm.prompt_adapter.layers import PromptAdapterMapping
    from vllm.prompt_adapter.request import PromptAdapterRequest
    from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager
    from vllm.sampling_params import SamplingParams
    from vllm.sequence import (
        CompletionSequenceGroupOutput,
        IntermediateTensors,
        Logprob,
        SequenceGroupMetadata,
        SequenceOutput,
    )
    from vllm.utils import DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_pin_memory_available
    from vllm.worker.model_runner_base import (
        BroadcastableModelInput,
        ModelRunnerBase,
        ModelRunnerInputBase,
        ModelRunnerInputBuilderBase,
        _add_attn_metadata_broadcastable_dict,
        _add_sampling_metadata_broadcastable_dict,
        _init_attn_metadata_from_tensor_dict,
        _init_frozen_model_input_from_tensor_dict,
        _init_sampling_metadata_from_tensor_dict,
    )
    
    if TYPE_CHECKING:
        from vllm.attention.backends.abstract import AttentionBackend
    
    logger = init_logger(__name__)
    
    LORA_WARMUP_RANK = 8
    _BATCH_SIZE_ALIGNMENT = 8
    # all the token sizes that **can** be captured by cudagraph.
    # they can be arbitrarily large.
    # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
    # the actual sizes to capture will be determined by the model,
    # depending on the model's max_num_seqs.
    # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)]
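    # "mindie-attn-backend" is listed alongside the upstream backends so that
    # multi-step execution is also permitted with the MindIE attention backend.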
    MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer", "mindie-attn-backend"]
    
    TModelInputForNPU = TypeVar("TModelInputForNPU", bound="ModelInputForNPU")
    
    
    @dataclass(frozen=True)
    class ModelInputForNPU(ModelRunnerInputBase):
        """
        This base class contains metadata needed for the base model forward pass
        but not metadata for possible additional steps, e.g., sampling. Model
        runners that run additional steps should subclass this class to add
        additional fields.
        """
    
        input_tokens: Optional[torch.Tensor] = None
        input_positions: Optional[torch.Tensor] = None
        seq_lens: Optional[List[int]] = None
        query_lens: Optional[List[int]] = None
        lora_mapping: Optional["LoRAMapping"] = None
        lora_requests: Optional[List[LoRARequest]] = None
        attn_metadata: Optional["AttentionMetadata"] = None
        prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
        prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
        multi_modal_kwargs: Optional[BatchedTensorInputs] = None
        request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
        finished_requests_ids: Optional[List[str]] = None
        virtual_engine: int = 0
        async_callback: Optional[Callable] = None
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
        scheduler_outputs: Optional[SchedulerOutputs] = None
    
        @classmethod
        def from_broadcasted_tensor_dict(
            cls: Type[TModelInputForNPU],
            tensor_dict: Dict[str, Any],
            attn_backend: Optional["AttentionBackend"] = None,
        ) -> TModelInputForNPU:
            if attn_backend is not None:
                tensor_dict = _init_attn_metadata_from_tensor_dict(attn_backend, tensor_dict)
            return cls(**tensor_dict)
    
        def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
            tensor_dict = {
                "input_tokens": self.input_tokens,
                "input_positions": self.input_positions,
                "lora_requests": self.lora_requests,
                "lora_mapping": self.lora_mapping,
                "multi_modal_kwargs": self.multi_modal_kwargs,
                "prompt_adapter_mapping": self.prompt_adapter_mapping,
                "prompt_adapter_requests": self.prompt_adapter_requests,
                "virtual_engine": self.virtual_engine,
                "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
                "finished_requests_ids": self.finished_requests_ids,
            }
            _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
            return tensor_dict
    
    
    @dataclass(frozen=True)
    class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
        """
        Used by the ModelRunner.
        """
    
        sampling_metadata: Optional["SamplingMetadata"] = None
        # Used for speculative decoding. We do not broadcast it because it is only
        # used by the driver worker.
        is_prompt: Optional[bool] = None
    
        @classmethod
        def from_broadcasted_tensor_dict(
            cls,
            tensor_dict: Dict[str, Any],
            attn_backend: Optional["AttentionBackend"] = None,
        ) -> "ModelInputForNPUWithSamplingMetadata":
            tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
            if attn_backend is not None:
                tensor_dict = _init_attn_metadata_from_tensor_dict(attn_backend, tensor_dict)
            return cls(**tensor_dict)
    
        def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
            tensor_dict = {
                "input_tokens": self.input_tokens,
                "input_positions": self.input_positions,
                "lora_requests": self.lora_requests,
                "lora_mapping": self.lora_mapping,
                "multi_modal_kwargs": self.multi_modal_kwargs,
                "prompt_adapter_mapping": self.prompt_adapter_mapping,
                "prompt_adapter_requests": self.prompt_adapter_requests,
                "virtual_engine": self.virtual_engine,
                "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
                "finished_requests_ids": self.finished_requests_ids,
            }
            _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
            _add_sampling_metadata_broadcastable_dict(tensor_dict, self.sampling_metadata)
            return tensor_dict
    
    
    class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
        """Build ModelInputForNPU from SequenceGroupMetadata."""
    
        # Note: ideally we would be using a dataclass(kw_only=True)
        # here, so that this can be subclassed easily,
        # but kw_only is not supported in python<3.10.
        class InterDataForSeqGroup:
            """Intermediate data for the current sequence group."""
    
            def __init__(
                self,
                *,
                # From sequence group metadata.
                request_id: str,
                seq_ids: List[int],
                is_prompt: bool,
                block_tables: Optional[Dict[int, List[int]]],
                computed_block_nums: List[int],
                n_seqs: int = 0,
                # Input tokens and positions.
                input_tokens: Optional[List[List[int]]] = None,
                input_positions: Optional[List[List[int]]] = None,
                # The sequence length (may be capped to the sliding window).
                seq_lens: Optional[List[int]] = None,
                # The original sequence length (before applying sliding window).
                # This is used to compute slot mapping.
                orig_seq_lens: Optional[List[int]] = None,
                # The query length.
                query_lens: Optional[List[int]] = None,
                # The number of tokens that are already computed.
                context_lens: Optional[List[int]] = None,
                # The current sliding window block.
                curr_sliding_window_blocks: Optional[List[int]] = None,
                # LoRA inputs.
                lora_index_mapping: Optional[List[List[int]]] = None,
                lora_prompt_mapping: Optional[List[List[int]]] = None,
                lora_requests: Optional[List[LoRARequest]] = None,
                # Prompt adapter inputs.
                prompt_adapter_index_mapping: Optional[List[int]] = None,
                prompt_adapter_prompt_mapping: Optional[List[int]] = None,
                prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                # Multi-modal inputs.
                multi_modal_inputs: Optional[MultiModalInputs] = None,
                # Whether the prefix cache is hit (prefill only).
                prefix_cache_hit: bool = False,
                reinit: bool = False,
                reinit_use_defaults: bool = False,
            ):
                if reinit:
                    assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                    for i, seq_id in enumerate(seq_ids):
                        self.seq_ids[i] = seq_id  # type: ignore
                else:
                    self.seq_ids = seq_ids
    
                self.request_id = request_id
                self.is_prompt = is_prompt
                self.block_tables = block_tables
                self.computed_block_nums = computed_block_nums
                self.n_seqs = n_seqs
    
                if reinit:
                    if len(self.seq_ids) == 1 and reinit_use_defaults:
                        self.simple_reinit()
                    else:
                        if input_tokens:
                            self.input_tokens = input_tokens
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.input_tokens[seq_id].clear()
    
                        if input_positions:
                            self.input_positions = input_positions
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.input_positions[seq_id].clear()
    
                        if seq_lens:
                            self.seq_lens = seq_lens
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.seq_lens[seq_id] = 0
    
                        if orig_seq_lens:
                            self.orig_seq_lens = orig_seq_lens
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.orig_seq_lens[seq_id] = 0
    
                        if query_lens:
                            self.query_lens = query_lens
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.query_lens[seq_id] = 0
    
                        if context_lens:
                            self.context_lens = context_lens
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.context_lens[seq_id] = 0
    
                        if curr_sliding_window_blocks:
                            self.curr_sliding_window_blocks = curr_sliding_window_blocks
                        else:
                            for seq_id in range(len(self.seq_ids)):
                                self.curr_sliding_window_blocks[seq_id] = 0
    
                        if lora_index_mapping:
                            self.lora_index_mapping = lora_index_mapping
                        else:
                            self.lora_index_mapping.clear()
    
                        if lora_prompt_mapping:
                            self.lora_prompt_mapping = lora_prompt_mapping
                        else:
                            self.lora_prompt_mapping.clear()
    
                        if lora_requests:
                            self.lora_requests = lora_requests
                        else:
                            self.lora_requests.clear()
    
                        if prompt_adapter_index_mapping:
                            self.prompt_adapter_index_mapping = prompt_adapter_index_mapping
                        else:
                            self.prompt_adapter_index_mapping.clear()
    
                        if prompt_adapter_prompt_mapping:
                            self.prompt_adapter_prompt_mapping = prompt_adapter_prompt_mapping
                        else:
                            self.prompt_adapter_prompt_mapping.clear()
    
                else:
                    self.input_tokens = input_tokens or []
                    self.input_positions = input_positions or []
                    self.seq_lens = seq_lens or []
                    self.orig_seq_lens = orig_seq_lens or []
                    self.query_lens = query_lens or []
                    self.context_lens = context_lens or []
                    self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
    
                    self.lora_index_mapping = lora_index_mapping or []
                    self.lora_prompt_mapping = lora_prompt_mapping or []
                    self.lora_requests = lora_requests or []
    
                    self.prompt_adapter_index_mapping = prompt_adapter_index_mapping or []
                    self.prompt_adapter_prompt_mapping = prompt_adapter_prompt_mapping or []
    
                self.prompt_adapter_request = prompt_adapter_request
                self.multi_modal_inputs = multi_modal_inputs
                self.prefix_cache_hit = prefix_cache_hit
    
                self.n_seqs = len(self.seq_ids)
    
                if not reinit:
                    self.__post_init__()
    
            def __post_init__(self):
                self.n_seqs = len(self.seq_ids)
    
                self.input_tokens = [[] for _ in range(self.n_seqs)]
                self.input_positions = [[] for _ in range(self.n_seqs)]
                self.seq_lens = [0] * self.n_seqs
                self.orig_seq_lens = [0] * self.n_seqs
                self.query_lens = [0] * self.n_seqs
                self.context_lens = [0] * self.n_seqs
                self.curr_sliding_window_blocks = [0] * self.n_seqs
    
                self.lora_index_mapping = []
                self.lora_prompt_mapping = []
    
            def simple_reinit(self):
                self.input_tokens[0].clear()  # type: ignore
                self.input_positions[0].clear()  # type: ignore
                self.seq_lens[0] = 0  # type: ignore
                self.orig_seq_lens[0] = 0  # type: ignore
                self.query_lens[0] = 0  # type: ignore
                self.context_lens[0] = 0  # type: ignore
                self.curr_sliding_window_blocks[0] = 0  # type: ignore
                self.lora_index_mapping.clear()  # type: ignore
                self.lora_prompt_mapping.clear()  # type: ignore
                self.lora_requests.clear()  # type: ignore
                self.prompt_adapter_index_mapping.clear()  # type: ignore
                self.prompt_adapter_prompt_mapping.clear()  # type: ignore
    
        def __init__(self, runner: "NPUModelRunnerBase", finished_requests_ids: Optional[List[str]] = None):
            super().__init__()
            # Compute functions for each sequence in a sequence group.
            # WARNING: The order of the functions matters!
            self.per_seq_compute_fns = [
                self._compute_lens,
                self._compute_for_prefix_cache_hit,
                self._compute_for_sliding_window,
                self._compute_lora_input,
            ]
            # Compute functions for each sequence group.
            # WARNING: The order of the functions matters!
            self.per_seq_group_compute_fns = [
                self._compute_prompt_adapter_input,
                self._compute_multi_modal_input,
            ]
    
            self.runner = runner
            self.model_input_cls = self.runner._model_input_cls
            self.attn_backend = self.runner.attn_backend
            self.scheduler_config = self.runner.scheduler_config
            self.sliding_window = self.runner.sliding_window
            self.block_size = self.runner.block_size
            self.enable_lora = self.runner.lora_config is not None
            self.enable_prompt_adapter = self.runner.prompt_adapter_config is not None
            self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
            self.finished_requests_ids = finished_requests_ids
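            # Assume a decode-only batch until a prefill sequence group is added (see add_seq_group).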
            self.decode_only = True
    
            # Intermediate data (data in CPU before going to NPU) for
            # the current sequence group.
            self.inter_data_list: List[ModelInputForNPUBuilder.InterDataForSeqGroup] = []
    
            # Attention metadata inputs.
            self.attn_metadata_builder = self.attn_backend.make_metadata_builder(weakref.proxy(self))
    
            # Engine/Model configurations.
            self.chunked_prefill_enabled = (
                self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled
            )
            if self.sliding_window is not None:
                self.sliding_window_blocks = (self.sliding_window + self.block_size - 1) // self.block_size
                self.block_aligned_sliding_window = self.sliding_window_blocks * self.block_size
    
        def gen_inter_data_builder(self, num_seqs: int):
            return lambda: ModelInputForNPUBuilder.InterDataForSeqGroup(
                request_id="", seq_ids=[0] * num_seqs, is_prompt=True, block_tables=None, computed_block_nums=[]
            )
    
        def init_cached_inter_data(self, *args, **kwargs):
            assert len(args) == 0
            assert "seq_ids" in kwargs
            seq_ids = kwargs.get("seq_ids")
            num_seqs = len(seq_ids)
    
            # The inter-data cache is per model_runner
            inter_data_cache = self.runner.inter_data_cache
            if num_seqs not in inter_data_cache:
                inter_data_cache[num_seqs] = PyObjectCache(self.gen_inter_data_builder(num_seqs))
    
            obj = inter_data_cache[num_seqs].get_object()
            obj.__init__(*args, **kwargs)
            return obj
    
        def reset_cached_inter_data(self):
            for cache in self.runner.inter_data_cache.values():
                cache.reset()
    
        def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
            """Add a sequence group to the builder."""
            seq_ids = seq_group_metadata.seq_data.keys()
            n_seqs = len(seq_ids)
            is_prompt = seq_group_metadata.is_prompt
    
            if is_prompt:
                assert n_seqs == 1
                self.decode_only = False
    
            inter_data = self.init_cached_inter_data(
                request_id=seq_group_metadata.request_id,
                seq_ids=seq_ids,
                is_prompt=is_prompt,
                block_tables=seq_group_metadata.block_tables,
                computed_block_nums=seq_group_metadata.computed_block_nums,
                reinit=True,
                reinit_use_defaults=True,
            )
    
            self.inter_data_list.append(inter_data)
    
            for seq_idx in range(n_seqs):
                for per_seq_fn in self.per_seq_compute_fns:
                    per_seq_fn(inter_data, seq_idx, seq_group_metadata)
            for per_seq_group_fn in self.per_seq_group_compute_fns:
                per_seq_group_fn(inter_data, seq_group_metadata)
    
        def build(self) -> ModelInputForNPU:
            """Finalize the builder intermediate data and
            create on-device tensors.
            """
            # Combine and flatten intermediate data.
            input_tokens = []
            for inter_data in self.inter_data_list:
                for cur_input_tokens in inter_data.input_tokens:
                    input_tokens.extend(cur_input_tokens)
    
            if not input_tokens:
                # This may happen when all prefill requests hit
                # prefix caching and there is no decode request.
                return self.model_input_cls()
    
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)
    
            seq_lens = []
            max_decode_seq_len = 0
            for inter_data in self.inter_data_list:
                seq_lens.extend(inter_data.seq_lens)
                if not inter_data.is_prompt:
                    max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens))
            query_lens = []
            for inter_data in self.inter_data_list:
                query_lens.extend(inter_data.query_lens)
    
            # Mapping from request IDs to sequence IDs. Used for Jamba models
            # that manage the cache by themselves.
            request_ids_to_seq_ids = {data.request_id: data.seq_ids for data in self.inter_data_list}
    
            batch_size = len(input_tokens)
            use_captured_graph = self._use_captured_graph(batch_size, max_decode_seq_len)
    
            # If cuda graph can be used, pad tensors accordingly.
            # See `capture_model` API for more details.
            # vLLM uses cuda graph only for decoding requests.
            cuda_graph_pad_size = -1
            if use_captured_graph:
                graph_batch_size = _get_graph_batch_size(batch_size)
                assert graph_batch_size >= batch_size
                cuda_graph_pad_size = graph_batch_size - batch_size
                batch_size = graph_batch_size
    
            # Tokens and positions.
            if cuda_graph_pad_size:
                input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
                input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
            assert self.runner.device is not None
            input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, self.runner.pin_memory)
            input_positions_tensor = async_tensor_h2d(
                input_positions, torch.long, self.runner.device, self.runner.pin_memory
            )
    
            # Sequence and query lengths.
            if cuda_graph_pad_size:
                seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
    
            # Attention metadata.
            attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens, cuda_graph_pad_size, batch_size)
    
            # LoRA data.
            lora_requests = []
            lora_mapping = None
            if self.enable_lora:
                lora_requests = list(r if r else None for data in self.inter_data_list for r in data.lora_requests)
                lora_index_mapping = flatten_2d_lists(
                    [flatten_2d_lists(inter_data.lora_index_mapping) for inter_data in self.inter_data_list]
                )
                if cuda_graph_pad_size:
                    lora_index_mapping.extend(itertools.repeat(0, cuda_graph_pad_size))
                lora_prompt_mapping = flatten_2d_lists(
                    [flatten_2d_lists(inter_data.lora_prompt_mapping) for inter_data in self.inter_data_list]
                )
    
                lora_mapping = LoRAMapping(
                    **dict(
                        index_mapping=lora_index_mapping,
                        prompt_mapping=lora_prompt_mapping,
                        is_prefill=not self.decode_only,
                    )
                )
    
            # Prompt adapter data.
            prompt_adapter_requests: Set[PromptAdapterRequest] = set()
            prompt_adapter_mapping = None
            if self.enable_prompt_adapter:
                prompt_adapter_requests = set(
                    data.prompt_adapter_request for data in self.inter_data_list if data.prompt_adapter_request is not None
                )
                prompt_adapter_index_mapping = flatten_2d_lists(
                    [inter_data.prompt_adapter_index_mapping for inter_data in self.inter_data_list]
                )
                if cuda_graph_pad_size:
                    prompt_adapter_index_mapping.extend(itertools.repeat(0, cuda_graph_pad_size))
                prompt_adapter_prompt_mapping = flatten_2d_lists(
                    [inter_data.prompt_adapter_prompt_mapping for inter_data in self.inter_data_list]
                )
                prompt_adapter_mapping = PromptAdapterMapping(
                    prompt_adapter_index_mapping,
                    prompt_adapter_prompt_mapping,
                )
    
            # Multi-modal data.
            multi_modal_inputs_list = [
                data.multi_modal_inputs for data in self.inter_data_list if data.multi_modal_inputs is not None
            ]
            multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
    
            return self.model_input_cls(
                input_tokens=input_tokens_tensor,
                input_positions=input_positions_tensor,
                attn_metadata=attn_metadata,
                seq_lens=seq_lens,
                query_lens=query_lens,
                lora_mapping=lora_mapping,
                lora_requests=lora_requests,
                multi_modal_kwargs=multi_modal_kwargs,
                request_ids_to_seq_ids=request_ids_to_seq_ids,
                finished_requests_ids=self.finished_requests_ids,
                prompt_adapter_mapping=prompt_adapter_mapping,
                prompt_adapter_requests=prompt_adapter_requests,
            )
    
        def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, seq_group_metadata: SequenceGroupMetadata):
            """Compute context length, sequence length and tokens
            for the given sequence data.
            """
            seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
            token_chunk_size = seq_group_metadata.token_chunk_size
    
            # Compute context length (the number of tokens that are
            # already computed) and sequence length (total number of tokens).
            seq_len = seq_data.get_len()
            if inter_data.is_prompt:
                context_len = seq_data.get_num_computed_tokens()
            else:
                # get_num_computed_tokens is incorrect for spec decoding.
                # So, we should have a special logic here.
                # TODO(sang): Fix it.
                context_len = seq_len - 1
            seq_len = min(seq_len, context_len + token_chunk_size)
    
            # Compute tokens.
            if inter_data.is_prompt:
                tokens = seq_data.get_token_ids()
                if context_len != 0 or seq_len < len(tokens):
                    tokens = tokens[context_len:seq_len]
            else:
                # Optimization. get_token_ids requires the entire copy of
                # tokens.
                tokens = seq_data.get_last_token_id()
    
            inter_data.seq_lens[seq_idx] = seq_len
            inter_data.orig_seq_lens[seq_idx] = seq_len
            inter_data.context_lens[seq_idx] = context_len
    
            if isinstance(tokens, list):
                inter_data.input_tokens[seq_idx].extend(tokens)
            else:
                inter_data.input_tokens[seq_idx].append(tokens)
    
            if (seq_len - context_len) == 1:
                inter_data.input_positions[seq_idx].append(seq_len - 1)
            else:
                inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
    
            inter_data.query_lens[seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
    
        def _compute_for_prefix_cache_hit(
            self, inter_data: InterDataForSeqGroup, seq_idx: int, seq_group_metadata: SequenceGroupMetadata
        ):
            """Check if hit prefix cache (i.e., some blocks are already computed).
            If hit, update input tokens and positions to only compute the
            remaining blocks.
            """
            computed_block_nums = inter_data.computed_block_nums
    
            # Note that prefix caching does not support sliding window.
            prefix_cache_hit = (
                computed_block_nums is not None
                and len(computed_block_nums) > 0
                and self.sliding_window is None
                and inter_data.is_prompt
            )
            inter_data.prefix_cache_hit = prefix_cache_hit
    
            if not prefix_cache_hit:
                return
    
            assert computed_block_nums is not None
            # The cache hit prompt tokens in this sequence. Note that
            # this may be larger than the sequence length if chunked
            # prefill is enabled.
            prefix_cache_len = len(computed_block_nums) * self.block_size
            # The number of so far computed prompt tokens in this sequence.
            context_len = inter_data.context_lens[seq_idx]
            # The total number of prompt tokens in this sequence.
            # When chunked prefill is enabled, this is the token number of
            # computed chunks + current chunk.
            seq_len = inter_data.seq_lens[seq_idx]
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                uncomputed_start = prefix_cache_len - context_len
                inter_data.input_tokens[seq_idx] = inter_data.input_tokens[seq_idx][uncomputed_start:]
                inter_data.input_positions[seq_idx] = inter_data.input_positions[seq_idx][uncomputed_start:]
                context_len = prefix_cache_len
    
                inter_data.context_lens[seq_idx] = context_len
                inter_data.query_lens[seq_idx] = inter_data.seq_lens[seq_idx] - context_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                inter_data.input_tokens[seq_idx] = inter_data.input_tokens[seq_idx][-1:]
                inter_data.input_positions[seq_idx] = inter_data.input_positions[seq_idx][-1:]
                inter_data.query_lens[seq_idx] = 1
                inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
    
        def _compute_for_sliding_window(
            self, inter_data: InterDataForSeqGroup, seq_idx: int, seq_group_metadata: SequenceGroupMetadata
        ):
            """Update seq_len and curr_sliding_window_block for the given
            sequence data (only required by decoding) if sliding window is enabled.
            """
            curr_sliding_window_block = 0
            sliding_seq_len = inter_data.seq_lens[seq_idx]
            if not inter_data.is_prompt and self.sliding_window is not None:
                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                curr_sliding_window_block = self.sliding_window_blocks
                if self.scheduler_config.use_v2_block_manager:
                    # number of elements in last block
                    suff_len = inter_data.seq_lens[seq_idx] % self.block_size
                    sliding_seq_len = min(inter_data.seq_lens[seq_idx], self.block_aligned_sliding_window + suff_len)
                    if suff_len > 0:
                        curr_sliding_window_block += 1
                else:
                    sliding_seq_len = min(inter_data.seq_lens[seq_idx], self.sliding_window)
    
            inter_data.curr_sliding_window_blocks[seq_idx] = curr_sliding_window_block
            inter_data.seq_lens[seq_idx] = sliding_seq_len
    
        def _compute_lora_input(
            self, inter_data: InterDataForSeqGroup, seq_idx: int, seq_group_metadata: SequenceGroupMetadata
        ):
            """If LoRA is enabled, compute LoRA index and prompt mapping."""
            if not self.enable_lora:
                return
    
            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                inter_data.lora_requests.append(seq_group_metadata.lora_request)
            query_len = inter_data.query_lens[seq_idx]
            inter_data.lora_index_mapping.append([lora_id] * query_len)
            inter_data.lora_prompt_mapping.append(
                [lora_id]
                * (
                    query_len
                    if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs is not None
                    else 1
                )
            )
    
        def _compute_prompt_adapter_input(
            self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata
        ):
            """If prompt adapter is enabled, compute index and prompt mapping."""
            # Note that when is_prompt=True, we expect only one sequence
            # in the group.
            if not self.enable_prompt_adapter:
                return
    
            prompt_adapter_id = seq_group_metadata.prompt_adapter_id
            if prompt_adapter_id <= 0 or not inter_data.is_prompt:
                return
    
            # We expect only one sequence in the group when is_prompt=True.
            assert inter_data.n_seqs == 1
            query_len = inter_data.query_lens[0]
            inter_data.prompt_adapter_request = seq_group_metadata.prompt_adapter_request
    
            num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
            inter_data.prompt_adapter_index_mapping = [prompt_adapter_id] * num_tokens + [0] * (query_len - num_tokens)
            inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
                query_len
                if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs
                else 1
            )
    
        def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata):
            """If multi-modal data is given, add it to the input."""
            mm_data = seq_group_metadata.multi_modal_data
            if not mm_data:
                return
    
            mm_kwargs = self.multi_modal_input_mapper(mm_data)
            inter_data.multi_modal_inputs = mm_kwargs
    
        def _use_captured_graph(self, batch_size: int, max_decode_seq_len: int) -> bool:
            return (
                self.decode_only
                and not self.runner.model_config.enforce_eager
                and batch_size <= self.runner.max_batchsize_to_capture
                and max_decode_seq_len <= self.runner.max_seq_len_to_capture
            )
    
    
    class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
        """
        Helper class for shared methods between NPU model runners.
        """
    
        _model_input_cls: Type[TModelInputForNPU]
        _builder_cls: Type[ModelInputForNPUBuilder]
    
        def __init__(
            self,
            model_config: ModelConfig,
            parallel_config: ParallelConfig,
            scheduler_config: SchedulerConfig,
            device_config: DeviceConfig,
            cache_config: CacheConfig,
            load_config: LoadConfig,
            lora_config: Optional[LoRAConfig],
            mindie_config: Dict[str, Any],
            kv_cache_dtype: Optional[str] = "auto",
            is_driver_worker: bool = False,
            prompt_adapter_config: Optional[PromptAdapterConfig] = None,
            return_hidden_states: bool = False,
            observability_config: Optional[ObservabilityConfig] = None,
            input_registry: InputRegistry = INPUT_REGISTRY,
            mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        ):
            self.model_config = model_config
            self.parallel_config = parallel_config
            self.scheduler_config = scheduler_config
            self.device_config = device_config
            self.cache_config = cache_config
            self.lora_config = lora_config
            self.load_config = load_config
            self.is_driver_worker = is_driver_worker
            self.prompt_adapter_config = prompt_adapter_config
            self.return_hidden_states = return_hidden_states
            self.observability_config = observability_config
            self.mindie_config = mindie_config
    
            self.device = self.device_config.device
            self.pin_memory = is_pin_memory_available()
    
            self.kv_cache_dtype = kv_cache_dtype
            self.sliding_window = model_config.get_sliding_window()
            self.block_size = cache_config.block_size
            self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
            self.max_batchsize_to_capture = _get_max_graph_batch_size(self.scheduler_config.max_num_seqs)
    
            self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(parallel_config)
    
            num_attn_heads = self.model_config.get_num_attention_heads(self.parallel_config)
            self.attn_backend = (
                get_attn_backend(
                    num_attn_heads,
                    self.model_config.get_head_size(),
                    self.model_config.get_num_kv_heads(self.parallel_config),
                    self.model_config.get_sliding_window(),
                    self.model_config.dtype,
                    self.kv_cache_dtype,
                    self.block_size,
                )
                if num_attn_heads
                else None
            )
            if self.attn_backend:
                self.attn_state = self.attn_backend.get_state_cls()(weakref.proxy(self))
            else:
                self.attn_state = CommonAttentionState(weakref.proxy(self))
    
            # Multi-modal data support
            self.input_registry = input_registry
            self.mm_registry = mm_registry
            self.multi_modal_input_mapper = mm_registry.create_input_mapper(model_config)
            self.mm_registry.init_mm_limits_per_prompt(self.model_config)
    
            # Lazy initialization
            self.model: nn.Module  # Set after load_model
            # Set after load_model.
            self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
            self.prompt_adapter_manager: Optional[LRUCacheWorkerPromptAdapterManager] = None
    
            set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
    
            # Used to cache python objects
            self.inter_data_cache: Dict[int, PyObjectCache] = {}
            self.sampling_metadata_cache: SamplingMetadataCache = SamplingMetadataCache()
    
        @property
        def vocab_size(self) -> int:
            return self.model_config.get_vocab_size()
    
        def load_model(self) -> None:
            logger.info("Starting to load model %s...", self.model_config.model)
            with DeviceMemoryProfiler() as m:
                self.model = get_model(
                    model_config=self.model_config,
                    device_config=self.device_config,
                    load_config=self.load_config,
                    mindie_config=self.mindie_config,
                    lora_config=self.lora_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    cache_config=self.cache_config,
                )
    
            self.model_memory_usage = m.consumed_memory
            logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
    
            if self.lora_config:
                assert supports_lora(self.model), "Model does not support LoRA"
                assert not supports_multimodal(self.model), "To be tested: Multi-modal model with LoRA settings."
    
                logger.info("LoRA manager will be initialized in the MindIE backend.")
    
            # TODO: What is this prompt adapter?
            if self.prompt_adapter_config:
                self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                    self.scheduler_config.max_num_seqs,
                    self.scheduler_config.max_num_batched_tokens,
                    self.device,
                    self.prompt_adapter_config,
                )
                self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model)
    
        def save_sharded_state(
            self,
            path: str,
            pattern: Optional[str] = None,
            max_size: Optional[int] = None,
        ) -> None:
            from vllm.model_executor.model_loader.loader import ShardedStateLoader
    
            ShardedStateLoader.save_model(
                self.model,
                path,
                pattern=pattern,
                max_size=max_size,
            )
    
        def save_tensorized_model(
            self,
            tensorizer_config: TensorizerConfig,
        ) -> None:
            from vllm.model_executor.model_loader.loader import TensorizerLoader
    
            TensorizerLoader.save_model(
                self.model,
                tensorizer_config=tensorizer_config,
            )
    
        def get_max_block_per_batch(self) -> int:
            block_size = self.block_size
            return (self.max_seq_len_to_capture + block_size - 1) // block_size
    
        def _prepare_model_input_tensors(
            self, seq_group_metadata_list: List[SequenceGroupMetadata], finished_requests_ids: Optional[List[str]] = None
        ) -> TModelInputForNPU:
            """Helper method to prepare the model input based on a given sequence
            group. Prepares metadata needed for the base model forward pass but not
            metadata for possible additional steps, e.g., sampling.
    
            The API assumes seq_group_metadata_list is sorted by prefill -> decode.
    
            The result tensors and data structure also batches input in prefill
            -> decode order. For example,
    
            - input_tokens[:num_prefill_tokens] contains prefill tokens.
            - input_tokens[num_prefill_tokens:] contains decode tokens.
    
            If cuda graph is required, this API automatically pads inputs.
            """
            builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
            for seq_group_metadata in seq_group_metadata_list:
                builder.add_seq_group(seq_group_metadata)
    
            builder.reset_cached_inter_data()
    
            return builder.build()  # type: ignore
    
        @torch.inference_mode()
        def profile_run(self) -> None:
            # Enable top-k sampling to reflect the accurate memory usage.
            sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
            max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
            max_num_seqs = self.scheduler_config.max_num_seqs
            # This represents the maximum number of different requests
            # that will have unique loras, and therefore the max amount of memory
            # consumption. Create dummy lora request copies from the lora request
            # passed in, which contains a lora from the lora warmup path.
            dummy_lora_requests: List[LoRARequest] = []
            dummy_lora_requests_per_seq: List[LoRARequest] = []
            if self.lora_config:
                assert self.lora_manager is not None
                with self.lora_manager.dummy_lora_cache():
                    for idx in range(self.lora_config.max_loras):
                        lora_id = idx + 1
                        dummy_lora_request = LoRARequest(
                            lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path",
                        )
                        self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK)
                        dummy_lora_requests.append(dummy_lora_request)
                    dummy_lora_requests_per_seq = [
                        dummy_lora_requests[idx % len(dummy_lora_requests)] for idx in range(max_num_seqs)
                    ]
    
            # Profile memory usage with max_num_sequences sequences and the total
            # number of tokens equal to max_num_batched_tokens.
            seqs: List[SequenceGroupMetadata] = []
            # TODO: support MM models
            # Additional NPU memory may be needed for multi-modal encoding, which
            # needs to be accounted for when calculating the NPU blocks for
            # vLLM block manager.
            # To exercise the worst scenario for NPU memory consumption,
            # the number of seqs (batch_size) is chosen to maximize the number
            # of images processed.
    
            batch_size = 0
            for group_id in range(max_num_seqs):
                seq_len = max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)
                batch_size += seq_len
    
                seq_data, dummy_multi_modal_data = self.input_registry.dummy_data_for_profiling(
                    self.model_config, seq_len, self.mm_registry
                )
    
                seq = SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None,
                    multi_modal_data=dummy_multi_modal_data,
                )
                seqs.append(seq)
    
            # Run the model with the dummy inputs.
            num_layers = self.model_config.get_num_layers(self.parallel_config)
            kv_caches = [None] * num_layers
            finished_requests_ids = [seq.request_id for seq in seqs]
            model_input = self.prepare_model_input(seqs, finished_requests_ids=finished_requests_ids)
            intermediate_tensors = None
            self.execute_model(model_input, kv_caches, intermediate_tensors)
            torch.npu.synchronize()
            return
    
        def remove_all_loras(self): ...
    
        def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: ...
    
        def add_lora(self, lora_request: LoRARequest) -> bool: ...
    
        def remove_lora(self, lora_id: int) -> bool: ...
    
        def pin_lora(self, lora_id: int) -> bool: ...
    
        def list_loras(self) -> Set[int]: ...
    
        def remove_all_prompt_adapters(self): ...
    
        def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest], prompt_adapter_mapping: PromptAdapterMapping
        ) -> None: ...
    
        def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: ...
    
        def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: ...
    
        def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: ...
    
        def list_prompt_adapters(self) -> Set[int]: ...
    
    
    class ModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
        """
        NPU model runner with sampling step.
        """
    
        _model_input_cls: Type[ModelInputForNPUWithSamplingMetadata] = ModelInputForNPUWithSamplingMetadata
        _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder
    
        def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str, Any],
        ) -> ModelInputForNPUWithSamplingMetadata:
            model_input = ModelInputForNPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
            return model_input
    
        def prepare_model_input(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            virtual_engine: int = 0,
            finished_requests_ids: Optional[List[str]] = None,
        ) -> ModelInputForNPUWithSamplingMetadata:
            """Prepare the model input based on a given sequence group, including
            metadata for the sampling step.
    
            The API assumes seq_group_metadata_list is sorted by prefill -> decode.
    
            The result tensors and data structure also batches input in prefill
            -> decode order. For example,
    
            - input_tokens[:num_prefill_tokens] contains prefill tokens.
            - input_tokens[num_prefill_tokens:] contains decode tokens.
    
            If cuda graph is required, this API automatically pads inputs.
            """
            model_input = self._prepare_model_input_tensors(seq_group_metadata_list, finished_requests_ids)
            if get_pp_group().is_last_rank:
                # Sampling metadata is only required for the final pp group
                generators = self.get_generators(finished_requests_ids)
                sampling_metadata = SamplingMetadata.prepare(
                    seq_group_metadata_list,
                    model_input.seq_lens,
                    model_input.query_lens,
                    self.device,
                    self.pin_memory,
                    generators,
                    self.sampling_metadata_cache,
                )
            else:
                sampling_metadata = None
            is_prompt = seq_group_metadata_list[0].is_prompt if seq_group_metadata_list else None
            return dataclasses.replace(
                model_input, sampling_metadata=sampling_metadata, is_prompt=is_prompt, virtual_engine=virtual_engine
            )
    
        @torch.inference_mode()
        def execute_model(
            self,
            model_input: ModelInputForNPUWithSamplingMetadata,
            kv_caches: List[torch.Tensor],
            intermediate_tensors: Optional[IntermediateTensors] = None,
            num_steps: int = 1,
        ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
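            """Run a single forward pass on the NPU for the prepared model input,
            compute logits on the last pipeline-parallel rank, and sample the next
            tokens on the driver worker; non-last ranks return IntermediateTensors."""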
            if num_steps > 1:
                raise ValueError("num_steps > 1 is not supported in ModelRunner")
    
            self.attn_state.begin_forward(model_input)
    
            # Currently cuda graph is only supported by the decode phase.
            assert model_input.attn_metadata is not None
            prefill_meta = model_input.attn_metadata.prefill_metadata
            decode_meta = model_input.attn_metadata.decode_metadata
            # TODO(andoorve): We can remove this once all
            # virtual engines share the same kv cache.
            virtual_engine = model_input.virtual_engine
            if prefill_meta is None and decode_meta.use_cuda_graph:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = self.graph_runners[virtual_engine][graph_batch_size]
            else:
                model_executable = self.model
    
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            seqlen_agnostic_kwargs = (
                {
                    "finished_requests_ids": model_input.finished_requests_ids,
                    "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
                }
                if self.has_seqlen_agnostic
                else {}
            )
            if self.observability_config is not None and self.observability_config.collect_model_forward_time:
                model_forward_start = torch.npu.streams.Event(enable_timing=True)
                model_forward_end = torch.npu.streams.Event(enable_timing=True)
                model_forward_start.record()
    
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                lora_requests=model_input.lora_requests,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs, device=self.device),
                **seqlen_agnostic_kwargs,
            )
    
            if self.observability_config is not None and self.observability_config.collect_model_forward_time:
                model_forward_end.record()
    
            # Compute the logits in the last pipeline stage.
            if not get_pp_group().is_last_rank:
                if (
                    self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states, IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time
                ):
                    model_forward_end.synchronize()
                    model_forward_time = model_forward_start.elapsed_time(model_forward_end)
                    orig_model_forward_time = 0.0
                    if intermediate_tensors is not None:
                        orig_model_forward_time = intermediate_tensors.tensors.get(
                            "model_forward_time", torch.tensor(0.0)
                        ).item()
                    hidden_or_intermediate_states.tensors["model_forward_time"] = torch.tensor(
                        model_forward_time + orig_model_forward_time
                    )
                return hidden_or_intermediate_states
    
            logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata)
    
            if not self.is_driver_worker:
                return []
    
            if model_input.async_callback is not None:
                model_input.async_callback()
    
            # Sample the next token.
            output: SamplerOutput = self.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            if (
                self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None
            ):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)
                    ).item()
                # If there are multiple workers, we are still tracking the latency
                # from the start time of the driver worker to the end time of the
                # driver worker. The model forward time will then end up covering
                # the communication time as well.
                output.model_forward_time = orig_model_forward_time + model_forward_time
    
            if self.return_hidden_states:
                # we only need to pass hidden states of most recent token
                assert model_input.sampling_metadata is not None
                indices = model_input.sampling_metadata.selected_token_indices
                if model_input.is_prompt:
                    hidden_states = hidden_or_intermediate_states.index_select(0, indices)
                    output.prefill_hidden_states = hidden_or_intermediate_states
                elif decode_meta.use_cuda_graph:
                    hidden_states = hidden_or_intermediate_states[: len(indices)]
                else:
                    hidden_states = hidden_or_intermediate_states
    
                output.hidden_states = hidden_states
    
            return [output]
    
    
    def _get_graph_batch_size(batch_size: int) -> int:
        """Returns the padded batch size given actual batch size.
    
        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
        """
        if batch_size <= 2:
            return batch_size
        elif batch_size <= 4:
            return 4
        else:
            return (batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT
    
    
    def _get_max_graph_batch_size(max_num_seqs: int) -> int:
        """
        max_num_seqs: Maximum number of sequences in a batch.
        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
    
        pad the max_num_seqs if necessary by calling _get_graph_batch_size,
        which will deal with some edge cases like 1, 2, 4.
    
        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
        if not, it means the padded size is larger than the largest size in
        _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
        """
        padded_size = _get_graph_batch_size(max_num_seqs)
        if padded_size in _BATCH_SIZES_TO_CAPTURE:
            return padded_size
        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
        return _BATCH_SIZES_TO_CAPTURE[-1]
    
    
    def seq_output_builder():
        return SequenceOutput(0, 0, {0: Logprob(logprob=float("inf"), rank=None, decoded_token=None)})
    
    
    def completion_seq_group_output_builder():
        return CompletionSequenceGroupOutput([], None)
    
    
    class PythonizationCache:
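        """Object pools for SequenceOutput and CompletionSequenceGroupOutput instances,
        reused across multi-step iterations to avoid re-allocating Python objects every
        time a sampler output is pythonized."""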
    
        def __init__(self):
            self.cached_seq_output = PyObjectCache(seq_output_builder)
            self.cached_completion_seq_group_output = PyObjectCache(completion_seq_group_output_builder)
    
        def reset(self):
            self.cached_seq_output.reset()
            self.cached_completion_seq_group_output.reset()
    
    
    @dataclass
    class ModelOutput:
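        """Sampler output of one multi-step iteration, together with the NPU event that
        signals when its tensors are ready; pythonize()/maybe_pythonize() convert the
        on-device results into Python objects either blocking or only when ready."""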
        sampler_output: SamplerOutput
        sampler_output_ready_event: torch.npu.streams.Event
        sampled_token_ids: Optional[torch.Tensor] = None
        pythonized: bool = False
        # On-device tensor containing the logprobs of each token.
        logprobs: Optional["torch.Tensor"] = None
        pythonization_cache: Optional[PythonizationCache] = None
    
        def pythonize(
            self,
            input_metadata: "StatefulModelInput",
            copy_stream: torch.npu.streams.Stream,
            pinned_sampled_token_buffer: torch.Tensor,
        ) -> None:
            """Pythonize the output. Blocking."""
            if not self.pythonized:
                self._pythonize_sampler_output(input_metadata, copy_stream, pinned_sampled_token_buffer, True)
                self.pythonized = True
    
        def maybe_pythonize(
            self,
            input_metadata: "StatefulModelInput",
            copy_stream: torch.npu.streams.Stream,
            pinned_sampled_token_buffer: torch.Tensor,
        ) -> None:
            """Pythonize the output if ready, else return None. Non-blocking."""
            if not self.pythonized:
                self.pythonized = self._pythonize_sampler_output(
                    input_metadata, copy_stream, pinned_sampled_token_buffer, False
                )
    
        def _pythonize_sampler_output(
            self,
            input_metadata: "StatefulModelInput",
            copy_stream: torch.npu.streams.Stream,
            pinned_sampled_token_buffer: torch.Tensor,
            blocking: bool,
        ) -> bool:
            """
            If blocking is set, will block until the forward pass for the output is
            ready and pythonize the output. Upon completing Pythonization, erases
            self.logprobs (note that a non-blocking call performed when the sampler
            output is not yet ready will not erase self.logprobs).
            """
            assert self.sampled_token_ids is not None
            if not blocking and not self.sampler_output_ready_event.query():
                return False
    
            if blocking:
                self.sampler_output_ready_event.synchronize()
            with torch.npu.utils.stream(copy_stream):
                _pythonize_sampler_output(
                    input_metadata,
                    self.sampler_output,
                    pinned_sampled_token_buffer,
                    self.sampled_token_ids,
                    self.logprobs,
                    self.pythonization_cache,
                )
            self.logprobs = None
            return True
    
    
    @dataclass(frozen=False)
    class StatefulModelInput(BroadcastableModelInput):
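        """Model input carried across the steps of a multi-step decode: wraps the frozen
        single-step model input and accumulates per-step sampler outputs plus the NPU
        events used to synchronize consecutive steps."""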
        # actual frozen model input dataclass passed to _base_model_runner
        frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None
    
        # list of model outputs for each step, may not be all pythonized
        cached_outputs: List[ModelOutput] = field(default_factory=list)
    
        # used to pass sampled token ids from the last step to the current step for
        # TP workers. Used to append to end of outputs and used by advance_step
        last_sampled_token_ids: Optional[torch.Tensor] = None
        current_step: int = 0
        is_multi_step: bool = True
        is_last_step: bool = False
        is_first_multi_step: bool = False
        # ping-pong data structures for multi-step to wait on the previous step
        step_npu_events: List[torch.npu.streams.Event] = field(
            default_factory=lambda: [torch.npu.streams.Event(blocking=True)] * 2
        )
        num_seqs: int = -1
        num_queries: int = -1
    
        @classmethod
        def from_broadcasted_tensor_dict(
            cls,
            tensor_dict: Dict[str, Any],
            attn_backend: Optional["AttentionBackend"] = None,
        ) -> "StatefulModelInput":
            tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
            if attn_backend is not None:
                tensor_dict = _init_attn_metadata_from_tensor_dict(attn_backend, tensor_dict)
            tensor_dict = _init_frozen_model_input_from_tensor_dict(ModelInputForNPUWithSamplingMetadata, tensor_dict)
    
            return cls(**tensor_dict)
    
        def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
            assert self.frozen_model_input is not None
            tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
            new_tensor_dict = {
                "last_sampled_token_ids": self.last_sampled_token_ids,
                "current_step": self.current_step,
                "is_multi_step": self.is_multi_step,
                "is_last_step": self.is_last_step,
                "is_first_multi_step": self.is_first_multi_step,
                "num_seqs": self.num_seqs,
                "num_queries": self.num_queries,
            }
            tensor_dict.update(new_tensor_dict)
            return tensor_dict
    
        def record_step_event(self, current_stream: torch.npu.streams.Stream):
            # record the event for the current step so that the next step can sync
            # on it. We modulo by 2 to keep the events in a circular buffer and
            # support any attn backends that may be supported in the future. ie
            # Flashinfer would want two DecodeWrappers to overlap the CPU and NPU.
            self.step_npu_events[self.current_step & 1] = torch.npu.streams.Event(blocking=True)
            self.step_npu_events[self.current_step & 1].record(current_stream)
    
        def wait_previous_step(self):
            # These NPU events are an explicit synchronization to ensure that
            # advance_step() (for other attn backends that may be supported in the
            # future) does not clobber any data structures that are also used by any
            # enqueued forward steps. For the distributed case, only a single event is
            # needed, but for single NPU case, since we can let the CPU run much
            # further ahead, two events allow us to overlap the advance_step with
            # the previous forward (ie using two DecodeWrappers for flashinfer
            # backend)
            self.step_npu_events[(self.current_step + 1) & 1].wait()
    
        def add_sampler_output(self, sampler_output: SamplerOutput, sampled_token_ids: Optional[torch.Tensor] = None):
            self.cached_outputs.append(
                ModelOutput(
                    sampler_output=sampler_output,
                    sampler_output_ready_event=None,
                    sampled_token_ids=sampled_token_ids,
                    pythonized=False,
                )
            )
    
    
    class MultiStepNPUModelRunner(NPUModelRunnerBase[StatefulModelInput]):
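        """Multi-step NPU model runner: wraps a single-step NPUModelRunnerBase and runs
        several decode steps per scheduler call, deferring pythonization of sampler
        outputs so the CPU can run ahead of the NPU."""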
    
        def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
            super().__init__(*args, **kwargs)
    
            # uses the base model runner to execute the model and wraps it with
            # multi-step logic
            self._base_model_runner: NPUModelRunnerBase = base_model_runner
    
            self.is_multi_step = self.scheduler_config.is_multi_step
            self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
    
            self.pythonization_cache = PythonizationCache()
    
        @property
        def vocab_size(self) -> int:
            return self._base_model_runner.vocab_size
    
        def make_model_input_from_broadcasted_tensor_dict(self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
            model_input = StatefulModelInput.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
            return model_input
    
        def prepare_model_input(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            virtual_engine: int = 0,
            finished_requests_ids: Optional[List[str]] = None,
        ) -> StatefulModelInput:
            frozen_model_input = self._base_model_runner.prepare_model_input(
                seq_group_metadata_list, virtual_engine, finished_requests_ids
            )
    
            model_input = StatefulModelInput(
                frozen_model_input=frozen_model_input,
                num_seqs=len(frozen_model_input.seq_lens),
                num_queries=len(frozen_model_input.query_lens),
            )
            return model_input
    
        @torch.inference_mode()
        def execute_model(
            self,
            model_input: StatefulModelInput,
            kv_caches: List[torch.Tensor],
            intermediate_tensors: Optional[IntermediateTensors] = None,
            num_steps: int = 1,
        ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
            """
            Execute the model for a single step and update multi-step
            metadata
            """
            assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
    
            # path for warm up runs
            if not model_input.is_multi_step:
                return self._base_model_runner.execute_model(frozen_model_input, kv_caches, intermediate_tensors, num_steps)
    
            # make sure we skip the sampler on the last rank and only pythonize
            # if CPU is ahead.
            if self.is_driver_worker and get_pp_group().is_last_rank:
                if self.pinned_sampled_token_ids is None:
                    self.pinned_sampled_token_ids = torch.zeros(
                        (self.scheduler_config.max_num_seqs, 1), dtype=torch.long, device="cpu", pin_memory=True
                    )
    
                self._base_model_runner.model.sampler.include_gpu_probs_tensor = True
                if frozen_model_input.sampling_metadata:
                    frozen_model_input.sampling_metadata.skip_sampler_cpu_output = True
    
            # some pre-execute model logic for multi-step:
            #   - if it's the first step, we need to reset the sampling tensors
            #   - if it's not the first step, we need to advance the step using the
            #   appended sampler output from last iteration
            #   - also maybe pythonize if CPU is ahead of NPU
    
            current_stream = torch.npu.utils.current_stream()
            if not model_input.is_first_multi_step:
                # Explicitly block on the previous step's forward to make sure we
                # don't clobber any NPU tensors still in use.
                # This is not needed for flashattn backend, but for other attn
                # backends such as flashinfer that performs extra CPU operations on
                # input metadata we may need to synchronize any CPU operations that
                # might clobber enqueued forwards. (prevents CPU from running too
                # far ahead if needed)
                model_input.wait_previous_step()
                model_input = self._advance_step(model_input, model_input.cached_outputs[-1].sampler_output)
    
            output_proc_callback = None
            if frozen_model_input.async_callback is not None:
                output_proc_callback = frozen_model_input.async_callback
                assert output_proc_callback is not None
                async_callback = functools.partial(
                    self._async_process_outputs, model_input=model_input, output_proc_callback=output_proc_callback
                )
    
                frozen_model_input = dataclasses.replace(  # type: ignore
                    model_input.frozen_model_input, async_callback=async_callback
                )
                assert frozen_model_input is not None
    
            # Execute the model
            output = self._base_model_runner.execute_model(frozen_model_input, kv_caches, intermediate_tensors, num_steps=1)
    
            # record the event for the current step so that the next step can sync
            model_input.record_step_event(current_stream)
    
            if get_pp_group().is_last_rank and self.is_driver_worker:
                assert len(output) == 1, "MultiStepModelRunner requires single-step base_models"
    
                # event for the pythonization so that we only pythonize if the
                # tensors are ready. May be able to be combined with the step event
                output_ready_event = torch.npu.streams.Event()
                output_ready_event.record(current_stream)
                if self.parallel_config.pipeline_parallel_size > 1:
                    output[0].sampled_token_ids_cpu = output[0].sampled_token_ids.cpu()
                model_input.cached_outputs.append(
                    ModelOutput(
                        output[0],
                        output_ready_event,
                        output[0].sampled_token_ids,
                        False,
                        output[0].logprobs,
                        self.pythonization_cache,
                    )
                )
    
                # These NPU tensors are not required by multi-step;
                # erase them to ensure they are not pythonized or
                # transferred to CPU
                output[0].sampled_token_ids = None
                output[0].sampled_token_probs = None
                output[0].logprobs = None
    
                # Pythonize the output if CPU is ahead and the previous step is
                # ready.
                if frozen_model_input.async_callback is None:
                    for model_output in model_input.cached_outputs:
                        model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids)
    
            model_input.current_step += 1
    
            if not get_pp_group().is_last_rank:
                # Should be IntermediateTensors
                assert isinstance(output, IntermediateTensors)
                return output
            if not self.is_driver_worker:
                return []
    
            # Pythonize the output and block if needed since it is the last step
            if model_input.is_last_step:
                outputs = self._final_process_outputs(model_input, output_proc_callback)
                self.pythonization_cache.reset()
                return outputs
    
            # should be [SamplerOutput]
            return output
    
        def load_model(self) -> None:
            return self._base_model_runner.load_model()
    
        def save_sharded_state(
            self,
            path: str,
            pattern: Optional[str] = None,
            max_size: Optional[int] = None,
        ) -> None:
            return self._base_model_runner.save_sharded_state(path, pattern, max_size)
    
        def save_tensorized_model(self, tensorizer_config: TensorizerConfig) -> None:
            return self._base_model_runner.save_tensorized_model(tensorizer_config)
    
        def profile_run(self) -> None:
            return self._base_model_runner.profile_run()
    
        def remove_all_loras(self):
            return self._base_model_runner.remove_all_loras()
    
        def capture_model(self, kv_caches: List[List]) -> None:
            return self._base_model_runner.capture_model(kv_caches)
    
        @functools.cached_property
        def _copy_stream(self):
            # used to copy tensors from NPU to CPU asynchronously
            return torch.npu.streams.Stream()
    
        def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Callable):
            # Proceed with pythonization and output_proc in order.
            # Stop on the first one that fails to pythonize
            output_proc_callback()
    
            cont = True
            for model_output in model_input.cached_outputs:
                if not model_output.pythonized:
                    model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids)
                    if model_output.pythonized:
                        ctx = output_proc_callback.keywords["ctx"]
                        ctx.append_output(
                            outputs=[model_output.sampler_output],
                            seq_group_metadata_list=ctx.seq_group_metadata_list,
                            scheduler_outputs=ctx.scheduler_outputs,
                            is_async=False,
                            is_last_step=False,
                        )
    
                        output_proc_callback()
                    else:
                        cont = False
    
                if not cont:
                    break
    
        def _final_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Optional[Callable]):
            assert model_input.frozen_model_input is not None
    
            has_async_callback = output_proc_callback is not None
    
            outputs = []
            for output_id in range(len(model_input.cached_outputs)):
                output = model_input.cached_outputs[output_id]
                is_last_step = output_id == len(model_input.cached_outputs) - 1
    
                # For non-async case:
                #   -- We simply add the outputs
                # For async case:
                #   -- Invoke callback, pythonize, add to callback queue and repeat
                #   -- For last output, just add to callback queue
                if has_async_callback:
                    assert output_proc_callback is not None
    
                    # Invoke callback before pythonize (to overlap with NPU)
                    output_proc_callback()
    
                    # Pythonize
                    if not output.pythonized:
                        output.pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids)
    
                        # For non last step, add to callback queue to chain
                        # callbacks=>pythonize pairs (for NPU overlap)
                        if not is_last_step:
                            ctx = output_proc_callback.keywords[  # type: ignore
                                "ctx"
                            ]  # type: ignore
                            ctx.append_output(
                                outputs=[output.sampler_output],
                                seq_group_metadata_list=ctx.seq_group_metadata_list,
                                scheduler_outputs=ctx.scheduler_outputs,
                                is_async=False,
                                is_last_step=False,
                            )
                        else:
                            outputs.append(output.sampler_output)
                else:
                    output.pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids)
                    outputs.append(output.sampler_output)
    
            return outputs
    
        def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries):
    
            assert sampling_metadata.num_prompts == 0
            assert len(sampling_metadata.seq_groups) == num_queries
            assert sampling_metadata.selected_token_indices.shape == (num_queries,)
            # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
    
            # Verify that all sequences are decodes
            for i in range(num_queries):
                seq_group = sampling_metadata.seq_groups[i]
    
                assert seq_group.is_prompt is False  # No prompt
                assert seq_group.prompt_logprob_indices == []  # No prompt
                assert seq_group.sample_indices == [i]  # Simple
                assert seq_group.seq_len is None  # Decode
                assert seq_group.query_len is None  # Decode
    
        def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput:
            if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
                raise ValueError(
                    f"Multi-step not supported for attention backend: "
                    f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
                    f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}."
                )
    
            sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
            num_seqs = model_input.num_seqs
            num_queries = model_input.num_queries
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            attn_metadata = frozen_model_input.attn_metadata
            assert attn_metadata is not None
    
            attn_metadata.advance_step(
                frozen_model_input,
                sampled_token_ids,
                self.block_size,
                num_seqs,
                num_queries,
            )
    
            return model_input
    
    
    DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], Optional[List[SampleLogprobs]]]
    
    
    def deferred_pythonize_logprobs(
        output: SamplerOutput,
        sampling_metadata: SamplingMetadata,
        logprobs_tensor: Optional[torch.Tensor],
    ) -> DeferredLogprobsReturnType:
        """Perform deferred logprob Pythonization.
    
        1. Pythonize NPU-side sampler result tensors into CPU-side sampler result.
        2. Pythonize NPU-side logprobs tensor into CPU-side logprobs lists,
           utilizing  the Pythonized sampler result computed in step 1.
    
        These deferred computations are not required for single-step scheduling
        or the `profile_run()` phase of multi-step scheduling.
    
        Args:
            output: sampler output (under deferred Pythonization)
            sampling_metadata
    
        Returns:
            prompt_logprobs (CPU), sample_logprobs (CPU)
        """
    
        # - Deferred pythonization of sample result
        sampler_result = get_pythonized_sample_results(output.deferred_sample_results_args)
    
        # - Erase the NPU-side deferred sample_result
        #   computation args to ensure it is never
        #   pythonized or transferred to CPU
        output.deferred_sample_results_args = None
    
        # - Deferred pythonization of logprobs
        (
            prompt_logprobs,
            sample_logprobs,
        ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
        assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
        assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
    
        return prompt_logprobs, sample_logprobs
    
    
    def _pythonize_sampler_output(
        model_input: StatefulModelInput,
        output: SamplerOutput,
        pinned_sampled_token_buffer: torch.Tensor,
        sampled_token_ids: torch.Tensor,
        logprobs_tensor: Optional[torch.Tensor],
        cache: Optional[PythonizationCache],
    ) -> None:
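        """Convert the NPU-side sampler results of one multi-step iteration into CPU-side
        SequenceOutput/CompletionSequenceGroupOutput objects appended to output.outputs,
        reusing pooled objects when a PythonizationCache is provided."""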
        assert model_input.frozen_model_input is not None
    
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input.sampling_metadata is not None
        sampling_metadata = frozen_model_input.sampling_metadata
        # samples generation should have been skipped
        # assert not output.outputs
    
        pinned_buffer = pinned_sampled_token_buffer[: model_input.num_queries]
    
        # We guarantee output tensors are ready, so it is safe to
        # pythonize the sampler output & obtain CPU-side logprobs.
        #
        # However we should check whether logprobs pythonization may
        # be skipped entirely, i.e. because no logprobs were requested
        # or pythonization was not deferred. To that end,
        #
        # * `prompt_logprobs_are_requested_for_prefill` signals that
        #   there are *any* prefill-phase requests which specify that
        #   prompt logprobs should be returned.
        #
        # * `any_logprobs_are_requested` signals that there are any
        #   requests which (1) specify that sample logprobs should be
        #   returned, or (2) are in the prefill phase AND specify that
        #   prompt logprobs should be returned.
        #
        # Later on, these flags cause adjustments to the pythonization
        # process to accommodate logprobs.
    
        seq_groups = sampling_metadata.seq_groups
        prompt_logprobs_are_requested_for_prefill = any(
            [sg.sampling_params.prompt_logprobs is not None and sg.is_prompt for sg in seq_groups]
        )
        any_logprobs_are_requested = prompt_logprobs_are_requested_for_prefill or any(
            [sg.sampling_params.logprobs is not None for sg in seq_groups]
        )
    
        if prompt_logprobs_are_requested_for_prefill:
            # CPU NPU sync, after gathering *only* sampled tokens (since
            # requesting prompt logprobs leads `sampled_token_ids` to
            # include prompt token ids in addition to sampled token ids.)
            sample_idx_tensor = torch.tensor([sdx for sg in seq_groups for sdx in sg.sample_indices])
            pinned_buffer = pinned_buffer.copy_(sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
        else:
            # CPU NPU sync
            pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
    
        # this will not block as the tensors are already on CPU
        samples_list = pinned_buffer.tolist()
    
        skip_sampler_cpu_output = frozen_model_input.sampling_metadata.skip_sampler_cpu_output
    
        # *Don't* skip logprobs pythonization *if*:
        # * Any requests require logprobs to be returned in this
        # iteration AND
        # * These requests are being scheduled in a fashion which
        # defers pythonization (i.e. multi-step scheduling.)
        do_pythonize_logprobs = skip_sampler_cpu_output and any_logprobs_are_requested
        (
            prompt_logprobs,
            sample_logprobs,
        ) = (
            deferred_pythonize_logprobs(output, sampling_metadata, logprobs_tensor)
            if do_pythonize_logprobs
            else (None, None)
        )
    
        for sgdx, (seq_group, sample_result) in enumerate(zip(seq_groups, samples_list)):
            if seq_group.sampling_params.logits_processors:
                assert (
                    len(seq_group.sampling_params.logits_processors) == 0
                ), "Logits Processors are not supported in multi-step decoding"
    
            if do_pythonize_logprobs:
                assert prompt_logprobs is not None
                assert sample_logprobs is not None
    
                (
                    group_prompt_logprobs,
                    group_sample_logprobs,
                ) = (  # Utilize deferred pythonization results
                    prompt_logprobs[sgdx],
                    sample_logprobs[sgdx],
                )
            elif any_logprobs_are_requested:
                (
                    group_prompt_logprobs,
                    group_sample_logprobs,
                ) = (
                    # profile_run: use already-computed logprobs
                    output.outputs[sgdx].prompt_logprobs,
                    [sample.logprobs for sample in output.outputs[sgdx].samples],
                )
    
            seq_ids = seq_group.seq_ids
            next_token_ids = sample_result
            parent_ids = [0]
    
            if cache is not None:
                completion_seq_group_output: CompletionSequenceGroupOutput = (
                    cache.cached_completion_seq_group_output.get_object()
                )
                completion_seq_group_output.samples.clear()
                seq_outputs: List[SequenceOutput] = completion_seq_group_output.samples
            else:
                seq_outputs = []
    
            for tdx, (parent_id, next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
                if cache is not None:
                    seq_output: SequenceOutput = cache.cached_seq_output.get_object()
                    seq_output.parent_seq_id = seq_ids[parent_id]
                    seq_output.output_token = next_token_id
    
                    if any_logprobs_are_requested:
                        seq_output.logprobs = group_sample_logprobs[tdx]
                    else:
                        logprobs = next(iter(seq_output.logprobs.values()))
                        seq_output.logprobs.clear()
    
                        logprobs.logprob = float("inf")
                        logprobs.rank = None
                        logprobs.decoded_token = None
    
                        seq_output.logprobs[next_token_id] = logprobs
    
                    seq_outputs.append(seq_output)
    
                else:
                    seq_outputs.append(
                        SequenceOutput(
                            seq_ids[parent_id],
                            next_token_id,
                            (
                                group_sample_logprobs[tdx]
                                if any_logprobs_are_requested
                                else {next_token_id: Logprob(logprob=float("inf"), rank=None, decoded_token=None)}
                            ),
                        )
                    )
            if cache is not None:
                completion_seq_group_output.prompt_logprobs = group_prompt_logprobs if any_logprobs_are_requested else None
                output.outputs.append(completion_seq_group_output)
            else:
                output.outputs.append(
                    CompletionSequenceGroupOutput(
                        seq_outputs, (group_prompt_logprobs if any_logprobs_are_requested else None)
                    )
                )
    
        assert len(output.outputs) > 0
    

The examples folder provides example code for both offline and online usage modes.

  • examples/offline_inference.py: offline inference example script.
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    # Part of codes in this file was copied from project [vLLM Team][vllm]
    
    import argparse
    from vllm import LLM, SamplingParams
    from vllm.logger import init_logger
    
    logger = init_logger(__name__)
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-125m")
    
    # input prompts for test
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    
    sampling_params = SamplingParams(max_tokens=512, temperature=0)
    args = parser.parse_args()
    model_path = args.model_path
    llm = LLM(
        model=model_path,
        tensor_parallel_size=1,  # number of NPUs to be used
        max_num_seqs=256,  # max batch number
        enforce_eager=True,  # disable CUDA graph mode
        trust_remote_code=True,  # needed if the model is a custom model not yet available in the HuggingFace transformers library
        worker_use_ray=True,
    )
    
    outputs = llm.generate(prompts, sampling_params)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        logger.info(
            f"req_num: {i}\nPrompt: {prompt!r}\nGenerated text: {generated_text!r}"
        )
  • examples/offline_inference.sh: shell launcher for the offline inference example.
    #!/bin/bash
    python3 offline_inference.py --model-path facebook/opt-125m
  • examples/start_server.sh: online-mode server startup example; the Ascend device to use is specified through an environment variable (a sample client request against this service is sketched at the end of this section).
    #!/bin/bash
    
    export ASCEND_RT_VISIBLE_DEVICES=0
    python -m vllm.entrypoints.openai.api_server  \
           --model=/home/data/models/LLaMA3-8B \
           --trust-remote-code \
           --enforce-eager \
           --max-model-len 4096 \
           -tp 1 \
           --port 8006 \
           --block-size 128 
  • README.md: usage instructions for the code.
    # Vllm-MindIE
    ## Introduction
    Patch for adapting the Ascend inference engine to the stable v0.6.2 release of the open-source vLLM framework.
    ## Adaptation approach
    vLLM is adapted to the Ascend environment as follows:
    - The upper layer keeps vLLM's native logic, including request scheduling, batch construction, and launching multi-card services through the Ray distributed framework.
    - The lower layer routes model inference and post-processing through the unified `GeneratorTorch` interface provided by MindIE-LLM, connecting to the MindIE model repository for unified management so that the full-graph acceleration library can speed up model inference.
    ## Environment preparation
    ### Dependency versions
    - Hardware on which the Vllm-MindIE adaptation repository can run
      - Atlas 800I A2 (32 GB/64 GB memory)
    - Companion software required to run the Vllm-MindIE adaptation repository
      - Operating system
      - Driver (HDK)
      - CANN
      - Python
      - PTA
      - Open-source software dependencies
    - Version compatibility
      - The current Vllm-MindIE adaptation repository must be deployed and run with CANN 8.0 or later, Python 3.10/3.11, and torch 2.1.0.
    ### Environment setup
    For the detailed environment setup procedure, refer to the [MindIE-LLM model repository](https://gitee.com/ascend/MindIE-LLM).
    ## Installation
    Once the basic Ascend inference environment is installed, run `install.sh` to install vllm together with the Ascend patch:
    ```sh
    bash install.sh
    ```
    ## Usage
    Startup demos for vLLM offline mode and the online service are provided for reference.
    **Note: adjust the model path parameters in offline_inference.sh and start_server.sh to run inference with a specific model.**
    - Offline mode: first set the model_path parameter in offline_inference.sh to the path of the model used for inference, then run:
        ```sh
        cd examples
        bash offline_inference.sh
        ```
    - Online service: first set the model parameter in start_server.sh to the path of the model used for inference. Other commonly used parameters are:
        - -tp n: number of cards used for model inference
        - --port port_num: port used by the inference service
        - --max-num-seqs bs: maximum batch size supported by the inference service
        - To run inference on a specific set of cards, add an environment variable before the python command in the script; for example, to use the first four cards:
        ```sh
        export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
        ```
        After configuring the parameters above, run:
        ```sh
        cd examples
        bash start_server.sh
        ```
  • install.sh: one-click installation script for the adaptation. After all code files have been restored, run this script to install the Ascend-adapted vllm framework in a single step; it automatically pulls the vllm source to install the native framework and then applies the corresponding modifications on top.
    set -e
    
    if [ -d "./vllm" ]; then
        echo "./vllm directory has already exist!"
        exit 1
    fi
    
    git clone -b v0.6.2 https://github.com/vllm-project/vllm.git vllm
    
    yes | cp -r cover/* vllm/
    
    cd vllm
    pip install -r requirements-npu.txt
    python setup.py install

    After installation, check the vLLM version with `pip show vllm`; the displayed version may not be 0.6.2. This is a bug in native vllm, is expected behavior, and does not affect operation.
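
After starting the online service with start_server.sh, it can be exercised through vLLM's OpenAI-compatible HTTP interface. The snippet below is a minimal sketch, assuming the server listens on local port 8006 and was started with --model=/home/data/models/LLaMA3-8B as in the example script; adjust the address, port, and model path to match the actual deployment.

    # Minimal sketch: send a completion request to the OpenAI-compatible server
    # started by start_server.sh. The address, port (8006) and model path below are
    # assumptions taken from the example script; adjust them to your deployment.
    import json

    import requests

    payload = {
        "model": "/home/data/models/LLaMA3-8B",  # served model name defaults to the --model path
        "prompt": "The capital of France is",
        "max_tokens": 64,
        "temperature": 0,
    }
    response = requests.post(
        "http://127.0.0.1:8006/v1/completions",
        headers={"Content-Type": "application/json"},
        data=json.dumps(payload),
        timeout=60,
    )
    response.raise_for_status()
    print(response.json()["choices"][0]["text"])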