
Full adaptation code for integrating the TEI framework with the MindIE Torch component

The directory structure of the TEI code adapted for the Ascend environment and the MindIE Torch component is identical to that of the TEI GitHub repository (the files that need to be modified are the ones listed in the code structure tree below; files and folders not shown remain unchanged from the TEI GitHub repository). The directory structure is as follows.

text-embeddings-inference
|____core
|  |____src
|  |  |____ infer.rs
|____backends
|  |____grpc-client
|  |  |____src
|  |  |  |____ client.rs
|  |____proto
|  |  |____ embed.proto
|  |____python
|  |  |____src
|  |  |  |____ lib.rs
|  |  |____server
|  |  |  |____pyproject.toml
|  |  |  |____requirements.txt
|  |  |  |____text_embeddings_server
|  |  |  |  |____ server.py
|  |  |  |  |____models
|  |  |  |  |  |____ __init__.py
|  |  |  |  |  |____ default_model.py
|  |  |  |  |  |____ model.py
|  |  |  |  |  |____ rerank_model.py
|  |  |  |  |  |____ types.py
  • Based on TEI v1.2.3, in the text-embeddings-inference/core/src/infer.rs file, reduce the number of batching_task instances spawned inside the Infer type from 2 to 1 by commenting out the corresponding spawn in the source (see the listing below); using a single batching_task gives better performance.

    Leaving this unmodified does not affect TEI's operation or its serving functionality, but the change brings a noticeable performance improvement.

    use crate::queue::{Entry, Metadata, NextBatch, Queue};
    use crate::tokenization::{EncodingInput, RawEncoding, Tokenization};
    use crate::TextEmbeddingsError;
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType};
    use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore};
    use tracing::instrument;
    
    /// Inference struct
    #[derive(Debug, Clone)]
    pub struct Infer {
        tokenization: Tokenization,
        queue: Queue,
        /// Shared notify
        notify_batching_task: Arc<Notify>,
        /// Inference limit
        limit_concurrent_requests: Arc<Semaphore>,
        backend: Backend,
    }
    
    impl Infer {
        #[allow(clippy::too_many_arguments)]
        pub fn new(
            tokenization: Tokenization,
            queue: Queue,
            max_concurrent_requests: usize,
            backend: Backend,
        ) -> Self {
            let notify_batching_task = Arc::new(Notify::new());
    
            let (embed_sender, embed_receiver) = mpsc::unbounded_channel();
    
            // Create only one batching task to prefetch batches
            tokio::spawn(batching_task(
                queue.clone(),
                notify_batching_task.clone(),
                embed_sender.clone(),
            ));
            /* The upstream source spawns two batching_tasks here to manage the request queue; commenting this second spawn out and using a single batching_task performs better
            tokio::spawn(batching_task(
                queue.clone(),
                notify_batching_task.clone(),
                embed_sender,
            ));
            */
    
            // Create embed task to communicate with backend
            tokio::spawn(backend_task(backend.clone(), embed_receiver));
    
            // Inference limit with a semaphore
            let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
    
            Self {
                tokenization,
                queue,
                notify_batching_task,
                limit_concurrent_requests: semaphore,
                backend,
            }
        }
    
        #[instrument(skip(self))]
        pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            add_special_tokens: bool,
        ) -> Result<RawEncoding, TextEmbeddingsError> {
            self.tokenization
                .tokenize(inputs.into(), add_special_tokens)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })
        }
    
        #[instrument(skip(self))]
        pub async fn decode(
            &self,
            ids: Vec<u32>,
            skip_special_tokens: bool,
        ) -> Result<String, TextEmbeddingsError> {
            self.tokenization
                .decode(ids, skip_special_tokens)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })
        }
    
        #[instrument(skip(self))]
        pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
            // Limit concurrent requests by acquiring a permit from the semaphore
            self.clone()
                .limit_concurrent_requests
                .try_acquire_owned()
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "overloaded");
                    tracing::error!("{err}");
                    TextEmbeddingsError::from(err)
                })
        }
        #[instrument(skip(self))]
        pub async fn acquire_permit(&self) -> OwnedSemaphorePermit {
            // Limit concurrent requests by acquiring a permit from the semaphore
            self.clone()
                .limit_concurrent_requests
                .acquire_owned()
                .await
                .expect("Semaphore has been closed. This is a bug.")
        }
    
        #[instrument(skip(self, permit))]
        pub async fn embed_all<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            permit: OwnedSemaphorePermit,
        ) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
            let start_time = Instant::now();
    
            if self.is_splade() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "`embed_all` is not available for SPLADE models".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let results = self
                .embed(inputs, truncate, false, &start_time, permit)
                .await?;
    
            let InferResult::AllEmbedding(response) = results else {
                panic!("unexpected enum variant")
            };
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_embed_success");
            metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_embed_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        #[instrument(skip(self, permit))]
        pub async fn embed_sparse<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            permit: OwnedSemaphorePermit,
        ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
            let start_time = Instant::now();
    
            if !self.is_splade() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "Model is not an embedding model with SPLADE pooling".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let results = self
                .embed(inputs, truncate, true, &start_time, permit)
                .await?;
    
            let InferResult::PooledEmbedding(response) = results else {
                panic!("unexpected enum variant")
            };
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_embed_success");
            metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_embed_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        #[instrument(skip(self, permit))]
        pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            normalize: bool,
            permit: OwnedSemaphorePermit,
        ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
            let start_time = Instant::now();
    
            if self.is_splade() && normalize {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "`normalize` is not available for SPLADE models".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let results = self
                .embed(inputs, truncate, true, &start_time, permit)
                .await?;
    
            let InferResult::PooledEmbedding(mut response) = results else {
                panic!("unexpected enum variant")
            };
    
            if normalize {
                // Normalize embedding
                let scale = (1.0
                    / response
                        .results
                        .iter()
                        .map(|v| {
                            let v = *v as f64;
                            v * v
                        })
                        .sum::<f64>()
                        .sqrt()) as f32;
                for v in response.results.iter_mut() {
                    *v *= scale;
                }
            }
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_embed_success");
            metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_embed_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            pooling: bool,
            start_time: &Instant,
            _permit: OwnedSemaphorePermit,
        ) -> Result<InferResult, TextEmbeddingsError> {
            if self.is_classifier() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "Model is not an embedding model".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            metrics::increment_counter!("te_embed_count");
    
            // Tokenization
            let encoding = self
                .tokenization
                .encode(inputs.into(), truncate)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })?;
    
            // MPSC channel to communicate with the background batching task
            let (response_tx, response_rx) = oneshot::channel();
    
            // Append the request to the queue
            self.queue.append(Entry {
                metadata: Metadata {
                    response_tx,
                    tokenization: start_time.elapsed(),
                    queue_time: Instant::now(),
                    prompt_tokens: encoding.input_ids.len(),
                    pooling,
                },
                encoding,
            });
    
            self.notify_batching_task.notify_one();
    
            let response = response_rx
                .await
                .expect(
                    "Infer batching task dropped the sender without sending a response. This is a bug.",
                )
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "inference");
                    tracing::error!("{err}");
                    err
                })?;
    
            Ok(response)
        }
    
        #[instrument(skip(self, _permit))]
        pub async fn predict<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            raw_scores: bool,
            _permit: OwnedSemaphorePermit,
        ) -> Result<ClassificationInferResponse, TextEmbeddingsError> {
            if !self.is_classifier() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "Model is not a classifier model".to_string();
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let start_time = Instant::now();
            metrics::increment_counter!("te_predict_count");
    
            // Tokenization
            let encoding = self
                .tokenization
                .encode(inputs.into(), truncate)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })?;
    
            // MPSC channel to communicate with the background batching task
            let (response_tx, response_rx) = oneshot::channel();
    
            // Append the request to the queue
            self.queue.append(Entry {
                metadata: Metadata {
                    response_tx,
                    tokenization: start_time.elapsed(),
                    queue_time: Instant::now(),
                    prompt_tokens: encoding.input_ids.len(),
                    pooling: true,
                },
                encoding,
            });
    
            self.notify_batching_task.notify_one();
    
            let response = response_rx
                .await
                .expect(
                    "Infer batching task dropped the sender without sending a response. This is a bug.",
                )
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "inference");
                    tracing::error!("{err}");
                    err
                })?;
    
            let InferResult::Classification(mut response) = response else {
                panic!("unexpected enum variant")
            };
    
            if !raw_scores {
                // Softmax
                if response.results.len() > 1 {
                    let max = *response
                        .results
                        .iter()
                        .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap())
                        .unwrap();
    
                    let mut den = 0.0;
                    for v in response.results.iter_mut() {
                        *v = (*v - max).exp();
                        den += *v;
                    }
                    for v in response.results.iter_mut() {
                        *v /= den;
                    }
                }
                // Sigmoid
                else {
                    response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp());
                }
            }
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_predict_success");
            metrics::histogram!("te_predict_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_predict_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_predict_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_predict_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        #[instrument(skip(self))]
        pub fn is_classifier(&self) -> bool {
            matches!(self.backend.model_type, ModelType::Classifier)
        }
    
        #[instrument(skip(self))]
        pub fn is_splade(&self) -> bool {
            matches!(
                self.backend.model_type,
                ModelType::Embedding(text_embeddings_backend::Pool::Splade)
            )
        }
    
        #[instrument(skip(self))]
        pub async fn health(&self) -> bool {
            self.backend.health().await.is_ok()
        }
    
        #[instrument(skip(self))]
        pub fn health_watcher(&self) -> watch::Receiver<bool> {
            self.backend.health_watcher()
        }
    }
    
    #[instrument(skip_all)]
    async fn batching_task(
        queue: Queue,
        notify: Arc<Notify>,
        embed_sender: mpsc::UnboundedSender<(NextBatch, oneshot::Sender<()>)>,
    ) {
        loop {
            notify.notified().await;
    
            while let Some(next_batch) = queue.next_batch().await {
                let (callback_sender, callback_receiver) = oneshot::channel();
                embed_sender
                    .send((next_batch, callback_sender))
                    .expect("embed receiver was dropped. This is a bug.");
                let _ = callback_receiver.await;
            }
        }
    }
    
    #[instrument(skip_all)]
    async fn backend_task(
        backend: Backend,
        mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>,
    ) {
        while let Some((batch, _callback)) = embed_receiver.recv().await {
            match &backend.model_type {
                ModelType::Classifier => {
                    let results = backend.predict(batch.1).await;
    
                    // Handle sending responses in another thread to avoid starving the backend
                    std::thread::spawn(move || match results {
                        Ok((mut predictions, inference_duration)) => {
                            batch.0.into_iter().enumerate().for_each(|(i, m)| {
                                let infer_metadata = InferMetadata {
                                    prompt_tokens: m.prompt_tokens,
                                    tokenization: m.tokenization,
                                    queue: m.queue_time.elapsed() - inference_duration,
                                    inference: inference_duration,
                                };
    
                                let _ = m.response_tx.send(Ok(InferResult::Classification(
                                    ClassificationInferResponse {
                                        results: predictions.remove(&i).expect(
                                            "prediction not found in results. This is a backend bug.",
                                        ),
                                        metadata: infer_metadata,
                                    },
                                )));
                            });
                        }
                        Err(err) => {
                            batch.0.into_iter().for_each(|m| {
                                let _ = m.response_tx.send(Err(err.clone()));
                            });
                        }
                    });
                }
                ModelType::Embedding(_) => {
                    let results = backend.embed(batch.1).await;
    
                    // Handle sending responses in another thread to avoid starving the backend
                    std::thread::spawn(move || match results {
                        Ok((mut embeddings, inference_duration)) => {
                            batch.0.into_iter().enumerate().for_each(|(i, m)| {
                                let metadata = InferMetadata {
                                    prompt_tokens: m.prompt_tokens,
                                    tokenization: m.tokenization,
                                    queue: m.queue_time.elapsed() - inference_duration,
                                    inference: inference_duration,
                                };
    
                                let results = match embeddings
                                    .remove(&i)
                                    .expect("embedding not found in results. This is a backend bug.")
                                {
                                    Embedding::Pooled(e) => {
                                        InferResult::PooledEmbedding(PooledEmbeddingsInferResponse {
                                            results: e,
                                            metadata,
                                        })
                                    }
                                    Embedding::All(e) => {
                                        InferResult::AllEmbedding(AllEmbeddingsInferResponse {
                                            results: e,
                                            metadata,
                                        })
                                    }
                                };
    
                                let _ = m.response_tx.send(Ok(results));
                            })
                        }
                        Err(err) => {
                            batch.0.into_iter().for_each(|m| {
                                let _ = m.response_tx.send(Err(err.clone()));
                            });
                        }
                    });
                }
            };
        }
    }
    
    #[derive(Debug)]
    pub struct InferMetadata {
        pub prompt_tokens: usize,
        pub tokenization: Duration,
        pub queue: Duration,
        pub inference: Duration,
    }
    
    #[derive(Debug)]
    pub(crate) enum InferResult {
        Classification(ClassificationInferResponse),
        PooledEmbedding(PooledEmbeddingsInferResponse),
        AllEmbedding(AllEmbeddingsInferResponse),
    }
    
    #[derive(Debug)]
    pub struct ClassificationInferResponse {
        pub results: Vec<f32>,
        pub metadata: InferMetadata,
    }
    
    #[derive(Debug)]
    pub struct PooledEmbeddingsInferResponse {
        pub results: Vec<f32>,
        pub metadata: InferMetadata,
    }
    
    #[derive(Debug)]
    pub struct AllEmbeddingsInferResponse {
        pub results: Vec<Vec<f32>>,
        pub metadata: InferMetadata,
    }
    
  • Based on TEI v1.2.3, in the text-embeddings-inference/backends/grpc-client/src/client.rs file, add embed_all and predict async interfaces to the backend's gRPC client (the embed_all and predict methods in the listing below).
    /// Single shard Client
    use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient;
    use crate::pb::embedding::v1::*;
    use crate::Result;
    use grpc_metadata::InjectTelemetryContext;
    use tonic::transport::{Channel, Uri};
    use tracing::instrument;
    
    /// Text Generation Inference gRPC client
    #[derive(Debug, Clone)]
    pub struct Client {
        stub: EmbeddingServiceClient<Channel>,
    }
    
    impl Client {
        /// Returns a client connected to the given url
        pub async fn connect(uri: Uri) -> Result<Self> {
            let channel = Channel::builder(uri).connect().await?;
    
            Ok(Self {
                stub: EmbeddingServiceClient::new(channel),
            })
        }
    
        /// Returns a client connected to the given unix socket
        pub async fn connect_uds(path: String) -> Result<Self> {
            let channel = Channel::from_shared("http://[::]:50051".to_string())
                .unwrap()
                .connect_with_connector(tower::service_fn(move |_: Uri| {
                    tokio::net::UnixStream::connect(path.clone())
                }))
                .await?;
    
            Ok(Self {
                stub: EmbeddingServiceClient::new(channel),
            })
        }
    
        /// Get backend health
        #[instrument(skip(self))]
        pub async fn health(&mut self) -> Result<HealthResponse> {
            let request = tonic::Request::new(HealthRequest {}).inject_context();
            let response = self.stub.health(request).await?.into_inner();
            Ok(response)
        }
    
        #[instrument(skip_all)]
        pub async fn embed(
            &mut self,
            input_ids: Vec<u32>,
            token_type_ids: Vec<u32>,
            position_ids: Vec<u32>,
            cu_seq_lengths: Vec<u32>,
            max_length: u32,
        ) -> Result<Vec<Embedding>> {
            let request = tonic::Request::new(EmbedRequest {
                input_ids,
                token_type_ids,
                position_ids,
                max_length,
                cu_seq_lengths,
            })
            .inject_context();
            let response = self.stub.embed(request).await?.into_inner();
            Ok(response.embeddings)
        }
    
        #[instrument(skip_all)]
        pub async fn embed_all(
            &mut self,
            input_ids: Vec<u32>,
            token_type_ids: Vec<u32>,
            position_ids: Vec<u32>,
            cu_seq_lengths: Vec<u32>,
            max_length: u32,
        ) -> Result<Vec<TokenEmbedding>> {
            let request = tonic::Request::new(EmbedRequest {
                input_ids,
                token_type_ids,
                position_ids,
                max_length,
                cu_seq_lengths,
            })
            .inject_context();
            let response = self.stub.embed_all(request).await?.into_inner();
            Ok(response.allembeddings)
        }
    
        #[instrument(skip_all)]
        pub async fn predict(
            &mut self,
            input_ids: Vec<u32>,
            token_type_ids: Vec<u32>,
            position_ids: Vec<u32>,
            cu_seq_lengths: Vec<u32>,
            max_length: u32,
        ) -> Result<Vec<Prediction>> {
            let request = tonic::Request::new(PredictRequest {
                input_ids,
                token_type_ids,
                position_ids,
                max_length,
                cu_seq_lengths,
            })
            .inject_context();
            let response = self.stub.predict(request).await?.into_inner();
            Ok(response.predictions)
        }
    }
  • Based on TEI v1.2.3, in the text-embeddings-inference/backends/proto/embed.proto file, add the Embed_all and Predict RPCs to EmbeddingService (see the listing below) and define the corresponding message types and request/response formats; a stub-regeneration sketch follows the listing.
    syntax = "proto3";
    
    package embedding.v1;
    
    service EmbeddingService {
        /// Decode token for a list of prefilled batches
        rpc Embed (EmbedRequest) returns (EmbedResponse);
        rpc Embed_all (EmbedRequest) returns (RawEmbedResponse);
        rpc Predict (PredictRequest) returns (PredictResponse);
        /// Health check
        rpc Health (HealthRequest) returns (HealthResponse);
    }
    
    message HealthRequest {}
    message HealthResponse {}
    
    message PredictRequest {
        repeated uint32 input_ids = 1;
        repeated uint32 token_type_ids = 2;
        repeated uint32 position_ids = 3;
        repeated uint32 cu_seq_lengths = 4;
        /// Length of the longest request
        uint32 max_length = 5;
    }
    
    message Prediction {
        repeated float values = 1;
    }
    
    message PredictResponse {
        repeated Prediction predictions = 1;
    }
    
    message EmbedRequest {
        repeated uint32 input_ids = 1;
        repeated uint32 token_type_ids = 2;
        repeated uint32 position_ids = 3;
        repeated uint32 cu_seq_lengths = 4;
        /// Length of the longest request
        uint32 max_length = 5;
    }
    
    
    message Embedding {
        repeated float values = 1;
    }
    
    message EmbedResponse {
        repeated Embedding embeddings = 1;
    }
    
    message TokenEmbedding {
        repeated Embedding embeddings = 1;
    }
    
    message RawEmbedResponse {
        repeated TokenEmbedding allembeddings = 1;
    }
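
    After modifying embed.proto, the Python stubs under text_embeddings_server/pb (embed_pb2, embed_pb2_grpc) must be regenerated so that the new Embed_all and Predict RPCs become visible to server.py. Below is a minimal regeneration sketch using grpc_tools; the include and output paths are assumptions based on the directory tree above, and the repository's own generation target should be preferred if one exists.

    # Sketch: regenerate the gRPC stubs for the extended embed.proto.
    # Run from backends/python/server; the output directory text_embeddings_server/pb
    # must already exist. Paths are assumptions, not fixed by TEI.
    from grpc_tools import protoc

    protoc.main([
        "grpc_tools.protoc",
        "-I../../proto",
        "--python_out=text_embeddings_server/pb",
        "--grpc_python_out=text_embeddings_server/pb",
        "../../proto/embed.proto",
    ])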
  • Based on TEI v1.2.3, in the text-embeddings-inference/backends/python/src/lib.rs file, remove the model-type check that rejects classifier models so that they are started normally (see the listing below); add a predict interface to the backend model, and modify the embed interface with a branch that, depending on the batch's contents, returns either the hidden_states of all tokens or the pooled global hidden_states.
    mod logging;
    mod management;
    
    use backend_grpc_client::Client;
    use nohash_hasher::BuildNoHashHasher;
    use std::collections::HashMap;
    use text_embeddings_backend_core::{
        Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
    };
    use tokio::runtime::Runtime;
    
    pub struct PythonBackend {
        _backend_process: management::BackendProcess,
        tokio_runtime: Runtime,
        backend_client: Client,
    }
    
    impl PythonBackend {
        pub fn new(
            model_path: String,
            dtype: String,
            model_type: ModelType,
            uds_path: String,
            otlp_endpoint: Option<String>,
        ) -> Result<Self, BackendError> {
            match model_type {
                ModelType::Classifier => {
                    let pool = Pool::Cls;
                    pool
                }
                ModelType::Embedding(pool) => {
                    if pool != Pool::Cls {
                        return Err(BackendError::Start(format!("{pool:?} is not supported")));
                    }
                    pool
                }
            };
    
            let backend_process =
                management::BackendProcess::new(model_path, dtype, &uds_path, otlp_endpoint)?;
            let tokio_runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;
    
            let backend_client = tokio_runtime
                .block_on(Client::connect_uds(uds_path))
                .map_err(|err| {
                    BackendError::Start(format!("Could not connect to backend process: {err}"))
                })?;
    
            Ok(Self {
                _backend_process: backend_process,
                tokio_runtime,
                backend_client,
            })
        }
    }
    
    impl Backend for PythonBackend {
        fn health(&self) -> Result<(), BackendError> {
            if self
                .tokio_runtime
                .block_on(self.backend_client.clone().health())
                .is_err()
            {
                return Err(BackendError::Unhealthy);
            }
            Ok(())
        }
    
        fn is_padded(&self) -> bool {
            false
        }
    
        fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
            /*
            if !batch.raw_indices.is_empty() {
                return Err(BackendError::Inference(
                    "raw embeddings are not supported for the Python backend.".to_string(),
                ));
            }
            */
    
            let batch_size = batch.len();
    
            let mut embeddings =
                HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
    
            if !batch.pooled_indices.is_empty() {
                let results = self
                    .tokio_runtime
                    .block_on(self.backend_client.clone().embed(
                        batch.input_ids,
                        batch.token_type_ids,
                        batch.position_ids,
                        batch.cumulative_seq_lengths,
                        batch.max_length,
                    ))
                    .map_err(|err| BackendError::Inference(err.to_string()))?;
    
                let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
                for (i, e) in pooled_embeddings.into_iter().enumerate() {
                    embeddings.insert(i, Embedding::Pooled(e));
                }
            }
            else if !batch.raw_indices.is_empty() {
                let results = self
                    .tokio_runtime
                    .block_on(self.backend_client.clone().embed_all(
                        batch.input_ids,
                        batch.token_type_ids,
                        batch.position_ids,
                        batch.cumulative_seq_lengths,
                        batch.max_length,
                    ))
                    .map_err(|err| BackendError::Inference(err.to_string()))?;
    
                let mut raw_embeddings = Vec::new();
                for token_embedding in results {
                    let mut two_dim_list = Vec::new();
                    for embeddings in token_embedding.embeddings {
                        let values = embeddings.values.clone();
                        two_dim_list.push(values);
                    }
                    raw_embeddings.push(two_dim_list);
                }
                for (i, e) in raw_embeddings.into_iter().enumerate() {
                    embeddings.insert(i, Embedding::All(e));
                }
            }
            Ok(embeddings)
        }
    
        fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
            let batch_size = batch.len();
    
            let results = self
                .tokio_runtime
                .block_on(self.backend_client.clone().predict(
                    batch.input_ids,
                    batch.token_type_ids,
                    batch.position_ids,
                    batch.cumulative_seq_lengths,
                    batch.max_length,
                ))
                .map_err(|err| BackendError::Inference(err.to_string()))?;
    
            let predictions_result: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
    
            let mut predictions =
                HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
            for (i, r) in predictions_result.into_iter().enumerate() {
                predictions.insert(i, r);
            }
    
            Ok(predictions)
        }
    }
  • Based on TEI v1.2.3, in the text-embeddings-inference/backends/python/server/requirements.txt file, update the version pins of "huggingface-hub", "safetensors", "torch", and "poetry" (see the listing below) to suit the Ascend inference environment; a version-check sketch follows the listing.
    backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
    certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
    charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
    click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
    colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
    deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
    filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
    fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
    googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
    grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
    grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
    grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
    grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
    huggingface-hub==0.23.2 ; python_version >= "3.9" and python_version < "3.13"
    idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
    jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
    loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
    markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
    mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
    networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
    packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
    protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
    pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
    requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
    safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
    setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
    sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
    torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
    tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
    typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
    typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
    urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
    win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
    wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    poetry==1.8.3 ; python_version >= "3.9" and python_version < "3.13"
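
    Before launching the server, it can be useful to confirm that the Ascend environment actually provides the pinned versions. The optional sketch below only checks the packages whose pins were changed above.

    # Sketch: compare installed versions against the pins changed in requirements.txt.
    from importlib.metadata import PackageNotFoundError, version

    expected = {"huggingface-hub": "0.23.2", "safetensors": "0.4.3", "torch": "2.1.0", "poetry": "1.8.3"}
    for pkg, want in expected.items():
        try:
            print(f"{pkg}: installed {version(pkg)}, pinned {want}")
        except PackageNotFoundError:
            print(f"{pkg}: not installed (pinned {want})")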
  • Based on TEI v1.2.3, and in line with the "safetensors" and "torch" versions changed in text-embeddings-inference/backends/python/server/requirements.txt, update the "safetensors" and "torch" version constraints in text-embeddings-inference/backends/python/server/pyproject.toml (see the listing below) so that the "text-embeddings-server" project dependencies stay consistent.
    [tool.poetry]
    name = "text-embeddings-server"
    version = "0.1.0"
    description = "Text Embeddings Python gRPC Server"
    authors = ["Olivier Dehaene <olivier@huggingface.co>"]

    [tool.poetry.scripts]
    python-text-embeddings-server = 'text_embeddings_server.cli:app'

    [tool.poetry.dependencies]
    python = ">=3.9,<3.13"
    protobuf = "^4.21.7"
    grpcio = "^1.51.1"
    grpcio-status = "^1.51.1"
    grpcio-reflection = "^1.51.1"
    grpc-interceptor = "^0.15.0"
    typer = "^0.6.1"
    safetensors = "^0.4.3"
    loguru = "^0.6.0"
    opentelemetry-api = "^1.15.0"
    opentelemetry-exporter-otlp = "^1.15.0"
    opentelemetry-instrumentation-grpc = "^0.36b0"
    torch = { version = "^2.1.0" }

    [tool.poetry.extras]

    [tool.poetry.group.dev.dependencies]
    grpcio-tools = "^1.51.1"
    pytest = "^7.3.0"

    [[tool.poetry.source]]
    name = "mirrors"
    url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
    priority = "default"

    [tool.pytest.ini_options]
    markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

    [build-system]
    requires = ["poetry-core>=1.0.0"]
    build-backend = "poetry.core.masonry.api"
  • Based on TEI v1.2.3, and following the formats defined in embed.proto, add Embed_all and Predict async methods to the EmbeddingService class in text-embeddings-inference/backends/python/server/text_embeddings_server/server.py (see the listing below): convert the request into the corresponding batch (PaddedBatch) instance, pass the batch to the model's matching inference interface, and convert its output into the corresponding response message; a client-side verification sketch follows the listing.
    import asyncio
    import torch
    import mindietorch
    
    from grpc import aio
    from loguru import logger
    
    from grpc_reflection.v1alpha import reflection
    from pathlib import Path
    from typing import Optional
    
    from text_embeddings_server.models import Model, get_model
    from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
    from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
    from text_embeddings_server.utils.interceptor import ExceptionInterceptor
    
    
    class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
        def __init__(self, model: Model):
            self.model = model
            # Force inference mode for the lifetime of EmbeddingService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
    
        async def Health(self, request, context):
            if self.model.device.type == "cuda":
                torch.zeros((2, 2), device="cuda")
            return embed_pb2.HealthResponse()
    
        async def Embed(self, request, context):
            batch = self.model.batch_type.from_pb(request, self.model.device)
    
            embeddings = self.model.embed(batch)
    
            return embed_pb2.EmbedResponse(embeddings=embeddings)
    
        async def Embed_all(self, request, context):
            batch = self.model.batch_type.from_pb(request, self.model.device)
    
            embeddings = self.model.embed_all(batch)
    
            return embed_pb2.RawEmbedResponse(allembeddings=embeddings)
    
        async def Predict(self, request, context):
            batch = self.model.batch_type.from_pb(request, self.model.device)
    
            predictions = self.model.predict(batch)
    
            return embed_pb2.PredictResponse(predictions=predictions)
    
    
    def serve(
        model_path: Path,
        dtype: Optional[str],
        uds_path: Path,
    ):
        async def serve_inner(
            model_path: Path,
            dtype: Optional[str] = None,
        ):
            unix_socket = f"unix://{uds_path}"
    
            try:
                model = get_model(model_path, dtype)
            except Exception:
                logger.exception("Error when initializing model")
                raise
    
            server = aio.server(
                interceptors=[
                    ExceptionInterceptor(),
                    UDSOpenTelemetryAioServerInterceptor(),
                ]
            )
            embed_pb2_grpc.add_EmbeddingServiceServicer_to_server(
                EmbeddingService(model), server
            )
            SERVICE_NAMES = (
                embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name,
                reflection.SERVICE_NAME,
            )
            reflection.enable_server_reflection(SERVICE_NAMES, server)
            server.add_insecure_port(unix_socket)
    
            await server.start()
    
            logger.info(f"Server started at {unix_socket}")
    
            try:
                await server.wait_for_termination()
            except KeyboardInterrupt:
                logger.info("Signal received. Shutting down")
                await server.stop(0)
    
        asyncio.run(serve_inner(model_path, dtype))
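
    Once the backend process is serving on its unix socket, the new endpoints can be exercised directly with a small gRPC client. The sketch below is for verification only: the socket path and token ids are placeholders, Embed_all is meaningful for an embedding model and Predict for a reranker model, and the generated embed_pb2 / embed_pb2_grpc stubs are assumed to be importable.

    # Sketch: call the new Embed_all and Predict RPCs over the backend's UDS.
    import grpc

    from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc


    def check_endpoints(uds_path: str = "/tmp/text-embeddings-server"):  # placeholder path
        channel = grpc.insecure_channel(f"unix://{uds_path}")
        stub = embed_pb2_grpc.EmbeddingServiceStub(channel)

        # One request of 4 tokens; cu_seq_lengths marks the single span [0, 4).
        request = embed_pb2.EmbedRequest(
            input_ids=[101, 2023, 2003, 102],
            token_type_ids=[0, 0, 0, 0],
            position_ids=[0, 1, 2, 3],
            cu_seq_lengths=[0, 4],
            max_length=4,
        )

        stub.Health(embed_pb2.HealthRequest())
        raw = stub.Embed_all(request)  # RawEmbedResponse with per-token embeddings
        print(len(raw.allembeddings[0].embeddings), "token embeddings for request 0")

        predict_request = embed_pb2.PredictRequest(
            input_ids=[101, 2023, 2003, 102],
            token_type_ids=[0, 0, 0, 0],
            position_ids=[0, 1, 2, 3],
            cu_seq_lengths=[0, 4],
            max_length=4,
        )
        print(stub.Predict(predict_request).predictions[0].values)  # raw reranker logits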
  • Based on TEI v1.2.3, add branching logic to the get_model method in text-embeddings-inference/backends/python/server/text_embeddings_server/models/__init__.py (see the listing below): when the hardware platform is NPU, set the torch.device according to the TEI_NPU_DEVICE environment variable, and decide whether to use default_model (the text-embedding model class) or rerank_model (the reranker model class) based on the architectures list in the model's config.json; a loading sketch follows the listing.
    import os
    import torch
    import mindietorch
    from loguru import logger
    from pathlib import Path
    from typing import Optional
    from transformers import AutoConfig
    
    from text_embeddings_server.models.model import Model
    from text_embeddings_server.models.default_model import DefaultModel
    from text_embeddings_server.models.rerank_model import RerankModel
    
    __all__ = ["Model"]
    
    # Disable gradients
    torch.set_grad_enabled(False)
    
    FLASH_ATTENTION = True
    try:
        from text_embeddings_server.models.flash_bert import FlashBert
    except ImportError as e:
        logger.warning(f"Could not import Flash Attention enabled models: {e}")
        FLASH_ATTENTION = False
    
    if FLASH_ATTENTION:
        __all__.append(FlashBert)
    
    
    def get_model(model_path: Path, dtype: Optional[str]):
        if dtype == "float32":
            dtype = torch.float32
        elif dtype == "float16":
            dtype = torch.float16
        elif dtype == "bfloat16":
            dtype = torch.bfloat16
        else:
            raise RuntimeError(f"Unknown dtype {dtype}")
        deviceIdx = os.environ.get('TEI_NPU_DEVICE', '0')
        if deviceIdx is not None and deviceIdx.isdigit() and 0 <= int(deviceIdx) <= 7:
            mindietorch.set_device(int(deviceIdx))
            device = torch.device(f"npu:{int(deviceIdx)}")
        else:
            raise RuntimeError(f"Invalid TEI_NPU_DEVICE value {deviceIdx!r}; expected an integer between 0 and 7")
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        if config.architectures[0].endswith("Classification"):
            return RerankModel(model_path, device, dtype)
        else:
            if (
                config.model_type == "bert"
                and device.type == "cuda"
                and config.position_embedding_type == "absolute"
                and dtype in [torch.float16, torch.bfloat16]
                and FLASH_ATTENTION
            ):
                return FlashBert(model_path, device, dtype)
            else:
                return DefaultModel(model_path, device, dtype)
        raise NotImplementedError
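
    The selection logic above can also be exercised outside the gRPC server, which is handy when checking a freshly compiled model. The sketch below assumes an Ascend environment with mindietorch installed and a reachable NPU; the model path is a placeholder for a directory containing config.json plus the MindIE Torch compiled *.pt file.

    # Sketch: pick an NPU card via TEI_NPU_DEVICE and load a model through get_model.
    import os
    from pathlib import Path

    from text_embeddings_server.models import get_model

    os.environ["TEI_NPU_DEVICE"] = "0"  # NPU card index, 0-7

    model = get_model(Path("/path/to/compiled_model"), dtype="float16")
    # DefaultModel for embedding architectures, RerankModel when
    # config.json architectures[0] ends with "Classification".
    print(type(model).__name__)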
  • Based on TEI v1.2.3, add an embed_all interface to the DefaultModel class in text-embeddings-inference/backends/python/server/text_embeddings_server/models/default_model.py (see the listing below), returning the hidden_states of all tokens; an indexing sketch follows the listing.
    import torch
    import mindietorch
    
    from pathlib import Path
    from typing import Type, List
    from opentelemetry import trace
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, AutoConfig
    from loguru import logger
    
    from text_embeddings_server.models import Model
    from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding
    
    tracer = trace.get_tracer(__name__)
    
    
    class DefaultModel(Model):
        def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
            mindietorch.set_device(device.index)
            model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)
            self.model_path = str(model_path)
            self.hidden_size = AutoConfig.from_pretrained(model_path, trust_remote_code=True).hidden_size
            super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)
    
        @property
        def batch_type(self) -> Type[PaddedBatch]:
            return PaddedBatch
    
        @tracer.start_as_current_span("embed")
        def embed(self, batch: PaddedBatch) -> List[Embedding]:
            kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
            output = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))
            if isinstance(output, dict):
                embedding = output['last_hidden_state'].to('cpu')
            else:
                embedding = output[0].to('cpu')
            embedding = embedding[:, 0].contiguous()
            cpu_results = embedding.view(-1).tolist()
            return [
                Embedding(
                    values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
                )
                for i in range(len(batch))
            ]
    
        @tracer.start_as_current_span("embed_all")
        def embed_all(self, batch: PaddedBatch):
            kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
            output = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))
            if isinstance(output, dict):
                embedding = output['last_hidden_state'].to('cpu').contiguous()
            else:
                embedding = output[0].to('cpu').contiguous()
            cpu_results = embedding.view(-1).tolist()
    
            embedding_result=[]
            for i in range(len(batch)):
                embedding_tmp=[
                    Embedding(values=cpu_results[(j+i*batch.max_length) * self.hidden_size :
                    (j + 1 + i*batch.max_length) * self.hidden_size])
                    for j in range(batch.input_ids.size()[1])
                    ]
                tokenembeddings=TokenEmbedding(embeddings=embedding_tmp)
                embedding_result.append(tokenembeddings)
    
            return embedding_result
    
        @tracer.start_as_current_span("predict")
        def predict(self, batch: PaddedBatch) -> List[Prediction]:
            print("embedding model does not support predict function")
  • Based on TEI v1.2.3, add embed_all and predict abstract interface definitions to the Model base class in text-embeddings-inference/backends/python/server/text_embeddings_server/models/model.py (see the listing below), each raising NotImplementedError.
    import torch
    
    from abc import ABC, abstractmethod
    from typing import List, TypeVar, Type
    
    from text_embeddings_server.models.types import Batch, Embedding, Prediction, TokenEmbedding
    
    B = TypeVar("B", bound=Batch)
    
    
    class Model(ABC):
        def __init__(
            self,
            model,
            dtype: torch.dtype,
            device: torch.device,
        ):
            self.model = model
            self.dtype = dtype
            self.device = device
    
        @property
        @abstractmethod
        def batch_type(self) -> Type[B]:
            raise NotImplementedError
    
        @abstractmethod
        def embed(self, batch: B) -> List[Embedding]:
            raise NotImplementedError
    
        @abstractmethod
        def embed_all(self, batch: B) -> List[TokenEmbedding]:
            raise NotImplementedError
    
        @abstractmethod
        def predict(self, batch: B) -> List[Prediction]:
            raise NotImplementedError
  • Based on TEI v1.2.3, create the new file text-embeddings-inference/backends/python/server/text_embeddings_server/models/rerank_model.py. It implements the reranker model class and provides the rerank (predict) interface, returning the match score between a query and a text; a note on score post-processing follows the listing.
    import torch
    import mindietorch
    from pathlib import Path
    from typing import Type, List
    from opentelemetry import trace
    from loguru import logger
    
    from text_embeddings_server.models import Model
    from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding
    
    tracer = trace.get_tracer(__name__)
    
    
    class RerankModel(Model):
        def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
            mindietorch.set_device(device.index)
            model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)
            super(RerankModel, self).__init__(model=model, dtype=dtype, device=device)
    
        @property
        def batch_type(self) -> Type[PaddedBatch]:
            return PaddedBatch
    
        @tracer.start_as_current_span("embed")
        def embed(self, batch: PaddedBatch) -> List[Embedding]:
            print("rerank model does not support embed function")
    
        @tracer.start_as_current_span("embed_all")
        def embed_all(self, batch: PaddedBatch) -> List[TokenEmbedding]:
            print("rerank model does not support embed_all function")
    
        @tracer.start_as_current_span("predict")
        def predict(self, batch: PaddedBatch) -> List[Prediction]:
            kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
    
            scores = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))[0].tolist()
            return [
                Prediction(
                    values=scores[i]
                )
                for i in range(len(batch))
            ]
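
    Note that predict returns the model's raw logits; the conversion to final relevance scores happens in the router (core/src/infer.rs, predict above), which applies a sigmoid when a single value comes back and a softmax otherwise. The sketch below mirrors that post-processing in Python with illustrative values, purely to show what the served rerank scores correspond to.

    # Sketch: the score post-processing the router applies to raw reranker logits.
    import math


    def postprocess(raw):
        if len(raw) > 1:
            m = max(raw, key=abs)  # same max-by-absolute-value as infer.rs
            exps = [math.exp(v - m) for v in raw]
            den = sum(exps)
            return [v / den for v in exps]  # softmax over multi-label logits
        return [1.0 / (1.0 + math.exp(-raw[0]))]  # sigmoid for a single logit


    print(postprocess([2.3]))        # single relevance logit
    print(postprocess([0.1, 1.5]))   # multi-label logits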
  • Based on TEI v1.2.3, in the text-embeddings-inference/backends/python/server/text_embeddings_server/models/types.py file, add a max_length field to the PaddedBatch class and adjust the data movement to suit mindietorch (see the listing below); a padding sketch follows the listing.
    import torch
    import mindietorch
    
    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from opentelemetry import trace
    
    from text_embeddings_server.pb import embed_pb2
    from text_embeddings_server.pb.embed_pb2 import Embedding, Prediction, TokenEmbedding
    
    tracer = trace.get_tracer(__name__)
    
    
    class Batch(ABC):
        @classmethod
        @abstractmethod
        def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch":
            raise NotImplementedError
    
        @abstractmethod
        def __len__(self):
            raise NotImplementedError
    
    
    @dataclass
    class PaddedBatch(Batch):
        input_ids: torch.Tensor
        token_type_ids: torch.Tensor
        position_ids: torch.Tensor
        attention_mask: torch.Tensor
        max_length: int
    
        @classmethod
        @tracer.start_as_current_span("from_pb")
        def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "PaddedBatch":
            # Allocate padded tensors all at once
            all_tensors = torch.zeros(
                [4, len(pb.cu_seq_lengths) - 1, pb.max_length], dtype=torch.int32, device='cpu'
            )
            max_length=pb.max_length
    
            for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
                end_index = pb.cu_seq_lengths[i + 1]
                input_length = end_index - start_index
    
                all_tensors[0, i, :input_length] = torch.tensor(
                    pb.input_ids[start_index:end_index], dtype=torch.int32
                )
                all_tensors[1, i, :input_length] = torch.tensor(
                    pb.token_type_ids[start_index:end_index], dtype=torch.int32
                )
                all_tensors[2, i, :input_length] = torch.tensor(
                    pb.position_ids[start_index:end_index], dtype=torch.int32
                )
                all_tensors[3, i, :input_length] = 1
            """
            # Move padded tensors all at once
            all_tensors = all_tensors.to(device)
            """
            return PaddedBatch(
                input_ids=all_tensors[0],
                token_type_ids=all_tensors[1],
                position_ids=all_tensors[2],
                attention_mask=all_tensors[3],
                max_length=max_length,
            )
    
        def __len__(self):
            return len(self.input_ids)
    
    
    @dataclass
    class FlashBatch(Batch):
        input_ids: torch.Tensor
        token_type_ids: torch.Tensor
        position_ids: torch.Tensor
    
        cu_seqlens: torch.Tensor
        max_s: int
        size: int
    
        @classmethod
        @tracer.start_as_current_span("from_pb")
        def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "FlashBatch":
            if device.type != "cuda":
                raise RuntimeError(f"FlashBatch does not support device {device}")
    
            batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device)
            batch_token_type_ids = torch.tensor(
                pb.token_type_ids, dtype=torch.int32, device=device
            )
            batch_position_ids = torch.tensor(
                pb.position_ids, dtype=torch.int32, device=device
            )
    
            cu_seqlens = torch.tensor(pb.cu_seq_lengths, dtype=torch.int32, device=device)
    
            return FlashBatch(
                input_ids=batch_input_ids,
                token_type_ids=batch_token_type_ids,
                position_ids=batch_position_ids,
                cu_seqlens=cu_seqlens,
                max_s=pb.max_length,
                size=len(cu_seqlens) - 1,
            )
    
        def __len__(self):
            return self.size
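
    To make the padding and cu_seq_lengths semantics of PaddedBatch concrete, the sketch below packs two requests of lengths 3 and 2 into a single EmbedRequest and inspects the resulting tensors. Token ids are placeholders; it assumes the generated embed_pb2 stubs and an environment where mindietorch is importable, since types.py imports it at module level.

    # Sketch: how PaddedBatch.from_pb pads two packed requests to max_length.
    import torch

    from text_embeddings_server.models.types import PaddedBatch
    from text_embeddings_server.pb import embed_pb2

    pb = embed_pb2.EmbedRequest(
        input_ids=[101, 7592, 102, 101, 102],
        token_type_ids=[0, 0, 0, 0, 0],
        position_ids=[0, 1, 2, 0, 1],
        cu_seq_lengths=[0, 3, 5],  # request 0 spans [0, 3), request 1 spans [3, 5)
        max_length=3,
    )

    batch = PaddedBatch.from_pb(pb, device=torch.device("cpu"))
    print(batch.input_ids)       # rows padded to max_length: [101, 7592, 102] and [101, 102, 0]
    print(batch.attention_mask)  # [1, 1, 1] and [1, 1, 0]
    print(len(batch), batch.max_length)  # 2 3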