Full Adaptation Code for Integrating the TEI Framework with the MindIE Torch Component
The directory structure of the TEI code adapted to the Ascend environment and the MindIE Torch component is identical to that of the TEI GitHub repository (the code files that need to be modified are the ones shown in the code structure tree below; files and folders not shown remain the same as in the TEI GitHub repository). The directory structure is as follows.
text-embeddings-inference
|____core
| |____src
| | |____ infer.rs
|____backends
| |____grpc-client
| | |____src
| | | |____ client.rs
| |____proto
| | |____ embed.proto
| |____python
| | |____src
| | | |____ lib.rs
| | |____server
| | | |____ pyproject.toml
| | | |____ requirements.txt
| | | |____ text_embeddings_server
| | | | |____ server.py
| | | | |____ models
| | | | | |____ __init__.py
| | | | | |____ default_model.py
| | | | | |____ model.py
| | | | | |____ rerank_model.py
| | | | | |____ types.py
- Based on TEI v1.2.3, in the Infer type in text-embeddings-inference/core/src/infer.rs, reduce the number of batching_task instances from 2 to 1, i.e. comment out the corresponding code in the source (shown commented out in the listing below); using a single batching_task gives better performance.
Leaving this unmodified does not affect TEI's operation or its serving functionality, but the change brings a noticeable performance improvement.
use crate::queue::{Entry, Metadata, NextBatch, Queue};
use crate::tokenization::{EncodingInput, RawEncoding, Tokenization};
use crate::TextEmbeddingsError;
use std::sync::Arc;
use std::time::{Duration, Instant};
use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType};
use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore};
use tracing::instrument;

/// Inference struct
#[derive(Debug, Clone)]
pub struct Infer {
    tokenization: Tokenization,
    queue: Queue,
    /// Shared notify
    notify_batching_task: Arc<Notify>,
    /// Inference limit
    limit_concurrent_requests: Arc<Semaphore>,
    backend: Backend,
}

impl Infer {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenization: Tokenization,
        queue: Queue,
        max_concurrent_requests: usize,
        backend: Backend,
    ) -> Self {
        let notify_batching_task = Arc::new(Notify::new());

        let (embed_sender, embed_receiver) = mpsc::unbounded_channel();

        // Create only one batching task to prefetch batches
        tokio::spawn(batching_task(
            queue.clone(),
            notify_batching_task.clone(),
            embed_sender.clone(),
        ));

        /* The original source spawns two batching_tasks to manage the request queue;
           commenting this one out so that only a single batching_task is used gives better performance.
        tokio::spawn(batching_task(
            queue.clone(),
            notify_batching_task.clone(),
            embed_sender,
        ));
        */

        // Create embed task to communicate with backend
        tokio::spawn(backend_task(backend.clone(), embed_receiver));

        // Inference limit with a semaphore
        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

        Self {
            tokenization,
            queue,
            notify_batching_task,
            limit_concurrent_requests: semaphore,
            backend,
        }
    }

    #[instrument(skip(self))]
    pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        add_special_tokens: bool,
    ) -> Result<RawEncoding, TextEmbeddingsError> {
        self.tokenization
            .tokenize(inputs.into(), add_special_tokens)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })
    }

    #[instrument(skip(self))]
    pub async fn decode(
        &self,
        ids: Vec<u32>,
        skip_special_tokens: bool,
    ) -> Result<String, TextEmbeddingsError> {
        self.tokenization
            .decode(ids, skip_special_tokens)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })
    }

    #[instrument(skip(self))]
    pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
        // Limit concurrent requests by acquiring a permit from the semaphore
        self.clone()
            .limit_concurrent_requests
            .try_acquire_owned()
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "overloaded");
                tracing::error!("{err}");
                TextEmbeddingsError::from(err)
            })
    }

    #[instrument(skip(self))]
    pub async fn acquire_permit(&self) -> OwnedSemaphorePermit {
        // Limit concurrent requests by acquiring a permit from the semaphore
        self.clone()
            .limit_concurrent_requests
            .acquire_owned()
            .await
            .expect("Semaphore has been closed. This is a bug.")
    }

    #[instrument(skip(self, permit))]
    pub async fn embed_all<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        permit: OwnedSemaphorePermit,
    ) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
        let start_time = Instant::now();

        if self.is_splade() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "`embed_all` is not available for SPLADE models".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let results = self
            .embed(inputs, truncate, false, &start_time, permit)
            .await?;

        let InferResult::AllEmbedding(response) = results else {
            panic!("unexpected enum variant")
        };

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_embed_success");
        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_embed_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    #[instrument(skip(self, permit))]
    pub async fn embed_sparse<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        permit: OwnedSemaphorePermit,
    ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
        let start_time = Instant::now();

        if !self.is_splade() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "Model is not an embedding model with SPLADE pooling".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let results = self
            .embed(inputs, truncate, true, &start_time, permit)
            .await?;

        let InferResult::PooledEmbedding(response) = results else {
            panic!("unexpected enum variant")
        };

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_embed_success");
        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_embed_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    #[instrument(skip(self, permit))]
    pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        normalize: bool,
        permit: OwnedSemaphorePermit,
    ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
        let start_time = Instant::now();

        if self.is_splade() && normalize {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "`normalize` is not available for SPLADE models".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let results = self
            .embed(inputs, truncate, true, &start_time, permit)
            .await?;

        let InferResult::PooledEmbedding(mut response) = results else {
            panic!("unexpected enum variant")
        };

        if normalize {
            // Normalize embedding
            let scale = (1.0
                / response
                    .results
                    .iter()
                    .map(|v| {
                        let v = *v as f64;
                        v * v
                    })
                    .sum::<f64>()
                    .sqrt()) as f32;
            for v in response.results.iter_mut() {
                *v *= scale;
            }
        }

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_embed_success");
        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_embed_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        pooling: bool,
        start_time: &Instant,
        _permit: OwnedSemaphorePermit,
    ) -> Result<InferResult, TextEmbeddingsError> {
        if self.is_classifier() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "Model is not an embedding model".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        metrics::increment_counter!("te_embed_count");

        // Tokenization
        let encoding = self
            .tokenization
            .encode(inputs.into(), truncate)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })?;

        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Append the request to the queue
        self.queue.append(Entry {
            metadata: Metadata {
                response_tx,
                tokenization: start_time.elapsed(),
                queue_time: Instant::now(),
                prompt_tokens: encoding.input_ids.len(),
                pooling,
            },
            encoding,
        });

        self.notify_batching_task.notify_one();

        let response = response_rx
            .await
            .expect(
                "Infer batching task dropped the sender without sending a response. This is a bug.",
            )
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "inference");
                tracing::error!("{err}");
                err
            })?;

        Ok(response)
    }

    #[instrument(skip(self, _permit))]
    pub async fn predict<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        raw_scores: bool,
        _permit: OwnedSemaphorePermit,
    ) -> Result<ClassificationInferResponse, TextEmbeddingsError> {
        if !self.is_classifier() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "Model is not a classifier model".to_string();
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let start_time = Instant::now();
        metrics::increment_counter!("te_predict_count");

        // Tokenization
        let encoding = self
            .tokenization
            .encode(inputs.into(), truncate)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })?;

        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Append the request to the queue
        self.queue.append(Entry {
            metadata: Metadata {
                response_tx,
                tokenization: start_time.elapsed(),
                queue_time: Instant::now(),
                prompt_tokens: encoding.input_ids.len(),
                pooling: true,
            },
            encoding,
        });

        self.notify_batching_task.notify_one();

        let response = response_rx
            .await
            .expect(
                "Infer batching task dropped the sender without sending a response. This is a bug.",
            )
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "inference");
                tracing::error!("{err}");
                err
            })?;

        let InferResult::Classification(mut response) = response else {
            panic!("unexpected enum variant")
        };

        if !raw_scores {
            // Softmax
            if response.results.len() > 1 {
                let max = *response
                    .results
                    .iter()
                    .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap())
                    .unwrap();

                let mut den = 0.0;
                for v in response.results.iter_mut() {
                    *v = (*v - max).exp();
                    den += *v;
                }
                for v in response.results.iter_mut() {
                    *v /= den;
                }
            }
            // Sigmoid
            else {
                response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp());
            }
        }

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_predict_success");
        metrics::histogram!("te_predict_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_predict_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_predict_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_predict_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    #[instrument(skip(self))]
    pub fn is_classifier(&self) -> bool {
        matches!(self.backend.model_type, ModelType::Classifier)
    }

    #[instrument(skip(self))]
    pub fn is_splade(&self) -> bool {
        matches!(
            self.backend.model_type,
            ModelType::Embedding(text_embeddings_backend::Pool::Splade)
        )
    }

    #[instrument(skip(self))]
    pub async fn health(&self) -> bool {
        self.backend.health().await.is_ok()
    }

    #[instrument(skip(self))]
    pub fn health_watcher(&self) -> watch::Receiver<bool> {
        self.backend.health_watcher()
    }
}

#[instrument(skip_all)]
async fn batching_task(
    queue: Queue,
    notify: Arc<Notify>,
    embed_sender: mpsc::UnboundedSender<(NextBatch, oneshot::Sender<()>)>,
) {
    loop {
        notify.notified().await;

        while let Some(next_batch) = queue.next_batch().await {
            let (callback_sender, callback_receiver) = oneshot::channel();
            embed_sender
                .send((next_batch, callback_sender))
                .expect("embed receiver was dropped. This is a bug.");
            let _ = callback_receiver.await;
        }
    }
}

#[instrument(skip_all)]
async fn backend_task(
    backend: Backend,
    mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>,
) {
    while let Some((batch, _callback)) = embed_receiver.recv().await {
        match &backend.model_type {
            ModelType::Classifier => {
                let results = backend.predict(batch.1).await;

                // Handle sending responses in another thread to avoid starving the backend
                std::thread::spawn(move || match results {
                    Ok((mut predictions, inference_duration)) => {
                        batch.0.into_iter().enumerate().for_each(|(i, m)| {
                            let infer_metadata = InferMetadata {
                                prompt_tokens: m.prompt_tokens,
                                tokenization: m.tokenization,
                                queue: m.queue_time.elapsed() - inference_duration,
                                inference: inference_duration,
                            };

                            let _ = m.response_tx.send(Ok(InferResult::Classification(
                                ClassificationInferResponse {
                                    results: predictions.remove(&i).expect(
                                        "prediction not found in results. This is a backend bug.",
                                    ),
                                    metadata: infer_metadata,
                                },
                            )));
                        });
                    }
                    Err(err) => {
                        batch.0.into_iter().for_each(|m| {
                            let _ = m.response_tx.send(Err(err.clone()));
                        });
                    }
                });
            }
            ModelType::Embedding(_) => {
                let results = backend.embed(batch.1).await;

                // Handle sending responses in another thread to avoid starving the backend
                std::thread::spawn(move || match results {
                    Ok((mut embeddings, inference_duration)) => {
                        batch.0.into_iter().enumerate().for_each(|(i, m)| {
                            let metadata = InferMetadata {
                                prompt_tokens: m.prompt_tokens,
                                tokenization: m.tokenization,
                                queue: m.queue_time.elapsed() - inference_duration,
                                inference: inference_duration,
                            };

                            let results = match embeddings
                                .remove(&i)
                                .expect("embedding not found in results. This is a backend bug.")
                            {
                                Embedding::Pooled(e) => {
                                    InferResult::PooledEmbedding(PooledEmbeddingsInferResponse {
                                        results: e,
                                        metadata,
                                    })
                                }
                                Embedding::All(e) => {
                                    InferResult::AllEmbedding(AllEmbeddingsInferResponse {
                                        results: e,
                                        metadata,
                                    })
                                }
                            };

                            let _ = m.response_tx.send(Ok(results));
                        })
                    }
                    Err(err) => {
                        batch.0.into_iter().for_each(|m| {
                            let _ = m.response_tx.send(Err(err.clone()));
                        });
                    }
                });
            }
        };
    }
}

#[derive(Debug)]
pub struct InferMetadata {
    pub prompt_tokens: usize,
    pub tokenization: Duration,
    pub queue: Duration,
    pub inference: Duration,
}

#[derive(Debug)]
pub(crate) enum InferResult {
    Classification(ClassificationInferResponse),
    PooledEmbedding(PooledEmbeddingsInferResponse),
    AllEmbedding(AllEmbeddingsInferResponse),
}

#[derive(Debug)]
pub struct ClassificationInferResponse {
    pub results: Vec<f32>,
    pub metadata: InferMetadata,
}

#[derive(Debug)]
pub struct PooledEmbeddingsInferResponse {
    pub results: Vec<f32>,
    pub metadata: InferMetadata,
}

#[derive(Debug)]
pub struct AllEmbeddingsInferResponse {
    pub results: Vec<Vec<f32>>,
    pub metadata: InferMetadata,
}
- Based on TEI v1.2.3, add embed_all and predict asynchronous interfaces for the backend gRPC client in text-embeddings-inference/backends/grpc-client/src/client.rs (see the listing below).
/// Single shard Client
use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient;
use crate::pb::embedding::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
    stub: EmbeddingServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: EmbeddingServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: EmbeddingServiceClient::new(channel),
        })
    }

    /// Get backend health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    #[instrument(skip_all)]
    pub async fn embed(
        &mut self,
        input_ids: Vec<u32>,
        token_type_ids: Vec<u32>,
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
    ) -> Result<Vec<Embedding>> {
        let request = tonic::Request::new(EmbedRequest {
            input_ids,
            token_type_ids,
            position_ids,
            max_length,
            cu_seq_lengths,
        })
        .inject_context();
        let response = self.stub.embed(request).await?.into_inner();
        Ok(response.embeddings)
    }

    #[instrument(skip_all)]
    pub async fn embed_all(
        &mut self,
        input_ids: Vec<u32>,
        token_type_ids: Vec<u32>,
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
    ) -> Result<Vec<TokenEmbedding>> {
        let request = tonic::Request::new(EmbedRequest {
            input_ids,
            token_type_ids,
            position_ids,
            max_length,
            cu_seq_lengths,
        })
        .inject_context();
        let response = self.stub.embed_all(request).await?.into_inner();
        Ok(response.allembeddings)
    }

    #[instrument(skip_all)]
    pub async fn predict(
        &mut self,
        input_ids: Vec<u32>,
        token_type_ids: Vec<u32>,
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
    ) -> Result<Vec<Prediction>> {
        let request = tonic::Request::new(PredictRequest {
            input_ids,
            token_type_ids,
            position_ids,
            max_length,
            cu_seq_lengths,
        })
        .inject_context();
        let response = self.stub.predict(request).await?.into_inner();
        Ok(response.predictions)
    }
}
- Based on TEI v1.2.3, add Embed_all and Predict services to EmbeddingService in text-embeddings-inference/backends/proto/embed.proto (see the listing below), and define the corresponding message types and request/response formats. An illustrative sketch of how the new messages nest follows the listing.
syntax = "proto3";

package embedding.v1;

service EmbeddingService {
    /// Decode token for a list of prefilled batches
    rpc Embed (EmbedRequest) returns (EmbedResponse);
    rpc Embed_all (EmbedRequest) returns (RawEmbedResponse);
    rpc Predict (PredictRequest) returns (PredictResponse);
    /// Health check
    rpc Health (HealthRequest) returns (HealthResponse);
}

message HealthRequest {}
message HealthResponse {}

message PredictRequest {
    repeated uint32 input_ids = 1;
    repeated uint32 token_type_ids = 2;
    repeated uint32 position_ids = 3;
    repeated uint32 cu_seq_lengths = 4;
    /// Length of the longest request
    uint32 max_length = 5;
}

message Prediction {
    repeated float values = 1;
}

message PredictResponse {
    repeated Prediction predictions = 1;
}

message EmbedRequest {
    repeated uint32 input_ids = 1;
    repeated uint32 token_type_ids = 2;
    repeated uint32 position_ids = 3;
    repeated uint32 cu_seq_lengths = 4;
    /// Length of the longest request
    uint32 max_length = 5;
}

message Embedding {
    repeated float values = 1;
}

message EmbedResponse {
    repeated Embedding embeddings = 1;
}

message TokenEmbedding {
    repeated Embedding embeddings = 1;
}

message RawEmbedResponse {
    repeated TokenEmbedding allembeddings = 1;
}
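For reference only, the following sketch (not part of the patch) shows how the request and the new Embed_all response messages nest once embed_pb2 has been generated from this proto file; the token IDs and embedding values are made-up examples.

# Illustrative sketch: structure of the new proto messages, using the generated embed_pb2 module.
from text_embeddings_server.pb import embed_pb2

# Two concatenated sequences of lengths 3 and 2; cu_seq_lengths marks their boundaries.
request = embed_pb2.EmbedRequest(
    input_ids=[101, 2023, 102, 101, 102],
    token_type_ids=[0, 0, 0, 0, 0],
    position_ids=[0, 1, 2, 0, 1],
    cu_seq_lengths=[0, 3, 5],
    max_length=3,
)

# Embed_all returns one TokenEmbedding per input sequence; each TokenEmbedding holds
# one Embedding (hidden_size floats) per token position.
response = embed_pb2.RawEmbedResponse(
    allembeddings=[
        embed_pb2.TokenEmbedding(
            embeddings=[embed_pb2.Embedding(values=[0.0, 0.1])]
        )
    ]
)
print(len(response.allembeddings[0].embeddings))  # number of per-token embeddings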
- Based on TEI v1.2.3, in text-embeddings-inference/backends/python/src/lib.rs, remove the model-type check that rejects classifier models so that such models are loaded normally (see the listing below); add a predict interface for the backend model, and modify the embed interface with a branch that, depending on the contents of the batch, returns either the hidden_states of all tokens or the pooled global hidden_states.
mod logging;
mod management;

use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
use tokio::runtime::Runtime;

pub struct PythonBackend {
    _backend_process: management::BackendProcess,
    tokio_runtime: Runtime,
    backend_client: Client,
}

impl PythonBackend {
    pub fn new(
        model_path: String,
        dtype: String,
        model_type: ModelType,
        uds_path: String,
        otlp_endpoint: Option<String>,
    ) -> Result<Self, BackendError> {
        match model_type {
            ModelType::Classifier => {
                let pool = Pool::Cls;
                pool
            }
            ModelType::Embedding(pool) => {
                if pool != Pool::Cls {
                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
                }
                pool
            }
        };

        let backend_process =
            management::BackendProcess::new(model_path, dtype, &uds_path, otlp_endpoint)?;
        let tokio_runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;

        let backend_client = tokio_runtime
            .block_on(Client::connect_uds(uds_path))
            .map_err(|err| {
                BackendError::Start(format!("Could not connect to backend process: {err}"))
            })?;

        Ok(Self {
            _backend_process: backend_process,
            tokio_runtime,
            backend_client,
        })
    }
}

impl Backend for PythonBackend {
    fn health(&self) -> Result<(), BackendError> {
        if self
            .tokio_runtime
            .block_on(self.backend_client.clone().health())
            .is_err()
        {
            return Err(BackendError::Unhealthy);
        }
        Ok(())
    }

    fn is_padded(&self) -> bool {
        false
    }

    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
        /*
        if !batch.raw_indices.is_empty() {
            return Err(BackendError::Inference(
                "raw embeddings are not supported for the Python backend.".to_string(),
            ));
        }
        */
        let batch_size = batch.len();
        let mut embeddings =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

        if !batch.pooled_indices.is_empty() {
            let results = self
                .tokio_runtime
                .block_on(self.backend_client.clone().embed(
                    batch.input_ids,
                    batch.token_type_ids,
                    batch.position_ids,
                    batch.cumulative_seq_lengths,
                    batch.max_length,
                ))
                .map_err(|err| BackendError::Inference(err.to_string()))?;
            let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

            for (i, e) in pooled_embeddings.into_iter().enumerate() {
                embeddings.insert(i, Embedding::Pooled(e));
            }
        } else if !batch.raw_indices.is_empty() {
            let results = self
                .tokio_runtime
                .block_on(self.backend_client.clone().embed_all(
                    batch.input_ids,
                    batch.token_type_ids,
                    batch.position_ids,
                    batch.cumulative_seq_lengths,
                    batch.max_length,
                ))
                .map_err(|err| BackendError::Inference(err.to_string()))?;

            let mut raw_embeddings = Vec::new();
            for token_embedding in results {
                let mut two_dim_list = Vec::new();
                for embeddings in token_embedding.embeddings {
                    let values = embeddings.values.clone();
                    two_dim_list.push(values);
                }
                raw_embeddings.push(two_dim_list);
            }

            for (i, e) in raw_embeddings.into_iter().enumerate() {
                embeddings.insert(i, Embedding::All(e));
            }
        }

        Ok(embeddings)
    }

    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
        let batch_size = batch.len();

        let results = self
            .tokio_runtime
            .block_on(self.backend_client.clone().predict(
                batch.input_ids,
                batch.token_type_ids,
                batch.position_ids,
                batch.cumulative_seq_lengths,
                batch.max_length,
            ))
            .map_err(|err| BackendError::Inference(err.to_string()))?;

        let predictions_result: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

        let mut predictions =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
        for (i, r) in predictions_result.into_iter().enumerate() {
            predictions.insert(i, r);
        }

        Ok(predictions)
    }
}
- Based on TEI v1.2.3, update the version information for "huggingface-hub", "safetensors", "torch" and "poetry" in text-embeddings-inference/backends/python/server/requirements.txt (see the listing below) to adapt to the Ascend inference environment.
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
poetry==1.8.3 ; python_version >= "3.9" and python_version < "3.13"
- Based on TEI v1.2.3, to stay consistent with the "safetensors" and "torch" versions changed in text-embeddings-inference/backends/python/server/requirements.txt, update the "safetensors" and "torch" version information in text-embeddings-inference/backends/python/server/pyproject.toml accordingly (see the listing below), which declares the dependencies of the "text-embeddings-server" project.
[tool.poetry]
name = "text-embeddings-server"
version = "0.1.0"
description = "Text Embeddings Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
python-text-embeddings-server = 'text_embeddings_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
safetensors = "^0.4.3"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
torch = { version = "^2.1.0" }

[tool.poetry.extras]

[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"
pytest = "^7.3.0"

[[tool.poetry.source]]
name = "mirrors"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default"

[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
- Based on TEI v1.2.3, following the formats defined in embed.proto, add asynchronous Embed_all and Predict interfaces to the EmbeddingService class in text-embeddings-inference/backends/python/server/text_embeddings_server/server.py (see the listing below): each interface converts the request content into an instance of the corresponding batch type (PaddedBatch), passes the batch to the model's corresponding inference interface, and converts the output into an instance of the corresponding service response class.
import asyncio
import torch
import mindietorch
from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional

from text_embeddings_server.models import Model, get_model
from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
from text_embeddings_server.utils.interceptor import ExceptionInterceptor


class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
    def __init__(self, model: Model):
        self.model = model
        # Force inference mode for the lifetime of EmbeddingService
        self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2), device="cuda")
        return embed_pb2.HealthResponse()

    async def Embed(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        embeddings = self.model.embed(batch)

        return embed_pb2.EmbedResponse(embeddings=embeddings)

    async def Embed_all(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        embeddings = self.model.embed_all(batch)

        return embed_pb2.RawEmbedResponse(allembeddings=embeddings)

    async def Predict(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        predictions = self.model.predict(batch)

        return embed_pb2.PredictResponse(predictions=predictions)


def serve(
    model_path: Path,
    dtype: Optional[str],
    uds_path: Path,
):
    async def serve_inner(
        model_path: Path,
        dtype: Optional[str] = None,
    ):
        unix_socket = f"unix://{uds_path}"

        try:
            model = get_model(model_path, dtype)
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        embed_pb2_grpc.add_EmbeddingServiceServicer_to_server(
            EmbeddingService(model), server
        )
        SERVICE_NAMES = (
            embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(unix_socket)

        await server.start()

        logger.info(f"Server started at {unix_socket}")

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(serve_inner(model_path, dtype))
- Based on TEI v1.2.3, add logic branches to the get_model method in text-embeddings-inference/backends/python/server/text_embeddings_server/models/__init__.py (see the listing below): when the hardware platform is an NPU, set torch.device according to the TEI_NPU_DEVICE environment variable, and, based on the contents of the architectures list in the model's config.json, decide whether to instantiate default_model (the text-embedding model class) or rerank_model (the reranking model class). A small sketch of this dispatch follows the listing.
import os
import torch
import mindietorch

from loguru import logger
from pathlib import Path
from typing import Optional
from transformers import AutoConfig

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.rerank_model import RerankModel

__all__ = ["Model"]

# Disable gradients
torch.set_grad_enabled(False)

FLASH_ATTENTION = True
try:
    from text_embeddings_server.models.flash_bert import FlashBert
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str]):
    if dtype == "float32":
        dtype = torch.float32
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    deviceIdx = os.environ.get('TEI_NPU_DEVICE', '0')
    if deviceIdx != None and deviceIdx.isdigit() and int(deviceIdx) >= 0 and int(deviceIdx) <= 7:
        mindietorch.set_device(int(deviceIdx))
        device = torch.device(f"npu:{int(deviceIdx)}")

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if config.architectures[0].endswith("Classification"):
        return RerankModel(model_path, device, dtype)
    else:
        if (
            config.model_type == "bert"
            and device.type == "cuda"
            and config.position_embedding_type == "absolute"
            and dtype in [torch.float16, torch.bfloat16]
            and FLASH_ATTENTION
        ):
            return FlashBert(model_path, device, dtype)
        else:
            return DefaultModel(model_path, device, dtype)

    raise NotImplementedError
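The dispatch rule above keys off the first entry of the "architectures" list in the model's config.json. The minimal sketch below (the architecture names are hypothetical examples, not taken from any particular model) illustrates the decision in isolation.

# Illustration of the get_model dispatch above: architectures ending in "Classification"
# are treated as rerank models, everything else as text-embedding models.
embedding_cfg = {"architectures": ["BertModel"]}                          # -> DefaultModel
rerank_cfg = {"architectures": ["XLMRobertaForSequenceClassification"]}   # -> RerankModel

for cfg in (embedding_cfg, rerank_cfg):
    arch = cfg["architectures"][0]
    model_cls = "RerankModel" if arch.endswith("Classification") else "DefaultModel"
    print(arch, "->", model_cls)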
- Based on TEI v1.2.3, add an embed_all interface to the DefaultModel class in text-embeddings-inference/backends/python/server/text_embeddings_server/models/default_model.py (see the listing below), which returns the hidden_states of all tokens. A toy illustration of the index arithmetic used there follows the listing.
import torch
import mindietorch

from pathlib import Path
from typing import Type, List
from opentelemetry import trace
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, AutoConfig
from loguru import logger

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding

tracer = trace.get_tracer(__name__)


class DefaultModel(Model):
    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        mindietorch.set_device(device.index)
        model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)
        self.model_path = str(model_path)
        self.hidden_size = AutoConfig.from_pretrained(model_path, trust_remote_code=True).hidden_size

        super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        output = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))
        if isinstance(output, dict):
            embedding = output['last_hidden_state'].to('cpu')
        else:
            embedding = output[0].to('cpu')
        embedding = embedding[:, 0].contiguous()
        cpu_results = embedding.view(-1).tolist()

        return [
            Embedding(
                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
            )
            for i in range(len(batch))
        ]

    @tracer.start_as_current_span("embed_all")
    def embed_all(self, batch: PaddedBatch):
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        output = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))
        if isinstance(output, dict):
            embedding = output['last_hidden_state'].to('cpu').contiguous()
        else:
            embedding = output[0].to('cpu').contiguous()
        cpu_results = embedding.view(-1).tolist()

        embedding_result = []
        for i in range(len(batch)):
            embedding_tmp = [
                Embedding(
                    values=cpu_results[
                        (j + i * batch.max_length) * self.hidden_size : (j + 1 + i * batch.max_length) * self.hidden_size
                    ]
                )
                for j in range(batch.input_ids.size()[1])
            ]
            tokenembeddings = TokenEmbedding(embeddings=embedding_tmp)
            embedding_result.append(tokenembeddings)
        return embedding_result

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Prediction]:
        print("embedding model does not support predict function")
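A toy illustration of the slicing used in embed_all above: the model output is flattened from shape [batch, max_length, hidden_size] into a single list, and each token vector is recovered by the same index arithmetic (the sizes here are made-up values).

# Sketch of the flat-index arithmetic in embed_all, on toy sizes.
batch_len, max_length, hidden_size = 2, 3, 4
cpu_results = list(range(batch_len * max_length * hidden_size))  # flattened [batch, seq, hidden]

for i in range(batch_len):          # sequence index
    for j in range(max_length):     # token index within the padded sequence
        start = (j + i * max_length) * hidden_size
        end = (j + 1 + i * max_length) * hidden_size
        token_vec = cpu_results[start:end]  # hidden_size values for token j of sequence i
        assert len(token_vec) == hidden_size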
- Based on TEI v1.2.3, add abstract embed_all and predict interface definitions to the model base class Model in text-embeddings-inference/backends/python/server/text_embeddings_server/models/model.py (see the listing below); both raise NotImplementedError.
import torch

from abc import ABC, abstractmethod
from typing import List, TypeVar, Type

from text_embeddings_server.models.types import Batch, Embedding, Prediction, TokenEmbedding

B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(
        self,
        model,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.model = model
        self.dtype = dtype
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def embed(self, batch: B) -> List[Embedding]:
        raise NotImplementedError

    @abstractmethod
    def embed_all(self, batch: B) -> List[TokenEmbedding]:
        raise NotImplementedError

    @abstractmethod
    def predict(self, batch: B) -> List[Prediction]:
        raise NotImplementedError
- Based on TEI v1.2.3, create the file text-embeddings-inference/backends/python/server/text_embeddings_server/models/rerank_model.py. This new file implements the reranking model class and provides the reranking (predict) interface, which returns the matching score between a query and a text. A toy illustration of how the scores are wrapped follows the listing.
import torch
import mindietorch

from pathlib import Path
from typing import Type, List
from opentelemetry import trace
from loguru import logger

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding

tracer = trace.get_tracer(__name__)


class RerankModel(Model):
    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        mindietorch.set_device(device.index)
        model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)

        super(RerankModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        print("rerank model does not support embed function")

    @tracer.start_as_current_span("embed_all")
    def embed_all(self, batch: PaddedBatch) -> List[TokenEmbedding]:
        print("rerank model does not support embed_all function")

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Prediction]:
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        scores = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))[0].tolist()

        return [
            Prediction(
                values=scores[i]
            )
            for i in range(len(batch))
        ]
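For orientation only, the sketch below shows how the raw scores returned by the traced rerank model are wrapped into Prediction messages, as predict does above; the score values are hypothetical and the final sigmoid/softmax is applied on the Rust router side (see the predict path in infer.rs).

# Toy illustration: wrapping per-pair rerank scores into Prediction messages.
from text_embeddings_server.pb.embed_pb2 import Prediction

scores = [[2.37], [-1.05]]  # hypothetical raw logits for a batch of two (query, text) pairs
predictions = [Prediction(values=scores[i]) for i in range(len(scores))]
print(predictions[0].values)  # one score list per pair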
- Based on TEI v1.2.3, add a max_length field to the PaddedBatch class in text-embeddings-inference/backends/python/server/text_embeddings_server/models/types.py and modify the data movement to suit mindietorch (see the listing below). A worked example of the padding behavior follows the listing.
import torch
import mindietorch

from abc import ABC, abstractmethod
from dataclasses import dataclass
from opentelemetry import trace

from text_embeddings_server.pb import embed_pb2
from text_embeddings_server.pb.embed_pb2 import Embedding, Prediction, TokenEmbedding

tracer = trace.get_tracer(__name__)


class Batch(ABC):
    @classmethod
    @abstractmethod
    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        raise NotImplementedError


@dataclass
class PaddedBatch(Batch):
    input_ids: torch.Tensor
    token_type_ids: torch.Tensor
    position_ids: torch.Tensor
    attention_mask: torch.Tensor
    max_length: int

    @classmethod
    @tracer.start_as_current_span("from_pb")
    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "PaddedBatch":
        # Allocate padded tensors all at once
        all_tensors = torch.zeros(
            [4, len(pb.cu_seq_lengths) - 1, pb.max_length], dtype=torch.int32, device='cpu'
        )
        max_length = pb.max_length

        for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
            end_index = pb.cu_seq_lengths[i + 1]
            input_length = end_index - start_index

            all_tensors[0, i, :input_length] = torch.tensor(
                pb.input_ids[start_index:end_index], dtype=torch.int32
            )
            all_tensors[1, i, :input_length] = torch.tensor(
                pb.token_type_ids[start_index:end_index], dtype=torch.int32
            )
            all_tensors[2, i, :input_length] = torch.tensor(
                pb.position_ids[start_index:end_index], dtype=torch.int32
            )
            all_tensors[3, i, :input_length] = 1

        """
        # Move padded tensors all at once
        all_tensors = all_tensors.to(device)
        """

        return PaddedBatch(
            input_ids=all_tensors[0],
            token_type_ids=all_tensors[1],
            position_ids=all_tensors[2],
            attention_mask=all_tensors[3],
            max_length=max_length,
        )

    def __len__(self):
        return len(self.input_ids)


@dataclass
class FlashBatch(Batch):
    input_ids: torch.Tensor
    token_type_ids: torch.Tensor
    position_ids: torch.Tensor

    cu_seqlens: torch.Tensor
    max_s: int
    size: int

    @classmethod
    @tracer.start_as_current_span("from_pb")
    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "FlashBatch":
        if device.type != "cuda":
            raise RuntimeError(f"FlashBatch does not support device {device}")

        batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device)
        batch_token_type_ids = torch.tensor(
            pb.token_type_ids, dtype=torch.int32, device=device
        )
        batch_position_ids = torch.tensor(
            pb.position_ids, dtype=torch.int32, device=device
        )

        cu_seqlens = torch.tensor(pb.cu_seq_lengths, dtype=torch.int32, device=device)

        return FlashBatch(
            input_ids=batch_input_ids,
            token_type_ids=batch_token_type_ids,
            position_ids=batch_position_ids,
            cu_seqlens=cu_seqlens,
            max_s=pb.max_length,
            size=len(cu_seqlens) - 1,
        )

    def __len__(self):
        return self.size
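A worked example of PaddedBatch.from_pb above, assuming the adapted environment (mindietorch installed) and using made-up token IDs: two sequences of lengths 3 and 2 are padded to max_length 3, and the resulting tensors stay on the CPU for the mindietorch-compiled model.

# Worked example of PaddedBatch.from_pb (hypothetical token IDs).
import torch
from text_embeddings_server.pb import embed_pb2
from text_embeddings_server.models.types import PaddedBatch

pb = embed_pb2.EmbedRequest(
    input_ids=[101, 7592, 102, 101, 102],
    token_type_ids=[0] * 5,
    position_ids=[0, 1, 2, 0, 1],
    cu_seq_lengths=[0, 3, 5],
    max_length=3,
)
batch = PaddedBatch.from_pb(pb, torch.device("cpu"))
print(batch.input_ids)               # [[101, 7592, 102], [101, 102, 0]] as int32
print(batch.attention_mask)          # [[1, 1, 1], [1, 1, 0]] as int32
print(len(batch), batch.max_length)  # 2 3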
Parent topic: Adaptation Examples